質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

コードレビュー

コードレビューは、ソフトウェア開発の一工程で、 ソースコードの検査を行い、開発工程で見過ごされた誤りを検出する事で、 ソフトウェア品質を高めるためのものです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

1241閲覧

pytorch RuntimeError: Expected hidden[0] size (1, 17938, 50), got [2, 1, 50]が発生

irene555

総合スコア1

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

コードレビュー

コードレビューは、ソフトウェア開発の一工程で、 ソースコードの検査を行い、開発工程で見過ごされた誤りを検出する事で、 ソフトウェア品質を高めるためのものです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/06/24 12:18

前提・実現したいこと

pytorchにてLSTMを使って、時系列データの予測をしたいのですが、
RuntimeErrorが発生しました。

解決方法がわからず困っておりますので、教えていただきたいです。
よろしくお願いいたします。

発生している問題・エラーメッセージ

RuntimeError: Expected hidden[0] size (1, 17938, 50), got [2, 1, 50]

該当のソースコード

python

1class LSTM(nn.Module): 2 3 #クラスのインスタンス化(はじめにLSTMクラスが実行されたときのみ実行される) 4 def __init__(self, input_size =2, hidden_size =50, output_size = 1 ): 5 super().__init__() 6 7 #隠れ層のサイズ 8 self.hidden_size = hidden_size 9 10 #input_size → hidden_sizeに変換 11 self.lstm = nn.LSTM(input_size, hidden_size) 12 13 #hidden_size → output_sizeに変換 lstmの出力に全結合レイヤーを介す 14 self.fc = nn.Linear(hidden_size, output_size) 15 16 #0で隠れ層を初期化する、hidden層の初期化とcell層の初期化 17 self.hidden = (torch.zeros(2,1,hidden_size),torch.zeros(2,1,hidden_size)) 18 19 20 #予測に用いる関数(順伝播) 21 def forward(self, x): 22 #データと隠れ層の状態が返ってくる 23 lstm_out , self.hidden = self.lstm(x, self.hidden) 24 #全結合層、順伝播させるのはデータが格納されているlstm_outのみ 25 pred = self.fc(lstm_out) 26 27 return pred[-1] #-1を指定することで一番最後の値を取り出すことができる 28 29 30print() 31print('***LSTMの定義終了***') 32print() 33 34 35 36 37 38''' ******************関数の定義・インスタンス化****************** ''' 39 40#モデルのインスタンス化と損失関数、最適化関数を定義する 41torch.manual_seed(3) #乱数の値を一定にし、処理結果を同じにする 42model = LSTM() #インスタンス 43 44#param = model.parameters() 45#print(param) 46 47#モデルの中身の確認 48#for k,v in model.state_dict().items(): 49# print(k) 50# print(v) 51 52 53criterion = nn.MSELoss() #損失関数 54optimizer = torch.optim.SGD(model.parameters(),lr=0.01) #最適化関数 55 56 57print() 58print('********************関数の定義・インスタンス化終了********************') 59print() 60 61 62 63 64''' ******************ここから学習****************** ''' 65 66model.train() #学習モード 67 68epochs = 10 #epochを定義 69loss_list_train = [] #損失を格納する配列を用意 70 71 72print(x_train) 73 74 75#epochごとに重みを更新 76for epoch in range(epochs): 77 78 #順伝播 ***LSTMモデルに入力データを引数で渡して、戻り値で順伝播後の予測値が返ってくる 79 rsoc_pred_train = model(x_train) 80     #★ここでエラーが発生 81

データのかたち

x_train.shape = torch.Size([60, 17938, 2])

補足情報(FW/ツールのバージョンなど)

pytorchを使用

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問