下記のGithubからダウンロードしたDnCNNで学習を行い、保存したモデルで実際にノイズを除去しています。その際、白黒画像のノイズ除去はできるのですがカラー画像を読み込む際に止まってしまいます。
学習コードは何も変えていません。
testコードは読み込むディレクトリと保存するディレクトリを変更しました。
また、関数def show内のplt.imshow(x,interpolation='nearest',cmap='gray')をplt.imshow(x,interpolation='nearest')に変更しました。
それぞれのディレクトリ構成は以下です。
読み込むディレクトリ:
data
|ーorigーset(カラー画像と白黒画像が入っている)
|ーTestーset12,set68
|ーTrain400
保存するディレクトリ:
result_imgーSet
カラー画像を使えるようにするにはどうしたらいいでしょうか?
環境はubuntu18.04,tensorflow1.5.0,keras2.2.4です。
LINK
# -*- coding: utf-8 -*- # ============================================================================= # @article{zhang2017beyond, # title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising}, # author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei}, # journal={IEEE Transactions on Image Processing}, # year={2017}, # volume={26}, # number={7}, # pages={3142-3155}, # } # by Kai Zhang (08/2018) # cskaizhang@gmail.com # https://github.com/cszn # modified on the code from https://github.com/husqin/DnCNN-keras # ============================================================================= # run this to test the model import argparse import os, time, datetime #import PIL.Image as Image import numpy as np from keras.models import load_model, model_from_json from skimage.measure import compare_psnr, compare_ssim from skimage.io import imread, imsave def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--set_dir', default='data/orig', type=str, help='directory of test dataset') parser.add_argument('--set_names', default=['set'], type=list, help='name of test dataset') parser.add_argument('--sigma', default=25, type=int, help='noise level') parser.add_argument('--model_dir', default=os.path.join('models','DnCNN_sigma25'), type=str, help='directory of the model') parser.add_argument('--model_name', default='model_010.hdf5', type=str, help='the model name') parser.add_argument('--result_dir', default='result_img', type=str, help='directory of results') #parser.add_argument('--result_names', default=['Set'], type=list, help='name of resulet image') parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0') return parser.parse_args() def to_tensor(img): if img.ndim == 2: return img[np.newaxis,...,np.newaxis] elif img.ndim == 3: return np.moveaxis(img,2,0)[...,np.newaxis] def from_tensor(img): return np.squeeze(np.moveaxis(img[...,0],0,-1)) def log(*args,**kwargs): print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),*args,**kwargs) def save_result(result,path): path = path if path.find('.') != -1 else path+'.png' ext = os.path.splitext(path)[-1] if ext in ('.txt','.dlm'): np.savetxt(path,result,fmt='%2.4f') else: imsave(path,np.clip(result,0,1)) def show(x,title=None,cbar=False,figsize=None): import matplotlib.pyplot as plt plt.figure(figsize=figsize) #plt.imshow(x,interpolation='nearest',cmap='gray') plt.imshow(x,interpolation='nearest') print("***************************************") if title: plt.title(title) if cbar: plt.colorbar() plt.show() if __name__ == '__main__': args = parse_args() # ============================================================================= # # serialize model to JSON # model_json = model.to_json() # with open("model.json", "w") as json_file: # json_file.write(model_json) # # serialize weights to HDF5 # model.save_weights("model.h5") # print("Saved model") # ============================================================================= if not os.path.exists(os.path.join(args.model_dir, args.model_name)): # load json and create model json_file = open(os.path.join(args.model_dir,'model.json'), 'r') loaded_model_json = json_file.read() json_file.close() model = model_from_json(loaded_model_json) # load weights into new model model.load_weights(os.path.join(args.model_dir,'model_010.hdf5')) log('load trained model on Train400 dataset by kai') else: model = load_model(os.path.join(args.model_dir, args.model_name),compile=False) log('load trained model') if not os.path.exists(args.result_dir): os.mkdir(args.result_dir) for set_cur in args.set_names: if not os.path.exists(os.path.join(args.result_dir,set_cur)): os.mkdir(os.path.join(args.result_dir,set_cur)) psnrs = [] ssims = [] for im in os.listdir(os.path.join(args.set_dir,set_cur)): if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"): #x = np.array(Image.open(os.path.join(args.set_dir,set_cur,im)), dtype='float32') / 255.0 x = np.array(imread(os.path.join(args.set_dir,set_cur,im)), dtype=np.float32) / 255.0 np.random.seed(seed=0) # for reproducibility y = x + np.random.normal(0, args.sigma/255.0, x.shape) # Add Gaussian noise without clipping y = y.astype(np.float32) y_ = to_tensor(y) start_time = time.time() x_ = model.predict(y_) # inference elapsed_time = time.time() - start_time print('%10s : %10s : %2.4f second'%(set_cur,im,elapsed_time)) x_=from_tensor(x_) psnr_x_ = compare_psnr(x, x_) ssim_x_ = compare_ssim(x, x_) if args.save_result: name, ext = os.path.splitext(im) show(np.hstack((y,x_))) # show the image save_result(x_,path=os.path.join(args.result_dir,set_cur,name+'_dncnn'+ext)) # save the denoised image print('===============================') psnrs.append(psnr_x_) #print("1111111111") ssims.append(ssim_x_) #print("22222222222222" psnr_avg = np.mean(psnrs) ssim_avg = np.mean(ssims) psnrs.append(psnr_avg) ssims.append(ssim_avg) if args.save_result: save_result(np.hstack((psnrs,ssims)),path=os.path.join(args.result_dir,set_cur,'results.txt')) log('Datset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))
カラー画像読み込み時に出力される警告文
main_test.py:127: UserWarning: DEPRECATED: skimage.measure.compare_psnr has been moved to skimage.metrics.peak_signal_noise_ratio. It will be removed from skimage.measure in version 0.18. psnr_x_ = compare_psnr(x, x_) main_test.py:128: UserWarning: DEPRECATED: skimage.measure.compare_ssim has been moved to skimage.metrics.structural_similarity. It will be removed from skimage.measure in version 0.18. ssim_x_ = compare_ssim(x, x_) *************************************** WARNING:imageio:Lossy conversion from float32 to uint8. Range [0, 1]. Convert image to uint8 prior to saving to suppress this warning. =============================== 2021-07-12 11:06:44.320393: W tensorflow/core/framework/allocator.cc:107] Allocation of 3221225472 exceeds 10% of system memory. 2021-07-12 11:06:44.657304: W tensorflow/core/framework/allocator.cc:107] Allocation of 3221225472 exceeds 10% of system memory. terminate called after throwing an instance of 'std::bad_alloc' what(): std::bad_alloc 中止 (コアダンプ)
回答1件
あなたの回答
tips
プレビュー