I wanted to run object detection on training data I created myself in Python, so I used ssd_keras, whose code is publicly available.
After matching the keras and tensorflow versions and running the code, an error message appeared while the model was training.
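For reference, the fit_generator signature in the traceback below (samples_per_epoch, nb_epoch, nb_val_samples, nb_worker, ...) matches the Keras 1.x API that ssd_keras was written for; the short snippet below only prints the versions actually loaded in my environment.
python
# quick environment check (assumption: ssd_keras targets Keras 1.x / TensorFlow 1.x)
import keras
import tensorflow as tf

print('keras:', keras.__version__)
print('tensorflow:', tf.__version__)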
python
# imports (as in the ssd_keras SSD_training notebook)
import pickle
from random import shuffle

import numpy as np
import keras
from keras.applications.imagenet_utils import preprocess_input
from PIL import Image
from scipy.misc import imread  # removed in newer scipy; imageio.imread is the usual substitute

from ssd import SSD300
from ssd_training import MultiboxLoss
from ssd_utils import BBoxUtility

# some constants
#NUM_CLASSES = 4
NUM_CLASSES = 21
input_shape = (300, 300, 3)

priors = pickle.load(open('prior_boxes_ssd300.pkl', 'rb'))
bbox_util = BBoxUtility(NUM_CLASSES, priors)

#gt = pickle.load(open('gt_pascal.pkl', 'rb'))
gt = pickle.load(open('Mytest/test.pkl', 'rb'))
keys = sorted(gt.keys())
print(keys)
num_train = int(round(0.8 * len(keys)))
train_keys = keys[:num_train]
val_keys = keys[num_train:]
num_val = len(val_keys)

# batch generator with the data augmentation from the ssd_keras notebook
class Generator(object):
    def __init__(self, gt, bbox_util,
                 batch_size, path_prefix,
                 train_keys, val_keys, image_size,
                 saturation_var=0.5,
                 brightness_var=0.5,
                 contrast_var=0.5,
                 lighting_std=0.5,
                 hflip_prob=0.5,
                 vflip_prob=0.5,
                 do_crop=True,
                 crop_area_range=[0.75, 1.0],
                 aspect_ratio_range=[3./4., 4./3.]):
        self.gt = gt
        self.bbox_util = bbox_util
        self.batch_size = batch_size
        self.path_prefix = path_prefix
        self.train_keys = train_keys
        self.val_keys = val_keys
        self.train_batches = len(train_keys)
        self.val_batches = len(val_keys)
        self.image_size = image_size
        self.color_jitter = []
        if saturation_var:
            self.saturation_var = saturation_var
            self.color_jitter.append(self.saturation)
        if brightness_var:
            self.brightness_var = brightness_var
            self.color_jitter.append(self.brightness)
        if contrast_var:
            self.contrast_var = contrast_var
            self.color_jitter.append(self.contrast)
        self.lighting_std = lighting_std
        self.hflip_prob = hflip_prob
        self.vflip_prob = vflip_prob
        self.do_crop = do_crop
        self.crop_area_range = crop_area_range
        self.aspect_ratio_range = aspect_ratio_range

    def grayscale(self, rgb):
        return rgb.dot([0.299, 0.587, 0.114])

    def saturation(self, rgb):
        gs = self.grayscale(rgb)
        alpha = 2 * np.random.random() * self.saturation_var
        alpha += 1 - self.saturation_var
        rgb = rgb * alpha + (1 - alpha) * gs[:, :, None]
        return np.clip(rgb, 0, 255)

    def brightness(self, rgb):
        alpha = 2 * np.random.random() * self.brightness_var
        alpha += 1 - self.saturation_var
        rgb = rgb * alpha
        return np.clip(rgb, 0, 255)

    def contrast(self, rgb):
        gs = self.grayscale(rgb).mean() * np.ones_like(rgb)
        alpha = 2 * np.random.random() * self.contrast_var
        alpha += 1 - self.contrast_var
        rgb = rgb * alpha + (1 - alpha) * gs
        return np.clip(rgb, 0, 255)

    def lighting(self, img):
        cov = np.cov(img.reshape(-1, 3) / 255.0, rowvar=False)
        eigval, eigvec = np.linalg.eigh(cov)
        noise = np.random.randn(3) * self.lighting_std
        noise = eigvec.dot(eigval * noise) * 255
        img += noise
        return np.clip(img, 0, 255)

    def horizontal_flip(self, img, y):
        if np.random.random() < self.hflip_prob:
            img = img[:, ::-1]
            y[:, [0, 2]] = 1 - y[:, [2, 0]]
        return img, y

    def vertical_flip(self, img, y):
        if np.random.random() < self.vflip_prob:
            img = img[::-1]
            y[:, [1, 3]] = 1 - y[:, [3, 1]]
        return img, y

    def random_sized_crop(self, img, targets):
        img_w = img.shape[1]
        img_h = img.shape[0]
        img_area = img_w * img_h
        random_scale = np.random.random()
        random_scale *= (self.crop_area_range[1] -
                         self.crop_area_range[0])
        random_scale += self.crop_area_range[0]
        target_area = random_scale * img_area
        random_ratio = np.random.random()
        random_ratio *= (self.aspect_ratio_range[1] -
                         self.aspect_ratio_range[0])
        random_ratio += self.aspect_ratio_range[0]
        w = np.round(np.sqrt(target_area * random_ratio))
        h = np.round(np.sqrt(target_area / random_ratio))
        if np.random.random() < 0.5:
            w, h = h, w
        w = min(w, img_w)
        w_rel = w / img_w
        w = int(w)
        h = min(h, img_h)
        h_rel = h / img_h
        h = int(h)
        x = np.random.random() * (img_w - w)
        x_rel = x / img_w
        x = int(x)
        y = np.random.random() * (img_h - h)
        y_rel = y / img_h
        y = int(y)
        img = img[y:y+h, x:x+w]
        new_targets = []
        for box in targets:
            cx = 0.5 * (box[0] + box[2])
            cy = 0.5 * (box[1] + box[3])
            if (x_rel < cx < x_rel + w_rel and
                    y_rel < cy < y_rel + h_rel):
                xmin = (box[0] - x_rel) / w_rel
                ymin = (box[1] - y_rel) / h_rel
                xmax = (box[2] - x_rel) / w_rel
                ymax = (box[3] - y_rel) / h_rel
                xmin = max(0, xmin)
                ymin = max(0, ymin)
                xmax = min(1, xmax)
                ymax = min(1, ymax)
                box[:4] = [xmin, ymin, xmax, ymax]
                new_targets.append(box)
        new_targets = np.asarray(new_targets).reshape(-1, targets.shape[1])
        return img, new_targets

    def generate(self, train=True):
        while True:
            if train:
                shuffle(self.train_keys)
                keys = self.train_keys
            else:
                shuffle(self.val_keys)
                keys = self.val_keys
            inputs = []
            targets = []
            for key in keys:
                img_path = self.path_prefix + key
                img = imread(img_path).astype('float32')
                y = self.gt[key].copy()
                if train and self.do_crop:
                    img, y = self.random_sized_crop(img, y)
                #img = imresize(img, self.image_size).astype('float32')  # no longer usable after the scipy version upgrade
                img = np.array(Image.fromarray(img).self.image_size, resample=0)  # my replacement attempt using PIL
                if train:
                    shuffle(self.color_jitter)
                    for jitter in self.color_jitter:
                        img = jitter(img)
                    if self.lighting_std:
                        img = self.lighting(img)
                    if self.hflip_prob > 0:
                        img, y = self.horizontal_flip(img, y)
                    if self.vflip_prob > 0:
                        img, y = self.vertical_flip(img, y)
                y = self.bbox_util.assign_boxes(y)
                inputs.append(img)
                targets.append(y)
                if len(targets) == self.batch_size:
                    tmp_inp = np.array(inputs)
                    tmp_targets = np.array(targets)
                    inputs = []
                    targets = []
                    yield preprocess_input(tmp_inp), tmp_targets

#path_prefix = '../../frames/'
path_prefix = 'Mytest/fig/'
"""gen = Generator(gt, bbox_util, 16, '../../frames/',
                train_keys, val_keys,
                (input_shape[0], input_shape[1]), do_crop=False)
"""
gen = Generator(gt, bbox_util, 1, 'Mytest/fig/', train_keys, val_keys,
                (input_shape[0], input_shape[1]), do_crop=False)

model = SSD300(input_shape, num_classes=NUM_CLASSES)
model.load_weights('weights_SSD300.hdf5', by_name=True)

# freeze the early VGG layers so only the later layers are fine-tuned
freeze = ['input_1', 'conv1_1', 'conv1_2', 'pool1',
          'conv2_1', 'conv2_2', 'pool2',
          'conv3_1', 'conv3_2', 'conv3_3', 'pool3']#,
#          'conv4_1', 'conv4_2', 'conv4_3', 'pool4']

for L in model.layers:
    if L.name in freeze:
        L.trainable = False

# learning rate decays by a factor of 0.9 every epoch
def schedule(epoch, decay=0.9):
    return base_lr * decay**(epoch)

callbacks = [keras.callbacks.ModelCheckpoint('./checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
                                             verbose=1,
                                             save_weights_only=True),
             keras.callbacks.LearningRateScheduler(schedule)]

base_lr = 3e-4
optim = keras.optimizers.Adam(lr=base_lr)
# optim = keras.optimizers.RMSprop(lr=base_lr)
# optim = keras.optimizers.SGD(lr=base_lr, momentum=0.9, decay=decay, nesterov=True)
model.compile(optimizer=optim,
              loss=MultiboxLoss(NUM_CLASSES, neg_pos_ratio=2.0).compute_loss)

nb_epoch = 1
history = model.fit_generator(gen.generate(True), gen.train_batches,
                              nb_epoch, verbose=1,
                              callbacks=callbacks,
                              validation_data=gen.generate(False),
                              nb_val_samples=gen.val_batches,
                              nb_worker=1)
~.conda\envs\keras\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, samples_per_epoch, nb_epoch, verbose, callbacks, validation_data, nb_val_samples, class_weight, max_q_size, nb_worker, pickle_safe, initial_epoch)
1529 '(x, y, sample_weight) '
1530 'or (x, y). Found: ' +
-> 1531 str(generator_output))
1532 if len(generator_output) == 2:
1533 x, y = generator_output
ValueError: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: None
This has become long, but the above is the error that occurs. I am using Labelimg-master as the annotation tool.
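For reference, I am assuming Mytest/test.pkl follows the same layout as gt_pascal.pkl, which is what the Generator above indexes; the file name and numbers below are made-up examples.
python
# hypothetical example of the ground-truth dict loaded from Mytest/test.pkl
# (assumed to follow the gt_pascal.pkl convention from ssd_keras):
# each value has shape (num_boxes, 4 + NUM_CLASSES - 1), i.e.
# [xmin, ymin, xmax, ymax] normalized to [0, 1] followed by a one-hot class
# vector that does not include the background class
import numpy as np

gt_example = {
    'image_000.jpg': np.array([
        [0.10, 0.20, 0.55, 0.80] + [1.0] + [0.0] * 19,  # one box of class 0
    ]),
}
print(gt_example['image_000.jpg'].shape)  # -> (1, 24) when NUM_CLASSES = 21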
Thank you in advance for any guidance on how to resolve this.