Background / what I want to achieve
I am a beginner in Python and machine learning.
I am trying to implement R2D2 with pfrl, but the error message below is raised and the program will not run.
I would like to fix it, but I cannot tell how the arguments are wired together inside the site-packages code.
Problem / error message
```
Traceback (most recent call last):
  File "R2D3.py", line 497, in <module>
    trainer.train()
  File "R2D3.py", line 312, in train
    action = self.agent.act(obs)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agent.py", line 161, in act
    return self.batch_act([obs])[0]
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 483, in batch_act
    batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 471, in _evaluate_model_and_update_recurrent_states
    self.model, batch_xs, self.train_recurrent_states
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/utils/recurrent.py", line 142, in one_step_forward
    y, recurrent_state = rnn(pack, recurrent_state)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```
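The failing call is `y, recurrent_state = rnn(pack, recurrent_state)` in pfrl/utils/recurrent.py, which is reached because `recurrent=True` is set in the agent setup below: the whole model gets called with the observation *and* the recurrent state, while the `forward` of the network below only accepts `state`. A toy reproduction of the same TypeError (the class name is made up):

```python
import torch
import torch.nn as nn

class OnlyState(nn.Module):
    def forward(self, state):  # one input argument besides self
        return state

OnlyState()(torch.zeros(1), None)
# TypeError: forward() takes 2 positional arguments but 3 were given
```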
Network model
```python
class R2D3(nn.Module, StateQFunction):
    """Distributional dueling fully-connected Q-function with discrete actions."""

    def __init__(
        self,
        n_actions,
        n_input_channels,
        n_added_input,
        n_atoms=51,
        v_min=-120,
        v_max=1000,
        activation=torch.relu,
        bias=0.1,
    ):
        assert n_atoms >= 2
        assert v_min < v_max
        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation
        self.n_atoms = n_atoms
        self.img_width = 48
        self.img_height = 27
        self.n_added_input = n_added_input

        super().__init__()

        self.pool = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.drop1d_B = nn.Dropout(0.5)
        self.drop1d_S = nn.Dropout(0.5)
        self.drop2d_B = nn.Dropout2d(0.4)
        self.drop2d_S = nn.Dropout2d(0.2)
        self.z_values = torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32)

        # Convolutional layers
        self.conv1_1 = nn.Conv2d(n_input_channels, 8, 3)
        nn.init.kaiming_normal_(self.conv1_1.weight)
        self.conv1_2 = nn.Conv2d(8, 8, 3)
        nn.init.kaiming_normal_(self.conv1_2.weight)

        self.conv2_1 = nn.Conv2d(8, 16, 3)
        nn.init.kaiming_normal_(self.conv2_1.weight)
        self.conv2_2 = nn.Conv2d(16, 16, 3, 1, 1)
        nn.init.kaiming_normal_(self.conv2_2.weight)

        self.conv3_1 = nn.Conv2d(16, 32, 3)
        nn.init.kaiming_normal_(self.conv3_1.weight)
        self.conv3_2 = nn.Conv2d(32, 32, 3)
        nn.init.kaiming_normal_(self.conv3_2.weight)
        self.conv3_3 = nn.Conv2d(32, 64, 3)
        nn.init.kaiming_normal_(self.conv3_3.weight)

        # Advantage
        # NOTE: input_size is 128*6 = 768, but forward feeds this layer
        # 2048 + n_added_input features.
        self.fc1 = nn.LSTM(128 * 6, 512, batch_first=True)
        # nn.init.kaiming_normal_(self.al1.weight)

        self.fc2 = nn.Linear(512, 512)
        nn.init.kaiming_normal_(self.fc2.weight)

        self.fc3 = nn.Linear(512, n_actions * n_atoms)
        nn.init.kaiming_normal_(self.fc3.weight)

    def forward(self, state):
        # Split the flat observation into image pixels and extra sensor values
        img = state[:, :-self.n_added_input]
        sen = state[:, -self.n_added_input:]

        img = torch.reshape(img, (-1, self.n_input_channels, self.img_width, self.img_height))

        h = F.relu(self.conv1_1(img))
        h = self.pool(F.relu(self.conv1_2(h)))
        h = F.relu(self.conv2_1(h))
        h = self.pool(F.relu(self.conv2_2(h)))
        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = self.pool(F.relu(self.conv3_3(h)))

        h = h.view(-1, 2048)  # reshape

        # Advantage
        batch_size = img.shape[0]

        h = torch.cat((torch.reshape(h, (h.shape[0], -1)), sen), dim=1)

        # NOTE: nn.LSTM returns (output, (h_n, c_n)), so this line would
        # also fail once the TypeError above is resolved.
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))

        # NOTE: h is 2-D here; it would need a view to
        # (batch, n_actions, n_atoms) for a softmax over dim=2.
        q = F.softmax(h, dim=2)

        self.z_values = self.z_values.to(state.device)
        return pfrl.action_value.DistributionalDiscreteActionValue(q, self.z_values)
```
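Judging from pfrl's DRQN example, a recurrent Q-function is normally built with `pfrl.nn.RecurrentSequential`, whose `forward` accepts `(x, recurrent_state)` and returns the output together with the next recurrent state, matching how the agent calls the model in the traceback. Below is a rough, untested sketch of restructuring the model that way. `ConvFeatures` is a simplified stand-in for the conv stack above, and `DistributionalHead` is a helper I made up by analogy with `pfrl.q_functions.DiscreteActionValueHead`; both names are mine, not pfrl API.

```python
import torch
import torch.nn as nn
import pfrl
from pfrl.action_value import DistributionalDiscreteActionValue


class ConvFeatures(nn.Module):
    """Simplified stand-in for the conv stack above: splits the flat
    observation into image and sensor parts, extracts image features,
    and concatenates the sensor values."""

    def __init__(self, n_input_channels, n_added_input, img_w=48, img_h=27):
        super().__init__()
        self.n_input_channels = n_input_channels
        self.n_added_input = n_added_input
        self.img_w, self.img_h = img_w, img_h
        self.conv = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 16)),  # crude placeholder for the deeper stack
            nn.Flatten(),                   # -> (batch, 16 * 8 * 16) = (batch, 2048)
        )

    def forward(self, state):
        img = state[:, :-self.n_added_input]
        sen = state[:, -self.n_added_input:]
        img = img.reshape(-1, self.n_input_channels, self.img_w, self.img_h)
        return torch.cat((self.conv(img), sen), dim=1)


class DistributionalHead(nn.Module):
    """Made-up helper: turns a flat (batch, n_actions * n_atoms) tensor
    into a categorical action-value distribution."""

    def __init__(self, n_actions, n_atoms, v_min, v_max):
        super().__init__()
        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.register_buffer(
            "z_values", torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32)
        )

    def forward(self, h):
        q = torch.softmax(h.view(-1, self.n_actions, self.n_atoms), dim=2)
        return DistributionalDiscreteActionValue(q, self.z_values)


n_actions, n_input_channels, n_added_input = 5, 4, 10  # placeholder values

# RecurrentSequential applies the plain modules element-wise and hands the
# packed sequence to nn.LSTM; its forward takes (x, recurrent_state).
q_func = pfrl.nn.RecurrentSequential(
    ConvFeatures(n_input_channels, n_added_input),
    nn.LSTM(input_size=2048 + n_added_input, hidden_size=512),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, n_actions * n_atoms),
    DistributionalHead(n_actions, n_atoms=51, v_min=-120, v_max=1000),
)
```

With this structure the `recurrent=True` path can thread the LSTM state through `q_func` itself, instead of passing it to a `forward(self, state)` that does not expect it.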
DQN agent setup
```python
self.agent = pfrl.agents.CategoricalDoubleDQN(
    self.q_func,
    optimizer=self.optimizer,
    replay_buffer=self.rbuf,
    gamma=0.99,
    explorer=self.explorer,
    gpu=0,
    minibatch_size=64,
    replay_start_size=self.replay_start_size,
    target_update_interval=self.target_update_interval,
    phi=self.phi,
    update_interval=1,
    batch_accumulator="mean",
    recurrent=True,
)
```
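For what it's worth, pfrl's DRQN example pairs `recurrent=True` with an episodic replay buffer so that contiguous subsequences can be sampled. A hedged sketch reusing the variables above (the capacity and `episodic_update_len=10` are simply the values from that example, not something I have verified for this code):

```python
import pfrl

# Episodic buffer: stores transitions grouped by episode so the
# recurrent agent can sample contiguous sequences.
self.rbuf = pfrl.replay_buffers.EpisodicReplayBuffer(10 ** 6)

self.agent = pfrl.agents.CategoricalDoubleDQN(
    self.q_func,
    optimizer=self.optimizer,
    replay_buffer=self.rbuf,
    gamma=0.99,
    explorer=self.explorer,
    gpu=0,
    minibatch_size=64,
    replay_start_size=self.replay_start_size,
    target_update_interval=self.target_update_interval,
    phi=self.phi,
    update_interval=1,
    batch_accumulator="mean",
    recurrent=True,
    episodic_update_len=10,  # length of each sampled subsequence
)
```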
Additional information
I have already gotten a Rainbow implementation working, so I believe the categorical-DQN side of things is fine.