Prerequisites / What I want to achieve
As a test of machine learning with SSD_keras, I am trying to run SSD_training.ipynb, which is included in
https://github.com/rykov8/ssd_keras
on Google Colab.
For now I only wanted it to run with epochs=2, but it fails with the error below.
Problem / error message
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-12-201d308d4729> in <module>()
      9     validation_data=gen.generate(False),
     10     validation_steps=gen.val_batches,
---> 11     workers=1)

6 frames
/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1245         use_multiprocessing=use_multiprocessing,
   1246         shuffle=shuffle,
-> 1247         initial_epoch=initial_epoch)
   1248 
   1249   def evaluate_generator(self,

/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    794         max_queue_size=max_queue_size,
    795         workers=workers,
--> 796         use_multiprocessing=use_multiprocessing)
    797 
    798   def evaluate(self,

/usr/local/lib/python3.7/dist-packages/keras/engine/training_generator_v1.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing)
    584         shuffle=shuffle,
    585         initial_epoch=initial_epoch,
--> 586         steps_name='steps_per_epoch')
    587 
    588   def evaluate(self,

/usr/local/lib/python3.7/dist-packages/keras/engine/training_generator_v1.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    250 
    251       is_deferred = not model._is_compiled
--> 252       batch_outs = batch_function(*batch_data)
    253       if not isinstance(batch_outs, list):
    254         batch_outs = [batch_outs]

/usr/local/lib/python3.7/dist-packages/keras/engine/training_v1.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
   1074     self._update_sample_weight_modes(sample_weights=sample_weights)
   1075     self._make_train_function()
-> 1076     outputs = self.train_function(ins)  # pylint: disable=not-callable
   1077 
   1078     if reset_metrics:

/usr/local/lib/python3.7/dist-packages/keras/backend.py in __call__(self, inputs)
   4185 
   4186     fetched = self._callable_fn(*array_vals,
-> 4187                                 run_metadata=self.run_metadata)
   4188     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   4189     output_structure = tf.nest.pack_sequence_as(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1483       ret = tf_session.TF_SessionRunCallable(self._session._session,
   1484                                              self._handle, args,
-> 1485                                              run_metadata_ptr)
   1486       if run_metadata:
   1487         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: Incompatible shapes: [16,7308,4] vs. [16,933,4]
	 [[{{node training/Adam/gradients/gradients/loss/predictions_loss/sub_1_grad/BroadcastGradientArgs}}]]
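The two shapes in the error agree except on the second axis, which in SSD is the number of default (prior) boxes. As a quick sanity check, the sketch below counts default boxes from a feature-map layout and locates the first axis where two shapes cannot be broadcast together. The layout SSD300_LAYOUT (feature-map sizes 38, 19, 10, 5, 3, 1 with 3 or 6 boxes per cell) is my assumption about the repository's SSD300 model in ssd.py, not something stated in this question, and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

# Assumed SSD300 layout: (feature-map side length, default boxes per cell).
# This is an assumption about ssd_keras's ssd.py; verify against your copy.
SSD300_LAYOUT: List[Tuple[int, int]] = [(38, 3), (19, 6), (10, 6), (5, 6), (3, 6), (1, 6)]


def count_default_boxes(layout: Sequence[Tuple[int, int]]) -> int:
    """Total default boxes: sum of side * side * boxes_per_cell over feature maps."""
    return sum(side * side * per_cell for side, per_cell in layout)


def first_mismatched_axis(shape_a: Sequence[int], shape_b: Sequence[int]) -> int:
    """Return the first axis where two same-rank shapes are not broadcast-compatible.

    An elementwise op such as y_true - y_pred needs each axis pair to be equal
    or to contain a 1. Returns -1 if every axis is compatible.
    """
    if len(shape_a) != len(shape_b):
        raise ValueError("shapes have different ranks")
    for axis, (a, b) in enumerate(zip(shape_a, shape_b)):
        if a != b and a != 1 and b != 1:
            return axis
    return -1


print(count_default_boxes(SSD300_LAYOUT))                   # 7308
print(first_mismatched_axis((16, 7308, 4), (16, 933, 4)))   # 1
```

With this assumed layout the count comes out to 7308, the first size in the error, which would suggest that the tensor with 933 boxes was built from a different layout or prior set. On the real objects, the same comparison could be made between len(priors) (assuming the notebook's prior-box variable is named priors) and model.output_shape[1].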
Relevant source code
base_lr = 3e-4
optim = keras.optimizers.Adam(lr=base_lr)
model.compile(optimizer=optim,
              loss=MultiboxLoss(NUM_CLASSES, neg_pos_ratio=2.0).compute_loss)

epochs = 2
batch_size = 8
history = model.fit_generator(gen.generate(True),
                              steps_per_epoch=gen.train_batches // batch_size,
                              epochs=epochs,
                              verbose=1,
                              callbacks=callbacks,
                              validation_data=gen.generate(False),
                              validation_steps=gen.val_batches,
                              workers=1)
I changed the model.fit_generator call based on the Qiita article linked below. For comparison, here is the original code:
history = model.fit_generator(gen.generate(True), gen.train_batches,
                              nb_epoch, verbose=1,
                              callbacks=callbacks,
                              validation_data=gen.generate(False),
                              nb_val_samples=gen.val_batches,
                              nb_worker=1)
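The original call uses the Keras 1 fit_generator signature, where the second positional argument is samples_per_epoch and the other arguments are nb_epoch, nb_val_samples and nb_worker. Keras 2 renamed these to steps_per_epoch, epochs, validation_steps and workers, and steps_per_epoch and validation_steps count batches rather than samples. The error above also shows a leading (batch) dimension of 16 rather than the batch_size = 8 set in the modified code, so the divisor should probably be the batch size the generator was built with. A minimal sketch of the mapping follows; legacy_to_keras2_kwargs is a hypothetical helper and the sample counts are made up.

```python
from typing import Any, Dict


def legacy_to_keras2_kwargs(samples_per_epoch: int, nb_epoch: int,
                            nb_val_samples: int, nb_worker: int,
                            batch_size: int) -> Dict[str, Any]:
    """Translate Keras 1 fit_generator arguments into Keras 2 keyword names.

    Keras 1 counted samples; Keras 2 counts batches (steps), so sample counts
    are floor-divided by the generator's batch size (at least one step).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return {
        "steps_per_epoch": max(1, samples_per_epoch // batch_size),
        "epochs": nb_epoch,
        "validation_steps": max(1, nb_val_samples // batch_size),
        "workers": nb_worker,
    }


# Hypothetical numbers; batch_size=16 is assumed to be the generator's batch size.
kwargs = legacy_to_keras2_kwargs(samples_per_epoch=4000, nb_epoch=2,
                                 nb_val_samples=1000, nb_worker=1,
                                 batch_size=16)
print(kwargs)
# {'steps_per_epoch': 250, 'epochs': 2, 'validation_steps': 62, 'workers': 1}
```

The resulting dict could then be unpacked into the new call, for example model.fit_generator(gen.generate(True), verbose=1, callbacks=callbacks, validation_data=gen.generate(False), **kwargs).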
What I tried
I adapted the code for TensorFlow 2.x by following
https://qiita.com/ttskng/items/4f67f4bbda2568229956
I am using VOC2007 as the training data.
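VOC2007 has 20 object classes, so with a background class the model would use NUM_CLASSES = 21. If the ground truth is built the way I understand the repository's get_data_from_XML.py to build it (each row: 4 normalized box coordinates followed by a one-hot vector over the 20 object classes), a quick width check can catch a class-count mismatch before training. The per-prior target layout 4 + NUM_CLASSES + 8 is likewise my reading of BBoxUtility.assign_boxes in ssd_utils.py and should be treated as an assumption; the helper names and toy data are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

VOC_OBJECT_CLASSES = 20                 # VOC2007 object categories
NUM_CLASSES = VOC_OBJECT_CLASSES + 1    # plus one background class (assumed)


def check_ground_truth(gt: Dict[str, Sequence[Sequence[float]]],
                       num_object_classes: int) -> List[str]:
    """Return keys whose rows do not have 4 + num_object_classes values.

    Assumed row layout: [xmin, ymin, xmax, ymax, one-hot over object classes].
    Works for lists of lists or 2-D NumPy arrays (iterating yields rows).
    """
    expected_cols = 4 + num_object_classes
    return [key for key, boxes in gt.items()
            if any(len(row) != expected_cols for row in boxes)]


def expected_target_shape(batch_size: int, num_priors: int,
                          num_classes: int) -> Tuple[int, int, int]:
    """Assumed y_true shape: 4 box offsets + class scores + 8 prior/variance values."""
    return (batch_size, num_priors, 4 + num_classes + 8)


# Hypothetical toy data: one entry with the expected width, one without.
toy_gt = {
    "000005.jpg": [[0.0] * (4 + VOC_OBJECT_CLASSES)] * 2,
    "000007.jpg": [[0.0] * (4 + 3)],
}
print(check_ground_truth(toy_gt, VOC_OBJECT_CLASSES))  # ['000007.jpg']
print(expected_target_shape(16, 7308, NUM_CLASSES))    # (16, 7308, 33)
```

This only checks the last axis; the mismatch reported in the error is in the second (prior-box) axis, so it complements rather than replaces a check on the number of priors.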
I think I searched quite thoroughly on my own, but I could not find any reports of a similar error with fit_generator.