質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.37%

Keras/TensorFlow カスタムレイヤーの call メソッドが呼び出せません

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 898

score 1

とある論文(有料論文のため本文・図の引用は避けます)で、非時系列データを含む LSTM モデルというものがあり、それを Keras/TensorFlow のカスタムレイヤーで作成しようとしています。
構造としては、時系列データは従来の LSTM と同様に入力ゲート、出力ゲート、忘却ゲート、ブロック入力に入力されますが、非時系列データは4つのうちブロック入力には入力されず、これにより記憶する必要のない非時系列データがメモリセルに記憶されず、非時系列データを含んだデータの学習精度が向上するといったものです。

時系列データと非時系列データを入力とするため、複数の入力としてリストでそれぞれ入力するようにしています。レイヤー内での計算は、下記の Qiita 記事を参考に call メソッドでその時刻での状態 h と、次の時刻に伝える状態 h と 記憶セルの出力 c のリストを返すようにしています。実行すると、call メソッドの呼び出しでエラーが出て行き詰まっています。init メソッド、もしくは build メソッドの中に誤りがあるように感じています。

Keras はよく使用しているのですが、カスタムレイヤーを作成するのは初めてですので、コードの誤りがあればご指摘いただければ幸いです。

参考ページ
[TensorFlow/Keras] 好きな構造のRNNを組み立てるまでの道のり - Qiita
python - Kerasで複数の入力を持つカスタムレイヤーを実装する方法 - ITツールウェブ

発生している問題・エラーメッセージ

下記のエラーが発生します。後に記載するコードの call メソッドへ LSTM の状態 states が入っていないようです。

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    920                     not base_layer_utils.is_in_eager_or_tf_function()):
    921                   with auto_control_deps.AutomaticControlDependencies() as acd:
--> 922                     outputs = call_fn(cast_inputs, *args, **kwargs)
    923                     # Wrap Tensors in `outputs` in `tf.identity` to avoid
    924                     # circular dependencies.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    263       except Exception as e:  # pylint:disable=broad-except
    264         if hasattr(e, 'ag_error_metadata'):
--> 265           raise e.ag_error_metadata.to_exception(e)
    266         else:
    267           raise

TypeError: in user code:


    TypeError: tf__call() missing 1 required positional argument: 'states'

該当のソースコード

import tensorflow as tf 
import numpy as np
from tensorflow.keras.layers import Input, Dense, RNN, AbstractRNNCell
from tensorflow.python.keras import activations, constraints, initializers, regularizers
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops

class NontsLSTM(AbstractRNNCell):
    def __init__(self,
                 units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(NontsLSTM, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.unit_forget_bias = unit_forget_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

    @property
    def state_size(self):
        return [self.units, self.units]

    def build(self, input_shape):
        input_dim1 = input_shape[0][-1]     # 時系列データの入力
        input_dim2 = input_shape[1][-1]     # 非時系列データの入力
        self.kernel1          = self.add_weight(shape=(input_dim1, self.units * 4),
                                                name='kernel1',
                                                initializer=self.kernel_initializer,
                                                regularizer=self.kernel_regularizer,
                                                constraint=self.kernel_constraint)
        self.kernel2          = self.add_weight(shape=(input_dim2, self.units * 3),
                                                name='kernel2',
                                                initializer=self.kernel_initializer,
                                                regularizer=self.kernel_regularizer,
                                                constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),
                                                name='recurrent_kernel',
                                                initializer=self.recurrent_initializer,
                                                regularizer=self.recurrent_regularizer,
                                                constraint=self.recurrent_constraint)
        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.cocatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 4,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer

            self.bias = self.add_weight(shape=(self.units * 4,),
                                        name='bias',
                                        initializer=self.bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]   # 前時刻の隠れ状態
        c_tm1 = states[1]   # 前時刻のメモリセル状態

        # 時系列データの入力
        kt_f, kt_u, kt_i, kt_o = array_ops.split(self.kernel1, num_or_size_splits=4, axis=1)
        xt_f = K.dot(inputs[0], kt_f)
        xt_u = K.dot(inputs[0], kt_u)
        xt_i = K.dot(inputs[0], kt_i)
        xt_o = K.dot(inputs[0], kt_o)

        # 非時系列データの入力
        kn_f, kn_i, kn_o = array_ops.aplit(self.kernel2, num_or_size_splits=3, axis=1)
        xn_f = K.dot(inputs[1], kn_f)
        xn_i = K.dot(inputs[1], kn_i)
        xn_o = K.dot(inputs[1], kn_o)

        x_f = xt_f + xn_f
        x_u = xt_u
        x_i = xt_i + xn_i
        x_o = xt_o + xn_o

        if self.use_bias:
            b_f, b_u, b_i, b_o = array_ops.split(self.bias, num_or_size_splits=4, axis=0)
            x_f = K.bias_add(x_f, b_f)
            x_u = K.bias_add(x_u, b_u)
            x_i = K.bias_add(x_i, b_i)
            x_o = K.bias_add(x_o, b_o)

        f = self.recurrent_activation(x_f + K.dot(h_tm1, self.recurrent_kernel[:, :self.units]))                    # 忘却ゲート
        u = self.activation(x_u + K.dot(h_tm1, self.recurrene_kernel[:, self.units:self.units * 2]))                # ブロック入力
        i = self.recurrent_activation(x_i + K.dot(h_tm1, self.recurrent_kernel[:, self.units * 2:self.units * 3]))  # 入力ゲート
        o = self.recurrent_activation(x_o + K.dot(h_tm1, self.recurrent_kernel[:, self.units * 3:self.units * 4]))  # 出力ゲート

        c = f * c_tm1 + u * i
        h = self.activation(c) * o

        return h, [h, c]

    def compute_output_shape(self, input_shape):
        return input_shape[0][-1], self.units
t_input = Input(shape=(30, 1))
n_input = Input(shape=(30, 1))
h = NontsLSTM(128)([t_input, n_input])
output = Dense(1, activation='linear')(h)

model = Model([t_input, n_input], output)
model.compile(optimizer='adam', loss='mse', metrics='loss')

試したこと

TensorFlow の github も参照しましたが、先の Qiita 記事によると AbstractRNNCell を継承するのが良いということで記事内のコードを参考にしました。
デバッグプリントをしたところ、init メソッド、build メソッドは最後まで処理されていました。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

checkベストアンサー

+1

Cellレイヤを継承してそれを直接呼び出しているのが原因です。RNNでラップしてから呼び出して下さい。
例えば呼び出すソースコードのNontsLSTMの箇所をLSTMCellにすれば同じエラーが見られると思います。

参考URL

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2020/06/23 10:52

    ご返信ありがとうございます。
    ご指摘の箇所を修正した上で手直しを加えるとエラーはなくなりました。
    複数のレイヤーを入力する場合はリストではなくタプルでないといけないようですね。
    ありがとうございました。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.37%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る