回答編集履歴
1
書式の修正
answer
CHANGED
@@ -1,37 +1,37 @@
|
|
1
|
-
> 予測値が [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
1
|
+
> 予測値が [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
2
|
-
教師データのインデックスが [1, 2]の場合、誤差は0になる
|
2
|
+
> 教師データのインデックスが [1, 2]の場合、誤差は0になる
|
3
|
-
|
3
|
+
|
4
|
-
[PytorchのCrossEntropyLossの解説](https://qiita.com/ground0state/items/8933f9ef54d6cd005a69)
|
4
|
+
[PytorchのCrossEntropyLossの解説](https://qiita.com/ground0state/items/8933f9ef54d6cd005a69)
|
5
|
-
の「PyTorchのCrossEntropyLoss」に書かれてるように、「torch.nn.CrossEntropyLoss()」は予測値をソフトマックス処理するので、そうはなりません
|
5
|
+
の「PyTorchのCrossEntropyLoss」に書かれてるように、「torch.nn.CrossEntropyLoss()」は予測値をソフトマックス処理するので、そうはなりません
|
6
|
-
|
6
|
+
|
7
|
-
> 出力結果
|
7
|
+
> 出力結果
|
8
|
-
tensor(0.5514)
|
8
|
+
> tensor(0.5514)
|
9
|
-
|
9
|
+
|
10
|
-
```python
|
10
|
+
```python
|
11
|
-
import numpy as np
|
11
|
+
import numpy as np
|
12
|
-
y = [0, 1, 0]
|
12
|
+
y = [0, 1, 0]
|
13
|
-
print(-1 * y[1] + np.log(np.exp(y).sum()))
|
13
|
+
print(-1 * y[1] + np.log(np.exp(y).sum()))
|
14
|
-
```
|
14
|
+
```
|
15
|
-
を計算したら「0.5514」になります
|
15
|
+
を計算したら「0.5514」になります
|
16
|
-
|
16
|
+
|
17
|
-
.
|
17
|
+
.
|
18
|
-
> 確率分布での学習の場合でも。 -t*log(y) でよいのでしょうか?
|
18
|
+
> 確率分布での学習の場合でも。 -t*log(y) でよいのでしょうか?
|
19
|
-
|
19
|
+
|
20
|
-
「torch.nn.CrossEntropyLoss()」での計算のやり方に合わせたら、という意味なら、
|
20
|
+
「torch.nn.CrossEntropyLoss()」での計算のやり方に合わせたら、という意味なら、
|
21
|
-
```python
|
21
|
+
```python
|
22
|
-
y = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
22
|
+
y = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
23
|
-
t = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
23
|
+
t = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
|
24
|
-
```
|
24
|
+
```
|
25
|
-
の場合の計算結果が「0.5514」になるはずですが、質問に掲載のコードではそうなりません
|
25
|
+
の場合の計算結果が「0.5514」になるはずですが、質問に掲載のコードではそうなりません
|
26
|
-
理由は二つあります
|
26
|
+
理由は二つあります
|
27
|
-
・上で指摘したように「torch.nn.CrossEntropyLoss()」は予測値をソフトマックス処理する
|
27
|
+
・上で指摘したように「torch.nn.CrossEntropyLoss()」は予測値をソフトマックス処理する
|
28
|
-
・一つのサンプルでの複数の結果(この例では三つ)は合計する
|
28
|
+
・一つのサンプルでの複数の結果(この例では三つ)は合計する
|
29
|
-
|
29
|
+
|
30
|
-
以上を踏まえると、こんな感じになると思います
|
30
|
+
以上を踏まえると、こんな感じになると思います
|
31
|
-
```python
|
31
|
+
```python
|
32
|
-
(-t * torch.log(y)).mean()
|
32
|
+
(-t * torch.log(y)).mean()
|
33
|
-
```
|
33
|
+
```
|
34
|
-
↓ 変更
|
34
|
+
↓ 変更
|
35
|
-
```python
|
35
|
+
```python
|
36
|
-
torch.mean(torch.sum(-t * torch.t(torch.t(y) - torch.log(torch.sum(torch.exp(y), dim=1))), dim=1))
|
36
|
+
torch.mean(torch.sum(-t * torch.t(torch.t(y) - torch.log(torch.sum(torch.exp(y), dim=1))), dim=1))
|
37
37
|
```
|