以下のshapeの入力用のtensorがあります
torch.Size([4, 15232]) (preds_)
これをnn.Linearに通して次元を894まで減らしたいです。
python
1# このshapeに 2torch.Size([4, 896])
以下のコードでやったのですが、エラーがでてうまく行きません。
解決策を教えていただけないでしょうか?
ご教授お願いします。
python
1class Net(nn.Module): 2 def __init__(self): 3 super().__init__() 4 INPUT_FEATURES = 4*15232 5 OUTPUT = 896 6 self.fc1 = nn.Linear(INPUT_FEATURES, OUTPUT) 7 8 def forward(self, x): 9 x = self.fc1(x) 10 return x 11 12 13fc = Net() 14fc.to(device) 15 16print(preds_.shape) 17>>> torch.Size([4, 15232]) 18preds_ = fc(preds_)
# エラー File "/usr/local/lib/python3.7/site-packages/torch/nn/functional.py", line 1690, in linear ret = torch.addmm(bias, input, weight.t()) RuntimeError: mat1 dim 1 must match mat2 dim 0
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。