質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.35%

  • Python 3.x

    11178questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Python3 自作データセットを使ったVAEの実装について

解決済

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 243

SuzuAya

score 19

前提・実現したいこと

自作のデータセットでkerasのVAEを実装中です。
ラベルデータは使わず、trainデータのみ使用するコードに少し修正を加えたいのですが、サイズ指定のところがよく分かりません。
画像のデータは幅1536, 高さ496のものを使いたいと思っています。
コードをどう修正したら良いか、コメントを追加した部分を中心に、ご教示頂けますと幸いです。

該当のソースコード

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model, np_utils 
from keras.utils import plot_model
from keras import backend as K
from keras.preprocessing.image import array_to_img, img_to_array, list_pictures, load_img  
from sklearn.model_selection import train_test_split 


import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
import easydict

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling(args):
    """Reparameterization trick by sampling fr an isotropic unit Gaussian.

    # Arguments
        args (tensor): mean and log of variance of Q(z|X)

    # Returns
        z (tensor): sampled latent vector
    """

    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # by default, random_normal has mean=0 and std=1.0
    epsilon = K.rando
m_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon


def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):
    """Plots labels and MNIST digits as function of 2-dim latent vector

    # Arguments
        models (tuple): encoder and decoder models
        data (tuple): test data and label
        batch_size (int): prediction batch size
        model_name (string): which model is using this function
    """

    encoder, decoder = models
    x_test = data #, y_test削除
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, "vae_mean.png")
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1])# c=y_testを削除
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()

    filename = os.path.join(model_name, "digits_over_latent.png")
    # display a 30x30 2D manifold of digits
    n = 30
    digit_size_width = 1536 #width追加
    digit_size_height = 496 #追加
    figure = np.zeros((digit_size_width * n, digit_size_height * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)#この部分はどう修正すれば良いでしょうか
            figure[i * digit_size: (i + 1) * digit_size,#この部分はどう修正すれば良いでしょうか
                   j * digit_size: (j + 1) * digit_size] = digit#この部分はどう修正すれば良いでしょうか

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)
    plt.show()


### original dataset
x = []
y = []#ラベル付きデータが無い場合はこの行は不要でしょうか

for picture in list_pictures('./folder_a'):
    img = img_to_array(load_img(picture, target_size=(1536,496))) #サイズ指定
    x.append(img)
    y.append(0)#ラベル付きデータが無い場合はこの行は不要でしょうか

x = np.asarray(x)
y = np.asarray(y)#ラベル付きデータが無い場合はこの行は不要でしょうか

x = x.astype('float32')
x = x/ 255.0
y = np_utils.to_categorical(y, 1)#ラベル付きデータが無い場合はこの行は不要でしょうか 

x_train, x_test = train_test_split(x, test_size=0.2, random_state=111)#y_train, y_test削除
original_dim = 761856 #1536x496x1  
image_size_width = 1536 #width追加    
image_size_height = 496 #追加

# network parameters
input_shape = (image_size,image_size,1 )
batch_size = 128
latent_dim = 2
epochs = 50

# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
plot_model(decoder, to_file='vae_mlp_decoder.png', show_shapes=True)

# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')

if __name__ == '__main__':
    args = easydict.EasyDict({
        "batchsize": 50,
        "epoch": 50,
        "gpu": 0,
        "out": "result",
        "resume": False,
        "unit": 1000
})
    models = (encoder, decoder)
    data = (x_test)#y_test削除

    # VAE loss = mse_loss or xent_loss + kl_loss
    reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)

    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
    vae.summary()
    plot_model(vae,
               to_file='vae_mlp.png',
               show_shapes=True)

    history = vae.fit(x_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(x_test, None))
    vae.save_weights('vae_CNN.h5')

    plot_results(models,
                 data,
                 batch_size=batch_size,
                 model_name="vae_CNN")
    plot_history(history)
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

回答 1

check解決した方法

0

いろいろ試行錯誤していたら、とりあえず動くコードができましたので解決済みとさせていただきます。全コードは別質問(Python3 for文を使った学習回数の設定方法、学習結果の保存名の変更、メモリ解放について)に載せています。

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.35%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る

  • Python 3.x

    11178questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。