chainer,機械学習ともに初心者でお恥ずかしいコードになっていると思いますが、エラーの解決ご協力お願いします。
将来的にはattentionの可視化まで行いたいと思っています。
おかしいのは最後から2番目の行の
y = self.l4(h3)
でh3に問題があると思います。
python
1import chainer 2import chainer.functions as F 3import chainer.links as L 4from chainer import Variable 5import numpy as np 6 7 8class AttentionDRNN(chainer.Chain): 9 10 def __init__(self, in_s, n_units,out_s, train=True): 11 super(AttentionDRNN, self).__init__( 12 l1=L.LSTM(in_s, n_units), 13 l2=L.LSTM(n_units, n_units), 14 l3=L.LSTM(n_units, n_units), 15 l4=L.Linear(n_units, out_s), 16 ) 17 self.train = train 18 self.i = 0 19 self.gl =[] 20 self.wei = [] 21 def reset_state(self): 22 self.l1.reset_state() 23 self.l2.reset_state() 24 self.l3.reset_state() 25 26 def __call__(self, x, train=True): 27 weigh = [] #ウェイトを記録するためのリストの初期化 28 batchsize = 5 # ミニバッチのサイズを記憶 29 sum = 0 30 h1 = self.l1(x) 31 h2 = self.l2(h1) 32 h3 = self.l3(h2) 33 self.gl.append(h3) 34 self.i += 1 35 if self.i==1 : 36 y = self.l4(h3) 37 return y 38 else: 39 for num, ght in enumerate(h3): 40 we = F.matmul(self.gl[self.i-1][0], ght) 41 we = F.exp(we) # softmax関数を使って正規化する 42 weigh.append(we) 43 sum += we 44 if len(ght)==2: 45 batchsize = 2 46 weigh = np.array(weigh) 47 sum = np.array(sum) 48 weigh = weigh / sum 49 weigh = weigh.reshape(5, 1).repeat(60, axis=1) 50 att_f = np.array(h3) * np.array(weigh) 51 att_f = F.reshape(att_f, (5, 60)) 52 print(att_f.shape) 53 h3 = att_f 54 y = self.l4(h3) 55 return y
packages\chainer\utils\type_check.py", line 482, in expect
'{0} {1} {2}'.format(left, self.inv, right))
chainer.utils.type_check.InvalidType:
Invalid operation is performed in: LinearFunction (Forward)
Expect: in_types[0].dtype.kind == f
Actual: O != f
Process finished with exit code 1
このエラーが解決できないのですが調査してみた結果
h2を出力したとき以下のような値が出るのに対して
variable([[ 0.01771316 -0.01998963 0.00072347 -0.00977764 -0.02070782
0.01098927 0.01607024 -0.01494737 -0.04490452 0.04066181
h3を出力したとき以下のようになってるからではないかと考えたのですが直し方がわかりません。
variable([[variable(-0.00482196) variable(-0.00196443)
variable(0.00170961) variable(-0.00334108) variable(0.0049839)
あなたの回答
tips
プレビュー