Background / what I want to achieve
When tracking with DeepSort on top of Yolov4, I want an object, once detected, to keep the same ID for the rest of the video.
Problem / error message
When an object in the video moves even a little, its track is refreshed and it is assigned a new ID number.
Relevant source code
```python
# Excerpt from object_tracker.py (imports and the earlier flag
# definitions are omitted here; see the linked repository below).
flags.DEFINE_float('iou', 0.45, 'iou threshold')
flags.DEFINE_float('score', 0.50, 'score threshold')
flags.DEFINE_boolean('dont_show', False, 'dont show video output')
flags.DEFINE_boolean('info', False, 'show detailed info of tracked objects')
flags.DEFINE_boolean('count', False, 'count objects being tracked on screen')

def main(_argv):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0

    # initialize deep sort
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # calculate cosine distance metric
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    # initialize tracker
    tracker = Tracker(metric)

    # load configuration for object detector
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    video_path = FLAGS.video

    # load tflite model if flag is set
    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    # otherwise load standard tensorflow saved model
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']

    # begin video capture
    try:
        vid = cv2.VideoCapture(int(video_path))
    except:
        vid = cv2.VideoCapture(video_path)

    out = None

    # get video ready to save locally if flag is set
    if FLAGS.output:
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
        out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))

    frame_num = 0
    # while video is running
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            print('Video has ended or failed, try a different video format!')
            break
        frame_num += 1
        print('Frame #: ', frame_num)
        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
        image_data = image_data / 255.
        image_data = image_data[np.newaxis, ...].astype(np.float32)
        start_time = time.time()

        # run detections on tflite if flag is set
        if FLAGS.framework == 'tflite':
            interpreter.set_tensor(input_details[0]['index'], image_data)
            interpreter.invoke()
            pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
            # run detections using yolov3 if flag is set
            if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
                boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
            else:
                boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
        else:
            batch_data = tf.constant(image_data)
            pred_bbox = infer(batch_data)
            for key, value in pred_bbox.items():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=FLAGS.iou,
            score_threshold=FLAGS.score
        )

        # convert data to numpy arrays and slice out unused elements
        num_objects = valid_detections.numpy()[0]
        bboxes = boxes.numpy()[0]
        bboxes = bboxes[0:int(num_objects)]
        scores = scores.numpy()[0]
        scores = scores[0:int(num_objects)]
        classes = classes.numpy()[0]
        classes = classes[0:int(num_objects)]

        # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
        original_h, original_w, _ = frame.shape
        bboxes = utils.format_boxes(bboxes, original_h, original_w)

        # store all predictions in one parameter for simplicity when calling functions
        pred_bbox = [bboxes, scores, classes, num_objects]

        # read in all class names from config
        class_names = utils.read_class_names(cfg.YOLO.CLASSES)

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())

        # custom allowed classes (uncomment line below to customize tracker for only people)
        #allowed_classes = ['person']

        # loop through objects and use class index to get class name, allow only classes in allowed_classes list
        names = []
        deleted_indx = []
        for i in range(num_objects):
            class_indx = int(classes[i])
            class_name = class_names[class_indx]
            if class_name not in allowed_classes:
                deleted_indx.append(i)
            else:
                names.append(class_name)
        names = np.array(names)
        count = len(names)
        if FLAGS.count:
            cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
            print("Objects being tracked: {}".format(count))
        # delete detections that are not in allowed_classes
        bboxes = np.delete(bboxes, deleted_indx, axis=0)
        scores = np.delete(scores, deleted_indx, axis=0)

        # encode yolo detections and feed to tracker
        features = encoder(frame, bboxes)
        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]

        # initialize color map
        cmap = plt.get_cmap('tab20b')
        colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

        # run non-maxima supression
        boxs = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Call the tracker
        tracker.predict()
        tracker.update(detections)

        # update tracks
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()
            class_name = track.get_class()

            # draw bbox on screen
            color = colors[int(track.track_id) % len(colors)]
            color = [i * 255 for i in color]
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1] - 30)), (int(bbox[0]) + (len(class_name) + len(str(track.track_id))) * 17, int(bbox[1])), color, -1)
            cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1] - 10)), 0, 0.75, (255, 255, 255), 2)

            # if enable info flag then print details about each track
            if FLAGS.info:
                print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))

        # calculate frames per second of running detections
        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = np.asarray(frame)
        result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if not FLAGS.dont_show:
            cv2.imshow("Output Video", result)

        # if output flag is set, save video file
        if FLAGS.output:
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
```
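For reference: in the snippet above the tracker is created as `Tracker(metric)`, so all of its lifetime parameters are left at their defaults. Assuming the bundled deep_sort follows the original nwojke/deep_sort implementation, the constructor also accepts `max_iou_distance`, `max_age`, and `n_init`, and `max_age` in particular controls how many frames a lost track is kept alive before its ID is deleted. A minimal sketch (the values are illustrative, not tested):

```python
# Sketch: keep lost tracks alive longer so a briefly occluded object can
# recover its old ID. Assumes the bundled deep_sort Tracker has the same
# signature as the original nwojke/deep_sort implementation.
metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
tracker = Tracker(
    metric,
    max_iou_distance=0.7,  # IoU gate used when appearance matching fails
    max_age=60,            # frames to keep a lost track (default 30)
    n_init=3,              # consecutive hits before a track is confirmed
)
```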
What I tried
I changed the threshold values and similar settings, but as expected that alone could not fix it. So instead, when an ID disappears and a new object shows up right afterwards, I would like to assign the recently vanished ID to that new object (see the sketch just below).
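The repository itself has no such re-assignment step, so the following is only a rough sketch of what this kind of post-processor could look like. Everything here (the class name `IdReassigner` and the gates `max_gap_frames` / `max_dist_px`) is my own invention for illustration: when a displayed ID vanishes and a brand-new track appears nearby within a few frames, the new track inherits the old ID for display purposes.

```python
import numpy as np

class IdReassigner:
    """Hypothetical post-processor (not part of the repository): when a
    displayed ID vanishes and a new track appears nearby soon afterwards,
    the new track inherits the old ID for display."""

    def __init__(self, max_gap_frames=30, max_dist_px=80.0):
        self.max_gap_frames = max_gap_frames  # how recently the old ID must have vanished
        self.max_dist_px = max_dist_px        # spatial gate between old and new positions
        self.alias = {}        # raw deep_sort track_id -> displayed id
        self.lost = {}         # displayed id -> (frame it vanished, last center)
        self.last_center = {}  # displayed id -> last center while active

    def display_ids(self, frame_num, tracks):
        """Return {raw track_id: displayed id} for the currently live tracks."""
        live = [t for t in tracks if t.is_confirmed() and t.time_since_update <= 1]
        centers = {}
        for t in live:
            x1, y1, x2, y2 = t.to_tlbr()
            centers[t.track_id] = np.array([(x1 + x2) / 2.0, (y1 + y2) / 2.0])

        # displayed ids whose track just disappeared move into the lost pool
        active = {self.alias[t.track_id] for t in live if t.track_id in self.alias}
        for disp in list(self.last_center):
            if disp not in active:
                self.lost[disp] = (frame_num, self.last_center.pop(disp))

        for t in live:
            if t.track_id not in self.alias:
                # brand-new raw id: adopt the nearest recently-lost displayed id, if any
                best, best_d = None, self.max_dist_px
                for disp, (f, c) in self.lost.items():
                    d = float(np.linalg.norm(centers[t.track_id] - c))
                    if frame_num - f <= self.max_gap_frames and d <= best_d:
                        best, best_d = disp, d
                self.alias[t.track_id] = best if best is not None else t.track_id
                if best is not None:
                    del self.lost[best]
            self.last_center[self.alias[t.track_id]] = centers[t.track_id]

        # forget lost ids that are too old to be revived
        self.lost = {d: v for d, v in self.lost.items()
                     if frame_num - v[0] <= self.max_gap_frames}
        return {t.track_id: self.alias[t.track_id] for t in live}
```

Note that this is purely a display-level remap on top of deep_sort's own IDs, and the distance gate on box centers is a crude stand-in for the appearance-based matching the Qiita article may be doing.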
Additional information (framework / tool versions)
I am building on this person's source code:
https://github.com/theAIGuysCode/yolov4-deepsort/blob/master/object_tracker.py
I would like to add something like the "End2New" feature that appears partway through this person's Qiita article (a rough integration sketch follows the link):
https://qiita.com/ComputerVision/items/fc64317006c25f37cc3e
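For completeness, here is a sketch of how the hypothetical `IdReassigner` above could be hooked into the posted loop (again, these names are mine, not the repository's):

```python
# next to the existing tracker setup in main()
reassigner = IdReassigner(max_gap_frames=30, max_dist_px=80.0)

# inside the frame loop, right after tracker.predict() / tracker.update(detections)
ids = reassigner.display_ids(frame_num, tracker.tracks)
for track in tracker.tracks:
    if not track.is_confirmed() or track.time_since_update > 1:
        continue
    shown_id = ids[track.track_id]  # stable displayed ID instead of track.track_id
    bbox = track.to_tlbr()
    class_name = track.get_class()
    cv2.putText(frame, class_name + "-" + str(shown_id),
                (int(bbox[0]), int(bbox[1] - 10)), 0, 0.75, (255, 255, 255), 2)
```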