chainerを用いたNNモデルのfwd内におけるエラーが解決できません。

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 757

futashige

score 26

現在、chainerを用いる練習として以下のようなモデルを組んでいます。

class MyChain(Chain):
    def __init__(self):
        super(MyChain, self).__init__(
            l1 = L.Linear(361,64),
            b1 = L.BatchNormalization(64),
            l2 = L.Linear(64,32),
            b2 = L.BatchNormalization(32),
            l3 = L.Linear(32,16),
            b3 = L.BatchNormalization(16),
            l4 = L.Linear(16,8),
            b4 = L.BatchNormalization(8),

            l1_time = L.Linear(4,4),
            b1_time = L.BatchNormalization(4),
            l2_time = L.Linear(4,2),
            b2_time = L.BatchNormalization(2),
            l3_time = L.Linear(2,1),

            l1_weather = L.Linear(2,2),
            b1_weather = L.BatchNormalization(2,2),
            l2_weather = L.Linear(2,2),
            b2_weather = L.BatchNormalization(2,2),
            l3_weather = L.Linear(2,1),

            l1_ratio = L.Linear(10,6),
        )

    def __call__(self, x, t):
        y = self.fwd(x)
        return F.mean_squared_error(y, t)

    def fwd(self,x):
        x_watt = x[:,0:361]
        x_time = x[:,361:365]
        x_weather = x[:,365:367]

        h = self.b1(F.relu(self.l1(x_watt)))
        h = self.b2(F.relu(self.l2(h)))
        h = self.b3(F.relu(self.l3(h)))
        h = F.relu(self.l4(h))
        print(h.shape)
        print(h.dtype)

        h_time = self.b1_time(F.relu(self.l1_time(x_time)))
        h_time = self.b2_time(F.relu(self.l2_time(h_time)))
        h_time = F.relu(self.l3_time(h_time))
        print(h_time.shape)
        print(h_time.dtype)

        h_weather = self.b1_weather(F.relu(self.l1_weather(x_weather)))
        h_weather = self.b2_weather(F.relu(self.l2_weather(h_weather)))
        h_weather = F.relu(self.l3_weather(h_weather))
        print(h_weather.shape)
        print(h_weather.dtype)

        h_ratio = np.hstack((h,h_time,h_weather))
        print(h_ratio.shape)
        print(h_ratio.dtype)
        h_ratio = F.softmax(self.l1_ratio(h_ratio))

        return h_ratio*x[:,180:181]

上記のモデルで訓練を行うと、以下のようなエラーが出てしまいます。

     61         h_ratio = np.hstack((h,h_time,h_weather))
     62         print(h_ratio.dtype)
---> 63         h_ratio = F.softmax(self.l1_ratio(h_ratio))
     64 
     65         return h_ratio*x[:,180:181]


InvalidType: 
Invalid operation is performed in: LinearFunction (Forward)

Expect: in_types[0].dtype.kind == f
Actual: O != f

この際、print文による出力で、以下のようになっていることが確認できました。

x_train.shape:(148320, 367)
y_train.shape:(148320, 6)

h.shape:(128, 8)
h.dtype:float32
h_time.shape:(128, 1)
h_time.dtype:float32
h_weather.shape:(128, 1)
h_weather.dtype:float32
h_ratio.shape:(128, 10)
h_ratio.dtype:object


どうやら、np.hstackにおいて、配列の型がfloat32からobjectとなってしまっていることが原因のようなのですが、なぜこのようなことが起きてしまうのかがわかりません。
どのようにすれば解決することができるのでしょうか?

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • 退会済みユーザー

    退会済みユーザー

    2018/03/05 13:20

    h_ratio = np.hstack((h,h_time,h_weather))のh, h_time, h_weatherとその結果に関する情報を追記願えませんでしょうか。

    キャンセル

  • futashige

    2018/03/05 15:20

    情報が足りず申し訳ございませんでした。追記させていただきました。

    キャンセル

回答 1

check解決した方法

0

chainer内部の関数であるF.hstack([x1,x2])を用いることで解決しました。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.21%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる