前提・実現したいこと
kerasで入力データに生じた欠損を補間すること目的とした3次元のGANの構築を目的としています.Generatorをオートエンコーダとして実装しGeneratorの入力を欠損の生じたデータとしてGeneratorの出力で欠損が補間されたデータを出力します.
そこでGANのモデルの定義のところでGeneratorとDiscriminatorを接続するcombinedモデルの定義の際にGeneratorの出力をgen_volとしたときにこのGeneratorの出力のgen_volに「np.where(gen_vol < 0.999, 1, 0)」と同様の処理をして,gen_volのデータの値が0.999以下のところには1をそれ以外には0を格納し,この処理を行ったものをcombinedモデルでDiscriminatorへの入力をしたいと思っております.
ここではnp.whereの代わりにtf.whereを用いて実現しようとしていますがうまくいきません.どのようにすれば目的のことを実現できるでしょうか.またそもそもcombinedモデルの定義のところでGeneratorの出力にこのように処理を施してcombinedモデルを定義することは可能なのか,Generatorの出力に「np.where(gen_vol < 0.999, 1, 0)」の処理を加えたものをDiscriminatorに入力できればいいので,何か他にもいい方法があればお願いします
よろしくお願いします.
該当のソースコード
python
1from __future__ import print_function, division 2 3import os 4import sys 5import glob 6import random 7import pickle 8from mpl_toolkits.mplot3d import Axes3D # you should keep the import 9import matplotlib.pyplot as plt 10import numpy as np 11import open3d as o3d 12from keras.layers import BatchNormalization, Activation 13from keras.layers import Input, Dense, Flatten, Dropout, concatenate 14from keras.layers.advanced_activations import LeakyReLU 15from keras.layers.convolutional import UpSampling3D, Conv3D, Deconv3D 16from keras.models import Sequential, Model 17from keras.models import load_model 18from keras.optimizers import Adam 19from sklearn.metrics import hamming_loss 20from utils import mkdirs 21from keras import backend as K 22import tensorflow as tf 23 24# Cudnn の例外を抑制 25if 'tensorflow' == K.backend(): 26 import tensorflow as tf 27from keras.backend.tensorflow_backend import set_session 28config = tf.ConfigProto() 29config.gpu_options.allow_growth = True 30config.gpu_options.visible_device_list = "0" 31set_session(tf.Session(config=config)) 32 33# オプション 34MODEL_DIR = './models/ivy_UNet' 35DATASET_DIR = './datasets/ivy' 36MAX_GRID = 32 # ボクセルグリッドの最大値 37GENERATOR_TYPE = "U-Net" # "U-Net":U-Netを利用 38 39mkdirs(os.path.join(MODEL_DIR, 'images')) 40mkdirs(os.path.join(MODEL_DIR, 'saved_model')) 41 42class EncoderDecoderGAN(): 43 def __init__(self): 44 self.g_type = GENERATOR_TYPE 45 self.channels = 1 46 self.num_classes = 2 47 self.vol_shape = (MAX_GRID, MAX_GRID, MAX_GRID, self.channels) 48 49 optimizer = Adam(0.0002, 0.5) 50 51 try: 52 self.discriminator = load_model(os.path.join(MODEL_DIR, 'saved_model', 'discriminator.h5')) 53 self.generator = load_model(os.path.join(MODEL_DIR, 'saved_model', 'generator.h5')) 54 55 print("Loaded checkpoints") 56 except: 57 58 print("No checkpoints found") 59 if self.g_type == "U-Net": 60 self.generator = self.build_generator_UNet() 61 print("Build generator with U-Net") 62 else: 63 self.generator = self.build_generator() 64 print("Build generator") 65 66 self.discriminator = self.build_discriminator() 67 68 # discriminator 69 self.discriminator.compile(loss='binary_crossentropy', 70 optimizer=optimizer, 71 metrics=['accuracy']) 72 73 # generator 74 75 # The generator takes noise as input and generates the missing part 76 masked_vol = Input(shape=self.vol_shape) 77 gen_vol = self.generator(masked_vol) 78 79 # For the combined model we will only train the generator 80 self.discriminator.trainable = False 81 82 # The discriminator takes generated voxels as input and determines 83 # if it is generated or if it is a real voxels 84 gen_vols = tf.where(gen_vol <= 0.999,tf.ones(tf.shape(gen_vol)),tf.zeros(tf.shape(gen_vol))) 85 valid = self.discriminator(gen_vols) 86 87 # The combined model (stacked generator and discriminator) 88 # Trains generator to fool discriminator 89 self.combined = Model(masked_vol, [gen_vol, valid]) 90 self.combined.compile(loss=['mse', 'binary_crossentropy'], 91 loss_weights=[0.999, 0.001], 92 optimizer=optimizer) 93 94 95 96 def build_generator_UNet(self): 97 98 inputs = Input(self.vol_shape) 99 100 # Encoder 101 conv1 = Conv3D(64, kernel_size=5, strides=2, padding="same")(inputs) 102 conv1 = LeakyReLU(alpha=0.2)(conv1) 103 conv1 = BatchNormalization(momentum=0.8)(conv1) 104 105 conv2 = Conv3D(128, kernel_size=5, strides=2, padding="same")(conv1) 106 conv2 = LeakyReLU(alpha=0.2)(conv2) 107 conv2 = BatchNormalization(momentum=0.8)(conv2) 108 109 conv3 = Conv3D(256, kernel_size=5, strides=2, padding="same")(conv2) 110 conv3 = LeakyReLU(alpha=0.2)(conv3) 111 conv3 = Dropout(0.5)(conv3) 112 113 # Decoder 114 up4 = UpSampling3D()(conv3) 115 up4 = Deconv3D(128, kernel_size=5, padding="same")(up4) 116 up4 = LeakyReLU(alpha=0.2)(up4) 117 up4 = BatchNormalization(momentum=0.8)(up4) 118 merge4 = concatenate([conv2,up4], axis = 4) 119 120 up5 = UpSampling3D()(merge4) 121 up5 = Deconv3D(64, kernel_size=5, padding="same")(up5) 122 up5 = LeakyReLU(alpha=0.2)(up5) 123 up5 = BatchNormalization(momentum=0.8)(up5) 124 merge5 = concatenate([conv1,up5], axis = 4) 125 126 up6 = UpSampling3D()(merge5) 127 128 outputs = Deconv3D(self.channels, kernel_size=5, padding="same")(up6) 129 outputs = Activation('tanh')(outputs) 130 131 masked_vol = inputs 132 gen_missing = outputs 133 134 model = Model(masked_vol, gen_missing) 135 model.summary(line_length=200) 136 137 return Model(masked_vol, gen_missing) 138 139 def build_discriminator(self): 140 141 inputs = Input(self.vol_shape) 142 143 # Encoder 144 conv1 = Conv3D(64, kernel_size=5, strides=2, padding="same")(inputs) 145 conv1 = LeakyReLU(alpha=0.2)(conv1) 146 conv1 = BatchNormalization(momentum=0.8)(conv1) 147 148 conv2 = Conv3D(128, kernel_size=5, strides=2, padding="same")(conv1) 149 conv2 = LeakyReLU(alpha=0.2)(conv2) 150 conv2 = BatchNormalization(momentum=0.8)(conv2) 151 152 conv3 = Conv3D(256, kernel_size=5, strides=2, padding="same")(conv2) 153 conv3 = LeakyReLU(alpha=0.2)(conv3) 154 conv3 = BatchNormalization(momentum=0.8)(conv3) 155 156 outputs = Flatten()(conv3) 157 outputs = Dense(1, activation='sigmoid')(outputs) 158 159 vol = inputs 160 validity = outputs 161 162 model = Model(vol, validity) 163 model.summary(line_length=200) 164 165 return Model(vol, validity)
発生している問題・エラーメッセージ
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, None, 256)). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.
試したこと
エラーメッセージについて調べてみたところ,すでに存在するモデルをロードした際に出ているエラーというようなものしか出てきませんでした.自分はすでにあるモデルをロードすることはしていないと思うのでこのエラーの解決策や意味が分かりません.
補足情報(FW/ツールのバージョンなど)
python3
cuda10.0
cudnn7.6.3
tensorflow1.15.0
keras2.3.1
googlecolabで実行しています
あなたの回答
tips
プレビュー