前提
model.pyのforwardに@staticmethodを追加しています
実現したいこと
sepconvの学習をこちらのgithubを参考にして行いたいです。
発生している問題・エラーメッセージ
学習を行おうとして、train.pyを実行しても以下のようにエラーが出てしまいます。Testmodule.pyのTestでframe_out = model(one, two)でmodel.pyのforwardを呼び出すときに発生しているようです。このときtwoには画像のデータが入っていました。
機械学習初学者なのでどうすればよいか、わかりません。何かアドバイスお願いします。
質問の仕方が悪かったら申し訳ないです
line 356, in compat_exec exec(code, globals, locals) line 111, in <module> main() line 84, in main TestDB.Test(model, test_output_dir, logfile, str(model.epoch.item()).zfill(3) + '.png') line 48, in Test frame_out = model(self.input0_list[idx], self.input1_list[idx]) line 1190, in _call_impl return forward_call(*input, **kwargs) line 105, in decorate_fwd return fwd(*args, **kwargs) TypeError: forward() missing 1 required positional argument: 'frame2'
該当のソースコード
python
1#TestModule.py 2from PIL import Image 3import torch 4from torchvision import transforms 5from math import log10 6from torchvision.utils import save_image as imwrite 7from torch.autograd import Variable 8import os 9import matplotlib.pyplot as plt 10 11def to_variable(x): 12 if torch.cuda.is_available(): 13 x = x.cuda() 14 return Variable(x) 15 16 17class Middlebury_eval: 18 def __init__(self, input_dir): 19 self.im_list = ['Army', 'Backyard', 'Basketball', 'Dumptruck', 'Evergreen', 'Grove', 'Mequon', 'Schefflera', 'Teddy', 'Urban', 'Wooden', 'Yosemite'] 20 21 22class Middlebury_other: 23 def __init__(self, input_dir, gt_dir): 24 self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] 25 self.transform = transforms.Compose([transforms.ToTensor()]) 26 27 self.input0_list = [] 28 self.input1_list = [] 29 self.gt_list = [] 30 for item in self.im_list: 31 self.input0_list.append(to_variable(self.transform(Image.open(input_dir + '/' + item + '/frame10.png')).unsqueeze(0))) 32 self.input1_list.append(to_variable(self.transform(Image.open(input_dir + '/' + item + '/frame11.png')).unsqueeze(0))) 33 self.gt_list.append(to_variable(self.transform(Image.open(gt_dir + '/' + item + '/frame10i11.png')).unsqueeze(0))) 34 35 def Test(self, model, output_dir, logfile=None, output_name='output.png'): 36 av_psnr = 0 37 if logfile is not None: 38 logfile.write('{:<7s}{:<3d}'.format('Epoch: ', model.epoch.item()) + '\n') 39 for idx in range(len(self.im_list)): 40 41 if not os.path.exists(output_dir + '/' + self.im_list[idx]): 42 os.makedirs(output_dir + '/' + self.im_list[idx]) 43 44 one = self.input0_list[idx] 45 two = self.input1_list[idx] 46 47 frame_out = model(one, two) 48 gt = self.gt_list[idx] 49 psnr = -10 * log10(torch.mean((gt - frame_out) * (gt - frame_out)).item()) 50 av_psnr += psnr 51 imwrite(frame_out, output_dir + '/' + self.im_list[idx] + '/' + output_name, range=(0, 1)) 52 msg = '{:<15s}{:<20.16f}'.format(self.im_list[idx] + ': ', psnr) + '\n' 53 print(msg, end='') 54 if logfile is not None: 55 logfile.write(msg) 56 av_psnr /= len(self.im_list) 57 msg = '{:<15s}{:<20.16f}'.format('Average: ', av_psnr) + '\n' 58 print(msg, end='') 59 if logfile is not None: 60 logfile.write(msg)
python
1#model.py 2class SepConvNet(torch.nn.Module): 3 def __init__(self, kernel_size): 4 super(SepConvNet, self).__init__() 5 self.kernel_size = kernel_size 6 self.kernel_pad = int(math.floor(kernel_size / 2.0)) 7 8 self.epoch = Variable(torch.tensor(0, requires_grad=False)) 9 self.get_kernel = KernelEstimation(self.kernel_size) 10 self.optimizer = optim.Adam(self.parameters(), lr=0.001) 11 self.criterion = torch.nn.MSELoss() 12 13 self.modulePad = torch.nn.ReplicationPad2d([self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad]) 14 15 @staticmethod 16 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 17 def forward(self, frame0, frame2): 18 print('1') 19 20 h0 = int(list(frame0.size())[2]) 21 w0 = int(list(frame0.size())[3]) 22 h2 = int(list(frame2.size())[2]) 23 w2 = int(list(frame2.size())[3]) 24 25 if h0 != h2 or w0 != w2: 26 sys.exit('Frame sizes do not match') 27 28 h_padded = False 29 w_padded = False 30 if h0 % 32 != 0: 31 pad_h = 32 - (h0 % 32) 32 frame0 = F.pad(frame0, (0, 0, 0, pad_h)) 33 frame2 = F.pad(frame2, (0, 0, 0, pad_h)) 34 h_padded = True 35 36 if w0 % 32 != 0: 37 pad_w = 32 - (w0 % 32) 38 frame0 = F.pad(frame0, (0, pad_w, 0, 0)) 39 frame2 = F.pad(frame2, (0, pad_w, 0, 0)) 40 w_padded = True 41 42 Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(frame0, frame2) 43 44 tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(frame0), Vertical1, Horizontal1) 45 tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(frame2), Vertical2, Horizontal2) 46 47 frame1 = tensorDot1 + tensorDot2 48 49 if h_padded: 50 frame1 = frame1[:, :, 0:h0, :] 51 if w_padded: 52 frame1 = frame1[:, :, :, 0:w0] 53 54 return frame1 55 56 def train_model(self, frame0, frame2, frame1): 57 self.optimizer.zero_grad() 58 output = self.forward(frame0, frame2) 59 loss = self.criterion(output, frame1) 60 loss.backward() 61 self.optimizer.step() 62 return loss 63 64 def increase_epoch(self): 65 self.epoch += 1 66 67
試したこと
frame_out = model(one, two)no
twoには画像が入っていました。
補足情報(FW/ツールのバージョンなど)
ここにより詳細な情報を記載してください。
回答1件
あなたの回答
tips
プレビュー