質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Q&A

0回答

3890閲覧

PyTorch でloss関数にtorch.sqrtを用いる方法について

fujifuji_

総合スコア4

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

0グッド

0クリップ

投稿2019/10/06 05:18

編集2019/10/14 00:47

前提・実現したいこと

pytorchでdeep learningの画像分類モデルを作成しています。

モデルのloss関数の一部にtorch.sqrt()をしようしたところ、backward時にnanが発生する問題に突き当たりました。
torch.sqrt()に入力されるベクトルの要素の大きさがとても小さいことが原因のようです。

torch.sqrt()のinputが小さいとbackward時に1/(2*torch.sqrt())がinfになるようです...

何か対処法がわかる方がいらっしゃいましたら、お教えいただければ幸いです。

エラーメッセージ Traceback (most recent call last): File "main_label_grad.py", line 504, in <module> model_g = main() File "main_label_grad.py", line 459, in main tr_acc, tr_acc5, tr_los, grad_train, last_v4 = train(train_loader, net_c, net_t, optimizer_c, optimizer_t, epoch, args, log_G, args.noise_dim, grad_train_old=None, v4=None) File "main_label_grad.py", line 320, in train loss_trans.backward() File "C:\Users\GUESTUSER.conda\envs\tf37\lib\site-packages\torch\tensor.py", line 118, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph) File "C:\Users\GUESTUSER.conda\envs\tf37\lib\site-packages\torch\autograd\__init__.py", line 93, in backward allow_unreachable=True) # allow_unreachable flag RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.

上記のloss_transがモデルの目的関数で、以下の関数の1つ目の返り値に該当します。
以下の関数(new_norm)の return torch.sqrt(v4_ema)のv4_emaが小さくてnanになってしまっております。

def new_norm(v, epoch, iter, last_v4=None): v2 = torch.pow(v,2) v4 = torch.pow(v,4) v4_ema = ema(v4, epoch, iter, last_v4) epsilon = torch.ones(v4_ema.size(0)) * 1e-10 epsilon = epsilon.cuda() return (v2/(torch.sqrt(v4_ema)+epsilon)).sum()/v4_ema.size(0), v4_ema

また上記の関数new_normは以下の式を求めようとして作成したものです。
イメージ説明

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

fujifuji_

2019/10/15 01:11

解決した際や情報の追加を求められた際はは、両方のサイトにお答えさせていただく、所存です。
fujifuji_

2019/10/16 06:11

「現状のプログラムでは、ルート内部に入る値が1e-13とかになっており、1e-5を足してしまうと、プログラムが目的通りに作用しないと考えております。」と回答いたしました
fujifuji_

2019/10/16 06:11

また、1e-5を足すとプログラムは一応動きました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問