質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.61%

pfrl,pytorchでのR2D2の実装の際の引数のエラー

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 658

score 0

前提・実現したいこと

python,機械学習初心者です
pfrlでR2D2の実装を試みていますが下記のようなエラーメッセージが出てしまい実行できません
修正を試みたいのですがサイトパッケージの中での引数がどのように紐づけられているのかわかりません

発生している問題・エラーメッセージ

Traceback (most recent call last):
  File "R2D3.py", line 497, in <module>
    trainer.train()
  File "R2D3.py", line 312, in train
    action = self.agent.act(obs)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agent.py", line 161, in act
    return self.batch_act([obs])[0]
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 483, in batch_act
    batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/agents/dqn.py", line 471, in _evaluate_model_and_update_recurrent_states
    self.model, batch_xs, self.train_recurrent_states
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/pfrl/utils/recurrent.py", line 142, in one_step_forward
    y, recurrent_state = rnn(pack, recurrent_state)
  File "/home/usr/.pyenv/versions/anaconda3-5.2.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

ネットワークモデル

class R2D3(nn.Module, StateQFunction):
    """Distributional dueling fully-connected Q-function with discrete actions."""

    def __init__(
        self,
        n_actions,
        n_input_channels,
        n_added_input,
        n_atoms = 51,
        v_min = -120,
        v_max = 1000,
        activation=torch.relu,
        bias=0.1,
    ):
        assert n_atoms >= 2
        assert v_min < v_max
        self.n_actions = n_actions
        self.n_input_channels = n_input_channels
        self.activation = activation
        self.n_atoms = n_atoms
        self.img_width = 48
        self.img_height = 27
        self.n_added_input = n_added_input

        super().__init__()

        self.pool = nn.MaxPool2d(2,2,ceil_mode=True)
        self.drop1d_B = nn.Dropout(0.5)
        self.drop1d_S = nn.Dropout(0.5)
        self.drop2d_B = nn.Dropout2d(0.4)
        self.drop2d_S = nn.Dropout2d(0.2)
        self.z_values = torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32)

        # 畳み込み

        self.conv1_1 = nn.Conv2d(n_input_channels, 8, 3)
        nn.init.kaiming_normal_(self.conv1_1.weight)
        self.conv1_2 = nn.Conv2d(8, 8, 3)
        nn.init.kaiming_normal_(self.conv1_2.weight)

        self.conv2_1 = nn.Conv2d(8, 16, 3)
        nn.init.kaiming_normal_(self.conv2_1.weight)
        self.conv2_2 = nn.Conv2d(16, 16, 3,1,1)
        nn.init.kaiming_normal_(self.conv2_2.weight)

        self.conv3_1 = nn.Conv2d(16, 32, 3)
        nn.init.kaiming_normal_(self.conv1_1.weight)
        self.conv3_2 = nn.Conv2d(32, 32,3)
        nn.init.kaiming_normal_(self.conv1_2.weight)
        self.conv3_3 = nn.Conv2d(32, 64,3)
        nn.init.kaiming_normal_(self.conv1_2.weight)


        # Advantage
        self.fc1 = nn.LSTM(128*6, 512,batch_first=True)
        #nn.init.kaiming_normal_(self.al1.weight)

        self.fc2 = nn.Linear(512, 512)
        nn.init.kaiming_normal_(self.fc2.weight)

        self.fc3 = nn.Linear(512, n_actions*n_atoms)
        nn.init.kaiming_normal_(self.fc3.weight)

    def forward(self, state):
        img = state[:,:-self.n_added_input]
        sen = state[:,-self.n_added_input:]

        img = torch.reshape(img,(-1,self.n_input_channels, self.img_width, self.img_height))

        h = F.relu(self.conv1_1(img))
        h = self.pool(F.relu(self.conv1_2(h)))
        h = F.relu(self.conv2_1(h))
        h = self.pool(F.relu(self.conv2_2(h)))
        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = self.pool(F.relu(self.conv3_3(h)))

        h = h.view(-1,  2048)#reshape 

        # Advantage
        batch_size = img.shape[0]

        h= torch.cat((torch.reshape(h,(h.shape[0], -1)), sen),dim=1)

        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))

        q = F.softmax(h, dim=2)

        self.z_values = self.z_values.to(state.device)
        return pfrl.action_value.DistributionalDiscreteActionValue(q, self.z_values)

DQNの登録

self.agent = pfrl.agents.CategoricalDoubleDQN(
  self.q_func, optimizer=self.optimizer, replay_buffer=self.rbuf,
  gamma=0.99, explorer=self.explorer, gpu=0, minibatch_size=64,
  replay_start_size=self.replay_start_size,target_update_interval=self.target_update_interval,
  phi=self.phi,update_interval=1,batch_accumulator="mean",recurrent=True
)

補足情報

Rainbowの実装はできたのでカテゴリカルDQN周りは大丈夫だと思います。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

check解決した方法

0

pfrlの関数を使用したらエラーメッセージはでなくなりました

以下コードになります

if q_func == 'R2D3':
            self.q_func =pfrl.nn.RecurrentSequential(network.R2D3(n_actions,n_input_channels,n_added_input=8))


一応動いてはいるのですがまだ動作確認はできていません
正しいやり方を知っておられる方がおられましたら教えていただけると幸いです

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.61%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る