回答編集履歴
2
説明の追記
test
CHANGED
@@ -10,7 +10,7 @@
|
|
10
10
|
|
11
11
|
[PyTorchを用いてディープラーニングによるワイン分類をしてみた](https://techtech-sorae.com/pytorch%E3%82%92%E7%94%A8%E3%81%84%E3%81%A6%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0%E3%81%AB%E3%82%88%E3%82%8B%E3%83%AF%E3%82%A4%E3%83%B3%E5%88%86%E9%A1%9E/)
|
12
12
|
|
13
|
-
のコードは、下記の修正をし
|
13
|
+
のコードは、下記の修正をした方がいいです
|
14
14
|
|
15
15
|
```python
|
16
16
|
|
@@ -51,3 +51,15 @@
|
|
51
51
|
```
|
52
52
|
|
53
53
|
の「dim=0」は間違いで、「dim=1」が正しいです
|
54
|
+
|
55
|
+
そこを直して、
|
56
|
+
|
57
|
+
```python
|
58
|
+
|
59
|
+
y_pred_prob = torch.exp(model(test_x))
|
60
|
+
|
61
|
+
```
|
62
|
+
|
63
|
+
を計算しても、二つの合計は1.0になります
|
64
|
+
|
65
|
+
ただし、「log_softmax」を二重に計算するので、効率が悪くなると思います
|
1
説明の追記
test
CHANGED
@@ -1,14 +1,24 @@
|
|
1
|
+
まず、この質問の内容とは直接は関係無いのですが、
|
2
|
+
|
3
|
+
[PytorchのCrossEntropyLossの解説](https://qiita.com/ground0state/items/8933f9ef54d6cd005a69)
|
4
|
+
|
5
|
+
とかに書かれてるように、「torch.nn.CrossEntropyLoss()」には「log_softmax」の計算も含まれてるので、「torch.nn.CrossEntropyLoss()」を使う場合はニューラルネットの定義側には「log_softmax」は不要です
|
6
|
+
|
7
|
+
|
8
|
+
|
9
|
+
なので、
|
10
|
+
|
1
11
|
[PyTorchを用いてディープラーニングによるワイン分類をしてみた](https://techtech-sorae.com/pytorch%E3%82%92%E7%94%A8%E3%81%84%E3%81%A6%E3%83%87%E3%82%A3%E3%83%BC%E3%83%97%E3%83%A9%E3%83%BC%E3%83%8B%E3%83%B3%E3%82%B0%E3%81%AB%E3%82%88%E3%82%8B%E3%83%AF%E3%82%A4%E3%83%B3%E5%88%86%E9%A1%9E/)
|
2
12
|
|
3
|
-
のコードの
|
13
|
+
のコードは、下記の修正をします
|
4
14
|
|
5
15
|
```python
|
6
16
|
|
7
|
-
return F.log_softmax(x, dim=
|
17
|
+
return F.log_softmax(x, dim=0)
|
8
18
|
|
9
19
|
```
|
10
20
|
|
11
|
-
|
21
|
+
↓ 変更
|
12
22
|
|
13
23
|
```python
|
14
24
|
|
@@ -16,7 +26,7 @@
|
|
16
26
|
|
17
27
|
```
|
18
28
|
|
19
|
-
|
29
|
+
その上で、
|
20
30
|
|
21
31
|
```python
|
22
32
|
|
@@ -24,6 +34,20 @@
|
|
24
34
|
|
25
35
|
```
|
26
36
|
|
27
|
-
|
37
|
+
を計算したら、二つの合計が1.0になります
|
28
38
|
|
29
39
|
(それが確率かどうかは別にして)
|
40
|
+
|
41
|
+
|
42
|
+
|
43
|
+
.
|
44
|
+
|
45
|
+
なお、toast-uzさんが回答に書いてるように、修正前の
|
46
|
+
|
47
|
+
```python
|
48
|
+
|
49
|
+
return F.log_softmax(x, dim=0)
|
50
|
+
|
51
|
+
```
|
52
|
+
|
53
|
+
の「dim=0」は間違いで、「dim=1」が正しいです
|