質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.60%

DnCNNの学習時のエラー

受付中

回答 0

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 131

score 10

下記のリンクからコピーしたDnCNNの学習をfitで行えるようにしたいと考えています。
LINK
しかし、学習を行おうとするとエラーが発生し解決できません。
教えていただきたいです。
data_generator.pyは変更していません。

流れとしては
1,Class Dataで保存している学習用画像を加工
2,加工したデータと学習用画像で学習
3,学習したモデルを保存

# 必要なライブラリーのインストール
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Subtract
import data_generator as dg
import numpy as np
import matplotlib.pyplot as plt
from keras.callbacks import EarlyStopping, ModelCheckpoint
from time import time
import matplotlib.pyplot as plt

t_start = time() # 開始時間

print('tf       :', tf.__version__)
print('tf.keras :', tf.keras.__version__)
print('keras    :', keras.__version__)

epochs = 1
batch = 128
save_dir = "/home/script/DnCNN-master/TrainingCodes/dncnn_keras/models/dndnn_model/"
test_dir = "/home/script/DnCNN-master/TrainingCodes/dncnn_keras/data/orig/set/"

# データを加工するクラス
class Data():
    def __init__(self):
        self.data = "/home/script/DnCNN-master/TrainingCodes/dncnn_keras/data/Train400"

    def train_datagen(self,epoch_iter=1000,epoch_num=5,batch_size=128):
        data_dir = self.data
        while(True):
            n_count = 0
            if n_count == 0:
                #print(n_count)
                xs = dg.datagenerator(data_dir)
                assert len(xs)%128 ==0, \
                log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!')
                xs = xs.astype('float32')/255.0
                indices = list(range(xs.shape[0]))
                n_count = 1
                for _ in range(epoch_num):
                    np.random.shuffle(indices)    # shuffle
                    for i in range(0, len(indices), batch_size):
                        batch_x = xs[indices[i:i+batch_size]]
                        noise =  np.random.normal(0, 25/255.0, batch_x.shape)    # noise
                        #noise =  K.random_normal(ge_batch_y.shape, mean=0, stddev=args.sigma/255.0)
                        batch_y = batch_x + noise
                        yield batch_y, batch_x


# Network
def DnCNN(depth,filters=64,image_channels=1, use_bnorm=True):
    layer_count = 0
    inpt = Input(shape=(None,None,image_channels),name = 'input'+str(layer_count))
    # 1st layer, Conv+relu
    layer_count += 1
    x = Conv2D(filters=filters, kernel_size=(3,3), strides=(1,1),kernel_initializer='Orthogonal', padding='same',name = 'conv'+str(layer_count))(inpt)
    layer_count += 1
    x = Activation('relu',name = 'relu'+str(layer_count))(x)
    # depth-2 layers, Conv+BN+relu
    for i in range(depth-2):
        layer_count += 1
        x = Conv2D(filters=filters, kernel_size=(3,3), strides=(1,1),kernel_initializer='Orthogonal', padding='same',use_bias = False,name = 'conv'+str(layer_count))(x)
        if use_bnorm:
            layer_count += 1
            #x = BatchNormalization(axis=3, momentum=0.1,epsilon=0.0001, name = 'bn'+str(layer_count))(x)
            x = BatchNormalization(axis=3, momentum=0.0,epsilon=0.0001, name = 'bn'+str(layer_count))(x)
        layer_count += 1
        x = Activation('relu',name = 'relu'+str(layer_count))(x)
    # last layer, Conv
    layer_count += 1
    x = Conv2D(filters=image_channels, kernel_size=(3,3), strides=(1,1), kernel_initializer='Orthogonal',padding='same',use_bias = False,name = 'conv'+str(layer_count))(x)
    layer_count += 1
    x = Subtract(name = 'subtract' + str(layer_count))([inpt, x])   # input - noise
    model = Model(inputs=inpt, outputs=x)

    return model

# 学習用のクラス
class Trainer():
    def __init__(self, model, loss, optimizer):
        self._model = model
        self._model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=["accuracy"]
        )

      # 実際の学習
    def fit(self, train_noise, train_org):
        chkpt = save_dir + 'DnCNN.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5'
        cp_cb = ModelCheckpoint(filepath = chkpt, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

        self._model.fit(
            train_noise,
            train_org,
            epochs = epochs,
            verbose = 1,
            callbacks=[cp_cb],
            shuffle=False
        )
        return self._model

dataset = Data() # データを取得するためのDatasetのインスタンス化
model = DnCNN(depth=17,filters=64,image_channels=1,use_bnorm=True) #モデルの取得
trainer = Trainer(model, loss="mse", optimizer="adam") # モデルとロス関数、最適化アルゴリズムを引数にして、Trainerのインスタンス化

model = trainer.fit(Data.train_datagen(dataset, batch_size = batch), dataset) # モデルの学習


t_end = time() #終了時間
t_elapsed = t_end - t_start

print("処理時間は{0}".format(t_elapsed))


error

Traceback (most recent call last):
  File "DnCNN.py", line 117, in <module>
    model = trainer.fit(Data.train_datagen(dataset, batch_size = batch), dataset) # モデルの学習
  File "DnCNN.py", line 109, in fit
    shuffle=False
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 658, in fit
    y, sample_weight, validation_split=validation_split)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 946, in check_generator_arguments
    raise ValueError('`y` argument is not supported when data is'
ValueError: `y` argument is not supported when data isa generator or Sequence instance. Instead pass targets as the second element of the generator.
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正の依頼

  • jbpb0

    2021/08/05 12:06

    そうなんですか
    同じ論文を参照してるので、ニューラルネットは同じものだと思ったのですが、違うのですか

    キャンセル

  • xeno

    2021/08/05 16:07

    論文ベースでは残差が組み込まれていますが、紹介して頂いたものには残差が組み込まれていないようです。

    キャンセル

  • jbpb0

    2021/08/05 17:55

    ニューラルネットの定義のところを差し替えたら、いけたりしませんか?

    キャンセル

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.60%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る