質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

3353閲覧

Pytorch/複数入力の深層学習モデルの構築がうまくできません

hidemomo

総合スコア31

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2020/02/26 15:10

pytorch初心者です。
CNNの複数入力モデルを構築したいのですが、関数内の定義がうまくいかず
local variable 'x_bs' referenced before assignmentのエラーが出ます。

入力は画像とほかmetadataが2つ(1次元、26次元)の計3変数、目的変数は連続データ、ResNet50のファインチューニング回帰モデルを目指しています。
コードは少し長くなり恐縮ですが、datasetから説明させていただきます。

pytorch

1 2#dataloader作成 3df_train = torch.utils.data.TensorDataset(x_train,x_bs_train, x_cate_train, y_train) 4df_val = torch.utils.data.TensorDataset(x_val,x_bs_val, x_cate_val, y_val) 5 6batch_size=32 7 8train_dataloader = torch.utils.data.DataLoader(df_train, batch_size=batch_size, shuffle=True) 9val_dataloader = torch.utils.data.DataLoader(df_val, batch_size=batch_size, shuffle=False) 10 11#モデル構築 12class MyModel(nn.Module): 13 def __init__(self): 14 super(MyModel, self).__init__() 15 # 画像 16 self.resnet = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1]) 17 self.image_fc = nn.Linear(2048, 256) 18 # metadata:bs1次元,cate26次元 19 self.bs_fc = nn.Linear(1, 256) 20 self.cate_fc = nn.Linear(26, 256) 21 self.all_fc = nn.Linear(256*3, 256*3) 22 self.dropout = nn.Dropout(0.25) 23 # 最終FC層 24 self.last_fc = nn.Linear(256, 1) 25 26 def forward(self, image, bs, cate): 27 # 画像features 28 x_img = self.resnet(image).view(-1, 2048) 29 x_img = F.relu(self.image_fc(x_img)) 30 # metadata_features 31 x_bs = F.relu(self.bs_fc(x_bs)) 32 x_cate = F.relu(self.bs_fc(x_cate)) 33 # concat,relu, dropout 34 x = torch.cat([x_img, x_bs, x_cate], 1) 35 x = F.relu(self.all_fc(x)) 36 x = self.dropout(x) 37 # 最終FC 38 y = self.last_fc(x) 39 return y 40 41model = MyModel() 42 43#モデルが挙動するか確認 44for x_train,x_bs_train, x_cate_train, y_train in train_dataloader: 45 y_pred = model(x_train,x_bs_train,x_cate_train) 46 47#エラー: 48UnboundLocalError Traceback (most recent call last) 49<ipython-input-15-181704a40564> in <module> 50 1 for x_train,x_bs_train, x_cate_train, y_train in train_dataloader: 51----> 2 y_pred = model(x_train,x_bs_train,x_cate_train) 52 53~\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs) 54 530 result = self._slow_forward(*input, **kwargs) 55 531 else: 56--> 532 result = self.forward(*input, **kwargs) 57 533 for hook in self._forward_hooks.values(): 58 534 hook_result = hook(self, input, result) 59 60<ipython-input-13-3a3ca0de36f3> in forward(self, image, bs, cate) 61 20 x_img = F.relu(self.image_fc(x_img)) 62 21 # メタデータ特徴 63---> 22 x_bs = F.relu(self.bs_fc(x_bs)) 64 23 x_cate = F.relu(self.bs_fc(x_cate)) 65 24 # concat,relu, dropout 66 67UnboundLocalError: local variable 'x_bs' referenced before assignment 68

グローバル変数とローカル変数の不整合にあると思うのですが修正方法がわかりません。

ご教授のほど、よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

単純な記載ミスでした。

def forward(self, image, bs, cate):
の部分を
def forward(self, image, x_bs, x_cate):
に変える必要がありますね。

投稿2020/02/27 06:44

hidemomo

総合スコア31

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問