質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.92%

skorchでのエラーについて

解決済

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 67

rikubon_

score 31

skorchを使ってMNISTを使いたいと思い、下記のようなコードを実行しましたがエラーが出てしまいました。解決策はありますでしょうか。試したこととしてnp.expand_dims(x, 1)をやってみましたが結果は変わりませんでした。

%matplotlib inline
from skorch import NeuralNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
torch.manual_seed(123)

data_x, data_y = fetch_openml('mnist_784', data_home = './data', cache=True, return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=0.25, random_state=42)

x_train = np.expand_dims(x_train / 255, 1).astype(np.float32)
x_test = np.expand_dims(x_test / 255, 1).astype(np.float32)

y_train = y_train.astype(np.int64)
y_test = y_train.astype(np.int64)

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))    # (32, 13, 13)
        x = F.relu(F.max_pool2d(self.conv2_drop(
            self.conv2(x)), 2))   # (64, 5, 5)
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, dim=-1)
        return x

device = torch.device('cuda')

model = NeuralNet(
  CNN,            # ここで、先程定義したNetクラスを引数として与える
  max_epochs=10,
  optimizer=torch.optim.Adam,
  lr=0.001,
  device=device,
  batch_size=128,
  train_split=None,
  criterion=nn.NLLLoss    # CNNの最後のactivationがlog_softmaxなので、lossはNLLoss。
)


model.fit(X_train, y_train)

pred = model.prefict(x_test)  

pred = pred.argmax(axis=1)
print("acc:{}".format(accuracy_score(y_test, pred)))
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 1, 3, 3], but got 2-dimensional input of size [128, 784] instead
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

check解決した方法

0

そもそも配列の形状がおかしかった

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.92%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る