質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1861閲覧

Keras/TensorFlow カスタムレイヤーの call メソッドが呼び出せません

ima_chan1107

総合スコア1

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/06/22 07:37

編集2020/06/29 05:58

とある論文(有料論文のため本文・図の引用は避けます)で、非時系列データを含む LSTM モデルというものがあり、それを Keras/TensorFlow のカスタムレイヤーで作成しようとしています。
構造としては、時系列データは従来の LSTM と同様に入力ゲート、出力ゲート、忘却ゲート、ブロック入力に入力されますが、非時系列データは4つのうちブロック入力には入力されず、これにより記憶する必要のない非時系列データがメモリセルに記憶されず、非時系列データを含んだデータの学習精度が向上するといったものです。

時系列データと非時系列データを入力とするため、複数の入力としてリストでそれぞれ入力するようにしています。レイヤー内での計算は、下記の Qiita 記事を参考に call メソッドでその時刻での状態 h と、次の時刻に伝える状態 h と 記憶セルの出力 c のリストを返すようにしています。実行すると、call メソッドの呼び出しでエラーが出て行き詰まっています。init メソッド、もしくは build メソッドの中に誤りがあるように感じています。

Keras はよく使用しているのですが、カスタムレイヤーを作成するのは初めてですので、コードの誤りがあればご指摘いただければ幸いです。

参考ページ
[TensorFlow/Keras] 好きな構造のRNNを組み立てるまでの道のり - Qiita
python - Kerasで複数の入力を持つカスタムレイヤーを実装する方法 - ITツールウェブ

発生している問題・エラーメッセージ

下記のエラーが発生します。後に記載するコードの call メソッドへ LSTM の状態 states が入っていないようです。

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs) 920 not base_layer_utils.is_in_eager_or_tf_function()): 921 with auto_control_deps.AutomaticControlDependencies() as acd: --> 922 outputs = call_fn(cast_inputs, *args, **kwargs) 923 # Wrap Tensors in `outputs` in `tf.identity` to avoid 924 # circular dependencies. /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs) 263 except Exception as e: # pylint:disable=broad-except 264 if hasattr(e, 'ag_error_metadata'): --> 265 raise e.ag_error_metadata.to_exception(e) 266 else: 267 raise TypeError: in user code: TypeError: tf__call() missing 1 required positional argument: 'states'

該当のソースコード

Python

