現在、chainerを用いる練習として以下のようなモデルを組んでいます。
python
1class MyChain(Chain): 2 def __init__(self): 3 super(MyChain, self).__init__( 4 l1 = L.Linear(361,64), 5 b1 = L.BatchNormalization(64), 6 l2 = L.Linear(64,32), 7 b2 = L.BatchNormalization(32), 8 l3 = L.Linear(32,16), 9 b3 = L.BatchNormalization(16), 10 l4 = L.Linear(16,8), 11 b4 = L.BatchNormalization(8), 12 13 l1_time = L.Linear(4,4), 14 b1_time = L.BatchNormalization(4), 15 l2_time = L.Linear(4,2), 16 b2_time = L.BatchNormalization(2), 17 l3_time = L.Linear(2,1), 18 19 l1_weather = L.Linear(2,2), 20 b1_weather = L.BatchNormalization(2,2), 21 l2_weather = L.Linear(2,2), 22 b2_weather = L.BatchNormalization(2,2), 23 l3_weather = L.Linear(2,1), 24 25 l1_ratio = L.Linear(10,6), 26 ) 27 28 def __call__(self, x, t): 29 y = self.fwd(x) 30 return F.mean_squared_error(y, t) 31 32 def fwd(self,x): 33 x_watt = x[:,0:361] 34 x_time = x[:,361:365] 35 x_weather = x[:,365:367] 36 37 h = self.b1(F.relu(self.l1(x_watt))) 38 h = self.b2(F.relu(self.l2(h))) 39 h = self.b3(F.relu(self.l3(h))) 40 h = F.relu(self.l4(h)) 41 print(h.shape) 42 print(h.dtype) 43 44 h_time = self.b1_time(F.relu(self.l1_time(x_time))) 45 h_time = self.b2_time(F.relu(self.l2_time(h_time))) 46 h_time = F.relu(self.l3_time(h_time)) 47 print(h_time.shape) 48 print(h_time.dtype) 49 50 h_weather = self.b1_weather(F.relu(self.l1_weather(x_weather))) 51 h_weather = self.b2_weather(F.relu(self.l2_weather(h_weather))) 52 h_weather = F.relu(self.l3_weather(h_weather)) 53 print(h_weather.shape) 54 print(h_weather.dtype) 55 56 h_ratio = np.hstack((h,h_time,h_weather)) 57 print(h_ratio.shape) 58 print(h_ratio.dtype) 59 h_ratio = F.softmax(self.l1_ratio(h_ratio)) 60 61 return h_ratio*x[:,180:181]
上記のモデルで訓練を行うと、以下のようなエラーが出てしまいます。
python
1 61 h_ratio = np.hstack((h,h_time,h_weather)) 2 62 print(h_ratio.dtype) 3---> 63 h_ratio = F.softmax(self.l1_ratio(h_ratio)) 4 64 5 65 return h_ratio*x[:,180:181] 6 7 8InvalidType: 9Invalid operation is performed in: LinearFunction (Forward) 10 11Expect: in_types[0].dtype.kind == f 12Actual: O != f
この際、print文による出力で、以下のようになっていることが確認できました。
python
1x_train.shape:(148320, 367) 2y_train.shape:(148320, 6) 3 4h.shape:(128, 8) 5h.dtype:float32 6h_time.shape:(128, 1) 7h_time.dtype:float32 8h_weather.shape:(128, 1) 9h_weather.dtype:float32 10h_ratio.shape:(128, 10) 11h_ratio.dtype:object
どうやら、np.hstackにおいて、配列の型がfloat32からobjectとなってしまっていることが原因のようなのですが、なぜこのようなことが起きてしまうのかがわかりません。
どのようにすれば解決することができるのでしょうか?
回答1件
あなたの回答
tips
プレビュー