ニューラルネットワークの損失関数において,(binary) cross entropy lossと(binary) cross entropy logitsの違いがわかりません.
入力0.1,0.9に対してそれぞれ試してみた結果次のようになりました.
Python
1x = torch.tensor([0.1]*10 ,dtype=torch.float) 2b = torch.zeros(10,dtype=torch.float) 3b1 = torch.ones(10, dtype=torch.float) 4 5m = F.binary_cross_entropy(x, b) 6m1 = F.binary_cross_entropy(x, b1) 7n1 = F.binary_cross_entropy_with_logits(x,b) 8n2 = F.binary_cross_entropy_with_logits(x,b1) 9print(m,m1,n1,n2) 10 11#tensor(0.1054) tensor(2.3026) tensor(0.7444) tensor(0.6444)
python
1x = torch.tensor([0.9]*10 ,dtype=torch.float) 2b = torch.zeros(10,dtype=torch.float) 3b1 = torch.ones(10, dtype=torch.float) 4 5m = F.binary_cross_entropy(x, b) 6m1 = F.binary_cross_entropy(x, b1) 7n1 = F.binary_cross_entropy_with_logits(x,b) 8n2 = F.binary_cross_entropy_with_logits(x,b1) 9 10print(m,m1,n1,n2) 11 12#tensor(2.3026) tensor(0.1054) tensor(1.2412) tensor(0.3412)
0.9に対しては,正解ラベル1の方がlossの値は小さくなっているので理解できます.
0.1に対して,with logits では正解ラベル0の方がlossの値が小さくなっているのは変じゃないですか?これでは正解ラベルに近づくようにパラメータを更新できるようには思えません….
Pytorch公式の解説(ページ)を読みましたがピンときませんでした.
なぜこのようなことが起こってしまうのか,そもそもwith logitsはなんなのか,ご教授頂けますと幸いです.
あなたの回答
tips
プレビュー