🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

14740閲覧

pytorchのRuntimeError: mat1 and mat2 shapes cannot be multiplied (16896x256 and 65536x100)を解決したい

hamusuke

総合スコア4

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/01/07 06:07

編集2021/01/07 06:09

###状況
pytorchを用いて文字認識を行うために、次のサイトに従ってコードを作成したのですが、タイトルのエラーが表示されてしまい、解決方法が分からないため質問しました。データセットは自作のものを使っていて、サイズが256×256のグレースケール画像を使用しています。エラーを調べたところ、エラー文中の(16896x256 and 65536x100)太字の数字を合わせる必要があるというのは分かったのですが、バッチサイズや入力層のパラメータを変えてもエラー文を解決できませんでした。

https://book.mynavi.jp/manatee/detail/id=89498

###該当ソースコード

python

1pip install opencv-python 2import matplotlib.pyplot as plt 3import os 4import cv2 5import random 6import numpy as np 7import torch 8from torch.utils.data import TensorDataset, DataLoader 9 10# 1. 自作データを保存 11x = np.array(x) #画像データ 12y = np.array(y) #ラベル 13 14# 2.1 データを訓練とテストに分割(6:1) 15from sklearn.model_selection import train_test_split 16x_train, x_test, y_train, y_test = train_test_split( 17 x, y, test_size=1/7, random_state=0) 18 19# 2.2 データをPyTorchのTensorに変換 20x_train = torch.Tensor(x_train) 21x_test = torch.Tensor(x_test) 22y_train = torch.LongTensor(y_train) 23y_test = torch.LongTensor(y_test) 24 25# 2.3 データとラベルをセットにしたDatasetを作成 26ds_train = TensorDataset(x_train, y_train) 27ds_test = TensorDataset(x_test, y_test) 28 29# 2.4 データセットのミニバッチサイズを指定した、Dataloaderを作成 30loader_train = DataLoader(ds_train, batch_size=64, shuffle=True) 31loader_test = DataLoader(ds_test, batch_size=64, shuffle=False) 32 33# 3. ネットワークの構築 34 35from torch import nn 36 37model = nn.Sequential() 38model.add_module('fc1', nn.Linear(256*256, 100)) 39model.add_module('relu1', nn.ReLU()) 40model.add_module('fc2', nn.Linear(100, 100)) 41model.add_module('relu2', nn.ReLU()) 42model.add_module('fc3', nn.Linear(100, 92)) #92 = ラベルの数 43 44print(model) 45 46# 4. 誤差関数と最適化手法の設定 47 48from torch import optim 49 50# 誤差関数の設定 51loss_fn = nn.CrossEntropyLoss() # 変数名にはcriterionも使われる 52 53# 重みを学習する際の最適化手法の選択 54optimizer = optim.Adam(model.parameters(), lr=0.01) 55 56# 5-1. 学習と推論の設定 57 58from torch.autograd import Variable 59 60 61def train(epoch): 62 model.train() # ネットワークを学習モードに切り替える 63 64 # データローダーから1ミニバッチずつ取り出して計算する 65 for data, target in loader_train: 66 data, target = Variable(data), Variable(target) # 微分可能に変換 67 optimizer.zero_grad() # 一度計算された勾配結果を0にリセット 68 69 output = model(data) # 入力dataをinputし、出力を求める 70 loss = loss_fn(output, target) # 出力と訓練データの正解との誤差を求める 71 loss.backward() # 誤差のバックプロパゲーションを求める 72 optimizer.step() # バックプロパゲーションの値で重みを更新する 73 74 print("epoch{}:終了\n".format(epoch)) 75 76# 5-2. 学習と推論の設定 77 78def test(): 79 model.eval() # ネットワークを推論モードに切り替える 80 correct = 0 81 82 # データローダーから1ミニバッチずつ取り出して計算する 83 for data, target in loader_test: 84 data, target = Variable(data), Variable(target) # 微分可能に変換 85 output = model(data) # 入力dataをinputし、出力を求める 86 87 # 推論する 88 pred = output.data.max(1, keepdim=True)[1] # 出力ラベルを求める 89 correct += pred.eq(target.data.view_as(pred)).sum() # 正解と一緒だったらカウントアップ 90 91 # 正解率を出力 92 data_num = len(loader_test.dataset) # データの総数 93 print('\nテストデータの正解率: {}/{} ({:.0f}%)\n'.format(correct, 94 data_num, 100. * correct / data_num)) 95 96# 6. 学習と推論の実行 97test() #学習をする前に試しにテストを実行

