pytorchのnn.BCELoss()と、scikit-learnのlog_lossと、自作のlog_loss関数の比較をしているのですが、
数値が合わず、改善方法を教えていただけないでしょうか。
python
1import torch 2import torch.nn as nn 3from sklearn.metrics import log_loss 4 5def my_logloss(target, output_ohe): 6 loss = 0 7 num_data = target.size()[0] 8 for idx in range(num_data): 9 loss += -(target[idx]*torch.log(output_ohe[idx,0])+(1-target[idx])*torch.log(output_ohe[idx,1]))/num_data 10 return loss 11 12#target 13target = torch.tensor([1,0,1]) 14NUM_data = target.size()[0] 15NUM_target = target.max()+1 16target_ohe = torch.zeros([NUM_data,NUM_target]) 17for idx in range(NUM_data): 18 target_ohe[idx,target[idx]] = 1 19 20#output 21output = torch.tensor([0.7,0.4,0.3]) 22output_ohe = torch.zeros([NUM_data,NUM_target]) 23for idx in range(NUM_data): 24 output_ohe[idx,0] = output[idx] 25 output_ohe[idx,1] = 1-output[idx] 26 27#pytorch_logloss 28loss_func = nn.BCELoss() 29loss1 = loss_func(output_ohe, target_ohe) 30print(loss1.numpy())#0.8256462 31 32#sklearn_logloss 33loss2 = log_loss(target,output_ohe) 34print(loss2)#0.825646162033081 35 36#my_logloss 37loss3 = my_logloss(target,output_ohe) 38print(loss3.numpy())#0.69049114 39
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。