質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

0回答

381閲覧

【LSTM】ChainerRLでのLSTM

junko_kobayashi

総合スコア11

Chainer

Chainerは、国産の深層学習フレームワークです。あらゆるニューラルネットワークをPythonで柔軟に書くことができ、学習させることが可能。GPUをサポートしており、複数のGPUを用いた学習も直感的に記述できます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

0クリップ

投稿2022/03/10 22:31

編集2022/03/10 22:34

ChainerRLで深層強化学習をしています。

全結合ではできたので、LSTMでの構築を試みたのですが、次のようなエラーが出ます。

Python

1 import chainer 2 import chainer.functions as F 3 import chainer.links as L 4 import chainerrl 5 import gym 6 import numpy as np 7 import cupy as cp 8 import time 9 10 c = cp.array([0]*1024) 11 print(c) 12 13 env = MyEnv() 14 print(env.step(2)) 15 16 obs = env.reset() #初期化 17 action = env.action_space.sample() 18 obs, r, done, info = env.step(action) 19 20 ### どんな値が入っているのか確認 21 print('next observation : {}'.format(obs)) 22 print('reward : {}'.format(r)) 23 print('done : {}'.format(done)) 24 print('info : {}'.format(info)) 25 26 obs_size = env.observation_space.shape[0] 27 28 n_actions = env.action_space.n 29 30 class QFunction(chainer.Chain): 31 def __init__(self, n_actions): 32 super().__init__( 33 L0=L.Linear(obs_size , 1024), 34 L1=L.Linear(1024 , 512), 35 L2=L.Linear(512 , 256), 36 L3=L.Linear(256 , 128), 37 L4=L.LSTM(128, 64), 38 L5=L.Linear(64, 8), 39 L6=L.Linear(8, n_actions)) 40 41 42 def __call__(self, x, test=False): 43 global c 44 h = F.relu(self.L0(x)) 45 h = F.relu(self.L1(h)) 46 h = F.relu(self.L2(h)) 47 h = F.relu(self.L3(h)) 48 c, h = F.lstm(c, self.L4(h)) 49 50 h = F.relu(self.L5(h)) 51 52 return chainerrl.action_value.DiscreteActionValue(self.L6(h)) 53 54 q_func = QFunction(n_actions) 55 56

c, h = F.lstm(c, self.L4(h))>>> tuple index out of range

初回だけ使うcの値もいろいろ試したのですが、自力では解決できませんでした。考えられる原因を教えてください。

追記

Python

1 q_func.to_gpu(0) ## GPUを使いたい人はこのコメントを外す 2 3 optimizer = chainer.optimizers.Adam(eps=1e-2) 4 optimizer.setup(q_func) #設計したq関数の最適化にAdamを使う 5 gamma = 0.95 6 explorer = chainerrl.explorers.ConstantEpsilonGreedy( 7 epsilon=0.3, random_action_func=env.action_space.sample) 8 replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity = 10**6) 9 phi = lambda x:x.astype(np.float32, copy=False)##型の変換(chainerはfloat32型。float64は駄目) 10 11 agent = chainerrl.agents.DoubleDQN( 12 q_func, optimizer, replay_buffer, gamma, explorer, 13 replay_start_size=500, update_interval=1, 14 target_update_interval=100, phi=phi)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問