質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

GitHub

GitHubは、Gitバージョン管理システムを利用したソフトウェア開発向けの共有ウェブサービスです。GitHub商用プランおよびオープンソースプロジェクト向けの無料アカウントを提供しています。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

809閲覧

sepconvの学習の際に’’ forward() missing 1 required positional argument: 'frame2'”と出てきます

ttt1111

総合スコア1

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

GitHub

GitHubは、Gitバージョン管理システムを利用したソフトウェア開発向けの共有ウェブサービスです。GitHub商用プランおよびオープンソースプロジェクト向けの無料アカウントを提供しています。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2022/11/07 09:04

前提

model.pyのforwardに@staticmethodを追加しています

実現したいこと

sepconvの学習をこちらのgithubを参考にして行いたいです。

発生している問題・エラーメッセージ

学習を行おうとして、train.pyを実行しても以下のようにエラーが出てしまいます。Testmodule.pyのTestでframe_out = model(one, two)でmodel.pyのforwardを呼び出すときに発生しているようです。このときtwoには画像のデータが入っていました。
機械学習初学者なのでどうすればよいか、わかりません。何かアドバイスお願いします。
質問の仕方が悪かったら申し訳ないです

line 356, in compat_exec exec(code, globals, locals) line 111, in <module> main() line 84, in main TestDB.Test(model, test_output_dir, logfile, str(model.epoch.item()).zfill(3) + '.png') line 48, in Test frame_out = model(self.input0_list[idx], self.input1_list[idx]) line 1190, in _call_impl return forward_call(*input, **kwargs) line 105, in decorate_fwd return fwd(*args, **kwargs) TypeError: forward() missing 1 required positional argument: 'frame2'

該当のソースコード

python

1#TestModule.py 2from PIL import Image 3import torch 4from torchvision import transforms 5from math import log10 6from torchvision.utils import save_image as imwrite 7from torch.autograd import Variable 8import os 9import matplotlib.pyplot as plt 10 11def to_variable(x): 12 if torch.cuda.is_available(): 13 x = x.cuda() 14 return Variable(x) 15 16 17class Middlebury_eval: 18 def __init__(self, input_dir): 19 self.im_list = ['Army', 'Backyard', 'Basketball', 'Dumptruck', 'Evergreen', 'Grove', 'Mequon', 'Schefflera', 'Teddy', 'Urban', 'Wooden', 'Yosemite'] 20 21 22class Middlebury_other: 23 def __init__(self, input_dir, gt_dir): 24 self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] 25 self.transform = transforms.Compose([transforms.ToTensor()]) 26 27 self.input0_list = [] 28 self.input1_list = [] 29 self.gt_list = [] 30 for item in self.im_list: 31 self.input0_list.append(to_variable(self.transform(Image.open(input_dir + '/' + item + '/frame10.png')).unsqueeze(0))) 32 self.input1_list.append(to_variable(self.transform(Image.open(input_dir + '/' + item + '/frame11.png')).unsqueeze(0))) 33 self.gt_list.append(to_variable(self.transform(Image.open(gt_dir + '/' + item + '/frame10i11.png')).unsqueeze(0))) 34 35 def Test(self, model, output_dir, logfile=None, output_name='output.png'): 36 av_psnr = 0 37 if logfile is not None: 38 logfile.write('{:<7s}{:<3d}'.format('Epoch: ', model.epoch.item()) + '\n') 39 for idx in range(len(self.im_list)): 40 41 if not os.path.exists(output_dir + '/' + self.im_list[idx]): 42 os.makedirs(output_dir + '/' + self.im_list[idx]) 43 44 one = self.input0_list[idx] 45 two = self.input1_list[idx] 46 47 frame_out = model(one, two) 48 gt = self.gt_list[idx] 49 psnr = -10 * log10(torch.mean((gt - frame_out) * (gt - frame_out)).item()) 50 av_psnr += psnr 51 imwrite(frame_out, output_dir + '/' + self.im_list[idx] + '/' + output_name, range=(0, 1)) 52 msg = '{:<15s}{:<20.16f}'.format(self.im_list[idx] + ': ', psnr) + '\n' 53 print(msg, end='') 54 if logfile is not None: 55 logfile.write(msg) 56 av_psnr /= len(self.im_list) 57 msg = '{:<15s}{:<20.16f}'.format('Average: ', av_psnr) + '\n' 58 print(msg, end='') 59 if logfile is not None: 60 logfile.write(msg)