1import tensorflow as tf 2import numpy as np 3from tensorflow.keras.layers import Input, Dense, RNN, AbstractRNNCell 4from tensorflow.python.keras import activations, constraints, initializers, regularizers 5from tensorflow.python.keras import backend as K 6from tensorflow.python.keras.utils import tf_utils 7from tensorflow.python.ops import array_ops 8 9class NontsLSTM(AbstractRNNCell): 10 def __init__(self, 11 units, 12 activation='tanh', 13 recurrent_activation='hard_sigmoid', 14 use_bias=True, 15 kernel_initializer='glorot_uniform', 16 recurrent_initializer='orthogonal', 17 bias_initializer='zeros', 18 unit_forget_bias=True, 19 kernel_regularizer=None, 20 recurrent_regularizer=None, 21 bias_regularizer=None, 22 kernel_constraint=None, 23 recurrent_constraint=None, 24 bias_constraint=None, 25 **kwargs): 26 super(NontsLSTM, self).__init__(**kwargs) 27 self.units = units 28 self.activation = activations.get(activation) 29 self.recurrent_activation = activations.get(recurrent_activation) 30 self.use_bias = use_bias 31 self.unit_forget_bias = unit_forget_bias 32 33 self.kernel_initializer = initializers.get(kernel_initializer) 34 self.recurrent_initializer = initializers.get(recurrent_initializer) 35 self.bias_initializer = initializers.get(bias_initializer) 36 37 self.kernel_regularizer = regularizers.get(kernel_regularizer) 38 self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 39 self.bias_regularizer = regularizers.get(bias_regularizer) 40 41 self.kernel_constraint = constraints.get(kernel_constraint) 42 self.recurrent_constraint = constraints.get(recurrent_constraint) 43 self.bias_constraint = constraints.get(bias_constraint) 44 45 @property 46 def state_size(self): 47 return [self.units, self.units] 48 49 def build(self, input_shape): 50 input_dim1 = input_shape[0][-1] # 時系列データの入力 51 input_dim2 = input_shape[1][-1] # 非時系列データの入力 52 self.kernel1 = self.add_weight(shape=(input_dim1, self.units * 4), 53 name='kernel1', 54 initializer=self.kernel_initializer, 55 regularizer=self.kernel_regularizer, 56 constraint=self.kernel_constraint) 57 self.kernel2 = self.add_weight(shape=(input_dim2, self.units * 3), 58 name='kernel2', 59 initializer=self.kernel_initializer, 60 regularizer=self.kernel_regularizer, 61 constraint=self.kernel_constraint) 62 self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4), 63 name='recurrent_kernel', 64 initializer=self.recurrent_initializer, 65 regularizer=self.recurrent_regularizer, 66 constraint=self.recurrent_constraint) 67 if self.use_bias: 68 if self.unit_forget_bias: 69 def bias_initializer(_, *args, **kwargs): 70 return K.cocatenate([ 71 self.bias_initializer((self.units,), *args, **kwargs), 72 initializers.Ones()((self.units,), *args, **kwargs), 73 self.bias_initializer((self.units * 4,), *args, **kwargs), 74 ]) 75 else: 76 bias_initializer = self.bias_initializer 77 78 self.bias = self.add_weight(shape=(self.units * 4,), 79 name='bias', 80 initializer=self.bias_initializer, 81 regularizer=self.bias_regularizer, 82 constraint=self.bias_constraint) 83 else: 84 self.bias = None 85 self.built = True 86 87 def call(self, inputs, states, training=None): 88 h_tm1 = states[0] # 前時刻の隠れ状態 89 c_tm1 = states[1] # 前時刻のメモリセル状態 90 91 # 時系列データの入力 92 kt_f, kt_u, kt_i, kt_o = array_ops.split(self.kernel1, num_or_size_splits=4, axis=1) 93 xt_f = K.dot(inputs[0], kt_f) 94 xt_u = K.dot(inputs[0], kt_u) 95 xt_i = K.dot(inputs[0], kt_i) 96 xt_o = K.dot(inputs[0], kt_o) 97 98 # 非時系列データの入力 99 kn_f, kn_i, kn_o = array_ops.aplit(self.kernel2, num_or_size_splits=3, axis=1) 100 xn_f = K.dot(inputs[1], kn_f) 101 xn_i = K.dot(inputs[1], kn_i) 102 xn_o = K.dot(inputs[1], kn_o) 103 104 x_f = xt_f + xn_f 105 x_u = xt_u 106 x_i = xt_i + xn_i 107 x_o = xt_o + xn_o 108 109 if self.use_bias: 110 b_f, b_u, b_i, b_o = array_ops.split(self.bias, num_or_size_splits=4, axis=0) 111 x_f = K.bias_add(x_f, b_f) 112 x_u = K.bias_add(x_u, b_u) 113 x_i = K.bias_add(x_i, b_i) 114 x_o = K.bias_add(x_o, b_o) 115 116 f = self.recurrent_activation(x_f + K.dot(h_tm1, self.recurrent_kernel[:, :self.units])) # 忘却ゲート 117 u = self.activation(x_u + K.dot(h_tm1, self.recurrene_kernel[:, self.units:self.units * 2])) # ブロック入力 118 i = self.recurrent_activation(x_i + K.dot(h_tm1, self.recurrent_kernel[:, self.units * 2:self.units * 3])) # 入力ゲート 119 o = self.recurrent_activation(x_o + K.dot(h_tm1, self.recurrent_kernel[:, self.units * 3:self.units * 4])) # 出力ゲート 120 121 c = f * c_tm1 + u * i 122 h = self.activation(c) * o 123 124 return h, [h, c] 125 126 def compute_output_shape(self, input_shape): 127 return input_shape[0][-1], self.units 128

Python

1t_input = Input(shape=(30, 1)) 2n_input = Input(shape=(30, 1)) 3h = NontsLSTM(128)([t_input, n_input]) 4output = Dense(1, activation='linear')(h) 5 6model = Model([t_input, n_input], output) 7model.compile(optimizer='adam', loss='mse', metrics='loss')

試したこと

TensorFlow の github も参照しましたが、先の Qiita 記事によると AbstractRNNCell を継承するのが良いということで記事内のコードを参考にしました。
デバッグプリントをしたところ、init メソッド、build メソッドは最後まで処理されていました。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

Cellレイヤを継承してそれを直接呼び出しているのが原因です。RNNでラップしてから呼び出して下さい。
例えば呼び出すソースコードのNontsLSTMの箇所をLSTMCellにすれば同じエラーが見られると思います。

参考URL

投稿2020/06/22 15:57

yymmt

総合スコア1615

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

ima_chan1107

2020/06/23 01:52

ご返信ありがとうございます。 ご指摘の箇所を修正した上で手直しを加えるとエラーはなくなりました。 複数のレイヤーを入力する場合はリストではなくタプルでないといけないようですね。 ありがとうございました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問