質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
87.20%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

解決済

PytorchのMNISTでエラーが出ます

SaitoHiroaki
SaitoHiroaki

総合スコア15

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

1回答

0評価

0クリップ

4761閲覧

投稿2019/11/12 22:56

編集2019/11/12 22:59

前提・実現したいこと

Pytorchで機械学習をしたところエラーが出ます。
Pytorchはhttps://www.tensorflow.org/tutorials/quickstart/advanced?hl=ja
こちらのTensorflowの実装と同じ実装をしたいと思っています。

発生している問題・エラーメッセージ

ValueError: Expected input batch_size (784) to match target batch_size (32).

該当のソースコード

python

import torch import torchvision import torchvision.transforms as transforms import numpy as np import torch.optim as optim import torch.nn as nn import torch.nn.functional as F transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2) testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2) print("trainset.shape:",trainset) print('train_dataset = ', len(trainset)) print('test_dataset = ', len(testset)) classes = tuple(np.linspace(0, 9, 10, dtype=np.uint8)) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3) #(入力のchannel数, 出力のチャンネル数, カーネル) self.flatten = nn.Flatten() self.d1 = nn.Linear(32, 128) self.d2 = nn.Linear(128, 10) def forward(self, x): print(x.shape) #torch.reshape(x, (32, 32)) x = x.view(-1, 32) x = self.flatten(x) x = F.relu(self.d1(x)) return F.softmax(self.d2(x)) # select device device = 'cuda' if torch.cuda.is_available() else 'cpu' net = Net().to(device) # optimizing criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters()) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(), lr=0.001) epochs = 5 for epoch in range(epochs): running_loss = 0.0 for i, (inputs, labels) in enumerate(trainloader, 0): # zero the parameter gradients #inputs, labels = inputs.view(-1, 28*28*1).to(device), labels.to(device) optimizer.zero_grad() #print(input) # forward + backward + optimize outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 100 == 99: print('[{:d}, {:5d}] loss: {.3f}' .format(epoch + 1, i + 1, running_loss / 100)) running_loss = 0.0 print('Finished Training')

試したこと

x = x.view(-1, 32)を付け加えました。

補足情報(FW/ツールのバージョンなど)

batchサイズをtensorflowの方と同じにしたいです。

良い質問の評価を上げる

以下のような質問は評価を上げましょう

  • 質問内容が明確
  • 自分も答えを知りたい
  • 質問者以外のユーザにも役立つ

評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

気になる質問をクリップする

クリップした質問は、後からいつでもマイページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

  • プログラミングに関係のない質問
  • やってほしいことだけを記載した丸投げの質問
  • 問題・課題が含まれていない質問
  • 意図的に内容が抹消された質問
  • 過去に投稿した質問と同じ内容の質問
  • 広告と受け取られるような投稿

評価を下げると、トップページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

まだ回答がついていません

会員登録して回答してみよう

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
87.20%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問

同じタグがついた質問を見る

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。