変分オートエンコーダ(VAE)の実装にて、Kerasのサンプルコード(https://keras.io/examples/variational_autoencoder/)を利用したいと思ったのですが、
損失関数の部分の条件分岐が理解できません。
平均二乗誤差と交差エントロピー誤差をどう使い分けているのでしょうか。
if __name__ == '__main__': parser = argparse.ArgumentParser() help_ = "Load h5 model trained weights" parser.add_argument("-w", "--weights", help=help_) help_ = "Use mse loss instead of binary cross entropy (default)" parser.add_argument("-m", "--mse", help=help_, action='store_true') args = parser.parse_args() models = (encoder, decoder) data = (x_test, y_test) # VAE loss = mse_loss or xent_loss + kl_loss if args.mse: reconstruction_loss = mse(inputs, outputs) else: reconstruction_loss = binary_crossentropy(inputs, outputs) reconstruction_loss *= original_dim kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var) kl_loss = K.sum(kl_loss, axis=-1) kl_loss *= -0.5 vae_loss = K.mean(reconstruction_loss + kl_loss) vae.add_loss(vae_loss) vae.compile(optimizer='adam') vae.summary() plot_model(vae, to_file='vae_mlp.png', show_shapes=True) if args.weights: vae.load_weights(args.weights) else: # train the autoencoder vae.fit(x_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test, None)) vae.save_weights('vae_mlp_mnist.h5') plot_results(models, data, batch_size=batch_size, model_name="vae_mlp")

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/01/17 09:07
2020/01/17 09:24
2020/01/20 05:34