質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.46%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

scikit-learn

scikit-learnは、Pythonで使用できるオープンソースプロジェクトの機械学習用ライブラリです。多くの機械学習アルゴリズムが実装されていますが、どのアルゴリズムも同じような書き方で利用できます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

2045閲覧

pytorch scikit-learn log_lossが自作計算と合わない

asaliquid1011

総合スコア16

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

scikit-learn

scikit-learnは、Pythonで使用できるオープンソースプロジェクトの機械学習用ライブラリです。多くの機械学習アルゴリズムが実装されていますが、どのアルゴリズムも同じような書き方で利用できます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

1クリップ

投稿2021/01/27 08:22

pytorchのnn.BCELoss()と、scikit-learnのlog_lossと、自作のlog_loss関数の比較をしているのですが、
数値が合わず、改善方法を教えていただけないでしょうか。

python

1import torch 2import torch.nn as nn 3from sklearn.metrics import log_loss 4 5def my_logloss(target, output_ohe): 6 loss = 0 7 num_data = target.size()[0] 8 for idx in range(num_data): 9 loss += -(target[idx]*torch.log(output_ohe[idx,0])+(1-target[idx])*torch.log(output_ohe[idx,1]))/num_data 10 return loss 11 12#target 13target = torch.tensor([1,0,1]) 14NUM_data = target.size()[0] 15NUM_target = target.max()+1 16target_ohe = torch.zeros([NUM_data,NUM_target]) 17for idx in range(NUM_data): 18 target_ohe[idx,target[idx]] = 1 19 20#output 21output = torch.tensor([0.7,0.4,0.3]) 22output_ohe = torch.zeros([NUM_data,NUM_target]) 23for idx in range(NUM_data): 24 output_ohe[idx,0] = output[idx] 25 output_ohe[idx,1] = 1-output[idx] 26 27#pytorch_logloss 28loss_func = nn.BCELoss() 29loss1 = loss_func(output_ohe, target_ohe) 30print(loss1.numpy())#0.8256462 31 32#sklearn_logloss 33loss2 = log_loss(target,output_ohe) 34print(loss2)#0.825646162033081 35 36#my_logloss 37loss3 = my_logloss(target,output_ohe) 38print(loss3.numpy())#0.69049114 39

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

ラベルが逆ですね。

python

1 loss += -((1-target[idx])*np.log(output_ohe[idx,0])+(target[idx])*np.log(output_ohe[idx,1]))/num_data

で。

(ワンホット表現がどうなっているのかをよくよく考えるとこの対処でいいことがわかります)

投稿2021/01/28 15:07

hayataka2049

総合スコア30933

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.46%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問