質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
コマンド

コマンドとは特定のタスクを行う為に、コンピュータープログラムへ提示する指示文です。多くの場合、コマンドはShellやcmdようなコマンドラインインターフェイスに対する指示文を指します。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

964閲覧

複数の出力ファイルを連番で保存させたい 機械学習

tiroha

総合スコア109

コマンド

コマンドとは特定のタスクを行う為に、コンピュータープログラムへ提示する指示文です。多くの場合、コマンドはShellやcmdようなコマンドラインインターフェイスに対する指示文を指します。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/10/08 07:19

編集2021/10/09 05:54

機械学習させたモデルを使って予測画像を出力させたいです。
出力画像は連番がいいです。
通常、1枚ずつ出力させるなら

python predict.py -i input.jpg -o output.jpg -m checkpoints/CP10.pth

となり、うまく実行できます。
これをディレクトリにある全ての画像ファイルに対応させたいです。
実際に試したコードは

python predict.py -i test/*.jpg -o test/output%05d.jpg -m checkpoints/CP10.pth

エラーが
Error : Input files and output files are not of the same length
と出てきます。

python predict.py -i *.jpg -o *.jpg -m checkpoints/CP5.pth

とすれば、今いるディレクトリに出力画像が保存されるんですが、上書きされるので、別のディレクトリに保存したいです。

以下がpredict.pyです。

python

1import argparse 2import os 3 4import numpy as np 5import torch 6import torch.nn.functional as F 7 8from PIL import Image 9 10from unet import UNet 11from utils import resize_and_crop, normalize, split_img_into_squares, hwc_to_chw, merge_masks, dense_crf 12from utils import plot_img_and_mask 13 14from torchvision import transforms 15 16def predict_img(net, 17 full_img, 18 scale_factor=0.5, 19 out_threshold=0.5, 20 use_dense_crf=True, 21 use_gpu=False): 22 23 net.eval() 24 img_height = full_img.size[1] 25 img_width = full_img.size[0] 26 27 img = resize_and_crop(full_img, scale=scale_factor) 28 img = normalize(img) 29 30 left_square, right_square = split_img_into_squares(img) 31 32 left_square = hwc_to_chw(left_square) 33 right_square = hwc_to_chw(right_square) 34 35 X_left = torch.from_numpy(left_square).unsqueeze(0) 36 X_right = torch.from_numpy(right_square).unsqueeze(0) 37 38 if use_gpu: 39 X_left = X_left.cuda() 40 X_right = X_right.cuda() 41 42 with torch.no_grad(): 43 output_left = net(X_left) 44 output_right = net(X_right) 45 46 left_probs = output_left.squeeze(0) 47 right_probs = output_right.squeeze(0) 48 49 tf = transforms.Compose( 50 [ 51 transforms.ToPILImage(), 52 transforms.Resize(img_height), 53 transforms.ToTensor() 54 ] 55 ) 56 57 left_probs = tf(left_probs.cpu()) 58 right_probs = tf(right_probs.cpu()) 59 60 left_mask_np = left_probs.squeeze().cpu().numpy() 61 right_mask_np = right_probs.squeeze().cpu().numpy() 62 63 full_mask = merge_masks(left_mask_np, right_mask_np, img_width) 64 65 if use_dense_crf: 66 full_mask = dense_crf(np.array(full_img).astype(np.uint8), full_mask) 67 68 return full_mask > out_threshold 69 70 71 72def get_args(): 73 parser = argparse.ArgumentParser() 74 parser.add_argument('--model', '-m', default='MODEL.pth', 75 metavar='FILE', 76 help="Specify the file in which is stored the model" 77 " (default : 'MODEL.pth')") 78 parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', 79 help='filenames of input images', required=True) 80 81 parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', 82 help='filenames of ouput images') 83 parser.add_argument('--cpu', '-c', action='store_true', 84 help="Do not use the cuda version of the net", 85 default=False) 86 parser.add_argument('--viz', '-v', action='store_true', 87 help="Visualize the images as they are processed", 88 default=False) 89 parser.add_argument('--no-save', '-n', action='store_true', 90 help="Do not save the output masks", 91 default=False) 92 parser.add_argument('--no-crf', '-r', action='store_true', 93 help="Do not use dense CRF postprocessing", 94 default=False) 95 parser.add_argument('--mask-threshold', '-t', type=float, 96 help="Minimum probability value to consider a mask pixel white", 97 default=0.5) 98 parser.add_argument('--scale', '-s', type=float, 99 help="Scale factor for the input images", 100 default=0.5) 101 102 return parser.parse_args() 103 104def get_output_filenames(args): 105 in_files = args.input 106 out_files = [] 107 108 if not args.output: 109 for f in in_files: 110 pathsplit = os.path.splitext(f) 111 out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1])) 112 elif len(in_files) != len(args.output): 113 print("Error : Input files and output files are not of the same length") 114 raise SystemExit() 115 else: 116 out_files = args.output 117 118 return out_files 119 120def mask_to_image(mask): 121 return Image.fromarray((mask * 255).astype(np.uint8)) 122 123if __name__ == "__main__": 124 args = get_args() 125 in_files = args.input 126 out_files = get_output_filenames(args) 127 128 net = UNet(n_channels=3, n_classes=1) 129 130 print("Loading model {}".format(args.model)) 131 132 if not args.cpu: 133 print("Using CUDA version of the net, prepare your GPU !") 134 net.cuda() 135 net.load_state_dict(torch.load(args.model)) 136 else: 137 net.cpu() 138 net.load_state_dict(torch.load(args.model, map_location='cpu')) 139 print("Using CPU version of the net, this may be very slow") 140 141 print("Model loaded !") 142 143 for i, fn in enumerate(in_files): 144 print("\nPredicting image {} ...".format(fn)) 145 146 img = Image.open(fn) 147 if img.size[0] < img.size[1]: 148 print("Error: image height larger than the width") 149 150 mask = predict_img(net=net, 151 full_img=img, 152 scale_factor=args.scale, 153 out_threshold=args.mask_threshold, 154 use_dense_crf= not args.no_crf, 155 use_gpu=not args.cpu) 156 157 if args.viz: 158 print("Visualizing results for image {}, close to continue ...".format(fn)) 159 plot_img_and_mask(img, mask) 160 161 if not args.no_save: 162 out_fn = out_files[i] 163 result = mask_to_image(mask) 164 result.save(out_files[i]) 165 166 print("Mask saved to {}".format(out_files[i])) 167

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

meg_

2021/10/08 10:54

predict.pyはあなたが作成したものですか?
tiroha

2021/10/09 05:52

いえ。違います。 github上で公開されているものを少し書き換えただけです。
guest

回答1

0

自己解決

python predict.py -i test/.jpg -o output_epoch25_kaggle/.jpg -m checkpoints/CP25.pth

とすればディレクトリ内に保存できました。

投稿2021/10/09 08:31

tiroha

総合スコア109

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問