回答編集履歴

1

追記

2020/12/03 14:02

投稿

meg_
meg_

スコア10948

answer CHANGED
@@ -33,4 +33,70 @@
33
33
  ^
34
34
  SyntaxError: invalid syntax
35
35
 
36
- 上記については``run train()``ではなく``run_train()``です。
36
+ 上記については``run train()``ではなく``run_train()``です。
37
+
38
+ ---
39
+ 【追記】
40
+ > なぜ、参考動画と同じなのに、エラーが出てしまっているのか。
41
+
42
+ 同じではないからです。forwardのインデントが間違っています。
43
+
44
+ <誤>
45
+ ```Python
46
+ class Model(nn.Module):
47
+
48
+ def __init__(self, input=1, h=50, output=1):
49
+ super().__init__()
50
+ self.hidden_size = h
51
+
52
+ self.lstm = nn.LSTM(input, h)
53
+ self.fc = nn.Linear(h, output)
54
+
55
+ self.hidden = (
56
+ torch.zeros(1, 1, h),
57
+ torch.zeros(1, 1, h)
58
+ )
59
+
60
+ def forward(self, seq):
61
+
62
+ out,_=self.lstm(
63
+ seq.view(len(seq), 1, -1),
64
+ self.hidden
65
+ )
66
+
67
+ out = self.fc(
68
+ out.view(len(seq), -1)
69
+ )
70
+
71
+ return out[-1]
72
+ ```
73
+
74
+ <正>
75
+ ```Python
76
+ class Model(nn.Module):
77
+
78
+ def __init__(self, input=1, h=50, output=1):
79
+ super().__init__()
80
+ self.hidden_size = h
81
+
82
+ self.lstm = nn.LSTM(input, h)
83
+ self.fc = nn.Linear(h, output)
84
+
85
+ self.hidden = (
86
+ torch.zeros(1, 1, h),
87
+ torch.zeros(1, 1, h)
88
+ )
89
+
90
+ def forward(self, seq):
91
+
92
+ out, _ = self.lstm(
93
+ seq.view(len(seq), 1, -1),
94
+ self.hidden
95
+ )
96
+
97
+ out = self.fc(
98
+ out.view(len(seq), -1)
99
+ )
100
+
101
+ return out[-1]
102
+ ```