質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

0回答

300閲覧

mnistでの畳み込みネットワークが動かない

KenyaTanaka

総合スコア13

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2019/04/19 18:18

前提・実現したいこと

畳み込みネットワークを導入したいのですが、mnistで行うとエラーが出ます

発生している問題・エラーメッセージ

Invalid operation is performed in: Convolution2DFunction (Forward)

Expect: in_types[0].ndim == 4
Actual: 2 != 4
と出てしまいます

該当のソースコード

python

1from chainer.datasets import mnist 2import matplotlib.pyplot as plt 3from chainer.datasets import split_dataset_random 4from chainer import iterators 5from chainer import optimizers 6from chainer import serializers 7from chainer import training 8import random 9import numpy 10import chainer 11import chainer.links as L 12import chainer.functions as F 13 14 15def reset_seed(seed=0): 16 random.seed(seed) 17 numpy.random.seed(seed) 18 if chainer.cuda.available: 19 chainer.cuda.cupy.random.seed(seed) 20 21reset_seed(0) 22 23train_val, test = mnist.get_mnist() 24train, valid = split_dataset_random(train_val, 50000, seed=0) 25 26batchsize = 128 27 28train_iter = iterators.SerialIterator(train, batchsize) 29valid_iter = iterators.SerialIterator(valid, batchsize, False, False) 30test_iter = iterators.SerialIterator(test, batchsize, False, False) 31 32 33class MyNet(chainer.Chain): 34 35 def __init__(self, n_out): 36 super(MyNet, self).__init__() 37 with self.init_scope(): 38 self.conv1 = L.Convolution2D(None, 28, 3, 3, 1) 39 self.conv2 = L.Convolution2D(28,28, 3, 3, 1) 40 self.conv3 = L.Convolution2D(28,28, 3, 3, 1) 41 self.fc4 = L.Linear(None, 1000) 42 self.fc5 = L.Linear(1000, n_out) 43 44 def __call__(self, x): 45 h = F.relu(self.conv1(x)) 46 h = F.relu(self.conv2(h)) 47 h = F.relu(self.conv3(h)) 48 h = F.relu(self.fc4(h)) 49 h = self.fc5(h) 50 return h 51 52 53gpu_id = 0 # CPUを用いたい場合は、-1を指定してください 54 55lr_decay=None 56 57net = MyNet(10) 58 59if gpu_id >= 0: 60 net.to_gpu(gpu_id) 61 62net = L.Classifier(net) 63 64optimizer = optimizers.SGD(lr=0.01).setup(net) 65 66updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id) 67 68max_epoch = 10 69 70trainer = training.Trainer( 71 updater, (max_epoch, 'epoch'), out='mnist_MyNet_result') 72 73from chainer.training import extensions 74 75trainer.extend(extensions.LogReport()) 76trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}')) 77trainer.extend(extensions.Evaluator(valid_iter, net, device=gpu_id), name='val') 78trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'val/main/loss', 'val/main/accuracy', 'elapsed_time','lr'])) 79trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', file_name='loss.png')) 80trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', file_name='accuracy.png')) 81if lr_decay is not None: 82 trainer.extend(extensions.ExponentialShift('lr', 0.1), trigger=lr_decay) 83trainer.run() 84del trainer 85 86trainer.run() 87 88test_evaluator = extensions.Evaluator(test_iter, net, device=gpu_id) 89results = test_evaluator() 90print('Test accuracy:', results['main/accuracy']) 91 92#以下予測用のコード 93 94reset_seed(0) 95 96infer_net = MyNet(10) 97serializers.load_npz( 98 'mnist_MyNet_result/snapshot_epoch-10', 99 infer_net, path='updater/model:main/predictor/') 100 101gpu_id=0 102 103if gpu_id >= 0: 104 infer_net.to_gpu(gpu_id) 105 106x, t = test[0] 107plt.imshow(x.reshape(28, 28), cmap='gray') 108plt.show() 109 110x = infer_net.xp.asarray(x[None, ...]) 111with chainer.using_config('train', False), chainer.using_config('enable_backprop', False): 112 y = infer_net(x) 113y = y.array 114 115print('予測ラベル:', y.argmax(axis=1)[0])

補足情報(FW/ツールのバージョンなど)

python3.7.1

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問