Chainer で MNIST を使って NN を実装したのですが、accuracy と loss を Plot する方法がわかりません。できれば、モデル A とモデル B の結果を比較できるような形で Plot したいです。最適化手法(optimizer)ごとの比較もできれば尚良いです。よろしくお願いします。
# Train a multilayer perceptron on a 1000-sample MNIST subset with Chainer's
# Trainer, printing per-epoch loss/accuracy and saving them as plots.
import matplotlib.pyplot as plt
import numpy as np

import chainer
from chainer import cuda
from chainer import functions as F
from chainer import links as L
from chainer import optimizers
from chainer import serializer
from chainer import training
from chainer import Variable

# Only the first 1000 examples of each split are used.
train_full, test_full = chainer.datasets.get_mnist()
train = chainer.datasets.SubDataset(train_full, 0, 1000)
test = chainer.datasets.SubDataset(test_full, 0, 1000)

batchsize = 30
train_iter = chainer.iterators.SerialIterator(train, batchsize)
# The evaluation iterator makes exactly one ordered pass per evaluation.
test_iter = chainer.iterators.SerialIterator(
    test, batchsize, repeat=False, shuffle=False)


class MultilayerPerceptron(chainer.Chain):
    """Fully connected network: two ReLU hidden layers and a linear output.

    Args:
        n_units: Number of units in each hidden layer.
        n_out: Number of output units.
    """

    def __init__(self, n_units, n_out):
        super(MultilayerPerceptron, self).__init__()
        with self.init_scope():
            # Input size None: the weight matrix is created on the first
            # forward pass, once the input width is known.
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out
            # Not used by this class's forward pass; kept because
            # MultilayerPerceptronV2 normalizes the first hidden layer with it.
            self.bn = L.BatchNormalization(n_units)

    def __call__(self, x):
        """Return the output of l3 (no softmax applied) for the batch ``x``."""
        hb1 = F.relu(self.l1(x))
        # Each linear layer is applied exactly once. The previous version
        # called self.l2 twice in a row with no activation in between.
        hb2 = F.relu(self.l2(hb1))
        return self.l3(hb2)


class MultilayerPerceptronV2(MultilayerPerceptron):
    """MultilayerPerceptron with batch normalization before each ReLU.

    Args:
        n_units: Number of units in each hidden layer.
        n_out: Number of output units.
    """

    def __init__(self, n_units, n_out):
        super(MultilayerPerceptronV2, self).__init__(n_units, n_out)
        with self.init_scope():
            # The second hidden layer gets its own normalization link so the
            # two layers do not share scale/shift parameters and running
            # statistics (previously self.bn was reused for both).
            self.bn2 = L.BatchNormalization(n_units)

    def __call__(self, x):
        """Return the output of l3 (no softmax applied) for the batch ``x``."""
        hb1 = F.relu(self.bn(self.l1(x)))
        hb2 = F.relu(self.bn2(self.l2(hb1)))
        return self.l3(hb2)


# L.Classifier wraps the predictor and reports 'loss' and 'accuracy'.
model = L.Classifier(MultilayerPerceptron(784, 10))

# choose optimizer
# AdaDelta, AdaGrad, Adam, MomentumSGD, NesterovAG, RMSprop, RMSpropGraves,
# SGD, SMORMS3
opt = optimizers.SGD()
opt.setup(model)

# device=-1 means using CPU
updater = training.StandardUpdater(train_iter, opt, device=-1)

epoch = 10
trainer = training.Trainer(updater, (epoch, 'epoch'), out='/tmp/result')
trainer.extend(training.extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(training.extensions.LogReport(trigger=(1, "epoch")))
trainer.extend(training.extensions.PrintReport(
    ['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss',
     'validation/main/accuracy', 'elapsed_time']), trigger=(1, "epoch"))
# Save loss and accuracy curves (train vs. validation) into the trainer's
# output directory. To compare another model or optimizer, run again with a
# different predictor/optimizer and a different `out` directory, then compare
# the saved plots.
trainer.extend(training.extensions.PlotReport(
    ['main/loss', 'validation/main/loss'],
    x_key='epoch', file_name='loss.png'))
trainer.extend(training.extensions.PlotReport(
    ['main/accuracy', 'validation/main/accuracy'],
    x_key='epoch', file_name='accuracy.png'))
trainer.run()
あなたの回答
tips
プレビュー