教材「作りながら学ぶ!深層強化学習Py-torchによる実践プログラミング」
に乗っているコードを打ち込みDQNを実装しました。
*以下その一部
python
1class Environment: 2 def __init__(self): 3 self.env = gym.make(ENV) #実行する課題を設定 4 self.num_states = self.env.observation_space.shape[0] 5 #課題の状態と行動の数を設定 6 self.num_actions = self.env.action_space.n #CartPoleの行動(右に左に押す) 7 #環境内で行動するAgentを生成 8 self.agent = Agent(self.num_states, self.num_actions) 9 10 def run(self): 11 """実行""" 12 episode_10_list = np.zeros(10) #10試行分の対地続けたstep数を格納しm平均ステップ数を出力に利用 13 complete_episodes = 0 #195step以上連続で立ち続けた試行回数 14 episode_final = False #最後の試行フラグ 15 frames = [] #最後の試行を動画にするため画像を格納する変数 16 17 for episode in range(NUM_EPISODES): #試行数分繰り返す 18 observation = self.env.reset() #感興の初期化 19 20 state = observation #観測をそのまま状態sとして使用 21 state = torch.from_numpy(state).type( 22 torch.FloatTensor)#numpy変数をPytorchのテンソルに変換 23 #FloatTensorof size 4をsize 1×4に変換 24 state = torch.unsqueeze(state, 0) 25 26 for step in range(MAX_STEPS): #1エピソードのループ 27 if episode_final is True: #最終試行ではFramesに各時刻の画像わ追加していく 28 frames.append(self.env.render(mode='rgb_array')) 29 action = self.agent.get_action(state, episode) #行動を求める 30 31 #行動a_tの実行により、s_{t+1}とdoneフラグを求める 32 #actionから.item()を指定して、中身を取り出す 33 observation_next, _, done, _ = self.env.step( 34 action.item()) #rewardとinfoは使わないので_にする 35 36 #報酬を与える。さらにepisodeの終了評価とstate_nextを設定する 37 if done: #ステップ数が200経過するか、一定角度以上傾くとdoneはtrueになる 38 state_next = None #次の状態はないので、Noneを格納 39 40 #直近10episodeの縦たstep数リストに追加 41 episode_10_list = np.hstack( 42 (episode_10_list[1:], step + 1)) 43 44 if step < 195: 45 reward = torch.FloatTensor( 46 [-1.0]) #途中でこけたら罰則として報酬-1を与える 47 complete_episodes = 0 #連続成功履歴をリセット 48 else: 49 reward = torch.FloatTensor([1.0]) # 立ったまま終了時は報酬1を与える 50 51 complete_episodes = complete_episodes + 1 52 else: 53 reward = torch.FloatTensor([0.0]) #普段は報酬0 54 state_next = observation_next #観測をそのまま状態とする 55 state_next = torch.from_numpy(state_next).type( 56 torch.FloatTensor) #numpy変数をpytorchのテンソルに変換 57 #FloatTensorof size 4をsize 1×4に変換 58 state_next = torch.unsqueeze(state_next, 0) 59 60 #メモリに経験を追加 61 self.agent.memorize(state, action, state_next, reward) 62 63 #Experience ReplayでQ関数を更新する 64 self.agent.update_q_function() 65 66 #観測の更新 67 state = state_next 68 69 #終了時の処理 70 if done: 71 print('%d Episode: Finished after %d steps: 10試行の平均step数 = %.1lf' %( 72 episode, step+1, episode_10_list.mean())) 73 break 74 if episode_final is True: 75 #動画を保存と描画 76 display_frames_as_gif(frames) 77 break 78 79 #10連続で200step立ち続けたら成功 80 if complete_episodes >=10: 81 print('10回連続成功') 82 episode_final = True #次の試行を描画を行う最終試行とする
これらを実装したところ「**indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. **」と表示されてしまいました。
要約すると、「unit8は非推奨になったのでtorch.boolを使ってください」と書かれています
恐らく21行目のtype(
が関係しているのではないかと思いコードの一部を持ってきました。ほかにtype関数はなかったので。
どのように打ち換えればいいか知恵をお貸しください。
あなたの回答
tips
プレビュー