I'm trying to solve CartPole with DQN.
However, it doesn't learn well.
The agent keeps taking the same action (the output is biased toward either right or left) and fails almost immediately.
What might be wrong?
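For reference, this is the kind of quick check I'm using to confirm the bias (a minimal sketch; `count_greedy_actions` is a hypothetical helper, not part of my script, and `model`/`env` are the model and environment from the code below):

```python
import numpy as np

# Hypothetical helper (not in my script): roll out one greedy episode and
# count how often each action is chosen, to confirm that the policy
# collapses onto a single action.
def count_greedy_actions(model, env, max_steps=200):
    counts = np.zeros(2, dtype=int)  # CartPole-v0 has 2 actions
    observation = env.reset()
    for _ in range(max_steps):
        q = model.predict(observation.reshape((1, 4)))
        action = int(np.argmax(q))
        counts[action] += 1
        observation, reward, done, info = env.step(action)
        if done:
            break
    return counts  # something like array([37, 0]) would show the bias I'm seeing
```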
Reference:
CartPoleでDQN(deep Q-learning)、DDQNを実装・解説【Phythonで強化学習:第2回】
```python
import gym
from collections import deque
import keras
from keras import optimizers
from keras import losses
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from keras.utils import plot_model
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
import random
import sys
import time

EPOCH = 200        # number of episodes
BATCH_SIZE = 32
CLEAR_TURN = 100   # an episode counts as a success if it survives past this step
GAMMA = 0.99       # discount factor

# epsilon-greedy: take the greedy action unless a uniform draw falls below epsilon
def getAction(model, observation, episode):
    y = model.predict(observation.reshape((1, 4)))
    if (0.01 + 0.9 / (1.0 + episode)) <= np.random.uniform(0, 1):
        return np.argmax(y)
    else:
        return np.random.choice([0, 1])

# fit the model on the replay buffer, using the current model for the TD targets
def learn(model, data):
    y_pred = []  # network inputs (states)
    y_true = []  # target Q-values
    for d in data:
        state, action, reward, nextState = d
        y = model.predict(state)
        # terminal transitions are stored with nextState = zeros
        if not (nextState == np.zeros(state.shape)).all(axis=1):
            y[0][action] = reward + GAMMA * np.max(model.predict(nextState)[0])
        else:
            y[0][action] = reward
        y_pred.append(state.reshape(4))
        y_true.append(y.reshape(2))
    model.fit(np.array(y_pred), np.array(y_true), batch_size=BATCH_SIZE, verbose=0, epochs=1)

def huberloss(y_true, y_pred):
    return K.mean(K.minimum(0.5 * K.square(y_pred - y_true), K.abs(y_pred - y_true) - 0.5), axis=1)

def main(args):

    env = gym.make('CartPole-v0')
    model = Sequential()
    model.add(Dense(16, activation="relu", input_dim=4))
    model.add(Dense(16, activation="relu"))
    model.add(Dense(2, activation="linear"))  # Q-values for the 2 actions
    model.compile(loss=huberloss, optimizer=Adam(lr=0.00001))
    model.summary()

    data = deque(maxlen=200)  # replay buffer

    plot_x = []
    plot_y = []

    point = 0

    for episode in range(EPOCH):
        observation = env.reset()
        nextObservation, reward, done, info = env.step(env.action_space.sample())

        for t in range(CLEAR_TURN * 2):

            # render near the end of training
            if EPOCH - episode == 2:
                env.render()
                time.sleep(0.1)

            action = getAction(model, observation, episode)
            nextObservation, reward, done, info = env.step(action)
            if done:
                # reward +1 for surviving long enough, -1 for failing early
                if t > CLEAR_TURN:
                    data.append((observation.reshape((1, 4)), action, 1, np.zeros((1, 4))))
                else:
                    data.append((observation.reshape((1, 4)), action, -1, np.zeros((1, 4))))
                print("{} times : finished after {} timestamps".format(episode, t + 1))

                plot_x.append(episode)
                plot_y.append(t + 1)

                break
            else:
                # intermediate steps get reward 0
                data.append((observation.reshape((1, 4)), action, 0, nextObservation.reshape((1, 4))))
                observation = nextObservation
                if BATCH_SIZE < len(data):
                    learn(model, data)

    plt.scatter(plot_x, plot_y, marker="+")
    plt.show()

main(sys.argv)
```
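For context, the exploration rate in `getAction` decays with the episode number; this quick check just prints the schedule from my code (no new logic, it only restates the formula above):

```python
# Schedule from getAction: a random action is taken with
# probability 0.01 + 0.9 / (1.0 + episode), otherwise argmax of the Q-values.
for episode in [0, 9, 49, 99, 199]:
    eps = 0.01 + 0.9 / (1.0 + episode)
    print("episode {:3d}: epsilon = {:.4f}".format(episode, eps))
# roughly 0.91 at episode 0, 0.10 at episode 9, 0.019 at episode 99
```

So exploration is almost entirely gone after the first few dozen episodes, which is when I start seeing the output stuck on one action.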