ResnetにDnCNNをいれてノイズ(落書き)を消そうとしているのですが、エラーを解決できません。
該当箇所に関しては参考にしたDnCNNと同じ書き方です。
modelまでは表示できるのですがその先に進めません。
import keras from tensorflow.keras.models import load_model from tensorflow.keras.datasets import cifar10 from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Activation from tensorflow.keras.models import Model from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint from tensorflow.keras.layers import Dropout from tensorflow.keras.optimizers import Adam from tensorflow.keras.layers import Concatenate import os import pickle import numpy as np batch_size = 32 #num_classes = 10 epochs = 10 saveDir = "/home/script/savex/" class Data(): def __init__(self): self.input_img = (32, 32, 3) def get_batch(self): (x_train, y_train), (x_test, y_test) = cifar10.load_data() x_train_line = np.copy(x_train) x_test_line = np.copy(x_test) print(x_train_line) x_train_line = self.drawLines(x_train_line) x_test_line = self.drawLines(x_test_line) # normalize data x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 # normalize data x_train_line = x_train_line.astype('float32') x_test_line = x_test_line.astype('float32') x_train_line /= 255 x_test_line /= 255 print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') def drawUpLeft(x,i,col): vh = np.random.randint(2) start = np.random.randint(10) length = start + np.random.randint(4,21) position = np.random.randint(32) for j in range(3): if vh == 0: x[i,:,:,j][start:length, position] = col[j] else: x[i,:,:,j][position, start:length] = col[j] return x def drawDnRight(x,i,col): vh = np.random.randint(2) start = np.random.randint(10,25) length = start + np.random.randint(5,32-start) position = np.random.randint(32) for j in range(3): if vh == 0: x[i,:,:,j][start:length, position] = col[j] else: x[i,:,:,j][position, start:length] = col[j] return x def drawLines(x): for i in range(len(x)): numLines = np.random.randint(25) + 1 for l in range(numLines): col = [np.random.randint(256),np.random.randint(256),np.random.randint(256)] if l % 2 == 0: drawUpLeft(x,i,col) else: drawDnRight(x,i,col) return x def network(input_shape): inputs = Input(shape=input_shape) x = Conv2D(32, kernel_size=3, padding="same", activation="relu")(inputs) x = BatchNormalization()(x) for i in range(17): shutcut = x #ショートカットコネクション用にモジュールの入力データを取得する x = Conv2D(64, (3, 3), padding='same')(x) x = BatchNormalization()(x) x = Activation('relu')(x) x = Dropout(rate=0.3)(x) x = Concatenate()([x, shutcut]) #ショートカットコネクション x = Conv2D(3, (3, 3), padding='same')(x) output_img = Activation('tanh')(x) model = Model(inputs,output_img) print(model.summary()) return model # モデルを学習させるクラス class Trainer(): # モデルをコンパイルして、学習するための設定をプライベートプロパティに設定する。 def __init__(self, model, loss, optimizer): self._model = model self._model.compile( loss=loss, optimizer=optimizer, metrics=["accuracy"] ) self._verbose = 1 self._batch_size = 128 self._epochs = 100 # 実際の学習 def fit(self, x_train, y_train, x_test, y_test): chkpt = saveDir + 'Cifar10_Resnet_weights.{epoch:02d}-{loss:.2f}-{val_loss:.2f}.hdf5' cp_cb = ModelCheckpoint(filepath = chkpt, monitor='val_loss', verbose=1, save_best_only=True, mode='auto') self._model.fit( x_train, x_train_line, batch_size=self._batch_size, epochs=self._epochs, verbose=self._verbose, validation_data=(x_test_line, x_test), callbacks=[cp_cb], shuffle=True ) score = model.evaluate(x_test_line, x_test, verbose=1) print(score) return self._model dataset = Data() # データを取得するためのCIFAR10Datasetのインスタンス化 model = network(dataset.input_img) #モデルの取得 x_train, y_train, x_test, y_test = dataset.get_batch() # 学習データとテストデータの取得 trainer = Trainer(model, loss="mean_squared_error", optimizer="adam") # モデルとロス関数、最適化アルゴリズムを引数にして、Trainerのインスタンス化 model = trainer.fit(x_train, y_train, x_test, y_test) # モデルの学習 # モデルの評価 score = model.evaluate(x_test, y_test, verbose=0) print('Test loss: ', score[0]) print('Test accuracy: ', score[1]) c10test = model.predict(x_test_line) # definition to show original image and reconstructed image def showOrigDec(orig, noise, denoise, num=10): import matplotlib.pyplot as plt n = num plt.figure(figsize=(20, 6)) for i in range(n): # display original ax = plt.subplot(3, n, i+1) plt.imshow(orig[i].reshape(32, 32, 3)) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) # display noisy image ax = plt.subplot(3, n, i +1 + n) plt.imshow(noise[i].reshape(32, 32, 3)) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) # display denoised image ax = plt.subplot(3, n, i +1 + n + n) plt.imshow(denoise[i].reshape(32, 32, 3)) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False) plt.show() showOrigDec(x_test, x_test_line, c10test)
ERROR
1Traceback (most recent call last): 2 File "DnResnet.py", line 139, in <module> 3 x_train, y_train, x_test, y_test = dataset.get_batch() # 学習データとテストデータの取得 4 File "DnResnet.py", line 29, in get_batch 5 x_train_line = self.drawLines(x_train_line) 6TypeError: drawLines() takes 1 positional argument but 2 were given 7
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。