python

1#model.py 2class SepConvNet(torch.nn.Module): 3 def __init__(self, kernel_size): 4 super(SepConvNet, self).__init__() 5 self.kernel_size = kernel_size 6 self.kernel_pad = int(math.floor(kernel_size / 2.0)) 7 8 self.epoch = Variable(torch.tensor(0, requires_grad=False)) 9 self.get_kernel = KernelEstimation(self.kernel_size) 10 self.optimizer = optim.Adam(self.parameters(), lr=0.001) 11 self.criterion = torch.nn.MSELoss() 12 13 self.modulePad = torch.nn.ReplicationPad2d([self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad]) 14 15 @staticmethod 16 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) 17 def forward(self, frame0, frame2): 18 print('1') 19 20 h0 = int(list(frame0.size())[2]) 21 w0 = int(list(frame0.size())[3]) 22 h2 = int(list(frame2.size())[2]) 23 w2 = int(list(frame2.size())[3]) 24 25 if h0 != h2 or w0 != w2: 26 sys.exit('Frame sizes do not match') 27 28 h_padded = False 29 w_padded = False 30 if h0 % 32 != 0: 31 pad_h = 32 - (h0 % 32) 32 frame0 = F.pad(frame0, (0, 0, 0, pad_h)) 33 frame2 = F.pad(frame2, (0, 0, 0, pad_h)) 34 h_padded = True 35 36 if w0 % 32 != 0: 37 pad_w = 32 - (w0 % 32) 38 frame0 = F.pad(frame0, (0, pad_w, 0, 0)) 39 frame2 = F.pad(frame2, (0, pad_w, 0, 0)) 40 w_padded = True 41 42 Vertical1, Horizontal1, Vertical2, Horizontal2 = self.get_kernel(frame0, frame2) 43 44 tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(frame0), Vertical1, Horizontal1) 45 tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(frame2), Vertical2, Horizontal2) 46 47 frame1 = tensorDot1 + tensorDot2 48 49 if h_padded: 50 frame1 = frame1[:, :, 0:h0, :] 51 if w_padded: 52 frame1 = frame1[:, :, :, 0:w0] 53 54 return frame1 55 56 def train_model(self, frame0, frame2, frame1): 57 self.optimizer.zero_grad() 58 output = self.forward(frame0, frame2) 59 loss = self.criterion(output, frame1) 60 loss.backward() 61 self.optimizer.step() 62 return loss 63 64 def increase_epoch(self): 65 self.epoch += 1 66 67

試したこと

frame_out = model(one, two)no
twoには画像が入っていました。

補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

ttt1111

2022/11/07 09:32

staticmethodをつけ足したことで、forward(self,frame0, frame2)のselfの部分が反応してしまったようです。しかし、staticmethodを消すと以下のエラーが出てきてしまいます /// RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function) /// pytorchのバージョンを下げても解決しませんでした。どうすればよいかアドバイスお願いします
guest

回答1

0

自己解決

以下の二つのことを行ったら一応動かすことができました
(1)model.pyの以下のコードにapplyを付け足した
変更前
tensorDot1 = sepconv.FunctionSepconv()(self.modulePad(frame0), Vertical1, Horizontal1)
tensorDot2 = sepconv.FunctionSepconv()(self.modulePad(frame2), Vertical2, Horizontal2)
変更後
tensorDot1 = sepconv.FunctionSepconv().apply(self.modulePad(frame0), Vertical1, Horizontal1)
tensorDot2 = sepconv.FunctionSepconv().apply(self.modulePad(frame2), Vertical2, Horizontal2)

(2)sepconvのforwardとbackwardに@staticmethodを追加しました

以上の二つを行ったところ一応動かすことができたのですが、なぜapplyを付け足したらできるようになったのか理解できていません。新しく質問しなおそうと思います

投稿2022/11/09 13:37

ttt1111

総合スコア1

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問