質問編集履歴
1
sigmoidを内蔵している誤差関数に変更し、層を一つ増やしました。
test
CHANGED
File without changes
|
test
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
チュートリアルの例文を少し書き換える形で初めてpytorchを使ってみたのですが、思ったような結果が得られないので質問させていただくことにしました。
|
2
2
|
|
3
|
-
|
3
|
+
|
4
4
|
|
5
5
|
tfidfVectorizerを使って文章をベクトル化したものを使って2クラス分類の問題を解こうとしていて、train_vectorsはcsr_matrix形式のベクトルで、それをCSR_to_Tensorという関数でTensor型に直しています。大きさは(7613, 15269)です。
|
6
6
|
|
@@ -20,7 +20,7 @@
|
|
20
20
|
|
21
21
|
# H is hidden dimension; D_out is output dimension.
|
22
22
|
|
23
|
-
N, D_in, H, D_out = 7613, 15269, 100, 1
|
23
|
+
N, D_in, H, I, D_out = 7613, 15269, 1000, 100, 1
|
24
24
|
|
25
25
|
|
26
26
|
|
@@ -46,7 +46,11 @@
|
|
46
46
|
|
47
47
|
torch.nn.ReLU(),
|
48
48
|
|
49
|
+
torch.nn.Linear(H, I),
|
50
|
+
|
51
|
+
torch.nn.ReLU(),
|
52
|
+
|
49
|
-
torch.nn.Linear(
|
53
|
+
torch.nn.Linear(I, D_out),
|
50
54
|
|
51
55
|
)
|
52
56
|
|
@@ -56,11 +60,11 @@
|
|
56
60
|
|
57
61
|
# case we will use Mean Squared Error (MSE) as our loss function.
|
58
62
|
|
59
|
-
loss_fn = torch.nn.
|
63
|
+
loss_fn = torch.nn.BCEWithLogitsLoss()
|
60
64
|
|
61
65
|
|
62
66
|
|
63
|
-
learning_rate = 1e-
|
67
|
+
learning_rate = 1e-3
|
64
68
|
|
65
69
|
for t in range(100):
|
66
70
|
|
@@ -74,7 +78,9 @@
|
|
74
78
|
|
75
79
|
y_pred = model(x.float())
|
76
80
|
|
81
|
+
y_pred = y_pred.reshape(len(y_pred))
|
77
82
|
|
83
|
+
|
78
84
|
|
79
85
|
# Compute and print loss. We pass Tensors containing the predicted and true
|
80
86
|
|
@@ -120,4 +126,28 @@
|
|
120
126
|
|
121
127
|
```
|
122
128
|
|
129
|
+
一応実行はできるのですが誤差関数の値がほとんど小さくなっていないことから学習ができていないと思われます。
|
130
|
+
|
123
|
-
|
131
|
+
その原因と解決方法を教えていただければ幸いです。
|
132
|
+
|
133
|
+
出力は以下のようになっていました。
|
134
|
+
|
135
|
+
> 9 0.7005422115325928
|
136
|
+
|
137
|
+
19 0.7004371285438538
|
138
|
+
|
139
|
+
29 0.7003325819969177
|
140
|
+
|
141
|
+
39 0.7002285718917847
|
142
|
+
|
143
|
+
49 0.7001250386238098
|
144
|
+
|
145
|
+
59 0.7000225186347961
|
146
|
+
|
147
|
+
69 0.6999202966690063
|
148
|
+
|
149
|
+
79 0.6998189687728882
|
150
|
+
|
151
|
+
89 0.6997179388999939
|
152
|
+
|
153
|
+
99 0.6996178030967712
|