質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.49%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Q&A

解決済

1回答

2546閲覧

Pytorchのautocastを使用した場合の畳み込み層出力がnanになる問題

pnbmg0044

総合スコア16

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

0グッド

0クリップ

投稿2022/03/28 09:51

Pytorchにて、U^2Netを実装し、(1, 3, 512, 512)構造のデータを入力した場合、一部の畳み込み層での出力がnanになります。
nanが発生する原因が分からず、苦戦しているため、ご教示していただければと思います。

U^2Net(github)

REBNCONV

python

1class REBNCONV(nn.Module): 2 def __init__(self, in_ch=3, out_ch=3, dirate=1): 3 super(REBNCONV, self).__init__() 4 5 self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate) 6 self.bn_s1 = nn.BatchNorm2d(out_ch) 7 self.relu_s1 = nn.ReLU(inplace=True) 8 9 def forward(self, x): 10 hx = x 11 xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 12 return xout

RSU-7

python

1class RSU7(nn.Module): # UNet07DRES(nn.Module): 2 def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 3 super(RSU7, self).__init__() 4 self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 5 以下省略 6 7 def forward(self, x): 8 hx = x 9 hxin = self.rebnconvin(hx) 10 以下省略

U2Net

python

1class U2NET(nn.Module): 2 def __init__(self, in_ch=3, out_ch=1, num_classes=20): 3 super(U2NET, self).__init__() 4 5 self.stage1 = RSU7(in_ch, 32, 64) 6 self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 7 以下省略 8 def forward(self, x): 9 省略 10 # -------------------- decoder -------------------- 11 hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 12 hx5dup = _upsample_like(hx5d, hx4) 13 14 hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 15 hx4dup = _upsample_like(hx4d, hx3) 16 17 hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 18 hx3dup = _upsample_like(hx3d, hx2) 19 20 hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 21 hx2dup = _upsample_like(hx2d, hx1) 22 23 hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

問題の箇所として、U2Netの
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
の箇所でnanが発生&遡及すると、
REBNCONVの
self.conv_s1(hx)
でnanが発生しているようでした。

PCスペック

  • RTX 3080
  • Pytorch 1.11
  • CUDA 11.3

試したこと
①MobilenetV3(torchvision.models)で実行してみたが、nanは発生しなかった
①の0層 (畳み込みをして画像を縮小する層: conv & batchnorm) を分離し、1層目以降をstack hourglass構造にし、(4, 3, 512, 512)を入力したところ、2つ目のデータの順伝播中に、0層目の畳み込み層でnanが発生しました。

備考
インターネット上には、CUDA10.2を使用すると解決するとあり、rtx3080はCUDA11.以降のみ対応のため、先の方法は適応できませんでした。
Half precision Convolution cause NaN in forward pass

宜しくお願い致します。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

自己解決

自己解決しました。
どうやら、各Conv2D層での出力が-infになっているようでしたので、そのクラスのforword関数中の出力に対し
torch.clamp({出力変数}, min=-65504.0, max=65504.0)
を適用したところ、-infが発生せず、結果的にnan発生の抑止に繋がりました。

投稿2022/03/29 07:41

pnbmg0044

総合スコア16

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.49%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問