質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

89.98%

PyTorch でloss関数にtorch.sqrtを用いる方法について

受付中

回答 0

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 348

fujifuji_

score 0

前提・実現したいこと

pytorchでdeep learningの画像分類モデルを作成しています。

モデルのloss関数の一部にtorch.sqrt()をしようしたところ、backward時にnanが発生する問題に突き当たりました。
torch.sqrt()に入力されるベクトルの要素の大きさがとても小さいことが原因のようです。

torch.sqrt()のinputが小さいとbackward時に1/(2*torch.sqrt())がinfになるようです...

何か対処法がわかる方がいらっしゃいましたら、お教えいただければ幸いです。

エラーメッセージ
Traceback (most recent call last):
  File "main_label_grad.py", line 504, in <module>
    model_g = main()
  File "main_label_grad.py", line 459, in main
    tr_acc, tr_acc5, tr_los, grad_train, last_v4 = train(train_loader, net_c, net_t, optimizer_c, optimizer_t, epoch, args, log_G, args.noise_dim, grad_train_old=None, v4=None)
  File "main_label_grad.py", line 320, in train
    loss_trans.backward()
  File "C:\Users\GUESTUSER\.conda\envs\tf37\lib\site-packages\torch\tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Users\GUESTUSER\.conda\envs\tf37\lib\site-packages\torch\autograd\__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.


上記のloss_transがモデルの目的関数で、以下の関数の1つ目の返り値に該当します。
以下の関数(new_norm)の return torch.sqrt(v4_ema)のv4_emaが小さくてnanになってしまっております。

def new_norm(v, epoch, iter, last_v4=None):
    v2 = torch.pow(v,2)
    v4 = torch.pow(v,4)
    v4_ema = ema(v4, epoch, iter, last_v4)
    epsilon = torch.ones(v4_ema.size(0)) * 1e-10
    epsilon = epsilon.cuda()
    return (v2/(torch.sqrt(v4_ema)+epsilon)).sum()/v4_ema.size(0), v4_ema


また上記の関数new_normは以下の式を求めようとして作成したものです。
イメージ説明

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正の依頼

  • fujifuji_

    2019/10/16 15:09

    stackoverflowにおいて、https://ja.stackoverflow.com/questions/59692/pytorchでdeep-learningの画像分類モデルを作成しています

    「infになってしまう時のinputの値は完全に0でしょうか?分数の時にinfとならないように分母に十分小さな値(例えば1e-5)を足しておいても同じエラーになりますか?」このような質問をいただきました

    キャンセル

  • fujifuji_

    2019/10/16 15:11

    「現状のプログラムでは、ルート内部に入る値が1e-13とかになっており、1e-5を足してしまうと、プログラムが目的通りに作用しないと考えております。」と回答いたしました

    キャンセル

  • fujifuji_

    2019/10/16 15:11

    また、1e-5を足すとプログラムは一応動きました。

    キャンセル

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 89.98%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる