読み込むExcelデータを変えただけでInconsistent dtypes or shapes between inputs
and input_tensor_spec
というエラーが出る
2020年にも別の掲示板で似たような質問をしている人がいましたがtf-agentsAtariのアップデートを待てとのことでした...
内部で生成するQネットワークにエラーがある?
一部名前を修正しているのでミスがあったらすみません.
ライブラリのバージョン
tensorflow 2.9.1 tf-agents 0.13.0
プログラム(一部抜粋)
python
#EnvironmentSimulatorの一部 class EnvironmentSimulator(py_environment.PyEnvironment): def __init__(self,GET_DATA,mix_alpha): self.GET_DATA=GET_DATA self.mix_alpha=mix_alpha self._observation_spec = array_spec.BoundedArraySpec( shape=(self.GET_DATA,6), dtype=np.float32 ) self._action_spec = array_spec.BoundedArraySpec( shape=(), dtype=np.int32, minimum=0, maximum=2 ) self._reset() def observation_spec(self): return self._observation_spec def action_spec(self): return self._action_spec def _reset(self): self._state= np.zeros((self.GET_DATA,6), dtype=np.float32) self._TABLE_INX=500 return ts.restart(np.array(self._state, dtype=np.float32)) def _step(self, action): df_mix=self.df_mix[self._TABLE_INX:self._TABLE_INX+self.GET_DATA] self._state=tf.constant(df_mix) self._TABLE_INX=self._TABLE_INX+1 reward = 0 if action==0: (step処理省略) return ts.transition(np.array(self._state,dtype=np.float32), reward=reward)
上記のステップ関数内でget_data行×6列のデータ取得しself._stateに更新しています.
ネットワーク構成
python
#ネットワーク構築 class MyQNetwork(network.Network): def __init__(self, observation_spec, action_spec, n_hidden_channels=100,name='QNetwork'): super(MyQNetwork, self).__init__( input_tensor_spec=observation_spec, state_spec=(), name=name ) n_action = action_spec.maximum - action_spec.minimum + 1 set_dropout=0.2 self.model = keras.Sequential( [ keras.layers.InputLayer(input_shape=(100,6)), keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=True), (略
メイン関数の一部
python
#メイン関数の一部 GET_DATA=100 mix_data=pd.read_excel('mix_data.xlsx',index_col=0) env_py = EnvironmentSimulator(GET_DATA,mix_data) env = tf_py_environment.TFPyEnvironment(env_py) primary_network = MyQNetwork(env.observation_spec(), env.action_spec()) #エージェントの設定 n_step_update = 1 agent = dqn_agent.DqnAgent( env.time_step_spec(), env.action_spec(), q_network=primary_network, optimizer=keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-5), n_step_update=n_step_update, epsilon_greedy=1.0, target_update_tau=1.0, target_update_period=10, gamma=0.9, td_errors_loss_fn = common.element_wise_squared_loss, train_step_counter = tf.Variable(0) ) agent.initialize() agent.train = common.function(agent.train) policy = agent.collect_policy replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer( data_spec=agent.collect_data_spec, batch_size=env.batch_size, max_length=10**6 ) dataset = replay_buffer.as_dataset( num_parallel_calls=tf.data.experimental.AUTOTUNE, sample_batch_size=32, num_steps=n_step_update+1 ).prefetch(tf.data.experimental.AUTOTUNE) iterator = iter(dataset) env.reset() driver = dynamic_step_driver.DynamicStepDriver( env, policy, observers=[replay_buffer.add_batch], num_steps = 500, ) driver.run(maximum_iterations=1000) num_episodes = 200 epsilon = np.linspace(start=1.0, stop=0.0, num=num_episodes+1) tf_policy_saver = policy_saver.PolicySaver(policy=agent.policy) for episode in range(num_episodes+1): (以下略
これを実行すると以下のエラーがでました
エラー内容
driver.run(maximum_iterations=1000) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\drivers\dynamic_step_driver.py", line 182, in run return self._run_fn( File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\utils\common.py", line 188, in with_check_resource_vars return fn(*fn_args, **fn_kwargs) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\drivers\dynamic_step_driver.py", line 202, in _run tf.while_loop( File "D:\Programs Files\Python\Python39\lib\site-packages\tensorflow\python\util\deprecation.py", line 629, in new_func return func(*args, **kwargs) File "D:\Programs Files\Python\Python39\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2507, in while_loop_v2 return while_loop( File "D:\Programs Files\Python\Python39\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2754, in while_loop loop_vars = body(*loop_vars) File "D:\Programs Files\Python\Python39\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2745, in <lambda> body = lambda i, lv: (i + 1, orig_body(*lv)) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\drivers\dynamic_step_driver.py", line 135, in loop_body action_step = self.policy.action(time_step, policy_state) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\tf_policy.py", line 324, in action step = action_fn(time_step=time_step, policy_state=policy_state, seed=seed) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\utils\common.py", line 188, in with_check_resource_vars return fn(*fn_args, **fn_kwargs) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\epsilon_greedy_policy.py", line 115, in _action greedy_action = self._greedy_policy.action(time_step, policy_state) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\tf_policy.py", line 324, in action step = action_fn(time_step=time_step, policy_state=policy_state, seed=seed) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\utils\common.py", line 188, in with_check_resource_vars return fn(*fn_args, **fn_kwargs) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\tf_policy.py", line 560, in _action distribution_step = self._distribution(time_step, policy_state) # pytype: disable=wrong-arg-types File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\greedy_policy.py", line 80, in _distribution distribution_step = self._wrapped_policy.distribution( File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\tf_policy.py", line 403, in distribution step = self._distribution(time_step=time_step, policy_state=policy_state) File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\policies\q_policy.py", line 156, in _distribution q_values, policy_state = self._q_network( File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\networks\network.py", line 391, in __call__ nest_utils.assert_matching_dtypes_and_inner_shapes( File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\utils\nest_utils.py", line 407, in assert_matching_dtypes_and_inner_shapes raise ValueError('{}: Inconsistent dtypes or shapes between {} and {}.\n' ValueError: <__main__.MyQNetwork object at 0x000001833D352550>: Inconsistent dtypes or shapes between `inputs` and `input_tensor_spec`. dtypes: <dtype: 'float32'> vs. <dtype: 'float32'>. shapes: (1, 99, 6) vs. (100, 6).
(100, 6)のデータを取り込んでも勝手に内部で(1,99,6)の形に変換されてる????
ちなみにネットワークや_observation_specの部分を(1, 99, 6)になるように変えても次は
(None,1, 99, 6)
vs.
(1, 99, 6)
と出てきます.
データの内容が違いますが同じデータ数のファイルを使ってもこのエラーが出ないものがあります.
エラーがでないときはきちんと学習できていましたが最近作成するExcelファイルにはすべてこのエラーを突き返されます.
excelは1000~ 10000行×6列のデータです.csvで保存して読み込んでも同じエラーが出ました.
またExcel内の小数点における処理によってエラーが発生しているかと思って,全部整数になるように数値変換してint型で読み込んでも同じエラーがでました.
エラーが出るものと出ないファイルのdataflameの型やサイズを調べても同じでした
具体的な解決策がわかりません.どうかよろしくお願いします.
まだ回答がついていません
会員登録して回答してみよう