前提・実現したいこと
以下のサイトを参考に、Mix-upを行う自作のgenerator（ImageDataGeneratorを継承したクラス）を作成しました。
https://qiita.com/koshian2/items/909360f50e3dd5922f32
しかし、コードを実行すると以下のエラー（UnboundLocalError）が発生してしまい、どう修正すればよいかわからず困っております。
お手数をお掛けするのですが、修正方法についてお分かりの方がいらっしゃいましたら教えていただけないでしょうか。
どうぞよろしくお願いいたします。
while True:
40 batch_x_2, batch_y_2 = next(batches)
---> 41 m1, m2 = batch_x.shape[0], batch_x_2.shape[0]
42 if m1 < m2:
43 batch_x_2 = batch_x_2[:m1]
UnboundLocalError: local variable 'batch_x' referenced before assignment
該当のソースコード
Python
1# model構築の準備 2import keras 3from keras.models import Model 4from keras.layers import Dense, GlobalAveragePooling2D,Input,Dropout,Activation 5from keras.preprocessing.image import ImageDataGenerator 6from keras import applications 7from keras.preprocessing.image import ImageDataGenerator 8from keras.optimizers import Adam 9from keras.callbacks import CSVLogger,EarlyStopping 10import numpy as np 11from keras import backend as K 12from keras.engine.topology import Layer 13import numpy as np 14import tensorflow as tf 15from keras.preprocessing.image import load_img, img_to_array, array_to_img 16import os 17import matplotlib.pyplot as plt 18%matplotlib inline 19 20 21class MyImageDataGenerator(ImageDataGenerator): 22 def __init__(self, featurewise_center = False, samplewise_center = False, 23 featurewise_std_normalization = False, samplewise_std_normalization = False, 24 zca_whitening = False, zca_epsilon = 1e-06, rotation_range = 0.0, width_shift_range = 0.0, 25 height_shift_range = 0.0, brightness_range = None, shear_range = 0.0, zoom_range = 0.0, 26 channel_shift_range = 0.0, fill_mode = 'nearest', cval = 0.0, horizontal_flip = False, 27 vertical_flip = False, rescale = None, preprocessing_function = None, data_format = None, validation_split = 0.0, 28 random_crop = None, mix_up_alpha = 0.0): 29 # 親クラスのコンストラクタ 30 super().__init__(featurewise_center, samplewise_center, featurewise_std_normalization, samplewise_std_normalization, zca_whitening, zca_epsilon, rotation_range, width_shift_range, height_shift_range, brightness_range, shear_range, zoom_range, channel_shift_range, fill_mode, cval, horizontal_flip, vertical_flip, rescale, preprocessing_function, data_format, validation_split) 31 # 拡張処理のパラメーター 32 # Mix-up 33 assert mix_up_alpha >= 0.0 34 self.mix_up_alpha = mix_up_alpha 35 # Mix-up 36 # 参考 https://qiita.com/yu4u/items/70aa007346ec73b7ff05 37 def mix_up(self, X1, y1, X2, y2): 38 assert X1.shape[0] == y1.shape[0] == X2.shape[0] == y2.shape[0] 39 batch_size = 
X1.shape[0] 40 l = np.random.beta(self.mix_up_alpha, self.mix_up_alpha, batch_size) 41 X_l = l.reshape(batch_size, 1, 1, 1) 42 y_l = l.reshape(batch_size, 1) 43 X = X1 * X_l + X2 * (1-X_l) 44 y = y1 * y_l + y2 * (1-y_l) 45 return X, y 46 47 def flow_from_directory(self, directory, target_size = (299,299), color_mode = 'rgb', 48 classes = None, class_mode = 'categorical', batch_size = 12, shuffle = True, 49 seed = None, save_to_dir = None, save_prefix = '', save_format = 'png', 50 follow_links = False, subset = None, interpolation = 'nearest'): 51 # 親クラスのflow_from_directory 52 batches = super().flow_from_directory(directory, target_size, color_mode, classes, class_mode, batch_size, shuffle, seed, save_to_dir, save_prefix, save_format, follow_links, subset, interpolation) 53 # 拡張処理 54 while True: 55 if self.mix_up_alpha > 0: 56 while True: 57 batch_x_2, batch_y_2 = next(batches) 58 m1, m2 = batch_x.shape[0], batch_x_2.shape[0] 59 if m1 < m2: 60 batch_x_2 = batch_x_2[:m1] 61 batch_y_2 = batch_y_2[:m1] 62 break 63 elif m1 == m2: 64 break 65 batch_x, batch_y = self.mix_up(batch_x, batch_y, batch_x_2, batch_y_2) 66 # 返り値 67 yield (batch_x, batch_y) 68 69train_dir = './train' 70validation_dir = './validation' 71 72train_datagen = MyImageDataGenerator( 73 rescale=1/255.0, 74 mix_up_alpha=0.2, 75 #rotation_range=8, 76 horizontal_flip=True) 77 78train_generator=train_datagen.flow_from_directory( 79 train_dir, 80 target_size=(299,299), 81 batch_size=12,#25, 82 class_mode='categorical', 83 shuffle=True) 84 85 86validation_datagen=ImageDataGenerator(rescale=1.0/255.) 
87 88validation_generator=validation_datagen.flow_from_directory( 89 validation_dir, 90 target_size=(299,299), 91 batch_size=12,#25, 92 class_mode='categorical', 93 shuffle=True) 94 95 96batch_size = 12 97 98base_model=keras.applications.inception_resnet_v2.InceptionResNetV2(input_shape=(299,299,3), 99 weights='imagenet', 100 include_top=False) 101 102x = base_model.output 103x = GlobalAveragePooling2D()(x) 104x = Dense(1024, activation='relu')(x) 105predictions = Dense(4, activation='softmax')(x) 106model = Model(inputs=base_model.input,outputs=predictions) 107 108model.summary() 109 110 111model.compile(optimizer=Adam(lr=0.001), 112 loss='categorical_crossentropy', 113 metrics=['accuracy']) 114 115callbacks_list = [ 116 callbacks.ModelCheckpoint( 117 filepath="model.ep{epoch:02d}.h5",#delsavepath, 118 save_best_only=True), 119 120 #バリデーションlossが改善しなくなったら学習率を変更する 121callbacks.ReduceLROnPlateau( 122 monitor="val_loss", 123 factor=0.8, 124 patience=5, 125 verbose=1)]#, 126 127#callbacks.EarlyStopping(monitor='val_loss', patience=7, verbose=1)] 128 129model.fit_generator(train_generator, steps_per_epoch=7208, epochs=10, validation_steps=1805, validation_data=validation_generator, callbacks=callbacks_list) 130
回答2件
あなたの回答
tips
プレビュー