質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

2200閲覧

fit_generatorを使用してもMemoryErrorになる

tamtam44444

総合スコア13

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/02/02 09:52

前提・実現したいこと

python初心者です。
SRCNNを用いた超解像を行いたいのですが、fit_generatorを使用してもMemoryErrorになります。
単にPCのメモリの問題なのか、generatorの仕様に問題があるのかわかりません。
generatorはこちらのサイトhttps://www.kumilog.net/entry/keras-generatorを参考にして作りました。

発生している問題・エラーメッセージ

Layer (type) Output Shape Param # ================================================================= conv2d_4 (Conv2D) (None, None, None, 64) 15616 _________________________________________________________________ conv2d_5 (Conv2D) (None, None, None, 32) 2080 _________________________________________________________________ conv2d_6 (Conv2D) (None, None, None, 3) 2403 ================================================================= Total params: 20,099 Trainable params: 20,099 Non-trainable params: 0 _________________________________________________________________ Epoch 1/50 Traceback (most recent call last): File "<ipython-input-2-d62fed411f42>", line 1, in <module> runfile('C:/Users/myname/.spyder-py3/temp.py', wdir='C:/Users/myname/.spyder-py3') File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 704, in runfile execfile(filename, namespace) File "C:\ProgramData\Anaconda3\lib\site-packages\spyder_kernels\customize\spydercustomize.py", line 108, in execfile exec(compile(f.read(), filename, 'exec'), namespace) File "C:/Users/myname/.spyder-py3/temp.py", line 90, in <module> verbose = 1, File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper return func(*args, **kwargs) File "C:\ProgramData\Anaconda3\lib\site-packages\keras\models.py", line 1256, in fit_generator initial_epoch=initial_epoch) File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper return func(*args, **kwargs) File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 2145, in fit_generator generator_output = next(output_generator) File "C:\ProgramData\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 755, in get six.reraise(value.__class__, value, value.__traceback__) File "C:\ProgramData\Anaconda3\lib\site-packages\six.py", line 693, in reraise raise value File "C:\ProgramData\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 635, in _data_generator_task generator_output = next(self._generator) File "C:/Users/myname/.spyder-py3/temp.py", line 58, in flow_from_directory self.x_images.append(np.asarray(f.convert("RGB"), dtype=np.float32)) File "C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\numeric.py", line 501, in asarray return array(a, dtype, copy=False, order=order) MemoryError

該当のソースコード

from keras.models import Sequential from keras.layers import Conv2D from keras import backend as K from PIL import Image import os import numpy as np BATCH_SIZE = 5 N_TRAIN_DATA = 1320 N_TEST_DATA = 100 model = Sequential() model.add(Conv2D(filters=64, kernel_size=9, padding="same", activation="relu", input_shape=(None,None,3) )) model.add(Conv2D(filters=32, kernel_size=1, padding="same", activation="relu", )) model.add(Conv2D(filters=3, kernel_size=5, padding="same", )) model.summary() def psnr(y_true, y_pred): return -10*K.log(K.mean(K.flatten((y_true - y_pred))**2 ))/np.log(10) class DataGenerator(object): def __init__(self): self.reset() def reset(self): self.x_images = [] self.y_images = [] def flow_from_directory(self,directory, batch_size): folderlist = os.listdir(directory) x_path = directory + "\" + folderlist[0] y_path = directory + "\" + folderlist[1] while True: for x_file in os.listdir(x_path): with Image.open(x_path + "\" + x_file) as f: self.x_images.append(np.asarray(f.convert("RGB"), dtype=np.float32)) if len(self.x_images) == batch_size: inputs = np.asarray(self.x_images, dtype=np.float32) for y_file in os.listdir(y_path): with Image.open(y_path + "\" + y_file) as t: self.y_images.append(np.asarray(t.convert("RGB"), dtype=np.float32)) if len(self.y_images) == batch_size: targets = np.asarray(self.y_images, dtype=np.float32) self.reset() yield inputs/255., targets/255. data_gen = DataGenerator() train_datagenerator = data_gen.flow_from_directory("\train", batch_size=BATCH_SIZE) test_datagenerator= data_gen.flow_from_directory("\test", batch_size=BATCH_SIZE) model.compile(loss="mean_squared_error", optimizer="adam", metrics=[psnr]) model.fit_generator(train_datagenerator, validation_data = test_datagenerator, steps_per_epoch=N_TRAIN_DATA//BATCH_SIZE, validation_steps=N_TEST_DATA//BATCH_SIZE, epochs=50, verbose = 1, )

試したこと

試しにバッチサイズを1で行ってみましたがMemoryErrorになりました。

補足情報(FW/ツールのバージョンなど)

動作環境
windows 10 home
CPU core i3-8100
メモリ 8.00GB
Gforce GTX1060 3GB

python 3.6.6
keras 2.1.3

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

入力層のサイズが input_shape=(None,None,3) となっていますが、
input_shape=(224,224,3) のようなサイズに固定し、入力前にリサイズを行うべきではないでしょうか。

元の画像サイズがわかりませんが、例えば (1000, 1000) ぐらいある画像をそのままネットワークの入力としているのだとしたら、当然 GPU メモリは足りなくなります。

CNN モデルでは画像は入力前に一定サイズにリサイズするのが基本です。

投稿2019/02/02 10:00

編集2019/02/02 10:01
tiitoi

総合スコア21956

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

tamtam44444

2019/02/02 10:08

回答ありがとうございます。 画像サイズは(2560,1920)でこれ以上小さくしたくないのですが、大きな画像を扱うにはメモリを大きくするしかないのでしょうか?
tiitoi

2019/02/02 10:44 編集

メモリを大きくするといっても、GTXシリーズでその上は8G、12Gになりますが、それでもそのサイズをそのまま扱うのは無理です。 論文は読んでないですが、確かSRCNN は元画像をいくつかの小さい画像にグリッド上に分割してモデルに流していたはずですよ。 既存の実装等参考にしてみてはどうでしょうか https://github.com/titu1994/Image-Super-Resolution
tamtam44444

2019/02/02 14:24

回答ありがとうございます。 メモリを大きくしても無理なんですね。 画像を分割してやってみます。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問