質問編集履歴

5

不要部分の取り消し

2022/06/24 01:55

投稿

yoshiya
yoshiya

スコア20

test CHANGED
File without changes
test CHANGED
@@ -65,14 +65,7 @@
65
65
  [
66
66
  keras.layers.InputLayer(input_shape=(100,6)),
67
67
  keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=True),
68
- keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=True),
69
- keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=False,kernel_regularizer=regularizers.l2(1e-4)),
70
- keras.layers.Dense(units=n_action)
71
- ]
68
+ (略
72
- )
73
- def call(self, observation, step_type=None, network_state=(), training=True):
74
- actions=self.model(observation, training=training)
75
- return actions,network_state
76
69
 
77
70
  ```
78
71
 

4

初心者マークの追加

2022/06/23 06:23

投稿

yoshiya
yoshiya

スコア20

test CHANGED
File without changes
test CHANGED
@@ -195,6 +195,7 @@
195
195
  と出てきます.
196
196
  データの内容が違いますが同じデータ数のファイルを使ってもこのエラーが出ないものがあります.
197
197
  エラーがでないときはきちんと学習できていましたが最近作成するExcelファイルにはすべてこのエラーを突き返されます.
198
+ excelは1000~ 10000行×6列のデータです.csvで保存して読み込んでも同じエラーが出ました.
198
199
  またExcel内の小数点における処理によってエラーが発生しているかと思って,全部整数になるように数値変換してint型で読み込んでも同じエラーがでました.
199
200
  エラーが出るものと出ないファイルのdataflameの型やサイズを調べても同じでした
200
201
 

3

step関数の戻り値の記載

2022/06/22 15:43

投稿

yoshiya
yoshiya

スコア20

test CHANGED
@@ -1 +1 @@
1
- Deep Q Network のInconsistent dtypes or shapes between `inputs` and `input_tensor_spec`エラー
1
+ tensorflow,Deep Q Networkに関するエラー
test CHANGED
@@ -1,4 +1,4 @@
1
- ## 読み込むデータを変えただけでInconsistent dtypes or shapes between `inputs` and `input_tensor_spec`というエラーが出る
1
+ ## 読み込むExcelデータを変えただけでInconsistent dtypes or shapes between `inputs` and `input_tensor_spec`というエラーが出る
2
2
  2020年にも別の掲示板で似たような質問をしている人がいましたがtf-agentsAtariのアップデートを待てとのことでした...
3
3
  内部で生成するQネットワークにエラーがある?
4
4
  一部名前を修正しているのでミスがあったらすみません.
@@ -41,7 +41,11 @@
41
41
 
42
42
  reward = 0
43
43
  if action==0:
44
+
44
- (以下
45
+ (step処理省)
46
+
47
+ return ts.transition(np.array(self._state,dtype=np.float32), reward=reward)
48
+
45
49
  ```
46
50
  上記のステップ関数内でget_data行×6列のデータ取得しself._stateに更新しています.
47
51
 

2

maximum_iterations=1000の追記

2022/06/22 15:25

投稿

yoshiya
yoshiya

スコア20

test CHANGED
File without changes
test CHANGED
@@ -120,7 +120,7 @@
120
120
  observers=[replay_buffer.add_batch],
121
121
  num_steps = 500,
122
122
  )
123
- driver.run()
123
+ driver.run(maximum_iterations=1000)
124
124
 
125
125
  num_episodes = 200
126
126
  epsilon = np.linspace(start=1.0, stop=0.0, num=num_episodes+1)

1

バージョンとmain関数の追記

2022/06/22 12:22

投稿

yoshiya
yoshiya

スコア20

test CHANGED
File without changes
test CHANGED
@@ -2,6 +2,10 @@
2
2
  2020年にも別の掲示板で似たような質問をしている人がいましたがtf-agentsAtariのアップデートを待てとのことでした...
3
3
  内部で生成するQネットワークにエラーがある?
4
4
  一部名前を修正しているのでミスがあったらすみません.
5
+ ###### ライブラリのバージョン
6
+ tensorflow 2.9.1
7
+ tf-agents 0.13.0
8
+ ###### プログラム(一部抜粋)
5
9
  ```python
6
10
  #EnvironmentSimulatorの一部
7
11
  class EnvironmentSimulator(py_environment.PyEnvironment):
@@ -40,6 +44,8 @@
40
44
  (以下略
41
45
  ```
