現在、Deep Learningを使用した学習を行っているのですが、データの水増しについてクラス継承を用いた自作のジェネレータを作成しようと考えました。ジェネレータの作成について参考にした資料は以下に示します。
参考資料
作成したプログラムについては以下に示す通りです。
python
1# agumantation用クラス 2class MyImageDataGenerator(ImageDataGenerator): 3 def __init__(self, featurewise_center = False, samplewise_center = False, 4 featurewise_std_normalization = False, samplewise_std_normalization = False, 5 zca_whitening = False, zca_epsilon = 1e-06, rotation_range = 0.0, width_shift_range = 0.0, 6 height_shift_range = 0.0, brightness_range = None, shear_range = 0.0, zoom_range = 0.0, 7 channel_shift_range = 0.0, fill_mode = 'nearest', cval = 0.0, horizontal_flip = False, 8 vertical_flip = False, rescale = None, preprocessing_function = None, data_format = None, validation_split = 0.0, 9 random_crop = None, mix_up_alpha = 0.0, cutout_mask_size = 0): 10 super().__init__(featurewise_center, samplewise_center, featurewise_std_normalization, samplewise_std_normalization, 11 zca_whitening, zca_epsilon, rotation_range, width_shift_range, height_shift_range, brightness_range, 12 shear_range, zoom_range, channel_shift_range, fill_mode, cval, horizontal_flip, vertical_flip, 13 rescale, preprocessing_function, data_format, validation_split) 14 15 assert random_crop == None or len(random_crop) == 2 16 self.random_crop_size = random_crop 17 18 19 assert cutout_mask_size >= 0 20 self.cutout_mask_size = cutout_mask_size 21 22 23 24 def random_crop(self, original_img): 25 # Note: image_data_format is 'channel_last' 26 assert original_img.shape[2] == 3 27 if original_img.shape[0] < self.random_crop_size[0] or original_img.shape[1] < self.random_crop_size[1]: 28 raise ValueError(f"Invalid random_crop_size : original = {original_img.shape}, crop_size = {self.random_crop_size}") 29 30 height, width = original_img.shape[0], original_img.shape[1] 31 dy, dx = self.random_crop_size 32 x = np.random.randint(0, width - dx + 1) 33 y = np.random.randint(0, height - dy + 1) 34 return original_img[y:(y+dy), x:(x+dx), :] 35 36 def cutout(self, x, y): 37 return np.array(list(map(self._cutout, x))), y 38 39 def _cutout(self, image_origin): 40 # 最後に使うfill()は元の画像を書き換えるので、コピーしておく 41 image = np.copy(image_origin) 42 mask_value = image.mean() 43 44 h, w, _ = image.shape 45 # マスクをかける場所のtop, leftをランダムに決める 46 # はみ出すことを許すので、0以上ではなく負の値もとる(最大mask_size // 2はみ出す) 47 top = np.random.randint(0 - self.cutout_mask_size // 2, h - self.cutout_mask_size) 48 left = np.random.randint(0 - self.cutout_mask_size // 2, w - self.cutout_mask_size) 49 bottom = top + self.cutout_mask_size 50 right = left + self.cutout_mask_size 51 52 # はみ出した場合の処理 53 if top < 0: 54 top = 0 55 if left < 0: 56 left = 0 57 58 # マスク部分の画素値を平均値で埋める 59 image[top:bottom, left:right, :].fill(mask_value) 60 return image 61 62 def flow_from_dataframe(self, dataframe, directory=None, x_col='filename', y_col='class', 63 target_size=(256, 256), color_mode='rgb', classes=None, class_mode='categorical', 64 batch_size=32, shuffle=True, seed=None, save_to_dir=None, save_prefix='', save_format='png', 65 subset=None, interpolation='nearest', drop_duplicates=True): 66 67 # 親クラスのflow_from_dataframe 68 batches = super().flow_from_dataframe(dataframe, directory, x_col, y_col, 69 target_size, color_mode, classes, class_mode, 70 batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, 71 subset, interpolation, drop_duplicates) 72 # 拡張処理 73 while True: 74 # Random crop 75 if self.random_crop_size != None: 76 x = np.zeros((batch_x.shape[0], self.random_crop_size[0], self.random_crop_size[1], 3)) 77 for i in range(batch_x.shape[0]): 78 x[i] = self.random_crop(batch_x[i]) 79 batch_x = x 80 81 if self.cutout_mask_size > 0: 82 batch_x, batch_y = next(batches) 83 result = self.cutout(batch_x, batch_y) 84 batch_x, batch_y = result 85 86 yield (batch_x, batch_y) 87 88 89 # 返り値 90 yield (batch_x, batch_y)
この中の
batches = super().flow_from_dataframe(dataframe, directory, x_col, y_col, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, subset, interpolation, drop_duplicates)
の部分において、以下のようなエラーが出てしまいます。
ValueError: ('Invalid color mode:', None, '; expected "rgb", "rgba", or "grayscale".')
color_modeについてデフォルト引数でrgbに設定しているのですがエラーが消えません。また、調べてみたのですが、
flow_from_dataframeを使用したものが無く原因がわからない状態です。
上記エラーの原因について教えていただけると助かります。よろしくお願いします。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。