chainer+MNISTで手書き数字の分類をしようとしていますがエラーが出てしまい。最後まで学習できません。
機械学習初心者なのでわかりやすく解説していただけると嬉しいです。
下準備
python
1import numpy as np 2import chainer 3import chainer 4import chainer.links as L 5import chainer.functions as F 6from chainer import optimizers 7 8 9 10train_data, test_data = chainer.datasets.get_mnist(withlabel=True, ndim=1) 11 12#ニューラルネットワークのクラスを作成 13class MLP(chainer.Chain): 14 #ニューラルネットワークの構成を定義 15 #hiddenは中間層のユニット数outは出力層のユニット数 16 def __init__(self, number_hidden_units=1000, number_out_units=10): 17 super(MLP,self).__init__() 18 19 #初期化時に実行(model = MLP()の時) 20 with self.init_scope(): 21 self.layer1=L.Linear(None, number_hidden_units) 22 self.layer2=L.Linear(number_hidden_units, number_hidden_units) 23 self.layer3=L.Linear(number_hidden_units, number_out_units) 24 25 #クラスのインスタンス作成時に実行(model(inputdate)の時) 26 def __call__(self, input_data): 27 result1 = F.relu(self.layer1(input_data)) 28 result2 = F.relu(self.layer2(result1)) 29 return self.layer3(result2) 30 31model = MLP() 32 33#反復作業を行うためのイテレータを用意 34 35from chainer import iterators 36 37BATCH_SIZE = 100 38 39train_iterator = iterators.SerialIterator(train_data, BATCH_SIZE) 40test_iterator = iterators.SerialIterator(test_data, BATCH_SIZE, repeat=False, shuffle=False) 41 42#learn rate(学習率) 43optimizer = optimizers.SGD(lr=0.01) 44optimizer.setup(model) 45 46import numpy as np 47from chainer.dataset import concat_examples 48import matplotlib.pyplot as plt 49 50#学習回数 51MAX_EPOCH = 20 52 53def testEpoch(train_iterator, loss): 54 #学習誤差の表示 55 print("学習回数:{:02d} --> 学習誤差:{:.02f}".format(train_iterator.epoch, float(loss.data)), end="") 56 57 #検証用誤差と精度 58 test_losses = [] 59 test_accuracies = [] 60 61 while True: 62 test_dataset = test_iterator.next() 63 test_data, test_teacher_labels = concat_examples(test_dataset) 64 65 #検証データをモデルに渡す 66 prediction_test = model(test_data) 67 68 #検証データに対して得られた予測値と予測値と教師ラベルと比較して。ロスの計算をする 69 loss_test = F.softmax_cross_entropy(prediction_test, test_teacher_labels) 70 test_losses.append(loss_test.data) 71 72 #精度を計算する 73 accuracy = F.accuracy(prediction_test, test_teacher_labels) 74 test_accuracies.append(accuracy.data) 75 76 if test_iterator.is_new_epoch: 77 test_iterator.epoch = 0 78 test_iterator.current_position = 0 79 test_iterator.is_new_epoch = False 80 test_iterator._pushed_position = None 81 break 82 83 print ("検証誤差:{: .04f} 検証精度:{: .02f}".format(np.mean(test_losses), np.mean(test_accuracies)))
エラーが出た場所
python
1#学習開始 2while train_iterator.epoch < MAX_EPOCH: 3 train_dataset = train_iterator.next() 4 5 train_data, teacher_labels = concat_examples(train_dataset) 6 7 prediction_train = model(train_data) 8 9 loss = F.softmax_cross_entropy(prediction_train, teacher_labels) 10 11 model.cleargrads() 12 13 loss.backward() 14 15 optimizer.update() 16 17 if train_iterator.is_new_epoch: 18 testEpoch(train_iterator, loss)
エラー内容
python
1学習回数:01 --> 学習誤差:0.47検証誤差: 0.5014 検証精度: 0.87 2検証誤差: 0.4853 検証精度: 0.89 3検証誤差: 0.4924 検証精度: 0.89 4検証誤差: 0.5116 検証精度: 0.87 5検証誤差: 0.5253 検証精度: 0.87 6検証誤差: 0.5436 検証精度: 0.86 7検証誤差: 0.5495 検証精度: 0.85 8検証誤差: 0.5381 検証精度: 0.85 9検証誤差: 0.5256 検証精度: 0.86 10検証誤差: 0.5323 検証精度: 0.86 11検証誤差: 0.5339 検証精度: 0.86 12検証誤差: 0.5438 検証精度: 0.86 13検証誤差: 0.5702 検証精度: 0.85 14検証誤差: 0.5700 検証精度: 0.85 15検証誤差: 0.5707 検証精度: 0.85 16検証誤差: 0.5747 検証精度: 0.85 17検証誤差: 0.5773 検証精度: 0.85 18検証誤差: 0.5856 検証精度: 0.85 19検証誤差: 0.5825 検証精度: 0.85 20検証誤差: 0.5807 検証精度: 0.85 21検証誤差: 0.5865 検証精度: 0.85 22検証誤差: 0.5944 検証精度: 0.85 23検証誤差: 0.5955 検証精度: 0.85 24検証誤差: 0.5941 検証精度: 0.85 25検証誤差: 0.5911 検証精度: 0.85 26検証誤差: 0.5913 検証精度: 0.85 27検証誤差: 0.5917 検証精度: 0.85 28検証誤差: 0.5888 検証精度: 0.85 29検証誤差: 0.5828 検証精度: 0.85 30検証誤差: 0.5832 検証精度: 0.85 31検証誤差: 0.5791 検証精度: 0.86 32検証誤差: 0.5794 検証精度: 0.85 33検証誤差: 0.5773 検証精度: 0.85 34検証誤差: 0.5761 検証精度: 0.85 35検証誤差: 0.5737 検証精度: 0.85 36検証誤差: 0.5765 検証精度: 0.85 37検証誤差: 0.5732 検証精度: 0.85 38検証誤差: 0.5782 検証精度: 0.85 39検証誤差: 0.5831 検証精度: 0.85 40検証誤差: 0.5835 検証精度: 0.85 41検証誤差: 0.5821 検証精度: 0.85 42検証誤差: 0.5817 検証精度: 0.85 43検証誤差: 0.5839 検証精度: 0.85 44検証誤差: 0.5841 検証精度: 0.85 45検証誤差: 0.5852 検証精度: 0.85 46検証誤差: 0.5842 検証精度: 0.85 47検証誤差: 0.5824 検証精度: 0.85 48検証誤差: 0.5812 検証精度: 0.85 49検証誤差: 0.5845 検証精度: 0.85 50検証誤差: 0.5847 検証精度: 0.85 51検証誤差: 0.5809 検証精度: 0.85 52検証誤差: 0.5766 検証精度: 0.85 53検証誤差: 0.5720 検証精度: 0.85 54検証誤差: 0.5653 検証精度: 0.86 55検証誤差: 0.5584 検証精度: 0.86 56検証誤差: 0.5541 検証精度: 0.86 57検証誤差: 0.5522 検証精度: 0.86 58検証誤差: 0.5485 検証精度: 0.86 59検証誤差: 0.5470 検証精度: 0.86 60検証誤差: 0.5461 検証精度: 0.86 61検証誤差: 0.5460 検証精度: 0.86 62検証誤差: 0.5432 検証精度: 0.86 63検証誤差: 0.5367 検証精度: 0.86 64検証誤差: 0.5324 検証精度: 0.87 65検証誤差: 0.5289 検証精度: 0.87 66検証誤差: 0.5294 検証精度: 0.87 67検証誤差: 0.5301 検証精度: 0.87 68検証誤差: 0.5306 検証精度: 0.87 69検証誤差: 0.5280 検証精度: 0.87 70検証誤差: 0.5248 検証精度: 0.87 71検証誤差: 0.5214 検証精度: 0.87 72検証誤差: 0.5192 検証精度: 0.87 73検証誤差: 0.5176 検証精度: 0.87 74検証誤差: 0.5134 検証精度: 0.87 75検証誤差: 0.5146 検証精度: 0.87 76検証誤差: 0.5131 検証精度: 0.87 77検証誤差: 0.5112 検証精度: 0.87 78検証誤差: 0.5079 検証精度: 0.88 79検証誤差: 0.5090 検証精度: 0.87 80検証誤差: 0.5064 検証精度: 0.88 81検証誤差: 0.5039 検証精度: 0.88 82検証誤差: 0.5015 検証精度: 0.88 83検証誤差: 0.5008 検証精度: 0.88 84検証誤差: 0.4986 検証精度: 0.88 85検証誤差: 0.4962 検証精度: 0.88 86検証誤差: 0.4939 検証精度: 0.88 87検証誤差: 0.4906 検証精度: 0.88 88検証誤差: 0.4862 検証精度: 0.88 89検証誤差: 0.4824 検証精度: 0.88 90検証誤差: 0.4793 検証精度: 0.88 91検証誤差: 0.4799 検証精度: 0.88 92検証誤差: 0.4776 検証精度: 0.88 93検証誤差: 0.4762 検証精度: 0.89 94検証誤差: 0.4737 検証精度: 0.89 95検証誤差: 0.4717 検証精度: 0.89 96検証誤差: 0.4702 検証精度: 0.89 97検証誤差: 0.4708 検証精度: 0.89 98検証誤差: 0.4740 検証精度: 0.89 99検証誤差: 0.4759 検証精度: 0.89 100--------------------------------------------------------------------------- 101AttributeError Traceback (most recent call last) 102/var/folders/tv/vd8qs8x14xg282z04cjgj2dm0000gn/T/ipykernel_1514/2234887032.py in <module> 103 15 104 16 if train_iterator.is_new_epoch: 105---> 17 testEpoch(train_iterator, loss) 106 107/var/folders/tv/vd8qs8x14xg282z04cjgj2dm0000gn/T/ipykernel_1514/2749362989.py in testEpoch(train_iterator, loss) 108 30 109 31 if test_iterator.is_new_epoch: 110---> 32 test_iterator.epoch = 0 111 33 test_iterator.current_position = 0 112 34 test_iterator.is_new_epoch = False 113 114AttributeError: can't set attribute
いろいろ調べてみましたが、このエラーについてはあまり記事がなくわかりませんでした。
わかる方いましたら、回答よろしくお願いします。
回答2件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。