質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

0回答

919閲覧

Fine Tuning実行時のエラー

xeno

総合スコア16

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

1クリップ

投稿2021/07/27 05:46

DnCNNの学習したモデルをFine Tuningしたいと考え下記のようなコードを作成しました。
しかし、Flattenへの形状?次元?が違うというエラーが発生し解決できません。
現状、model.summary()でもともとのモデルは確認できます。

import os, time, datetime import PIL.Image as Image import numpy as np from keras.optimizers import Adam from keras.models import load_model, model_from_json from keras.layers import Input,Conv2D,BatchNormalization,Activation,Subtract,Flatten from skimage.metrics import peak_signal_noise_ratio, structural_similarity from skimage.io import imread, imsave import keras.backend as K from time import time from keras.models import Sequential, Model t_start = time() # 開始時間 train_data = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/data/Train400' save_dir = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/models/000/' model_dir = '/home/script/DnCNN-master/TrainingCodes/dncnn_keras/models/DnCNN_sigma25/model_050.hdf5' epoch = 3 batch_size = 128 initial_epoch = 1000 # define loss def sum_squared_error(y_true, y_pred): #return K.mean(K.square(y_pred - y_true), axis=-1) #return K.sum(K.square(y_pred - y_true), axis=-1)/2 return K.sum(K.square(y_pred - y_true))/2 def train_datagen(epoch_iter=1000,epoch_num=5,batch_size=128,data_dir=train_data): while(True): n_count = 0 if n_count == 0: #print(n_count) xs = dg.datagenerator(data_dir) assert len(xs)%128 ==0, \ log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!') xs = xs.astype('float32')/255.0 indices = list(range(xs.shape[0])) n_count = 1 for _ in range(epoch_num): np.random.shuffle(indices) # shuffle for i in range(0, len(indices), batch_size): batch_x = xs[indices[i:i+batch_size]] noise = np.random.normal(0, 25/255.0, batch_x.shape) # noise #noise = K.random_normal(ge_batch_y.shape, mean=0, stddev=args.sigma/255.0) batch_y = batch_x + noise yield batch_y, batch_x DnCNN_model = load_model(model_dir,custom_objects={'sum_squared_error': sum_squared_error}) print("Original model is ") DnCNN_model.summary() DnCNN_model.trainable = True model = Sequential() model.add(DnCNN_model) model.add(Flatten(input_shape=DnCNN_model.output_shape[1:])) print("Connected model is ") model.summary() model.compile(optimizer=Adam(0.001), loss=sum_squared_error) history = model.fit_generator(train_datagen(batch_size=batch_size), steps_per_epoch=1000, epochs=epoch, verbose=1, initial_epoch=initial_epoch) t_end = time() #終了時間 t_elapsed = t_end - t_start print("処理時間は{0}".format(t_elapsed))

エラー

Traceback (most recent call last): File "tune.py", line 58, in <module> model.add(Flatten(input_shape=DnCNN_model.output_shape[1:])) File "/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py", line 182, in add output_tensor = layer(self.outputs[0]) File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 468, in __call__ output_shape = self.compute_output_shape(input_shape) File "/usr/local/lib/python3.6/dist-packages/keras/layers/core.py", line 501, in compute_output_shape '(got ' + str(input_shape[1:]) + '). ' ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1)). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

toast-uz

2021/07/27 23:05

DnCNN_model.summary() DnCNN_model.output_shape の結果を見せていただけますか。 エラーメッセージとしては、DnCNN_model.output_shape = (None, None, None, 1) になってしまっているように思われます。もしそのようになっていたら、DnCNN_modelの定義に問題があるので、その部分のコードを示すようにお願いします。
xeno

2021/07/28 04:51

