質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.77%

pytorch Variableレイヤーをgpuに指定できません

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 342

Flan.

score 32

.to(cuda)で指定しているのですが うまくいきません
cpuにあるとエラーがでます

試したこと
試しにもう一度def gpu(self): で二重に指定したのですがうまくいきません
乱数をに入れてみたり 分けて実行もしても同じエラーが出たので
self.tril_maskレイヤーが原因だと思います

なぜこうなるのかまったくわからない わかる人教えてください

     モデル定義
        self.L = nn.Linear(self.hidden_size, num_outputs ** 2)
        nn.init.kaiming_normal_(self.L.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
        self.tril_mask = Variable(torch.tril(torch.ones(num_outputs, num_outputs), diagonal=-1).unsqueeze(0))
        self.diag_mask = Variable(torch.diag(torch.diag(torch.ones(num_outputs, num_outputs))).unsqueeze(0))

       略
            L1=self.diag_mask.expand_as(L)
            L2=torch.exp(L)
            L3 = L * self.tril_mask.expand_as(L)
            L=L3 + L2 * L1
            P = torch.bmm(L, L.transpose(2, 1))
            print(self.tril_mask.is_cuda)
       略
       return output
            def gpu(self):
                self.tril_mask.to('cuda:0')


criterion = nn.MSELoss()
targetQN = mainQN
mainQN.gpu()#!!!!!!
targetQN.eval()
mainQN.eval()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-87610d0531cc> in <module>
    401             total_reward_vec = np.hstack((total_reward_vec[1:], episode_reward))  # 報酬を記録
    402             if batch_size<len(memory.buffer*4):
--> 403                 memory_TDerror.update_TDerror(gamma,multireward_steps)
    404                 for _ in range(t):
    405                     trin.pioritized_experience_replay(batch_size, gamma,step=episode,state_size=state_,action_size=acthon,multireward_steps=multireward_steps)

<ipython-input-1-87610d0531cc> in update_TDerror(self, gamma, multireward_steps)
    304             next_state=memory.buffer[i][0]
    305             target = memory.buffer[i][2] + (gamma**multireward_steps) * targetQN.forward(next_state,"net_v")[0]
--> 306             self.buffer[i] =target - mainQN.forward(inpp,"net_q")[0]
    307 
    308 

<ipython-input-1-87610d0531cc> in forward(self, inputs, net)
    194             L1=self.diag_mask.expand_as(L)
    195             L2=torch.exp(L)
--> 196             L3 = L * self.tril_mask.expand_as(L)
    197             L=L3 + L2 * L1
    198             P = torch.bmm(L, L.transpose(2, 1))

~\Anaconda3\envs\pyflan\lib\site-packages\apex\amp\wrap.py in wrapper(*args, **kwargs)
     56                                          args,
     57                                          kwargs)
---> 58             return orig_fn(*new_args, **kwargs)
     59         else:
     60             raise NotImplementedError('Do not know how to handle ' +

RuntimeError: expected device cuda:0 but got device cpu

追記 is_cuda という物を使って調べた結果 False と出ました
.to(cuda:0)をforでやってみましたが(やけくそ)結果同じ

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

check解決した方法

0

Variableを消したらうまくいきました 

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.77%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る