急に以下のようなエラーが出ました。
解決策を教えて頂けるとありがたいです。
python
1import chainer 2import numpy as np 3import chainer.functions as F 4from chainer import serializers,iterators,optimizer 5from chainer.dataset import convert 6 7 8#学習に関する基本情報の定義 9EPOCH = 10 #学習回数 10BATCH = 20 #バッチサイズ 11NUM_SHAPE = 48 #画像一辺の長さ 12LEARN_RATE = 0.001 13WEIGHT_DECAY =1e-4 14GPU = 0 15SAVE_MODEL = "/content/drive/My Drive/saved_model/myresnet.npz" 16 17def learn(csvfile): 18 19 train,publictest,_= dataFromCsv(csvfile) 20 21 train_iter = iterators.SerialIterator(train,batch_size=BATCH,shuffle=True) 22 publictest_iter = iterators.SerialIterator(publictest,batch_size=BATCH,repeat=False,shuffle=False) 23 24 25 model = ResNet(class_labels=9) 26 27 chainer.cuda.get_device(GPU).use 28 model.to_gpu() 29 30 optimizer = chainer.optimizers.MomentumSGD(LEARN_RATE) 31 optimizer.setup(model) 32 optimizer.add_hook(chainer.optimizer.WeightDecay(WEIGHT_DECAY)) 33 34 saved_model = model 35 36 while train_iter.epoch < EPOCH: 37 38 batch = train_iter.next() 39 40 trainLossList = [] 41 42 x_array, y_array = convert.concat_examples(batch,GPU) 43 44 x = chainer.Variable(x_array) 45 y = chainer.Variable(y_array) 46 m = model(x) 47 48 loss_train = myCrossEntropyError(m,y) 49 50 model.cleargrads() 51 52 loss_train.backward() 53 54 optimizer.update() 55 56 trainLossList.append(chainer.cuda.to_cpu(loss_train.data)) 57 58 59 if train_iter.is_new_epoch: 60 61 testLossList = [] 62 63 for batch in publictest_iter: 64 x_array, y_array = convert.concat_examples(batch, GPU) 65 x = chainer.Variable(x_array) 66 y = chainer.Variable(y_array) 67 m = model(x) 68 69 loss_test = myCrossEntropyError(m, y) 70 testLossList.append(chainer.cuda.to_cpu(loss_test.data)) 71 72 73 if loss_test.data == np.min(testLossList): 74 saved_model = model 75 76 publictest_iter.reset() 77 78 print("epo:" + str(train_iter.epoch) + " train_loss:" + str(np.mean(trainLossList)) + " test_loss:" + str(np.mean(testLossList))) 79 80 81 82 chainer.serializers.save_npz(SAVE_MODEL, saved_model) 83 84 return 85 86 87def myCrossEntropyError(m,y): 88 DELTA = 1e-7 # マイナス無限大を発生させないように微小な値を追加する 89 return -F.sum(y*F.log(m+DELTA)+(1-y)*F.log(1-m+DELTA)) 90 91if __name__ == "__main__": 92 learn("/content/drive/My Drive/a.csv")
error
1--------------------------------------------------------------------------- 2InvalidType Traceback (most recent call last) 3<ipython-input-4-917b77ca19ac> in <module>() 4 94 5 95 if __name__ == "__main__": 6---> 96 learn("/content/drive/My Drive/a.csv") 7 86 frames 9<ipython-input-4-917b77ca19ac> in learn(csvfile) 10 71 m = model(x) 11 72 12---> 73 loss_test = myCrossEntropyError(m, y) 13 74 testLossList.append(chainer.cuda.to_cpu(loss_test.data)) 14 75 15 16<ipython-input-4-917b77ca19ac> in myCrossEntropyError(m, y) 17 91 def myCrossEntropyError(m,y): 18 92 DELTA = 1e-7 # マイナス無限大を発生させないように微小な値を追加する 19---> 93 return -F.sum(y*F.log(m+DELTA)+(1-y)*F.log(1-m+DELTA)) 20 94 21 95 if __name__ == "__main__": 22 23/usr/local/lib/python3.6/dist-packages/chainer/functions/math/basic_math.py in mul(self, rhs) 24 391 return MulConstant(rhs).apply((self,))[0] 25 392 rhs = _preprocess_rhs(self, rhs) 26--> 393 return Mul().apply((self, rhs))[0] 27 394 28 395 29 30/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in apply(self, inputs) 31 295 32 296 if configuration.config.type_check: 33--> 297 self._check_data_type_forward(in_data) 34 298 35 299 hooks = chainer.get_function_hooks() 36 37/usr/local/lib/python3.6/dist-packages/chainer/function_node.py in _check_data_type_forward(self, in_data) 38 398 in_type = type_check.get_types(in_data, 'in_types', False) 39 399 with type_check.get_function_check_context(self): 40--> 400 self.check_type_forward(in_type) 41 401 42 402 def check_type_forward(self, in_types): 43 44/usr/local/lib/python3.6/dist-packages/chainer/functions/math/basic_math.py in check_type_forward(self, in_types) 45 339 ) 46 340 type_check.expect_broadcast_shapes( 47--> 341 in_types[0].shape, in_types[1].shape) 48 342 49 343 def forward_chainerx(self, x): 50 51/usr/local/lib/python3.6/dist-packages/chainer/utils/type_check.py in expect_broadcast_shapes(*shape_types) 52 643 error = InvalidType('', '', msg='\n'.join(msgs)) 53 644 if error is not None: 54--> 645 raise error 55 56InvalidType: cannot broadcast inputs of the following shapes: 57lhs.shape = (20, 8) 58rhs.shape = (20, 9)
あなたの回答
tips
プレビュー