summaryの結果は下記です。 input0 (InputLayer) (None, None, None, 1 0 __________________________________________________________________________________________________ conv1 (Conv2D) (None, None, None, 6 640 input0[0][0] __________________________________________________________________________________________________ relu2 (Activation) (None, None, None, 6 0 conv1[0][0] __________________________________________________________________________________________________ conv3 (Conv2D) (None, None, None, 6 36864 relu2[0][0] __________________________________________________________________________________________________ bn4 (BatchNormalization) (None, None, None, 6 256 conv3[0][0] __________________________________________________________________________________________________ relu5 (Activation) (None, None, None, 6 0 bn4[0][0] __________________________________________________________________________________________________ conv6 (Conv2D) (None, None, None, 6 36864 relu5[0][0] __________________________________________________________________________________________________ bn7 (BatchNormalization) (None, None, None, 6 256 conv6[0][0] __________________________________________________________________________________________________ relu8 (Activation) (None, None, None, 6 0 bn7[0][0] __________________________________________________________________________________________________ conv9 (Conv2D) (None, None, None, 6 36864 relu8[0][0] __________________________________________________________________________________________________ bn10 (BatchNormalization) (None, None, None, 6 256 conv9[0][0] __________________________________________________________________________________________________ relu11 (Activation) (None, None, None, 6 0 bn10[0][0] __________________________________________________________________________________________________ conv12 (Conv2D) (None, None, None, 6 36864 relu11[0][0] __________________________________________________________________________________________________ bn13 (BatchNormalization) (None, None, None, 6 256 conv12[0][0] __________________________________________________________________________________________________ relu14 (Activation) (None, None, None, 6 0 bn13[0][0] __________________________________________________________________________________________________ conv15 (Conv2D) (None, None, None, 6 36864 relu14[0][0] __________________________________________________________________________________________________ bn16 (BatchNormalization) (None, None, None, 6 256 conv15[0][0] __________________________________________________________________________________________________ relu17 (Activation) (None, None, None, 6 0 bn16[0][0] __________________________________________________________________________________________________ conv18 (Conv2D) (None, None, None, 6 36864 relu17[0][0] __________________________________________________________________________________________________ bn19 (BatchNormalization) (None, None, None, 6 256 conv18[0][0] __________________________________________________________________________________________________ relu20 (Activation) (None, None, None, 6 0 bn19[0][0] __________________________________________________________________________________________________ conv21 (Conv2D) (None, None, None, 6 36864 relu20[0][0] __________________________________________________________________________________________________ bn22 (BatchNormalization) (None, None, None, 6 256 conv21[0][0] __________________________________________________________________________________________________ relu23 (Activation) (None, None, None, 6 0 bn22[0][0] __________________________________________________________________________________________________ conv24 (Conv2D) (None, None, None, 6 36864 relu23[0][0] __________________________________________________________________________________________________ bn25 (BatchNormalization) (None, None, None, 6 256 conv24[0][0] __________________________________________________________________________________________________ relu26 (Activation) (None, None, None, 6 0 bn25[0][0] __________________________________________________________________________________________________ conv27 (Conv2D) (None, None, None, 6 36864 relu26[0][0] __________________________________________________________________________________________________ bn28 (BatchNormalization) (None, None, None, 6 256 conv27[0][0] __________________________________________________________________________________________________ relu29 (Activation) (None, None, None, 6 0 bn28[0][0] __________________________________________________________________________________________________ conv30 (Conv2D) (None, None, None, 6 36864 relu29[0][0] __________________________________________________________________________________________________ bn31 (BatchNormalization) (None, None, None, 6 256 conv30[0][0] __________________________________________________________________________________________________ relu32 (Activation) (None, None, None, 6 0 bn31[0][0] __________________________________________________________________________________________________ conv33 (Conv2D) (None, None, None, 6 36864 relu32[0][0] __________________________________________________________________________________________________ bn34 (BatchNormalization) (None, None, None, 6 256 conv33[0][0] __________________________________________________________________________________________________ relu35 (Activation) (None, None, None, 6 0 bn34[0][0] __________________________________________________________________________________________________ conv36 (Conv2D) (None, None, None, 6 36864 relu35[0][0] __________________________________________________________________________________________________ bn37 (BatchNormalization) (None, None, None, 6 256 conv36[0][0] __________________________________________________________________________________________________ relu38 (Activation) (None, None, None, 6 0 bn37[0][0] __________________________________________________________________________________________________ conv39 (Conv2D) (None, None, None, 6 36864 relu38[0][0] __________________________________________________________________________________________________ bn40 (BatchNormalization) (None, None, None, 6 256 conv39[0][0] __________________________________________________________________________________________________ relu41 (Activation) (None, None, None, 6 0 bn40[0][0] __________________________________________________________________________________________________ conv42 (Conv2D) (None, None, None, 6 36864 relu41[0][0] __________________________________________________________________________________________________ bn43 (BatchNormalization) (None, None, None, 6 256 conv42[0][0] __________________________________________________________________________________________________ relu44 (Activation) (None, None, None, 6 0 bn43[0][0] __________________________________________________________________________________________________ conv45 (Conv2D) (None, None, None, 6 36864 relu44[0][0] __________________________________________________________________________________________________ bn46 (BatchNormalization) (None, None, None, 6 256 conv45[0][0] __________________________________________________________________________________________________ relu47 (Activation) (None, None, None, 6 0 bn46[0][0] __________________________________________________________________________________________________ conv48 (Conv2D) (None, None, None, 1 576 relu47[0][0] __________________________________________________________________________________________________ subtract49 (Subtract) (None, None, None, 1 0 input0[0][0] conv48[0][0]
xeno

2021/07/28 04:52

DnCNN_model.output_shapeはおっしゃるとおり(None, None, None, 1)になっています
toast-uz

2021/07/28 23:11

確かにリンク先のモデルは(None, None, None, 1)になりますね。使い方に工夫が必要なのでは・・・。 https://qiita.com/cvusk/items/9b822860bb2c501a0fe4 の方(3つめのコード)は、(None, 32, 32, 3)なので普通に使えると思います。
toast-uz

2021/07/31 05:43

ちなみにリンク先に対して、変更は一切加えていないとのことですが、Usageは、main_train.py だけですね。そもそもどうやって使っているのでしょうか?どこかにFine Tuningでの使い方が掲載されているのですか?
xeno

2021/08/02 12:57

main_train.pyで学習、main_test.pyで学習したモデルでテストを行っています。 Fine Tuningはmain_train.pyで学習したモデルを再学習したいと思い質問させて頂いています。
jbpb0

2021/08/03 01:08 編集

https://github.com/cszn/DnCNN/tree/master/TrainingCodes/dncnn_keras の「main_train.py」の # load the last model in matconvnet style と書かれてるところを見ると、学習したときに作成される「model_数字.hdf5」というファイルがあれば、それを読み込んでから学習が始まるようなので、2回目以降の「main_train.py」の実行時には、自動的に > main_train.pyで学習したモデルを再学習したい となるのでは? (「model_数字.hdf5」というファイルを削除してなければ)
xeno

2021/08/03 01:13

今回はFlattenだけにしていますが層を追加したいと考えているのでmain_train.pyでは厳しいです。
jbpb0

2021/08/03 03:00

> 層を追加したいと考えているのでmain_train.pyでは厳しい 「main_train.py」の def DnCNN(... の中で追加したらいいだけでは?
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問