質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

2144閲覧

【MNIST】PyTorchのsize mismatchエラー

jamboc

総合スコア16

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/01/04 05:23

こちらのサイトを元に以下の構成のネットワークをPyTorchで実装しようと思ったのですが、
size mismatch, m1: [128 x 256], m2: [128 x 256]axというエラーが出てきてしまいました。

当方まだ初学者で、ディープラーニングのお作法やPyTorchの記述方法があまりわかっていないので解決方法を教えていただきたいです。


実装したいネットワーク構成

conv - relu - conv- relu - pool -
conv - relu - conv- relu - pool -
conv - relu - conv- relu - pool -
affine - relu - dropout - affine - dropout - softmax


ソースコード(ネットワーク部分)

class CNNModel (nn.Module): def __init__(self): super(CNNModel, self).__init__() self.conv1=nn.Conv2d(1,16,3,1) self.conv2=nn.Conv2d(16,16,3,1) self.conv3=nn.Conv2d(16,32,3,1) self.conv4=nn.Conv2d(32,32,3,1) self.conv5=nn.Conv2d(32,64,3,1) self.conv6=nn.Conv2d(64,64,3,1) self.pool=nn.MaxPool2d(2,2) self.dropout1=nn.Dropout2d(0.25) self.dropout2=nn.Dropout2d(0.5) self.fc1=nn.Linear(128,256) # ,256 self.fc2=nn.Linear(256,10) def forward(self, x): x=self.conv1(x) x=F.relu(x) x=self.conv2(x) x=F.relu(x) x=self.pool(x) x=self.conv3(x) x=F.relu(x) x=self.conv4(x) x=F.relu(x) x=self.conv5(x) x=F.relu(x) x=self.conv6(x) x=F.relu(x) x=self.pool(x) x=self.dropout1(x) x=x.view(-1,self.num_flat_features(x)) # 128*256 print(x.size()[0]) x=self.fc1(x) x=F.relu(x) x=self.dropout2(x) x=self.fc2(x) x=F.log_softmax(x,dim=1) return x

エラー詳細

--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-5-7a8a290ae702> in <module> 144 optimizer.zero_grad() 145 --> 146 output = cnn.forward(x) # WRITE ME (予測の計算) 147 loss = loss_fn(output,t) # WRITE ME (損失関数の計算) 148 loss.backward() # WRITE ME (勾配の計算) <ipython-input-5-7a8a290ae702> in forward(self, x) 97 print(x.size()[0]) 98 ---> 99 x=self.fc1(x) 100 x=F.relu(x) 101 x=self.dropout2(x) /usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs) 548 result = self._slow_forward(*input, **kwargs) 549 else: --> 550 result = self.forward(*input, **kwargs) 551 for hook in self._forward_hooks.values(): 552 hook_result = hook(self, input, result) /usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input) 85 86 def forward(self, input): ---> 87 return F.linear(input, self.weight, self.bias) 88 89 def extra_repr(self): /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias) 1608 if input.dim() == 2 and bias is not None: 1609 # fused op is marginally faster -> 1610 ret = torch.addmm(bias, input, weight.t()) 1611 else: 1612 output = input.matmul(weight.t()) RuntimeError: size mismatch, m1: [128 x 256], m2: [128 x 256] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:283

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

m1: [128 x 256] の2番目の「256」と、
m2: [128 x 256] の1番目の「128」が違うというエラーです

m2の形状は、self.fc1=nn.Linear(128,256) の定義より

m1はfc1の入力なので、x=x.view(-1,self.num_flat_features(x)) の結果の形状

投稿2021/01/04 07:54

jbpb0

総合スコア7653

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問