前提・実現したいこと
時系列データの生成を行うため、kerasにてGANを構成しようと思っております。Generatorの入力次元、出力次元はそれぞれ、
input_shape = (n_data, max_len, dim_in)
output_shape = (n_data, dim_out)
とし、Discriminatorの入力次元、出力次元をそれぞれ、
input_shape = (1, n_data, dim_out)
output_shape = (1, 1)
にしたいと考えております。GeneratorとDiscriminatorのうまく接続する方法はございますでしょうか?
該当のソースコード
Python
1import numpy as np 2from tensorflow.keras.models import Model 3from tensorflow.keras.layers import Input, LSTM, Dense 4 5# 任意の数 6n_data = 2 7dim_in = 3 8dim_out = 5 9max_length = 7 10dim_hidden = 11 11 12d_in = Input(shape=(n_data, dim_out, )) 13d = LSTM(dim_hidden)(d_in) 14d_out = Dense(1, activation='sigmoid')(d) 15discriminator = Model(d_in, d_out) 16 17g_in = Input(shape=(max_len, dim_in, )) 18g = LSTM(dim_hidden)(g_in) 19g_out = Dense(dim_out)(g) 20generator = Model(g_in, g_out) 21 22#combined modelの構築が分かりません。 23 24x = np.random.normal(0, 1, (n_data, max_len, dim_in)) 25gen_x = generator.predict(x) #gen_x.shape = (n_data, dim_out) 26gen_x = gen_x.reshape(1, n_data, dim_out) 27y = discriminator.predict(gen_x) # y.shape = (1, 1)
試したこと
kerasでは、入力と出力でのバッチサイズを同じにする必要があるというのを見つけました。GeneratorとDiscriminatorのバッチサイズを同じにするために、Generatorの入力を4階テンソルにすればよいのではないかと思いました。つまり、Generatorの入力次元、出力次元をそれぞれ、
input_shape = (batch_size, n_data, max_length, dim_in)
out_shape = (batch_size, n_data, dim_out)
にし、Discriminatorの入力次元、出力次元をそれぞれ、
input_shape = (batch_size, n_data, dim_out)
output_shape = (batch_size, 1)
にしようと思いました。しかし、公式ドキュメントによると、LSTMの入力は3階テンソルのみで、4階テンソルを入力にしようとすると、エラーが発生しました。
Python
1batch_size = 13 2 3d_in = Input(shape=(n_data, dim_out, )) 4d = LSTM(dim_hidden)(d_in) 5d_out = Dense(1, activation='sigmoid')(d) 6discriminator = Model(d_in, d_out) 7 8g_in = Input(shape=(n_data, max_len, dim_in, )) 9g = LSTM(dim_hidden)(g_in) #ここでエラーが発生 10g_out = Dense(dim_out)(g) 11generator = Model(g_in, g_out) 12 13#combined modelの構築が分かりません。 14 15x = np.random.normal(0, 1, (batch_size, n_data, max_len, dim_in)) 16gen_x = generator.predict(x) #理想は、gen_x.shape = (batch_size, n_data, dim_out) 17y = discriminator.predict(gen_x) # y.shape = (batch_size, 1)
補足情報(FW/ツールのバージョンなど)
Python 3.6.5
tensorflow-gpu 2.1.0