機械学習させたモデルを使って予測画像を出力させたいです。
出力画像は連番がいいです。
通常、1枚ずつ出力させるなら
python predict.py -i input.jpg -o output.jpg -m checkpoints/CP10.pth
となり、うまく実行できます。
これをディレクトリにある全ての画像ファイルに対応させたいです。
実際に試したコードは
python predict.py -i test/*.jpg -o test/output%05d.jpg -m checkpoints/CP10.pth
エラーが
Error : Input files and output files are not of the same length
と出てきます。
python predict.py -i *.jpg -o *.jpg -m checkpoints/CP5.pth
とすれば、今いるディレクトリに出力画像が保存されるんですが、上書きされるので、別のディレクトリに保存したいです。
以下がpredict.pyです。
python
1import argparse 2import os 3 4import numpy as np 5import torch 6import torch.nn.functional as F 7 8from PIL import Image 9 10from unet import UNet 11from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf 12from utils import plot_img_and_mask 13 14from torchvision import transforms 15 16def predict_img(net, 17 full_img, 18 scale_factor=0.5, 19 out_threshold=0.5, 20 use_dense_crf=True, 21 use_gpu=False): 22 23 net.eval() 24 img_height = full_img.size[1] 25 img_width = full_img.size[0] 26 27 img = resize_and_crop(full_img, scale=scale_factor) 28 img = normalize(img) 29 30 left_square, right_square = split_img_into_squares(img) 31 32 left_square = hwc_to_chw(left_square) 33 right_square = hwc_to_chw(right_square) 34 35 X_left = torch.from_numpy(left_square).unsqueeze(0) 36 X_right = torch.from_numpy(right_square).unsqueeze(0) 37 38 if use_gpu: 39 X_left = X_left.cuda() 40 X_right = X_right.cuda() 41 42 with torch.no_grad(): 43 output_left = net(X_left) 44 output_right = net(X_right) 45 46 left_probs = output_left.squeeze(0) 47 right_probs = output_right.squeeze(0) 48 49 tf = transforms.Compose( 50 [ 51 transforms.ToPILImage(), 52 transforms.Resize(img_height), 53 transforms.ToTensor() 54 ] 55 ) 56 57 left_probs = tf(left_probs.cpu()) 58 right_probs = tf(right_probs.cpu()) 59 60 left_mask_np = left_probs.squeeze().cpu().numpy() 61 right_mask_np = right_probs.squeeze().cpu().numpy() 62 63 full_mask = merge_masks(left_mask_np, right_mask_np, img_width) 64 65 if use_dense_crf: 66 full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask) 67 68 return full_mask > out_threshold 69 70 71 72def get_args(): 73 parser = argparse.ArgumentParser() 74 parser.add_argument('--model', '-m', default='MODEL.pth', 75 metavar='FILE', 76 help="Specify the file in which is stored the model" 77 " (default : 'MODEL.pth')") 78 parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', 79 help='filenames of input images', required=True) 80 81 parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', 82 help='filenames of ouput images') 83 parser.add_argument('--cpu', '-c', action='store_true', 84 help="Do not use the cuda version of the net", 85 default=False) 86 parser.add_argument('--viz', '-v', action='store_true', 87 help="Visualize the images as they are processed", 88 default=False) 89 parser.add_argument('--no-save', '-n', action='store_true', 90 help="Do not save the output masks", 91 default=False) 92 parser.add_argument('--no-crf', '-r', action='store_true', 93 help="Do not use dense CRF postprocessing", 94 default=False) 95 parser.add_argument('--mask-threshold', '-t', type=float, 96 help="Minimum probability value to consider a mask pixel white", 97 default=0.5) 98 parser.add_argument('--scale', '-s', type=float, 99 help="Scale factor for the input images", 100 default=0.5) 101 102 return parser.parse_args() 103 104def get_output_filenames(args): 105 in_files = args.input 106 out_files = [] 107 108 if not args.output: 109 for f in in_files: 110 pathsplit = os.path.splitext(f) 111 out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1])) 112 elif len(in_files) != len(args.output): 113 print("Error : Input files and output files are not of the same length") 114 raise SystemExit() 115 else: 116 out_files = args.output 117 118 return out_files 119 120def mask_to_image(mask): 121 return Image.fromarray((mask * 255).astype(np.uint8)) 122 123if __name__ == "__main__": 124 args = get_args() 125 in_files = args.input 126 out_files = get_output_filenames(args) 127 128 net = UNet(n_channels=3, n_classes=1) 129 130 print("Loading model {}".format(args.model)) 131 132 if not args.cpu: 133 print("Using CUDA version of the net, prepare your GPU !") 134 net.cuda() 135 net.load_state_dict(torch.load(args.model)) 136 else: 137 net.cpu() 138 net.load_state_dict(torch.load(args.model, map_location='cpu')) 139 print("Using CPU version of the net, this may be very slow") 140 141 print("Model loaded !") 142 143 for i, fn in enumerate(in_files): 144 print("\nPredicting image {} ...".format(fn)) 145 146 img = Image.open(fn) 147 if img.size[0] < img.size[1]: 148 print("Error: image height larger than the width") 149 150 mask = predict_img(net=net, 151 full_img=img, 152 scale_factor=args.scale, 153 out_threshold=args.mask_threshold, 154 use_dense_crf= not args.no_crf, 155 use_gpu=not args.cpu) 156 157 if args.viz: 158 print("Visualizing results for image {}, close to continue ...".format(fn)) 159 plot_img_and_mask(img, mask) 160 161 if not args.no_save: 162 out_fn = out_files[i] 163 result = mask_to_image(mask) 164 result.save(out_files[i]) 165 166 print("Mask saved to {}".format(out_files[i])) 167
回答1件
あなたの回答
tips
プレビュー