###実現したいこと
RNNを使用して数字列の順番を予測させたいと考えています
入力する数字列は[3,4,1,4,2,0,3,2]のような形式で、長さは10〜14程度、全100セットあります
###該当のソースコード
ソースコードはこちらです
python
1class LSTM(Chain): 2 # in_size = 1 3 # hidden_size = 5 4 # out_size = 1 5 def __init__(self, in_size, hidden_size, out_size): 6 super(LSTM, self).__init__( 7 xh = L.Linear(in_size, hidden_size), 8 hh_x = L.Linear(hidden_size, 4 * hidden_size), 9 hh_h = L.Linear(hidden_size, 4 * hidden_size), 10 hy = L.Linear(hidden_size, out_size) 11 ) 12 self.hidden_size = hidden_size 13 14 def __call__(self, x, t=None): 15 if self.h is None: 16 self.h = Variable(np.zeros((x.shape[0], self.hidden_size))) 17 self.c = Variable(np.zeros((x.shape[0], self.hidden_size))) 18 19 x = Variable(x) 20 # print x >> variable([0 4 0 1 2 4 3 1 2 4 0 3]) 21 t = Variable(t) 22 # print t >> variable([4 0 1 2 4 3 1 2 4 0 3 4]) 23 24 h = self.xh(x) 25 h = self.hh_x(h) + self.hh_h(self.h) 26 self.c, self.h = F.lstm(self.c, h) 27 y = self.hy(self.h) 28 29 return F.mean_squared_error(y, t) 30 31 def reset(self): 32 self.zerograds() 33 self.h = None 34 self.c = None
###発生している問題・エラーメッセージ
上記のコードにおいて ____call____の中程
h = self.xh(x)
の部分で止まっています
インデックスエラーのため、入力層のユニット数やリストの大きさを変えてみましたが違うようです
エラーメッセージはこちらです
Traceback (most recent call last): File "/program/rnn/rnn.py", line 112, in <module> loss += model(x, t, train=True) File "/program/rnn/rnn.py", line 35, in __call__ h = self.xh(x) File "/usr/local/lib/python2.7/dist-packages/chainer/links/connection/linear.py", line 127, in __call__ return linear.linear(x, self.W, self.b) File "/usr/local/lib/python2.7/dist-packages/chainer/functions/connection/linear.py", line 118, in linear y, = LinearFunction().apply(args) File "/usr/local/lib/python2.7/dist-packages/chainer/function_node.py", line 207, in apply self._check_data_type_forward(in_data) File "/usr/local/lib/python2.7/dist-packages/chainer/function_node.py", line 267, in _check_data_type_forward self.check_type_forward(in_type) File "/usr/local/lib/python2.7/dist-packages/chainer/functions/connection/linear.py", line 20, in check_type_forward x_type.shape[1] == w_type.shape[1], IndexError: tuple index out of range
なぜこのようなエラーが出るのか、わかる方教えていただけませんか?

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2017/11/09 00:47