やりたいこと
TensorflowでRankNetを実装しようと考えています。
想定しているモデルは図のような構造です。本来RankNetでは図中の緑枠内と赤枠内は同一のレイヤーを用いて計算を行うため、「やりたかった実装」のように中間モデルmodel_mid1
を複数回モデルに与えたいのですが、下記のようなエラーが返ってきてしまいます。
現在は「現在の実装」のように中間モデルをそれぞれ別のものとして実装をしたのですが、これではmodel_mid1
、model_mid2
で異なる計算をしてしまうため想定している構造ではありません。
最終的な用途として、複数の入力に対してmodel_mid
で推論を行い、その値を元に入力同士の順序関係を推測したいと考えています。そのため学習時はmodel.fit()
で学習し、推論時はmodel_mid.predict()
で推論を行いたいです。
どのようにすればモデル中に同一のレイヤーを複数回使用できるのでしょうか?
(機械学習に関して初学者なので用語の用法等に誤りがありましたらご指摘お願いします。)
実行環境
Python 3.8.3
tensorFlow 2.6.0
やりたかった実装
Python
1#実装したいコード 2dense_input = Input(shape=(LSTM_VAR_NUM,)) 3lstm_input = Input(shape=(None, DENSE_VAR_NUM,)) 4x = LSTM(64)(lstm_input) 5x = concatenate([dense_input,x]) 6x = Dense(64,activation='relu')(x) 7x = Dense(32,activation='relu')(x) 8output_mid1 = Dense(1,activation='linear')(x) 9 10model_mid = Model(inputs=[dense_input1,lstm_input1],outputs=output_mid1) 11 12z = Subtract()([model_mid.output, model_mid.output]) 13output = sigmoid(z) 14 15model = Model(inputs=[model_mid.input,model_mid.input],outputs=output)
エラー
error
1ValueError: The list of inputs passed to the model is redundant. All inputs should only appear once.
現在の実装
Python
1#実際に実装したコード 2dense_input1=Input(shape=(LSTM_VAR_NUM,)) 3lstm_input1 = Input(shape=(None, DENSE_VAR_NUM,)) 4x = LSTM(64)(lstm_input1) 5x = concatenate([dense_input1,x]) 6x = Dense(64,activation='relu')(x) 7x = Dense(32,activation='relu')(x) 8output_mid1 = Dense(1,activation='linear')(x) 9 10model_mid1 = Model(inputs=[dense_input1,lstm_input1],outputs=output_mid1) 11 12dense_input2=Input(shape=(LSTM_VAR_NUM,)) 13lstm_input2 = Input(shape=(None, DENSE_VAR_NUM,)) 14y = LSTM(64)(lstm_input2) 15y = concatenate([dense_input2,y]) 16y = Dense(64,activation='relu')(y) 17y = Dense(32,activation='relu')(y) 18output_mid2 = Dense(1,activation='linear')(y) 19 20model_mid2 = Model(inputs=[dense_input2,lstm_input2],outputs=output_mid1) 21 22z = Subtract()([model_mid1.output, model_mid2.output]) 23output = sigmoid(z) 24 25model = Model(inputs=[model_mid1.input,model_mid1.input],outputs=output)
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。