Kerasのfit_generatorに渡すデータを独自開発したいのですが、正常に動作せず、エポックの最初でどうしてもフリーズしてしまいます。
渡す値としては画像のパスが格納された配列です
比較対象のデータについても同様サイズの画像データになります。
メモリ節約のため、ジェネレータ内でパスから画像を読み込み学習を行おうと思い、Generator関数内に画像の読み込み処理を導入しています
import numpy as np import cv2 class CustomGenerator(object): def __init__(self,color_imgpath,gray_imgpath,IMAGE_SIZE=64): self.reset(0) self.color_path=color_imgpath self.gray_path=gray_imgpath self.IMAGE_SIZE=IMAGE_SIZE self.batch_size=36 def reset(self,ind): self.ind = ind self._in = [] self._out = [] def GetImg(self): local_i = 1 while True: gray_img = cv2.imread(self.gray_path[self.ind:self.ind + local_i][0]) gray_img = cv2.resize(gray_img, (self.IMAGE_SIZE, self.IMAGE_SIZE)) self._in.append(np.asarray(gray_img/255.0, dtype=np.float32)) color_img = cv2.imread(self.color_path[self.ind:self.ind + local_i][0]) color_img = cv2.resize(color_img/255.0, (self.IMAGE_SIZE, self.IMAGE_SIZE)) self._out.append(np.asarray(color_img, dtype=np.float32)) local_i += 1 if len(self._in) ==self.batch_size: inputs = np.asarray(self._in, dtype=np.float32) targets = np.asarray(self._out, dtype=np.float32) self.reset(self.ind+self.batch_size) yield inputs,targets
これらデータをモデルに渡しfit_generatorを開始したのですが、1エポックで停止してしまいます。
(内部的にはGetImg関数のWhile文が途中停止しています)
なにかこのコードには間違いがあるはずなのですが原因がわかりません。
model = get_model() model.compile(loss=dice_coef_loss, optimizer=Adam(), metrics=[dice_coef]) gen = CustomGenerator(color_pics,gray_pics,IMAGE_SIZE=IMAGE_SIZE) model.fit_generator( generator=gen.GetImg(), steps_per_epoch=int(np.ceil(len(color_pics) / 32)), epochs=20)
While 文が途中停止の意味がよくわからないのですが、無限ループして if 文で抜けないということでしょうか?
Keras はオープンソースなので、Github のコードを参考にして改造したらどうですか。
https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image.py
あなたの回答
tips
プレビュー