構文が無効と言っているのですが、参考動画と同じなのですが、どこが間違っているのか、教えてくださると助かります。
また、print("The Number of Training Data: ", len(train_data)) の実行結果が1しかないのですが、本当であれば、564と結果がないといけないのですが、なぜ、データが1つしないのかも教えてくださると助かります。
ご教授お願いいたします。
参考動画リンク内容 13:41
import torch import torch.nn as nn import numpy as np import pandas as pd import matplotlib.pyplot as plt
stock_data = pd.read_csv( "Downloads/PP.csv", index_col = 0, parse_dates=True ) stock_data
stock_data.drop( ["Open", "High", "Low", "Close", "Volume"], axis="columns", inplace=True ) stock_data
stock_data.plot(figsize=(12, 4))
y = stock_data["Adj Close"].values y
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler(feature_range=(-1, 1)) scaler.fit(y.reshape(-1, 1)) y = scaler.transform(y.reshape(-1, 1)) y
y = torch.FloatTensor(y).view(-1) y
test_size = 24 train_seq = y[:-test_size:] test_seq = y[-test_size:]
plt.figure(figsize=(12, 4)) plt.xlim(-20, len(test_seq)+20) plt.grid(True) plt.plot(test_seq)
train_window_size = 12
def input_data(seq, ws): out = [] L = len(seq) for i in range(L-ws): window = seq[i:i+ws] label = seq[i+ws:i+ws+1] out.append((window, label)) return out
train_data = input_data(train_seq, train_window_size)
print("The Number of Training Data: ", len(train_data)) ↓ The Number of Training Data: 1
class Model(nn.Module): def __init__(self, input=1, h=50, output=1): super().__init__() self.hidden_size = h self.lstm = nn.LSTM(input, h) self.fc = nn.Linear(h, output) self.hidden = ( torch.zeros(1, 1, h), torch.zeros(1, 1, h) ) def forward(self, seq): out,_=self.lstm( seq.view(len(seq), 1, -1), self.hidden ) out = self.fc( out.view(len(seq), -1) ) return out[-1]
torch.manual_seed(123) model = Model() criterion = nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 10 train_losses = [] test_losses = []
def run_train(): model.train() for train_window, correct_label in train_data: optimizer.zero_grad() model.hidden = ( torch.zeros(1, 1, model.hidden_size), torch.zeros(1, 1, model.hidden_size) ) train_predicted_label = model.forward(train_window) train_loss = criterion(train_precicted_label, correct_label) train_loss.backward() optimizer.step() train_losses.append(train_loss)
a = torch.tensor([3]) a.item()
def run_test(): model.eval() for i in range(test_size): test_window = torch.Float(extending_seq[-test_size:]) with torch.no_grad(): model.hidden = ( torch.zeros(1, 1, model.hidden_size), torch.zeros(1, 1, model.hidden_size) ) test_predicred_label = model.forward(test_window) extending_seq.append(test_predicred_label.item()) test_loss = criterion( torch.FloatTensor(extending_seq[-test_size:]), y[len(y)-test_size:] ) test_losses.append(test_loss)
train_seq[-test_size:]
train_seq[-test_size:].tolist()
for epoch in range(epochs): print() print(f'Epoch: {epoch+1}') run train() extending_seq = train_seq[-test_size:].tolist() run test() plt.figure(figsize=(12, 4)) plt.xlim(-20, len(y)+20) plt.grid(True) plt.plot(y.numpy()) plt.plot( range(len(y)-test_size, len(y)), extending_seq[-test_size:] ) plt.show()
↓ エラー
File "<ipython-input-67-5ed17c2958f4>", line 6 run train() ^ SyntaxError: invalid syntax
追記。。
for epoch in range(epochs): print() print(f'Epoch: {epoch+1}') run_train() extending_seq = train_seq[-test_size:].tolist() run_test() plt.figure(figsize=(12, 4)) plt.xlim(-20, len(y)+20) plt.grid(True) plt.plot(y.numpy()) plt.plot( range(len(y)-test_size, len(y)), extending_seq[-test_size:] ) plt.show()
エラー内容
Epoch:
1----------------------------------------------------------------------- 2NotImplementedError Traceback (most recent call last) 3<ipython-input-212-863045c37b1e> in <module> 4 4 print(f'Epoch: {epoch+1}') 5 5 6----> 6 run_train() 7 7 8 8 extending_seq = train_seq[-test_size:].tolist() 9 10<ipython-input-206-e770668ef1d7> in run_train() 11 10 ) 12 11 13---> 12 train_predicted_label = model.forward(train_window) 14 13 train_loss = criterion(train_precicted_label, correct_label) 15 14 16 17~/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input) 18 173 registered hooks while the latter silently ignores them. 19 174 """ 20--> 175 raise NotImplementedError 21 176 22 177 23 24NotImplementedError:

回答1件
あなたの回答
tips
プレビュー