I want to fine-tune a trained DnCNN model, so I wrote the code below.
However, I get an error saying that the shape (dimensions?) fed into Flatten is wrong, and I cannot resolve it.
At the moment, model.summary() does display the original model correctly.
```python
import os, time, datetime
import PIL.Image as Image
import numpy as np
from keras.optimizers import Adam
from keras.models import load_model, model_from_json
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Subtract, Flatten
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.io import imread, imsave
import keras.backend as K
from time import time
from keras.models import Sequential, Model
import data_generator as dg  # from the DnCNN repo; needed for dg.datagenerator() below

t_start = time()  # start time

train_data = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/data/Train400'
save_dir = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/models/000/'
model_dir = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/models/DnCNN_sigma25/model_050.hdf5'
epoch = 3
batch_size = 128
initial_epoch = 1000


def log(*args, **kwargs):
    # timestamped print helper, as in the repo's main_train.py
    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


# define loss
def sum_squared_error(y_true, y_pred):
    # return K.mean(K.square(y_pred - y_true), axis=-1)
    # return K.sum(K.square(y_pred - y_true), axis=-1)/2
    return K.sum(K.square(y_pred - y_true)) / 2


def train_datagen(epoch_iter=1000, epoch_num=5, batch_size=128, data_dir=train_data):
    while True:
        n_count = 0
        if n_count == 0:
            # print(n_count)
            xs = dg.datagenerator(data_dir)
            assert len(xs) % 128 == 0, \
                log('make sure the last iteration has a full batchsize, '
                    'this is important if you use batch normalization!')
            xs = xs.astype('float32') / 255.0
            indices = list(range(xs.shape[0]))
            n_count = 1
        for _ in range(epoch_num):
            np.random.shuffle(indices)  # shuffle
            for i in range(0, len(indices), batch_size):
                batch_x = xs[indices[i:i + batch_size]]
                noise = np.random.normal(0, 25 / 255.0, batch_x.shape)  # noise
                # noise = K.random_normal(ge_batch_y.shape, mean=0, stddev=args.sigma/255.0)
                batch_y = batch_x + noise
                yield batch_y, batch_x


DnCNN_model = load_model(model_dir, custom_objects={'sum_squared_error': sum_squared_error})
print("Original model is ")
DnCNN_model.summary()
DnCNN_model.trainable = True

model = Sequential()
model.add(DnCNN_model)
model.add(Flatten(input_shape=DnCNN_model.output_shape[1:]))
print("Connected model is ")
model.summary()

model.compile(optimizer=Adam(0.001), loss=sum_squared_error)
history = model.fit_generator(train_datagen(batch_size=batch_size),
                              steps_per_epoch=1000,
                              epochs=epoch,
                              verbose=1,
                              initial_epoch=initial_epoch)

t_end = time()  # end time
t_elapsed = t_end - t_start
print("Elapsed time: {0}".format(t_elapsed))
```
Error:
```
Traceback (most recent call last):
  File "tune.py", line 58, in <module>
    model.add(Flatten(input_shape=DnCNN_model.output_shape[1:]))
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py", line 182, in add
    output_tensor = layer(self.outputs[0])
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 468, in __call__
    output_shape = self.compute_output_shape(input_shape)
  File "/usr/local/lib/python3.6/dist-packages/keras/layers/core.py", line 501, in compute_output_shape
    '(got ' + str(input_shape[1:]) + '). '
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1)). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.
```
Could you show us the output of

```
DnCNN_model.summary()
DnCNN_model.output_shape
```

Judging from the error message, it looks like DnCNN_model.output_shape is coming out as (None, None, None, 1). If that is indeed the case, the problem lies in how DnCNN_model is defined, so please post that part of the code as well.
The summary output is below.
```
Layer (type)                 Output Shape            Param #   Connected to
============================================================================
input0 (InputLayer)          (None, None, None, 1)   0
____________________________________________________________________________
conv1 (Conv2D)               (None, None, None, 64)  640       input0[0][0]
____________________________________________________________________________
relu2 (Activation)           (None, None, None, 64)  0         conv1[0][0]
____________________________________________________________________________
conv3 (Conv2D)               (None, None, None, 64)  36864     relu2[0][0]
____________________________________________________________________________
bn4 (BatchNormalization)     (None, None, None, 64)  256       conv3[0][0]
____________________________________________________________________________
relu5 (Activation)           (None, None, None, 64)  0         bn4[0][0]
____________________________________________________________________________
conv6 (Conv2D)               (None, None, None, 64)  36864     relu5[0][0]
____________________________________________________________________________
bn7 (BatchNormalization)     (None, None, None, 64)  256       conv6[0][0]
____________________________________________________________________________
relu8 (Activation)           (None, None, None, 64)  0         bn7[0][0]
____________________________________________________________________________
conv9 (Conv2D)               (None, None, None, 64)  36864     relu8[0][0]
____________________________________________________________________________
bn10 (BatchNormalization)    (None, None, None, 64)  256       conv9[0][0]
____________________________________________________________________________
relu11 (Activation)          (None, None, None, 64)  0         bn10[0][0]
____________________________________________________________________________
conv12 (Conv2D)              (None, None, None, 64)  36864     relu11[0][0]
____________________________________________________________________________
bn13 (BatchNormalization)    (None, None, None, 64)  256       conv12[0][0]
____________________________________________________________________________
relu14 (Activation)          (None, None, None, 64)  0         bn13[0][0]
____________________________________________________________________________
conv15 (Conv2D)              (None, None, None, 64)  36864     relu14[0][0]
____________________________________________________________________________
bn16 (BatchNormalization)    (None, None, None, 64)  256       conv15[0][0]
____________________________________________________________________________
relu17 (Activation)          (None, None, None, 64)  0         bn16[0][0]
____________________________________________________________________________
conv18 (Conv2D)              (None, None, None, 64)  36864     relu17[0][0]
____________________________________________________________________________
bn19 (BatchNormalization)    (None, None, None, 64)  256       conv18[0][0]
____________________________________________________________________________
relu20 (Activation)          (None, None, None, 64)  0         bn19[0][0]
____________________________________________________________________________
conv21 (Conv2D)              (None, None, None, 64)  36864     relu20[0][0]
____________________________________________________________________________
bn22 (BatchNormalization)    (None, None, None, 64)  256       conv21[0][0]
____________________________________________________________________________
relu23 (Activation)          (None, None, None, 64)  0         bn22[0][0]
____________________________________________________________________________
conv24 (Conv2D)              (None, None, None, 64)  36864     relu23[0][0]
____________________________________________________________________________
bn25 (BatchNormalization)    (None, None, None, 64)  256       conv24[0][0]
____________________________________________________________________________
relu26 (Activation)          (None, None, None, 64)  0         bn25[0][0]
____________________________________________________________________________
conv27 (Conv2D)              (None, None, None, 64)  36864     relu26[0][0]
____________________________________________________________________________
bn28 (BatchNormalization)    (None, None, None, 64)  256       conv27[0][0]
____________________________________________________________________________
relu29 (Activation)          (None, None, None, 64)  0         bn28[0][0]
____________________________________________________________________________
conv30 (Conv2D)              (None, None, None, 64)  36864     relu29[0][0]
____________________________________________________________________________
bn31 (BatchNormalization)    (None, None, None, 64)  256       conv30[0][0]
____________________________________________________________________________
relu32 (Activation)          (None, None, None, 64)  0         bn31[0][0]
____________________________________________________________________________
conv33 (Conv2D)              (None, None, None, 64)  36864     relu32[0][0]
____________________________________________________________________________
bn34 (BatchNormalization)    (None, None, None, 64)  256       conv33[0][0]
____________________________________________________________________________
relu35 (Activation)          (None, None, None, 64)  0         bn34[0][0]
____________________________________________________________________________
conv36 (Conv2D)              (None, None, None, 64)  36864     relu35[0][0]
____________________________________________________________________________
bn37 (BatchNormalization)    (None, None, None, 64)  256       conv36[0][0]
____________________________________________________________________________
relu38 (Activation)          (None, None, None, 64)  0         bn37[0][0]
____________________________________________________________________________
conv39 (Conv2D)              (None, None, None, 64)  36864     relu38[0][0]
____________________________________________________________________________
bn40 (BatchNormalization)    (None, None, None, 64)  256       conv39[0][0]
____________________________________________________________________________
relu41 (Activation)          (None, None, None, 64)  0         bn40[0][0]
____________________________________________________________________________
conv42 (Conv2D)              (None, None, None, 64)  36864     relu41[0][0]
____________________________________________________________________________
bn43 (BatchNormalization)    (None, None, None, 64)  256       conv42[0][0]
____________________________________________________________________________
relu44 (Activation)          (None, None, None, 64)  0         bn43[0][0]
____________________________________________________________________________
conv45 (Conv2D)              (None, None, None, 64)  36864     relu44[0][0]
____________________________________________________________________________
bn46 (BatchNormalization)    (None, None, None, 64)  256       conv45[0][0]
____________________________________________________________________________
relu47 (Activation)          (None, None, None, 64)  0         bn46[0][0]
____________________________________________________________________________
conv48 (Conv2D)              (None, None, None, 1)   576       relu47[0][0]
____________________________________________________________________________
subtract49 (Subtract)        (None, None, None, 1)   0         input0[0][0]
                                                               conv48[0][0]
============================================================================
```
DnCNN_model.output_shape is, as you say, (None, None, None, 1).
The DnCNN model I am using is the one below; I have not modified it at all.
https://github.com/cszn/DnCNN/tree/master/TrainingCodes/dncnn_keras
The model at that link does indeed come out as (None, None, None, 1). You probably need to be a bit clever about how you use it, for example by wrapping it in a new model with a fixed-size input, as sketched below.
By contrast, the third code example at
https://qiita.com/cvusk/items/9b822860bb2c501a0fe4
has an output shape of (None, 32, 32, 3), so it can be used as is.
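Here is a minimal sketch of that workaround. It assumes DnCNN_model has been loaded as in your script; the 40x40x1 patch size is my assumption (the repo's data generator produces 40x40 grayscale patches), so match it to whatever your training data actually is.

```python
# Minimal sketch: wrap the pretrained model in a new Model whose Input has
# fixed spatial dimensions, so Flatten knows its output size.
# 40x40x1 is an assumption -- match it to your actual patch size.
from keras.layers import Input, Flatten
from keras.models import Model

inp = Input(shape=(40, 40, 1))   # pin down H and W so Flatten's size is defined
out = DnCNN_model(inp)           # reuse the loaded pretrained model as a layer
out = Flatten()(out)             # output shape is now (None, 1600)
model = Model(inputs=inp, outputs=out)
model.summary()
```

Because DnCNN is fully convolutional, the weights themselves do not depend on the input size; you only give up the any-size property for this particular training graph, and can still rebuild an all-None-shape model from the same weights for inference.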
Incidentally, you say you have made no changes to the code at the link, but its Usage only mentions main_train.py. How exactly are you using it in the first place? Is there a write-up somewhere on using it for fine tuning?
I train with main_train.py and test the trained model with main_test.py.
By fine tuning I mean that I want to retrain a model already trained with main_train.py, which is why I am asking this question.
Looking at the part of main_train.py in
https://github.com/cszn/DnCNN/tree/master/TrainingCodes/dncnn_keras
marked

# load the last model in matconvnet style

it appears that if a "model_<number>.hdf5" file produced by training exists, it is loaded before training starts. So from the second run of main_train.py onward, you automatically get

> retrain a model trained with main_train.py

don't you? (Provided you haven't deleted the "model_<number>.hdf5" files.) The relevant logic is sketched below.
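Roughly, that part of main_train.py works like the following sketch. This is a paraphrase from memory, not a verbatim copy; findLastCheckpoint is the repo's helper that scans save_dir for checkpoints, and exact names or details may differ.

```python
# Paraphrased sketch of the resume logic in the repo's main_train.py,
# under the "# load the last model in matconvnet style" comment.
import glob, os, re
from keras.models import load_model

save_dir = 'models/DnCNN_sigma25'  # wherever model_*.hdf5 files are written

def findLastCheckpoint(save_dir):
    # return the highest epoch number among saved model_*.hdf5 files, 0 if none
    files = glob.glob(os.path.join(save_dir, 'model_*.hdf5'))
    epochs = [int(re.findall(r'model_(\d+)\.hdf5', f)[0]) for f in files]
    return max(epochs) if epochs else 0

initial_epoch = findLastCheckpoint(save_dir=save_dir)
if initial_epoch > 0:
    print('resuming by loading epoch %03d' % initial_epoch)
    model = load_model(os.path.join(save_dir, 'model_%03d.hdf5' % initial_epoch),
                       compile=False)
```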
This time I only added a Flatten layer, but I eventually want to add more layers, so main_train.py alone won't cut it.
> I want to add layers, so main_train.py won't cut it

Couldn't you just add them inside the

def DnCNN(...

function in main_train.py? For example, see the sketch below.
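A hypothetical, simplified sketch of what that could look like. The real DnCNN() in main_train.py also names every layer and takes slightly different arguments; the marked Conv2D is the added layer, purely for illustration.

```python
# Simplified sketch of main_train.py's DnCNN() builder with one extra
# (hypothetical) layer added just before the final Subtract.
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Subtract
from keras.models import Model

def DnCNN(depth=17, filters=64, image_channels=1):
    inpt = Input(shape=(None, None, image_channels))
    x = Conv2D(filters, 3, padding='same')(inpt)
    x = Activation('relu')(x)
    for _ in range(depth - 2):                        # middle Conv+BN+ReLU blocks
        x = Conv2D(filters, 3, padding='same', use_bias=False)(x)
        x = BatchNormalization(axis=3, epsilon=1e-3)(x)
        x = Activation('relu')(x)
    x = Conv2D(image_channels, 3, padding='same', use_bias=False)(x)
    x = Conv2D(image_channels, 3, padding='same')(x)  # <-- the added layer
    x = Subtract()([inpt, x])                         # residual: clean = noisy - noise
    return Model(inputs=inpt, outputs=x)
```

Note, though, that once the architecture changes, previously saved model_*.hdf5 checkpoints no longer match it, so the automatic resume described above would fail; you would have to transfer the old weights layer by layer instead.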