###エラーコード

python

1RuntimeError Traceback (most recent call last) 2<ipython-input-103-4fda56ceff28> in <module>() 3 1 # 6. 学習と推論の実行 4 2 5----> 3 test() 6 75 frames 8/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias) 9 1690 ret = torch.addmm(bias, input, weight.t()) 10 1691 else: 11-> 1692 output = input.matmul(weight.t()) 12 1693 if bias is not None: 13 1694 output += bias 14 15RuntimeError: mat1 and mat2 shapes cannot be multiplied (16896x256 and 65536x100)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2021/01/07 10:08 編集

x = np.array(x)の次元が分からないのですが、x_trainが256x256の2次元のままなのではないですか? ネットワークの入力は、256*256=65536個の数値が1次元で並んでいるのを入力するようになってますけど
hamusuke

2021/01/07 16:23

データセットの作成には次のサイトのコードを使用しました。 x = np.array(x) のあとに一次元配列化するために x = x.ravel() を追加してみましたが、データセットを学習用とテスト用に分割するところで新たなエラーが発生しました。 ###新たなエラー ValueError: Found input variables with inconsistent numbers of samples: [30146560, 460] ###サイト https://intellectual-curiosity.tokyo/2019/07/02/%E3%82%AA%E3%83%AA%E3%82%B8%E3%83%8A%E3%83%AB%E3%81%AE%E7%94%BB%E5%83%8F%E3%81%8B%E3%82%89%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88%E3%82%92%E4%BD%9C%E6%88%90%E3%81%99%E3%82%8B%E6%96%B9/
jbpb0

2021/01/07 22:11

私の書き方が悪かったです 上で2次元、1次元と書いたのは、一つのサンプルに付いてです サンプルの次元は別にあるので、全サンプルでは、それぞれ3次元、2次元です ravel()だと、サンプルの次元も無くなってしまいます
jbpb0

2021/01/07 23:45 編集

https://book.mynavi.jp/manatee/detail/id=89498 では、mnistデータをsklearn.datasets.fetch_mldata()で入手してます このデータに合わせて以降で解説されてるネットワークが設計されてますので、そのネットワークを(画素数だけ変えて)流用するのなら、自作データの形式を上記mnistデータに合わせる必要があります https://note.nkmk.me/python-scikit-learn-svm-mnist/ の「データ読み込み」の「データとラベルをそれぞれ取り出す。」と書かれてるところの下にあるように、そのmnistデータのmnist_data.shapeは(70000, 784)です 70000はサンプル数、784は画像の画素数(28x28)です それに合わせて、自作データのx.shapeが(サンプル数, 65536)となるようにしてください
hamusuke

2021/01/08 05:10

詳しく教えてくださり、ありがとうございます。 xをreshapeすることで、mnistと同様の形式にすることができました。 また、学習とテストも問題なく実行することができました、ありがとうございます。 ぜひ、jbpb0さんをベストアンサーとして選択させていただきたいのですが、改めて回答をしていただけますでしょうか。
guest

回答1

0

ベストアンサー

第12回 PyTorchによるディープラーニング実装入門(1)
では、mnistデータをsklearn.datasets.fetch_mldata()で入手してます
このデータに合わせて以降で解説されてるネットワークが設計されてますので、そのネットワークを(画素数だけ変えて)流用するのなら、自作データの形式を上記mnistデータに合わせる必要があります

scikit-learnのSVMでMNISTの手書き数字データを分類
の「データ読み込み」の「データとラベルをそれぞれ取り出す。」と書かれてるところの下にあるように、そのmnistデータのmnist_data.shapeは(70000, 784)です
70000はサンプル数、784は画像の画素数(28x28)です
それに合わせて、自作データのx.shapeが(サンプル数, 65536)となるようにしてください
65536は自作データの画像の画素数(256x256)です

投稿2021/01/09 11:12

編集2021/01/09 11:14
jbpb0

総合スコア7653

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

hamusuke

2021/01/11 01:18

ご回答ありがとうございます。 問題なく学習を行うことができました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問