質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

473閲覧

強化学習 pythonを用いた方策勾配の実装で困っています

Luisu

総合スコア10

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/04/05 06:41

編集2022/01/12 10:55

お世話になります。
初めての質問です。
宜しくお願い致します。

やりたいこと: 方策勾配を用いて報酬を最大化したい。
わからないこと: NNを用いた場合の方策勾配の実装が合っているか、そもそも全体の実装コードが悪いのか。

NeuralNetworkを用いて方策勾配を求めてパラメータを更新して学習させたいのですが上手くいきません。
環境のCartPole問題を用いて常にポール直立させたいのですが学習が上手く行っておらず、すぐに倒れてしまします。
以下が実装したコードです。改善すべき箇所があれば御指摘を頂きたいです。
宜しくお願い致します。

python

1import numpy as np 2import gym 3import matplotlib.pyplot as plt 4%matplotlib inline 5 6np.random.seed(0) 7 8def e_greedy(prediction_action, epsilon_percentage, start_epsilon=0.5, end_epsilon=0.001): 9 epsilon = start_epsilon * (1 - epsilon_percentage) + end_epsilon * epsilon_percentage 10 if np.random.rand() < epsilon: 11 return np.random.randint(len(prediction_action[0])) 12 else: 13 return prediction_action[0].argmax() 14 15class Agent: 16 def __init__(self, state_size=4, action_size=2): 17 self.parms = { 18 'W1': np.random.randn(state_size, 16) * np.sqrt(2) * np.sqrt(6 / (state_size + 16)), 19 'B1': np.zeros(16), 20 'W2': np.random.randn(16, 16) * np.sqrt(2) * np.sqrt(6 / (16 + 16)), 21 'B2': np.zeros(16), 22 'W3': np.random.randn(16, action_size) * np.sqrt(2) * np.sqrt(6 / (16 + action_size)), 23 'B3': np.zeros(action_size), 24 } 25 self.grads = {} 26 27 def prediction(self, x): 28 self.x = x 29 self.ih = np.dot(self.x, self.parms['W1']) + self.parms['B1'] 30 # Relu 31 ih_relu = self.ih.copy() 32 ih_relu[ih_relu < 0] = 0 33 34 self.hh = np.dot(ih_relu, self.parms['W2']) + self.parms['B2'] 35 # Relu 36 hh_relu = self.hh.copy() 37 hh_relu[hh_relu < 0] = 0 38 39 self.ho = np.dot(hh_relu, self.parms['W3']) + self.parms['B3'] 40 # softmax 41 exp = np.exp(self.ho - np.max(self.ho)) 42 self.y = exp / np.sum(exp) 43 return self.y 44 45 def gradient(self, dx=1): 46 # Softmax Backward 47 dx = self.y - dx 48 self.grads['B3'] = np.sum(dx, 0) 49 self.grads['W3'] = np.dot(self.hh.T, dx) 50 dx = np.dot(dx, self.parms['W3'].T) 51 52 # Relu Backward 53 dx[dx < 0] = 0 54 self.grads['B2'] = np.sum(dx, 0) 55 self.grads['W2'] = np.dot(self.ih.T, dx) 56 dx = np.dot(dx, self.parms['W2'].T) 57 58 # Relu Backward 59 dx[dx < 0] = 0 60 self.grads['B1'] = np.sum(dx, 0) 61 self.grads['W1'] = np.dot(self.x.T, dx) 62 dx = np.dot(dx, self.parms['W1'].T) 63 64 def fit(self, state, action ,reward): 65 y = self.prediction(state) 66 67 # j = - 1 / M * log(y) * R 68 loss = - np.mean(np.log(y[np.arange(y.shape[0]), action]) * reward) 69 grad_loss = y 70 # grad_j = 1 / y * R 71 grad_loss[np.arange(y.shape[0]), action] = 1 / y[np.arange(y.shape[0]), action] * reward 72 73 # 勾配を求める 74 self.gradient(grad_loss) 75 76 #更新 77 self.Updata() 78 79 return loss 80 81 def Updata(self, lr=0.01): 82 for key in self.parms.keys(): 83 self.parms[key] -= lr * self.grads[key] 84 85def discount_rewards(rewards, gamma=0.98): 86 Rewards = np.zeros_like(rewards) 87 r = 0. 88 for t in reversed(range(len(rewards))): 89 r = rewards[t] + r * gamma 90 Rewards[t] = r 91 return Rewards 92 93class Buffer_hist: 94 def __init__(self): 95 self.states = [] 96 self.actions = [] 97 self.rewards = [] 98 self.next_state = [] 99 self.discounted_returns = [] 100 101 def add_buffer(self, state, action, reward, next_state): 102 self.states.append(state) 103 self.actions.append(action) 104 self.rewards.append(reward) 105 self.next_state.append(next_state) 106 107class Buffer: 108 def __init__(self): 109 self.states = [] 110 self.actions = [] 111 self.discounted_returns = [] 112 113 def add(self, hist): 114 self.states += hist.states 115 self.actions += hist.actions 116 self.discounted_returns += list(hist.discounted_returns) 117 118 def reset_buffer(self): 119 self.states = [] 120 self.actions = [] 121 self.discounted_returns = [] 122 123 124class Train: 125 def __init__(self, episode=5001, epsilon_stop=3000): 126 self.episode = episode 127 self.epsilon_stop = epsilon_stop 128 self.env = gym.make('CartPole-v0') 129 self.agent = Agent() 130 131 def play(self): 132 self.Rewards = [] 133 self.batch_loss = [] 134 global_buffer = Buffer() 135 136 for t in range(self.episode): 137 state = self.env.reset() 138 episode_reward = 0. 139 episode_hist = Buffer_hist() 140 epsilon_percentage = float(min(t / float(self.epsilon_stop), 1.)) 141 done = False 142 143 while not done: 144 action = e_greedy(self.agent.prediction([state]), epsilon_percentage) 145 next_state, reward, done, _ = self.env.step(action) 146 episode_hist.add_buffer(state, action, reward, next_state) 147 state = next_state 148 episode_reward += reward 149 150 if done: 151 episode_hist.discounted_returns = discount_rewards(episode_hist.rewards) 152 global_buffer.add(episode_hist) 153 154 # エピソード8回ごとに学習 155 if t % 8 == 0: 156 loss = self.agent.fit( 157 np.array(global_buffer.states), 158 np.array(global_buffer.actions), 159 np.array(global_buffer.discounted_returns) 160 ) 161 global_buffer.reset_buffer() 162 163 self.Rewards.append(episode_reward) 164 self.batch_loss.append(loss) 165 166 if t % 250 == 0: 167 print('episode: {} total reward: {}'.format(t, episode_reward)) 168 169T = Train() 170T.play() 171plt.plot(T.Rewards) 172plt.plot(T.batch_loss)

追伸

色々と方策勾配の実装について調べても見たのですが、ほとんどがTensorflowやkerasなどのライブラリを用いた実装ばかりで参考になりません。できればライブラリを使わずに実装したいです。
宜しくお願い致します。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tiitoi

2019/04/05 07:04 編集

> ほとんどがTensorflowやkerasなどのライブラリを用いた実装ばかりで参考になりません。 ニューラルネットワークをライブラリを使わずに実装するのは、逆伝播法とか全部自前で実装する必要があるので、「深層強化学習」を実装するという目的、本質からそれて、かなり遠回りすることになるのでオススメできません。 数値計算するのに、numpy を使わずに行列積とかを全部自前で実装したいと言っているようなものです。 深層強化学習をしたいのであれば、Tensorflow や PyTorch などのライブラリを使いましょう。
Luisu

2019/04/05 07:23

ご回答有難う御座います。 本質からは遠回りですが、ライブラリに頼って依存してしまうと他の言語で実装したい場合に不都合なので出来れば、依存しない方向で行きたいです。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問