前提・実現したいこと
python3.8
tensorflow2.4
を使用しています
https://qiita.com/yutaoba/items/6eb0e12ba0d169a480df
上記のサイトを参考に画像キャプション生成を行っています。
しかし、エラーが発生しわからない状態です。
申し訳ございませんがご教示いただけないでしょうか。
発生している問題・エラーメッセージ
Traceback (most recent call last): File "/home/limlab/deep_learning/keras_imagecaption/japa_p3.py", line 419, in <module> tr.TrainModel() File "/home/limlab/deep_learning/keras_imagecaption/japa_p3.py", line 339, in TrainModel self.model.fit_generator(generator, epochs=self.epochs, steps_per_epoch=steps, verbose=1) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1847, in fit_generator return self.fit( File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit tmp_logs = self.train_function(iterator) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__ result = self._call(*args, **kwds) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected graph_function, _ = self._maybe_define_function(args, kwargs) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function graph_function = self._create_graph_function(args, kwargs) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function func_graph_module.func_graph_from_py_func( File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn out = weak_wrapped_fn().__wrapped__(*args, **kwds) File "/home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 977, in wrapper raise e.ag_error_metadata.to_exception(e) ValueError: in user code: /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:805 train_function * return step_function(self, iterator) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:795 step_function ** outputs = model.distribute_strategy.run(run_step, args=(data,)) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica return self._call_for_each_replica(fn, args, kwargs) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica return fn(*args, **kwargs) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:788 run_step ** outputs = model.train_step(data) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:754 train_step y_pred = self(x, training=True) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:998 __call__ input_spec.assert_input_compatibility(self.input_spec, inputs, self.name) /home/limlab/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/input_spec.py:204 assert_input_compatibility raise ValueError('Layer ' + layer_name + ' expects ' + ValueError: Layer model expects 2 input(s), but it received 3 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, None) dtype=int32>, <tf.Tensor 'IteratorGetNext:2' shape=(None, None) dtype=float32>]
該当のソースコード
def DataGenerator(self): """ ・model.fit_generatorに読み込ませるデータジェネレータを作成する。 """ while 1: for key, image_texts in self.train_texts_dict.items(): image_feature = self.train_features_dict[key][0] in_img, in_seq, out_word = self.MakeInputOutput(image_texts, image_feature) yield [[in_img, in_seq], out_word] def TrainModel(self): """ ・モデルを訓練する。 """ self.MakeTokenizer() self.GetVocabSize() self.GetMaxLength() self.MakeCaptioningModel() steps=len(self.train_texts_dict) for i in range(self.epochs): generator = self.DataGenerator() self.model.fit_generator(generator, epochs=self.epochs, steps_per_epoch=steps, verbose=1) self.model.save('model_' + str(i) + '.h5') return None
あなたの回答
tips
プレビュー