Context and goal
I want to build an app with Chainer that can tell Kantai Collection (KanColle) characters apart.
At the moment I am testing with a dataset of only 4 characters,
about 280 images in total across the 4 of them.
Relevant source code
```python
from chainer.datasets import LabeledImageDataset
from chainercv.transforms import resize
from chainer.datasets import TransformDataset
import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
from chainer.datasets import split_dataset_random
from chainer import iterators
from chainer import optimizers
from chainer import training
from chainer.training import extensions

train = LabeledImageDataset('output/hibiki.txt', 'output/Images')
np.random.seed(0)


def transform(in_data):
    img, label = in_data
    img = resize(img, (100, 100))
    if np.random.rand() > 0.5:
        img = img[..., ::-1]
    return img, label


dataset = TransformDataset(train, transform)

train, valid = split_dataset_random(dataset, int(len(dataset) * 0.7), seed=1)

print('Training dataset size:', len(train))
print('Validation dataset size:', len(valid))

batchsize = 32
train_iter = iterators.SerialIterator(train, batchsize, shuffle=True, repeat=True)
valid_iter = iterators.SerialIterator(valid, batchsize, shuffle=False, repeat=False)


class MLP(chainer.Chain):
    def __init__(self, n_hidden_1=20, n_out=4):
        super().__init__()
        with self.init_scope():
            self.fc1 = L.Linear(None, n_hidden_1)
            self.fc2 = L.Linear(None, n_out)

    def forward(self, x):
        h = F.sigmoid(self.fc1(x))
        h = self.fc2(h)
        return h


predictor = MLP()
net = L.Classifier(predictor)
optimizer = optimizers.Adam()
optimizer.setup(net)
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (100, 'epoch'), out='result/hibiki')
trainer.extend(extensions.LogReport(trigger=(1, 'epoch'), log_name='hibiki_log'))
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.dump_graph('main/loss'))
trainer.extend(extensions.Evaluator(valid_iter, net), name='val')
trainer.extend(extensions.PrintReport(['epoch', 'iteration',
                                       'main/loss', 'main/accuracy', 'val/main/loss',
                                       'val/main/accuracy', 'fc1/W/data/mean', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['fc1/W/grad/mean'], x_key='epoch', filename='mean.png'))
trainer.extend(extensions.PlotReport(['main/loss', 'val/main/loss'], x_key='epoch', filename='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'val/main/accuracy'], x_key='epoch', filename='accuracy.png'))
trainer.run()
```
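For what it's worth, one thing I suspect is the input scale: as far as I know, `LabeledImageDataset` returns float32 arrays in the 0-255 range, so the sigmoid layer may be saturating. Here is a minimal sketch of adding scaling to `transform` (the resize/flip steps from the script above are omitted so the snippet runs without chainercv):

```python
import numpy as np


def transform(in_data):
    # Sketch only: resize and random flip are left out here.
    img, label = in_data
    # Scale pixel values from the 0-255 range down to [0, 1].
    img = img.astype(np.float32) / 255.0
    return img, label


# Dummy CHW image standing in for one loaded sample
img = np.full((3, 100, 100), 255, dtype=np.float32)
out, label = transform((img, 0))
print(out.max())  # 1.0
```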
### Results
Only 50 epochs' worth, because I hit the character limit.
```
Training dataset size: 201
Validation dataset size: 87
epoch  iteration  main/loss  main/accuracy  val/main/loss  val/main/accuracy  fc1/W/data/mean  elapsed_time
1      7          1.41295    0.294643       1.40185        0.324275                            1.61004
2      13         1.39771    0.286458       1.3931         0.324275                            3.05204
3      19         1.39028    0.286458       1.39186        0.324275                            4.51765
4      26         1.37493    0.299107       1.389          0.324275                            6.08734
5      32         1.37575    0.307292       1.38776        0.324275                            7.61809
6      38         1.37768    0.291667       1.38939        0.324275                            9.08393
7      44         1.37627    0.296875       1.38646        0.324275                            10.506
8      51         1.37081    0.303571       1.38681        0.324275                            12.0721
9      57         1.37856    0.291667       1.38776        0.324275                            13.5506
10     63         1.37301    0.291667       1.38581        0.324275                            15.0259
11     70         1.37237    0.308036       1.38371        0.324275                            16.5973
12     76         1.36652    0.307292       1.38623        0.324275                            18.0419
13     82         1.38441    0.28125        1.38305        0.324275                            19.5243
14     88         1.36964    0.307292       1.38388        0.324275                            21.0289
15     95         1.37387    0.28125        1.3847         0.324275                            22.5887
16     101        1.38087    0.296875       1.38363        0.324275                            24.0153
17     107        1.36879    0.317708       1.38594        0.324275                            25.4613
18     114        1.3666     0.325893       1.38427        0.324275                            27.0275
19     120        1.38243    0.25           1.38568        0.324275                            28.4708
20     126        1.3741     0.317708       1.38467        0.324275                            29.9948
21     132        1.37115    0.302083       1.383          0.324275                            31.4303
22     139        1.37763    0.294643       1.38392        0.324275                            33.1488
23     145        1.36742    0.296875       1.38577        0.324275                            34.6832
24     151        1.37904    0.291667       1.38499        0.324275                            36.1234
25     158        1.36891    0.3125         1.38481        0.324275                            37.6757
26     164        1.37316    0.291667       1.38329        0.324275                            39.1111
27     170        1.3731     0.28125        1.38597        0.324275                            40.7936
28     176        1.37593    0.3125         1.38378        0.324275                            42.4755
29     183        1.37217    0.294643       1.38361        0.324275                            44.3277
30     189        1.36924    0.3125         1.38463        0.324275                            46.1707
31     195        1.37829    0.291667       1.38456        0.324275                            47.837
32     201        1.37436    0.296875       1.38286        0.324275                            49.3686
33     208        1.37437    0.299107       1.38355        0.324275                            51.0571
34     214        1.37272    0.291667       1.38397        0.324275                            52.6414
35     220        1.37244    0.307292       1.38313        0.324275                            54.3373
36     227        1.37533    0.294643       1.38375        0.324275                            56.2733
37     233        1.36825    0.3125         1.38442        0.324275                            58.303
38     239        1.37652    0.286458       1.38259        0.324275                            60.6793
39     245        1.37562    0.296875       1.38438        0.324275                            62.615
40     252        1.37595    0.285714       1.38363        0.324275                            64.3411
41     258        1.37335    0.317708       1.38652        0.324275                            65.8537
42     264        1.37306    0.286458       1.38621        0.324275                            67.431
43     271        1.37141    0.308036       1.38587        0.324275                            69.253
44     277        1.37753    0.286458       1.38313        0.324275                            71.109
45     283        1.36983    0.3125         1.38398        0.324275                            72.8181
46     289        1.3716     0.296875       1.38408        0.324275                            74.4545
47     296        1.37039    0.299107       1.3852         0.324275                            76.1762
48     302        1.37535    0.291667       1.38599        0.324275                            77.7544
49     308        1.37393    0.3125         1.3845         0.324275                            79.5401
50     315        1.37401    0.294643       1.38474        0.324275                            81.2977
```
### The problem
As the results above show, the loss does not decrease and the accuracy stops improving.
What I have tried
- Increasing the number of nodes in the hidden layer
- Changing the optimizer
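One more check I have been meaning to run: the validation accuracy is pinned at exactly 0.324275 every epoch, which looks like it could just be the share of the largest class (i.e. the model always predicts one class). A small sketch of that check, with hypothetical per-class counts in place of the real labels from `output/hibiki.txt`:

```python
from collections import Counter


def majority_share(labels):
    """Return the most common label and its fraction of the dataset."""
    counts = Counter(labels)
    label, n = counts.most_common(1)[0]
    return label, n / len(labels)


# Hypothetical per-class counts standing in for my real 280-image dataset
labels = [0] * 91 + [1] * 70 + [2] * 65 + [3] * 54
label, share = majority_share(labels)
print(label, round(share, 3))  # 0 0.325
```

If the real majority share matches the stuck validation accuracy, that would confirm the network has collapsed to predicting a single class.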
Supplementary information (framework/tool versions)
Python 3.7.3
Chainer 6.1.0
ChainerCV 0.13.1
NumPy 1.16.2