42
46
  上記のステップ関数内でget_data行×6列のデータ取得しself._stateに更新しています.
47
+
48
+ #### ネットワーク構成
43
49
  ```python
44
50
  #ネットワーク構築
45
51
  class MyQNetwork(network.Network):
@@ -65,15 +71,67 @@
65
71
  return actions,network_state
66
72
 
67
73
  ```
74
+
75
+ #### メイン関数の一部
68
76
  ```python
69
77
  #メイン関数の一部
70
- GET_DATA=100
78
+ GET_DATA=100
71
- mix_data=pd.read_excel('mix_data.xlsx',index_col=0)
79
+ mix_data=pd.read_excel('mix_data.xlsx',index_col=0)
72
-
80
+
73
- env_py = EnvironmentSimulator(GET_DATA,mix_data)
81
+ env_py = EnvironmentSimulator(GET_DATA,mix_data)
74
- env = tf_py_environment.TFPyEnvironment(env_py)
82
+ env = tf_py_environment.TFPyEnvironment(env_py)
83
+ primary_network = MyQNetwork(env.observation_spec(), env.action_spec())
84
+ #エージェントの設定
85
+ n_step_update = 1
86
+ agent = dqn_agent.DqnAgent(
87
+ env.time_step_spec(),
88
+ env.action_spec(),
89
+ q_network=primary_network,
90
+ optimizer=keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-5),
91
+ n_step_update=n_step_update,
92
+ epsilon_greedy=1.0,
93
+ target_update_tau=1.0,
94
+ target_update_period=10,
95
+ gamma=0.9,
96
+ td_errors_loss_fn = common.element_wise_squared_loss,
97
+ train_step_counter = tf.Variable(0)
98
+ )
99
+ agent.initialize()
100
+ agent.train = common.function(agent.train)
101
+
102
+ policy = agent.collect_policy
103
+
104
+ replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
105
+ data_spec=agent.collect_data_spec,
106
+ batch_size=env.batch_size,
107
+ max_length=10**6
108
+ )
109
+ dataset = replay_buffer.as_dataset(
110
+ num_parallel_calls=tf.data.experimental.AUTOTUNE,
111
+ sample_batch_size=32,
112
+ num_steps=n_step_update+1
113
+ ).prefetch(tf.data.experimental.AUTOTUNE)
114
+ iterator = iter(dataset)
115
+
116
+ env.reset()
117
+ driver = dynamic_step_driver.DynamicStepDriver(
118
+ env,
119
+ policy,
120
+ observers=[replay_buffer.add_batch],
121
+ num_steps = 500,
122
+ )
123
+ driver.run()
124
+
125
+ num_episodes = 200
126
+ epsilon = np.linspace(start=1.0, stop=0.0, num=num_episodes+1)
127
+ tf_policy_saver = policy_saver.PolicySaver(policy=agent.policy)
128
+
129
+
130
+ for episode in range(num_episodes+1):
131
+ (以下略
75
132
  ```
76
133
  これを実行すると以下のエラーがでました
134
+ ##### エラー内容
77
135
  ```ここに言語を入力
78
136
  driver.run(maximum_iterations=1000)
79
137
  File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\drivers\dynamic_step_driver.py", line 182, in run
@@ -124,6 +182,8 @@
124
182
  vs.
125
183
  (100, 6).
126
184
  ```
185
+ (100, 6)のデータを取り込んでも勝手に内部で(1,99,6)の形に変換されてる????
186
+
127
187
  ちなみにネットワークや_observation_specの部分を(1, 99, 6)になるように変えても次は
128
188
  (None,1, 99, 6)
129
189
  vs.
@@ -131,7 +191,10 @@
131
191
  と出てきます.
132
192
  データの内容が違いますが同じデータ数のファイルを使ってもこのエラーが出ないものがあります.
133
193
  エラーがでないときはきちんと学習できていましたが最近作成するExcelファイルにはすべてこのエラーを突き返されます.
134
- ぶんですがExcel内の小数点における処理によってエラーが発生してるものかと思われます(type関数それぞれのdataflameを調べても同じでした)
194
+ たExcel内の小数点における処理によってエラーが発生してるかと思って,全部整数になるように数値変換してint読み込んでも同じエラーがした
195
+ エラーが出るものと出ないファイルのdataflameの型やサイズを調べても同じでした
196
+
197
+
135
198
  具体的な解決策がわかりません.どうかよろしくお願いします.
136
199
 
137
200