質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Q&A

解決済

1回答

523閲覧

Pytorchの shape '[51, 44944]' is invalid for input of size 20400のエラーを修正したい

wataske

総合スコア11

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

0グッド

0クリップ

投稿2022/01/26 10:09

前提・実現したいこと

pytorchにてCNNを用いた距離学習による画像判別を実装しようと試みているのですがCNNを組む段階で詰まってしまいました。

発生している問題・エラーメッセージ

shape '[51, 44944]' is invalid for input of size 20400 上記のエラーコードが発生します。入力サイズが割り切れない、という趣旨のエラーであると調べた結果わかったのですが

該当のソースコード

pytorch

1import torch 2import torch.nn as nn 3import torch.nn.functional as F 4 5 6class Net(nn.Module): 7 def __init__(self): 8 super(Net, self).__init__() 9 self.conv1 = nn.Conv2d(3, 6, 5) 10 self.pool = nn.MaxPool2d(2, 2) 11 self.conv2 = nn.Conv2d(6, 16, 5) 12 self.fc1 = nn.Linear(16 * 53 *53, 120) 13 self.fc2 = nn.Linear(120, 84) 14 self.fc3 = nn.Linear(84, 10) 15 16 def forward(self, x): 17 x = self.pool(F.relu(self.conv1(x))) 18 x = self.pool(F.relu(self.conv2(x))) 19 x = x.view(x.size(0), 16 * 53 *53) 20 x = F.relu(self.fc1(x)) 21 x = F.relu(self.fc2(x)) 22 x = self.fc3(x) 23 return x 24 25 26net = Net() 27///////////////////// 28batch_size = 64 29device = torch.device("cuda") 30 31model = Net().to(device) 32optimizer = optim.Adam(model.parameters(), lr=0.01) 33num_epochs = 300 34n_splits = 3 35x_epoch_data = [] 36y_train_loss_data = [] 37y_valid_loss_data = [] 38acc_test = [] 39acc_valid = [] 40 41summary(model,(3,224,224)) # summary(model,(channels,H,W)) 42///////////////////// 43acc_save = 0.0 44valid_loss_save = 1.0 45 46for epoch in range(1, num_epochs+1): 47 model, train_loss = train(model, loss_func, mining_func, device, train_loader, optimizer, epoch) 48 valid_loss = valid(model, valid_loader) 49 acc,train_embeddings, train_labels,valid_embeddings, valid_labels = accuracy(train_dataset, 50 valid_dataset, 51 model, 52 accuracy_calculator) 53 54 if(acc >= acc_save): 55 if(valid_loss_save >= valid_loss): 56 valid_loss_save = valid_loss 57 acc_save = acc 58 epoch_save = epoch 59 60 torch.save(model.state_dict(), "model.pth") 61 print("_______model更新_________") 62 63 64 acc_valid.append(acc) 65 print('epoch:{} ,train_loss {:.3f} ,valid loss {:.3f}'.format(epoch, train_loss, valid_loss)) 66 x_epoch_data.append(epoch) 67 y_train_loss_data.append(train_loss) 68 y_valid_loss_data.append(valid_loss) 69 70print("save model: epoch{}, valid_loss{}, valid_acc{}".format(epoch_save, valid_loss_save, acc_save))

試したこと

当初は”/////////////////////”に挟まれているところで同様のエラー(input of size 89888)が発生していたためそれで割り切ることができるように調整したところ現状のエラーが出てしまった状態です。入力されているデータのtorchサイズは(64,3,32,32)です。どのようにして解決すればよいかどなたかお願いしますm(__)m
また、入力サイズの89888と現在で出ている20400ではどういった違いがあるのでしょうか?

補足情報(FW/ツールのバージョンなど)

ここにより詳細な情報を記載してください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

”/////////////////////”に挟まれているsummary()の大きさが間違っていたためでした。

投稿2022/01/30 07:47

wataske

総合スコア11

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問