質問編集履歴
5
不要部分の取り消し
test
CHANGED
File without changes
|
test
CHANGED
@@ -65,14 +65,7 @@
|
|
65
65
|
[
|
66
66
|
keras.layers.InputLayer(input_shape=(100,6)),
|
67
67
|
keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=True),
|
68
|
-
keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=True),
|
69
|
-
keras.layers.LSTM(units=n_hidden_channels,dropout=set_dropout,return_sequences=False,kernel_regularizer=regularizers.l2(1e-4)),
|
70
|
-
keras.layers.Dense(units=n_action)
|
71
|
-
|
68
|
+
(略
|
72
|
-
)
|
73
|
-
def call(self, observation, step_type=None, network_state=(), training=True):
|
74
|
-
actions=self.model(observation, training=training)
|
75
|
-
return actions,network_state
|
76
69
|
|
77
70
|
```
|
78
71
|
|
4
初心者マークの追加
test
CHANGED
File without changes
|
test
CHANGED
@@ -195,6 +195,7 @@
|
|
195
195
|
と出てきます.
|
196
196
|
データの内容が違いますが同じデータ数のファイルを使ってもこのエラーが出ないものがあります.
|
197
197
|
エラーがでないときはきちんと学習できていましたが最近作成するExcelファイルにはすべてこのエラーを突き返されます.
|
198
|
+
excelは1000~ 10000行×6列のデータです.csvで保存して読み込んでも同じエラーが出ました.
|
198
199
|
またExcel内の小数点における処理によってエラーが発生しているかと思って,全部整数になるように数値変換してint型で読み込んでも同じエラーがでました.
|
199
200
|
エラーが出るものと出ないファイルのdataflameの型やサイズを調べても同じでした
|
200
201
|
|
3
step関数の戻り値の記載
test
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
tensorflow,Deep Q Networkに関するエラー
|
test
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
## 読み込むデータを変えただけでInconsistent dtypes or shapes between `inputs` and `input_tensor_spec`というエラーが出る
|
1
|
+
## 読み込むExcelデータを変えただけでInconsistent dtypes or shapes between `inputs` and `input_tensor_spec`というエラーが出る
|
2
2
|
2020年にも別の掲示板で似たような質問をしている人がいましたがtf-agentsAtariのアップデートを待てとのことでした...
|
3
3
|
内部で生成するQネットワークにエラーがある?
|
4
4
|
一部名前を修正しているのでミスがあったらすみません.
|
@@ -41,7 +41,11 @@
|
|
41
41
|
|
42
42
|
reward = 0
|
43
43
|
if action==0:
|
44
|
+
|
44
|
-
(
|
45
|
+
(step処理省略)
|
46
|
+
|
47
|
+
return ts.transition(np.array(self._state,dtype=np.float32), reward=reward)
|
48
|
+
|
45
49
|
```
|
46
50
|
上記のステップ関数内でget_data行×6列のデータ取得しself._stateに更新しています.
|
47
51
|
|
2
maximum_iterations=1000の追記
test
CHANGED
File without changes
|
test
CHANGED
@@ -120,7 +120,7 @@
|
|
120
120
|
observers=[replay_buffer.add_batch],
|
121
121
|
num_steps = 500,
|
122
122
|
)
|
123
|
-
driver.run()
|
123
|
+
driver.run(maximum_iterations=1000)
|
124
124
|
|
125
125
|
num_episodes = 200
|
126
126
|
epsilon = np.linspace(start=1.0, stop=0.0, num=num_episodes+1)
|
1
バージョンとmain関数の追記
test
CHANGED
File without changes
|
test
CHANGED
@@ -2,6 +2,10 @@
|
|
2
2
|
2020年にも別の掲示板で似たような質問をしている人がいましたがtf-agentsAtariのアップデートを待てとのことでした...
|
3
3
|
内部で生成するQネットワークにエラーがある?
|
4
4
|
一部名前を修正しているのでミスがあったらすみません.
|
5
|
+
###### ライブラリのバージョン
|
6
|
+
tensorflow 2.9.1
|
7
|
+
tf-agents 0.13.0
|
8
|
+
###### プログラム(一部抜粋)
|
5
9
|
```python
|
6
10
|
#EnvironmentSimulatorの一部
|
7
11
|
class EnvironmentSimulator(py_environment.PyEnvironment):
|
@@ -40,6 +44,8 @@
|
|
40
44
|
(以下略
|
41
45
|
```
|
42
46
|
上記のステップ関数内でget_data行×6列のデータ取得しself._stateに更新しています.
|
47
|
+
|
48
|
+
#### ネットワーク構成
|
43
49
|
```python
|
44
50
|
#ネットワーク構築
|
45
51
|
class MyQNetwork(network.Network):
|
@@ -65,15 +71,67 @@
|
|
65
71
|
return actions,network_state
|
66
72
|
|
67
73
|
```
|
74
|
+
|
75
|
+
#### メイン関数の一部
|
68
76
|
```python
|
69
77
|
#メイン関数の一部
|
70
|
-
GET_DATA=100
|
78
|
+
GET_DATA=100
|
71
|
-
mix_data=pd.read_excel('mix_data.xlsx',index_col=0)
|
79
|
+
mix_data=pd.read_excel('mix_data.xlsx',index_col=0)
|
72
|
-
|
80
|
+
|
73
|
-
env_py = EnvironmentSimulator(GET_DATA,mix_data)
|
81
|
+
env_py = EnvironmentSimulator(GET_DATA,mix_data)
|
74
|
-
env = tf_py_environment.TFPyEnvironment(env_py)
|
82
|
+
env = tf_py_environment.TFPyEnvironment(env_py)
|
83
|
+
primary_network = MyQNetwork(env.observation_spec(), env.action_spec())
|
84
|
+
#エージェントの設定
|
85
|
+
n_step_update = 1
|
86
|
+
agent = dqn_agent.DqnAgent(
|
87
|
+
env.time_step_spec(),
|
88
|
+
env.action_spec(),
|
89
|
+
q_network=primary_network,
|
90
|
+
optimizer=keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-5),
|
91
|
+
n_step_update=n_step_update,
|
92
|
+
epsilon_greedy=1.0,
|
93
|
+
target_update_tau=1.0,
|
94
|
+
target_update_period=10,
|
95
|
+
gamma=0.9,
|
96
|
+
td_errors_loss_fn = common.element_wise_squared_loss,
|
97
|
+
train_step_counter = tf.Variable(0)
|
98
|
+
)
|
99
|
+
agent.initialize()
|
100
|
+
agent.train = common.function(agent.train)
|
101
|
+
|
102
|
+
policy = agent.collect_policy
|
103
|
+
|
104
|
+
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
|
105
|
+
data_spec=agent.collect_data_spec,
|
106
|
+
batch_size=env.batch_size,
|
107
|
+
max_length=10**6
|
108
|
+
)
|
109
|
+
dataset = replay_buffer.as_dataset(
|
110
|
+
num_parallel_calls=tf.data.experimental.AUTOTUNE,
|
111
|
+
sample_batch_size=32,
|
112
|
+
num_steps=n_step_update+1
|
113
|
+
).prefetch(tf.data.experimental.AUTOTUNE)
|
114
|
+
iterator = iter(dataset)
|
115
|
+
|
116
|
+
env.reset()
|
117
|
+
driver = dynamic_step_driver.DynamicStepDriver(
|
118
|
+
env,
|
119
|
+
policy,
|
120
|
+
observers=[replay_buffer.add_batch],
|
121
|
+
num_steps = 500,
|
122
|
+
)
|
123
|
+
driver.run()
|
124
|
+
|
125
|
+
num_episodes = 200
|
126
|
+
epsilon = np.linspace(start=1.0, stop=0.0, num=num_episodes+1)
|
127
|
+
tf_policy_saver = policy_saver.PolicySaver(policy=agent.policy)
|
128
|
+
|
129
|
+
|
130
|
+
for episode in range(num_episodes+1):
|
131
|
+
(以下略
|
75
132
|
```
|
76
133
|
これを実行すると以下のエラーがでました
|
134
|
+
##### エラー内容
|
77
135
|
```ここに言語を入力
|
78
136
|
driver.run(maximum_iterations=1000)
|
79
137
|
File "D:\Programs Files\Python\Python39\lib\site-packages\tf_agents\drivers\dynamic_step_driver.py", line 182, in run
|
@@ -124,6 +182,8 @@
|
|
124
182
|
vs.
|
125
183
|
(100, 6).
|
126
184
|
```
|
185
|
+
(100, 6)のデータを取り込んでも勝手に内部で(1,99,6)の形に変換されてる????
|
186
|
+
|
127
187
|
ちなみにネットワークや_observation_specの部分を(1, 99, 6)になるように変えても次は
|
128
188
|
(None,1, 99, 6)
|
129
189
|
vs.
|
@@ -131,7 +191,10 @@
|
|
131
191
|
と出てきます.
|
132
192
|
データの内容が違いますが同じデータ数のファイルを使ってもこのエラーが出ないものがあります.
|
133
193
|
エラーがでないときはきちんと学習できていましたが最近作成するExcelファイルにはすべてこのエラーを突き返されます.
|
134
|
-
た
|
194
|
+
またExcel内の小数点における処理によってエラーが発生しているかと思って,全部整数になるように数値変換してint型で読み込んでも同じエラーがでました.
|
195
|
+
エラーが出るものと出ないファイルのdataflameの型やサイズを調べても同じでした
|
196
|
+
|
197
|
+
|
135
198
|
具体的な解決策がわかりません.どうかよろしくお願いします.
|
136
199
|
|
137
200
|
|