pytorch初心者です。
CNNの複数入力モデルを構築したいのですが、関数内の定義がうまくいかず
local variable 'x_bs' referenced before assignment
のエラーが出ます。
入力は画像とほかmetadataが2つ(1次元、26次元)の計3変数、目的変数は連続データ、ResNet50のファインチューニング回帰モデルを目指しています。
コードは少し長くなり恐縮ですが、datasetから説明させていただきます。
pytorch
1 2#dataloader作成 3df_train = torch.utils.data.TensorDataset(x_train,x_bs_train, x_cate_train, y_train) 4df_val = torch.utils.data.TensorDataset(x_val,x_bs_val, x_cate_val, y_val) 5 6batch_size=32 7 8train_dataloader = torch.utils.data.DataLoader(df_train, batch_size=batch_size, shuffle=True) 9val_dataloader = torch.utils.data.DataLoader(df_val, batch_size=batch_size, shuffle=False) 10 11#モデル構築 12class MyModel(nn.Module): 13 def __init__(self): 14 super(MyModel, self).__init__() 15 # 画像 16 self.resnet = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1]) 17 self.image_fc = nn.Linear(2048, 256) 18 # metadata:bs1次元,cate26次元 19 self.bs_fc = nn.Linear(1, 256) 20 self.cate_fc = nn.Linear(26, 256) 21 self.all_fc = nn.Linear(256*3, 256*3) 22 self.dropout = nn.Dropout(0.25) 23 # 最終FC層 24 self.last_fc = nn.Linear(256, 1) 25 26 def forward(self, image, bs, cate): 27 # 画像features 28 x_img = self.resnet(image).view(-1, 2048) 29 x_img = F.relu(self.image_fc(x_img)) 30 # metadata_features 31 x_bs = F.relu(self.bs_fc(x_bs)) 32 x_cate = F.relu(self.bs_fc(x_cate)) 33 # concat,relu, dropout 34 x = torch.cat([x_img, x_bs, x_cate], 1) 35 x = F.relu(self.all_fc(x)) 36 x = self.dropout(x) 37 # 最終FC 38 y = self.last_fc(x) 39 return y 40 41model = MyModel() 42 43#モデルが挙動するか確認 44for x_train,x_bs_train, x_cate_train, y_train in train_dataloader: 45 y_pred = model(x_train,x_bs_train,x_cate_train) 46 47#エラー: 48UnboundLocalError Traceback (most recent call last) 49<ipython-input-15-181704a40564> in <module> 50 1 for x_train,x_bs_train, x_cate_train, y_train in train_dataloader: 51----> 2 y_pred = model(x_train,x_bs_train,x_cate_train) 52 53~\AppData\Local\Continuum\anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs) 54 530 result = self._slow_forward(*input, **kwargs) 55 531 else: 56--> 532 result = self.forward(*input, **kwargs) 57 533 for hook in self._forward_hooks.values(): 58 534 hook_result = hook(self, input, result) 59 60<ipython-input-13-3a3ca0de36f3> in forward(self, image, bs, cate) 61 20 x_img = F.relu(self.image_fc(x_img)) 62 21 # メタデータ特徴 63---> 22 x_bs = F.relu(self.bs_fc(x_bs)) 64 23 x_cate = F.relu(self.bs_fc(x_cate)) 65 24 # concat,relu, dropout 66 67UnboundLocalError: local variable 'x_bs' referenced before assignment 68
グローバル変数とローカル変数の不整合にあると思うのですが修正方法がわかりません。
ご教授のほど、よろしくお願いします。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。