🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Q&A

解決済

1回答

2621閲覧

pfrl,pytorchでのR2D2の実装の際の引数のエラー

taiki6004

総合スコア0

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

0グッド

0クリップ

投稿2020/12/22 09:51

編集2020/12/23 01:32

前提・実現したいこと

python,機械学習初心者です
pfrlでR2D2の実装を試みていますが下記のようなエラーメッセージが出てしまい実行できません
修正を試みたいのですがサイトパッケージの中での引数がどのように紐づけられているのかわかりません

発生している問題・エラーメッセージ

Traceback (most recent call last): File "R2D3.py", line 497, in <module> trainer.train() File "R2D3.py", line 312, in train action = self.agent.act(obs) File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agent.py", line 161, in act return self.batch_act([obs])[0] File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 483, in batch_act batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs) File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 471, in _evaluate_model_and_update_recurrent_states self.model, batch_xs, self.train_recurrent_states File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/utils/recurrent.py", line 142, in one_step_forward y, recurrent_state = rnn(pack, recurrent_state) File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) TypeError: forward() takes 2 positional arguments but 3 were given

ネットワークモデル

python

1class R2D3(nn.Module, StateQFunction): 2 """Distributional dueling fully-connected Q-function with discrete actions.""" 3 4 def __init__( 5 self, 6 n_actions, 7 n_input_channels, 8 n_added_input, 9 n_atoms = 51, 10 v_min = -120, 11 v_max = 1000, 12 activation=torch.relu, 13 bias=0.1, 14 ): 15 assert n_atoms >= 2 16 assert v_min < v_max 17 self.n_actions = n_actions 18 self.n_input_channels = n_input_channels 19 self.activation = activation 20 self.n_atoms = n_atoms 21 self.img_width = 48 22 self.img_height = 27 23 self.n_added_input = n_added_input 24 25 super().__init__() 26 27 self.pool = nn.MaxPool2d(2,2,ceil_mode=True) 28 self.drop1d_B = nn.Dropout(0.5) 29 self.drop1d_S = nn.Dropout(0.5) 30 self.drop2d_B = nn.Dropout2d(0.4) 31 self.drop2d_S = nn.Dropout2d(0.2) 32 self.z_values = torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32) 33 34 # 畳み込み 35 36 self.conv1_1 = nn.Conv2d(n_input_channels, 8, 3) 37 nn.init.kaiming_normal_(self.conv1_1.weight) 38 self.conv1_2 = nn.Conv2d(8, 8, 3) 39 nn.init.kaiming_normal_(self.conv1_2.weight) 40 41 self.conv2_1 = nn.Conv2d(8, 16, 3) 42 nn.init.kaiming_normal_(self.conv2_1.weight) 43 self.conv2_2 = nn.Conv2d(16, 16, 3,1,1) 44 nn.init.kaiming_normal_(self.conv2_2.weight) 45 46 self.conv3_1 = nn.Conv2d(16, 32, 3) 47 nn.init.kaiming_normal_(self.conv1_1.weight) 48 self.conv3_2 = nn.Conv2d(32, 32,3) 49 nn.init.kaiming_normal_(self.conv1_2.weight) 50 self.conv3_3 = nn.Conv2d(32, 64,3) 51 nn.init.kaiming_normal_(self.conv1_2.weight) 52 53 54 # Advantage 55 self.fc1 = nn.LSTM(128*6, 512,batch_first=True) 56 #nn.init.kaiming_normal_(self.al1.weight) 57 58 self.fc2 = nn.Linear(512, 512) 59 nn.init.kaiming_normal_(self.fc2.weight) 60 61 self.fc3 = nn.Linear(512, n_actions*n_atoms) 62 nn.init.kaiming_normal_(self.fc3.weight) 63 64 def forward(self, state): 65 img = state[:,:-self.n_added_input] 66 sen = state[:,-self.n_added_input:] 67 68 img = torch.reshape(img,(-1,self.n_input_channels, self.img_width, self.img_height)) 69 70 h = F.relu(self.conv1_1(img)) 71 h = self.pool(F.relu(self.conv1_2(h))) 72 h = F.relu(self.conv2_1(h)) 73 h = self.pool(F.relu(self.conv2_2(h))) 74 h = F.relu(self.conv3_1(h)) 75 h = F.relu(self.conv3_2(h)) 76 h = self.pool(F.relu(self.conv3_3(h))) 77 78 h = h.view(-1, 2048)#reshape 79 80 # Advantage 81 batch_size = img.shape[0] 82 83 h= torch.cat((torch.reshape(h,(h.shape[0], -1)), sen),dim=1) 84 85 h = F.relu(self.fc1(h)) 86 h = F.relu(self.fc2(h)) 87 h = F.relu(self.fc3(h)) 88 89 q = F.softmax(h, dim=2) 90 91 self.z_values = self.z_values.to(state.device) 92 return pfrl.action_value.DistributionalDiscreteActionValue(q, self.z_values)

DQNの登録

python

1self.agent = pfrl.agents.CategoricalDoubleDQN( 2 self.q_func, optimizer=self.optimizer, replay_buffer=self.rbuf, 3 gamma=0.99, explorer=self.explorer, gpu=0, minibatch_size=64, 4 replay_start_size=self.replay_start_size,target_update_interval=self.target_update_interval, 5 phi=self.phi,update_interval=1,batch_accumulator="mean",recurrent=True 6)

補足情報

Rainbowの実装はできたのでカテゴリカルDQN周りは大丈夫だと思います。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

####pfrlの関数を使用したらエラーメッセージはでなくなりました
以下コードになります

python

1if q_func == 'R2D3': 2 self.q_func =pfrl.nn.RecurrentSequential(network.R2D3(n_actions,n_input_channels,n_added_input=8)) 3

一応動いてはいるのですがまだ動作確認はできていません
正しいやり方を知っておられる方がおられましたら教えていただけると幸いです

投稿2020/12/24 08:22

編集2020/12/24 08:25
taiki6004

総合スコア0

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問