質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%

Q&A

0回答

827閲覧

独自データで学習させたモデルを使用したい

trafield

総合スコア0

0グッド

0クリップ

投稿2021/09/09 11:36

前提・実現したいこと

https://github.com/aqeelanwar/SocialDistancingAI
現在上記を、自分で学習させたモデルを使って実行しようと考えています。

tensorflow_objectdetection_api にて転移学習をおこない、frozen_inferene_graphを出力しました。

SocialDistancingAIのmain.pyを、このモデルを使用して実行したいのですが、どこを書き換えればよいのかが分かりません。

発生している問題・エラーメッセージ

エラーメッセージ

main.py

該当のソースコード

python

1import cv2 2import os 3import argparse 4from network_model import model 5from aux_functions import * 6 7# Suppress TF warnings 8os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 9 10mouse_pts = [] 11 12 13def get_mouse_points(event, x, y, flags, param): 14 # Used to mark 4 points on the frame zero of the video that will be warped 15 # Used to mark 2 points on the frame zero of the video that are 6 feet away 16 global mouseX, mouseY, mouse_pts 17 if event == cv2.EVENT_LBUTTONDOWN: 18 mouseX, mouseY = x, y 19 cv2.circle(image, (x, y), 10, (0, 255, 255), 10) 20 if "mouse_pts" not in globals(): 21 mouse_pts = [] 22 mouse_pts.append((x, y)) 23 print("Point detected") 24 print(mouse_pts) 25 26 27# Command-line input setup 28parser = argparse.ArgumentParser(description="SocialDistancing") 29parser.add_argument( 30 "--videopath", type=str, default="123.avi", help="Path to the video file" 31) 32args = parser.parse_args() 33 34input_video = args.videopath 35 36# Define a DNN model 37DNN = model() 38# Get video handle 39cap = cv2.VideoCapture(input_video) 40height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 41width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 42fps = int(cap.get(cv2.CAP_PROP_FPS)) 43 44scale_w = 1.2 / 2 45scale_h = 4 / 2 46 47SOLID_BACK_COLOR = (41, 41, 41) 48# Setuo video writer 49fourcc = cv2.VideoWriter_fourcc(*"XVID") 50output_movie = cv2.VideoWriter("Pedestrian_detect.avi", fourcc, fps, (width, height)) 51bird_movie = cv2.VideoWriter( 52 "Pedestrian_bird.avi", fourcc, fps, (int(width * scale_w), int(height * scale_h)) 53) 54# Initialize necessary variables 55frame_num = 0 56total_pedestrians_detected = 0 57total_six_feet_violations = 0 58total_pairs = 0 59abs_six_feet_violations = 0 60pedestrian_per_sec = 0 61sh_index = 1 62sc_index = 1 63 64cv2.namedWindow("image") 65cv2.setMouseCallback("image", get_mouse_points) 66num_mouse_points = 0 67first_frame_display = True 68 69# Process each frame, until end of video 70while cap.isOpened(): 71 frame_num += 1 72 ret, frame = cap.read() 73 74 if not ret: 75 print("end of the video file...") 76 break 77 78 frame_h = frame.shape[0] 79 frame_w = frame.shape[1] 80 81 if frame_num == 1: 82 # Ask user to mark parallel points and two points 6 feet apart. Order bl, br, tr, tl, p1, p2 83 while True: 84 image = frame 85 cv2.imshow("image", image) 86 cv2.waitKey(1) 87 if len(mouse_pts) == 7: 88 cv2.destroyWindow("image") 89 break 90 first_frame_display = False 91 four_points = mouse_pts 92 93 # Get perspective 94 M, Minv = get_camera_perspective(frame, four_points[0:4]) 95 pts = src = np.float32(np.array([four_points[4:]])) 96 warped_pt = cv2.perspectiveTransform(pts, M)[0] 97 d_thresh = np.sqrt( 98 (warped_pt[0][0] - warped_pt[1][0]) ** 2 99 + (warped_pt[0][1] - warped_pt[1][1]) ** 2 100 ) 101 bird_image = np.zeros( 102 (int(frame_h * scale_h), int(frame_w * scale_w), 3), np.uint8 103 ) 104 105 bird_image[:] = SOLID_BACK_COLOR 106 pedestrian_detect = frame 107 108 print("Processing frame: ", frame_num) 109 110 # draw polygon of ROI 111 pts = np.array( 112 [four_points[0], four_points[1], four_points[3], four_points[2]], np.int32 113 ) 114 cv2.polylines(frame, [pts], True, (0, 255, 255), thickness=4) 115 116 # Detect person and bounding boxes using DNN 117 pedestrian_boxes, num_pedestrians = DNN.detect_pedestrians(frame) 118 119 if len(pedestrian_boxes) > 0: 120 pedestrian_detect = plot_pedestrian_boxes_on_image(frame, pedestrian_boxes) 121 warped_pts, bird_image = plot_points_on_bird_eye_view( 122 frame, pedestrian_boxes, M, scale_w, scale_h 123 ) 124 six_feet_violations, ten_feet_violations, pairs = plot_lines_between_nodes( 125 warped_pts, bird_image, d_thresh 126 ) 127 # plot_violation_rectangles(pedestrian_boxes, ) 128 total_pedestrians_detected += num_pedestrians 129 total_pairs += pairs 130 131 total_six_feet_violations += six_feet_violations / fps 132 abs_six_feet_violations += six_feet_violations 133 pedestrian_per_sec, sh_index = calculate_stay_at_home_index( 134 total_pedestrians_detected, frame_num, fps 135 ) 136 137 last_h = 75 138 text = "# 6ft violations: " + str(int(total_six_feet_violations)) 139 pedestrian_detect, last_h = put_text(pedestrian_detect, text, text_offset_y=last_h) 140 141 text = "Stay-at-home Index: " + str(np.round(100 * sh_index, 1)) + "%" 142 pedestrian_detect, last_h = put_text(pedestrian_detect, text, text_offset_y=last_h) 143 144 if total_pairs != 0: 145 sc_index = 1 - abs_six_feet_violations / total_pairs 146 147 text = "Social-distancing Index: " + str(np.round(100 * sc_index, 1)) + "%" 148 pedestrian_detect, last_h = put_text(pedestrian_detect, text, text_offset_y=last_h) 149 150 cv2.imshow("Street Cam", pedestrian_detect) 151 cv2.waitKey(1) 152 output_movie.write(pedestrian_detect) 153 bird_movie.write(bird_image) 154

試したこと

ここに問題に対して試したことを記載してください。

補足情報(FW/ツールのバージョンなど)

python、AIどちらもほとんど扱った経験がなく、初心者ですのでお手柔らかにお願い致します。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問