LSTMを使うGANの実装をしています。
その際に使うpredict()が実行できずに困っております。
keras
1def build_generator(seq_length,z_dim): 2 model = Sequential() 3 model.add(Dense(32, input_dim=z_dim)) 4 model.add(Embedding(50,32,input_length=z_dim)) 5 model.add(LSTM(32)) 6 model.add(Dense(seq_length*jigen_length, activation='softmax')) 7 model.add(Reshape((5, 1))) 8 model.summary() 9 return model 10 11z = np.random.normal(0, 1, (batch_size, 100)) 12gen_text = generator.predict(z) #実行できない
error
1InvalidArgumentError Traceback (most recent call last) 2<ipython-input-169-f04cfd1d3585> in <module> 3 6 print(len(x_train)) 4 7 z = np.random.normal(0, 1, (batch_size, 100)) 5----> 8 gen_text = generator.predict(z) 6 9 7 10 print(z) 8 9~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs) 10 128 raise ValueError('{} is not supported in multi-worker mode.'.format( 11 129 method.__name__)) 12--> 130 return method(self, *args, **kwargs) 13 131 14 132 return tf_decorator.make_decorator( 15 16~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing) 17 1597 for step in data_handler.steps(): 18 1598 callbacks.on_predict_batch_begin(step) 19-> 1599 tmp_batch_outputs = predict_function(iterator) 20 1600 if data_handler.should_sync: 21 1601 context.async_wait() 22 23~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds) 24 778 else: 25 779 compiler = "nonXla" 26--> 780 result = self._call(*args, **kwds) 27 781 28 782 new_tracing_count = self._get_tracing_count() 29 30~\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds) 31 812 # In this case we have not created variables on the first call. So we can 32 813 # run the first trace but we should fail if variables are created. 33--> 814 results = self._stateful_fn(*args, **kwds) 34 815 if self._created_variables: 35 816 raise ValueError("Creating variables on a non-first call to a function" 36 37~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs) 38 2827 with self._lock: 39 2828 graph_function, args, kwargs = self._maybe_define_function(args, kwargs) 40-> 2829 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access 41 2830 42 2831 @property 43 44~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs, cancellation_manager) 45 1841 `args` and `kwargs`. 46 1842 """ 47-> 1843 return self._call_flat( 48 1844 [t for t in nest.flatten((args, kwargs), expand_composites=True) 49 1845 if isinstance(t, (ops.Tensor, 50 51~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager) 52 1921 and executing_eagerly): 53 1922 # No tape is watching; skip to running the function. 54-> 1923 return self._build_call_outputs(self._inference_function.call( 55 1924 ctx, args, cancellation_manager=cancellation_manager)) 56 1925 forward_backward = self._select_forward_and_backward_functions( 57 58~\anaconda3\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager) 59 543 with _InterpolateFunctionError(self): 60 544 if cancellation_manager is None: 61--> 545 outputs = execute.execute( 62 546 str(self.signature.name), 63 547 num_outputs=self._num_outputs, 64 65~\anaconda3\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name) 66 57 try: 67 58 ctx.ensure_initialized() 68---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, 69 60 inputs, attrs, num_outputs) 70 61 except core._NotOkStatusException as e: 71 72InvalidArgumentError: 2 root error(s) found. 73 (0) Invalid argument: indices[0,3] = -1 is not in [0, 50) 74 [[node sequential_117/embedding_73/embedding_lookup (defined at <ipython-input-168-f04cfd1d3585>:8) ]] 75 (1) Invalid argument: indices[0,3] = -1 is not in [0, 50) 76 [[node sequential_117/embedding_73/embedding_lookup (defined at <ipython-input-168-f04cfd1d3585>:8) ]] 77 [[sequential_117/embedding_73/embedding_lookup/_6]] 780 successful operations. 790 derived errors ignored. [Op:__inference_predict_function_68877] 80 81Errors may have originated from an input operation. 82Input Source operations connected to node sequential_117/embedding_73/embedding_lookup: 83 sequential_117/embedding_73/embedding_lookup/68418 (defined at C:\Users\gest\anaconda3\lib\contextlib.py:113) 84 85Input Source operations connected to node sequential_117/embedding_73/embedding_lookup: 86 sequential_117/embedding_73/embedding_lookup/68418 (defined at C:\Users\gest\anaconda3\lib\contextlib.py:113) 87 88Function call stack: 89predict_function -> predict_function
また、EmbeddingレイヤーとLSTMレイヤーを削除したらpredict()が実行出来ました。
Embeddingレイヤーのみ削除、LSTMレイヤーのみ削除の場合は実行出来ませんでした。
両コードともコンパイルは通ります。
どうか宜しくお願いします。
keras
1def build_generator(seq_length,z_dim): 2 model = Sequential() 3 model.add(Dense(32, input_dim=z_dim)) 4 model.add(Dense(seq_length*jigen_length, activation='softmax')) 5 #model.add(Reshape((5, 1))) 6 model.summary() 7 return model 8 9 10z = np.random.normal(0, 1, (batch_size, 100)) 11gen_text = generator.predict(z) #実行できる 12
keras
1#実際関数を呼び出してGANのモデルをコンパイルしてあげる 2discriminator = build_discriminatior() 3discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy']) 4generator = build_generator(seq_length,z_dim) 5 6#識別器の学習機能をオフにしてあげる。識別器と生成器を別々に学習させてあげられる 7discriminator.trainable = False 8 9gan = build_gan(generator, discriminator) 10gan.compile(loss='binary_crossentropy', optimizer=Adam())
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。