Background and what I want to achieve
I am running DeepSORT with YOLOv4. The text output already contains the frame number, FPS, detected class name, tracker ID, and bounding-box coordinates. In addition to these, I would like to output the video playback time at each processed frame, but I don't know how to write that.
Current txt output
Frame #: 1
FPS: 0.57
Frame #: 2
FPS: 41.59
Frame #: 3
Tracker ID: 1, Class: KPHN, BBox Coords (xmin, ymin, xmax, ymax): (866, 800, 965, 850)
Tracker ID: 2, Class: KPHN, BBox Coords (xmin, ymin, xmax, ymax): (796, 850, 851, 904)
FPS: 42.47
The excerpt above shows only the first few frames; the full video has 2,588 frames. The format I would like to produce is the following (the Time values here are placeholders I made up):

Frame #: 1
FPS: 0.57
Time: 0.0001
Frame #: 2
FPS: 41.59
Time: 0.0002
Frame #: 3
Tracker ID: 1, Class: KPHN, BBox Coords (xmin, ymin, xmax, ymax): (866, 800, 965, 850)
Tracker ID: 2, Class: KPHN, BBox Coords (xmin, ymin, xmax, ymax): (796, 850, 851, 904)
FPS: 42.47
Time: 0.0003
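One point worth noting: the FPS the script prints is the detector's processing speed, which varies per frame, while the playback time should come from the video's own frame rate. For a constant-frame-rate file, the time of frame N is simply N / video_fps, where video_fps can be read with cv2.CAP_PROP_FPS. A minimal self-contained sketch of this idea (the file name input.mp4 is a placeholder, not from the question):

```python
import cv2

vid = cv2.VideoCapture("input.mp4")    # placeholder path
video_fps = vid.get(cv2.CAP_PROP_FPS)  # nominal frame rate of the file (may be 0 for webcams)

frame_num = 0
while True:
    ret, frame = vid.read()
    if not ret:
        break
    frame_num += 1
    play_time = frame_num / video_fps  # playback time in seconds (constant frame rate assumed)
    # OpenCV can also report the position directly, in milliseconds:
    # play_time = vid.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
    print("Frame #: %d  Time: %.4f" % (frame_num, play_time))
vid.release()
```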
Relevant source code
```python
def main(_argv):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0

    # initialize deep sort
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # calculate cosine distance metric
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    # initialize tracker
    tracker = Tracker(metric)

    # load configuration for object detector
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    video_path = FLAGS.video

    # load tflite model if flag is set
    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    # otherwise load standard tensorflow saved model
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']

    # begin video capture
    try:
        vid = cv2.VideoCapture(int(video_path))
    except:
        vid = cv2.VideoCapture(video_path)

    out = None

    # get video ready to save locally if flag is set
    if FLAGS.output:
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
        out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))

    frame_num = 0
    # while video is running
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            print('Video has ended or failed, try a different video format!')
            break
        frame_num += 1
        print('Frame #: ', frame_num)
        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
        image_data = image_data / 255.
        image_data = image_data[np.newaxis, ...].astype(np.float32)
        start_time = time.time()

        # run detections on tflite if flag is set
        if FLAGS.framework == 'tflite':
            interpreter.set_tensor(input_details[0]['index'], image_data)
            interpreter.invoke()
            pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
            # run detections using yolov3 if flag is set
            if FLAGS.model == 'yolov3' and FLAGS.tiny == True:
                boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
            else:
                boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
        else:
            batch_data = tf.constant(image_data)
            pred_bbox = infer(batch_data)
            for key, value in pred_bbox.items():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=FLAGS.iou,
            score_threshold=FLAGS.score
        )

        # convert data to numpy arrays and slice out unused elements
        num_objects = valid_detections.numpy()[0]
        bboxes = boxes.numpy()[0]
        bboxes = bboxes[0:int(num_objects)]
        scores = scores.numpy()[0]
        scores = scores[0:int(num_objects)]
        classes = classes.numpy()[0]
        classes = classes[0:int(num_objects)]

        # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
        original_h, original_w, _ = frame.shape
        bboxes = utils.format_boxes(bboxes, original_h, original_w)

        # store all predictions in one parameter for simplicity when calling functions
        pred_bbox = [bboxes, scores, classes, num_objects]

        # read in all class names from config
        class_names = utils.read_class_names(cfg.YOLO.CLASSES)

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())

        # custom allowed classes (uncomment line below to customize tracker for only people)
        #allowed_classes = ['person']

        # loop through objects and use class index to get class name, allow only classes in allowed_classes list
        names = []
        deleted_indx = []
        for i in range(num_objects):
            class_indx = int(classes[i])
            class_name = class_names[class_indx]
            if class_name not in allowed_classes:
                deleted_indx.append(i)
            else:
                names.append(class_name)
        names = np.array(names)
        count = len(names)
        if FLAGS.count:
            cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
            print("Objects being tracked: {}".format(count))
        # delete detections that are not in allowed_classes
        bboxes = np.delete(bboxes, deleted_indx, axis=0)
        scores = np.delete(scores, deleted_indx, axis=0)

        # encode yolo detections and feed to tracker
        features = encoder(frame, bboxes)
        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]

        # initialize color map
        cmap = plt.get_cmap('tab20b')
        colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

        # run non-maxima supression
        boxs = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Call the tracker
        tracker.predict()
        tracker.update(detections)

        # update tracks
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()
            class_name = track.get_class()

            # draw bbox on screen
            color = colors[int(track.track_id) % len(colors)]
            color = [i * 255 for i in color]
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1]-30)), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1])), color, -1)
            cv2.putText(frame, class_name + "-" + str(track.track_id), (int(bbox[0]), int(bbox[1]-10)), 0, 0.75, (255, 255, 255), 2)

            # if enable info flag then print details about each track
            if FLAGS.info:
                print("Tracker ID: {}, Class: {}, BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))

        # calculate frames per second of running detections
        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = np.asarray(frame)
        result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if not FLAGS.dont_show:
            cv2.imshow("Output Video", result)

        # if output flag is set, save video file
        if FLAGS.output:
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord('q'): break
    cv2.destroyAllWindows()

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
```
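Applied to the script above, the change would be small: read the nominal frame rate once after the VideoCapture is opened, then print the time right after the existing print("FPS: %.2f" % fps) inside the loop. A hedged helper sketch of that change (frame_time and video_fps are names I introduce here; they are not part of the repository):

```python
def frame_time(frame_num, video_fps):
    """Playback time in seconds of a 1-based frame number,
    assuming a constant-frame-rate video."""
    if video_fps <= 0:
        # webcams and some containers report 0 FPS; no meaningful timestamp then
        raise ValueError("video FPS must be positive")
    return frame_num / video_fps

# Sketch of where it would plug into the loop above:
#   video_fps = vid.get(cv2.CAP_PROP_FPS)                    # once, after the capture is opened
#   ...
#   print("FPS: %.2f" % fps)                                 # existing line
#   print("Time: %.4f" % frame_time(frame_num, video_fps))   # new line
```

If the txt file is produced by redirecting stdout, the new print line will land in it the same way the existing Frame #/FPS lines do.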
Supplementary information (framework/tool versions, etc.)
Due to the character limit I could not fit the whole source file, so the import section and a few other parts are omitted above. If you want to see the full file, please look at object_tracker.py in the repository below, which I am using as a reference:
https://github.com/theAIGuysCode/yolov4-deepsort