###困っていること
pytorch初心者です。Resnet50のファインチューニングを行っています。
入力変数は画像3次元、スカラー1次元で、出力の予測は数値なのでスカラになります。
いまスタックしていてのは入力の次元数が合わないことが原因と考えています。kerasでは次元数の違いは気にしなくてよかったのですが、pytorchは次元を合わせる必要があると知りました。しかし、合わせ方がわかりません。
torch.catに問題があるのかなと思っていますが、不確かです。
###エラー表示
RuntimeError: size mismatch, m1: [1 x 64], m2: [1 x 256]
###試行したこと
1.下記コード内でコメアウトしていますが、
torch型にする前後で[np.newaxis,np.newaxis,]を試行しました。しかし、
ndexError: too many indices for array
のエラーが出ます。
2.viewでtorch型の後でスカラを(-1,1)にしましたが、これも2次元なのでエラーが出ます。
RuntimeError: Tensors must have same number of dimensions: got 3 and 2
###コード(該当箇所だけ抽出しました)
model
1class ScorePredictionImageMeta1(nn.Module): 2 def __init__(self, meta_len=1): 3 super(ScorePredictionImageMeta1, self).__init__() 4 self.meta_len = meta_len 5 6 my_resnet = models.resnet50(pretrained=True) 7 self.resnet = nn.Sequential(*list(my_resnet.children())[:-1]) 8 self.image_fc = nn.Linear(2048, 256) #画像 9 self.meta_fc = nn.Linear(1, 256) #スカラー 10 self.relu = nn.ReLU() 11 self.dropout = nn.Dropout(p=0.25) 12 self.last_fc = nn.Linear(256,1) 13 14 def forward(self, image, meta): 15 16 x = self.resnet(image) 17 x = self.relu(self.image_fc(x.view(-1, 2048))) 18 y = self.relu(self.meta_fc(meta)) 19 z = torch.cat([x,y]) #怪しい?? 20 z = self.dropout(z) 21 z = self.last_fc(z) 22 23 return z
dataset
1class Dataset(): 2 #def __init__(self, mode='train', transform=None, args=None): 3 def __init__(self, file_list, mode='train', transform=None): 4 self.file_list = file_list 5 self.mode = mode 6 self.transform = ImageTransform() 7 self.meta = meta 8 self.label = label 9 10 #self.ft_type = args.ft_type 11 12 def __len__(self): 13 14 return len(self.file_list) 15 16 def __getitem__(self, index): 17 18 #画像 19 img_path = self.file_list[index] 20 img = Image.open(img_path) 21 22 img_transformed = self.transform(img, mode='train') 23 24 #meta 25 mate = self.meta[index] 26 #meta = np.array(meta)[np.newaxis,np.newaxis,] 27 meta = torch.tensor(meta).float() 28 #meta = meta[np.newaxis,np.newaxis,] 29 #meta = meta.view(-1, 1) 30 31 #label 32 label = self.label[index] 33 label = torch.tensor(label).float() 34 35 36 return img_transformed, brandscore, score
よろしくお願いします。
回答1件
あなたの回答
tips
プレビュー