質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.46%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

946閲覧

添付のコードを、Edge TPU PythonAPI 仕様から、PyCoralAPI(python)仕様に作り直したい。

taboopython

総合スコア40

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2021/08/16 07:01

編集2021/08/16 07:10

#■やりたいこと
Coral devboard mendel(Linux)環境でSmart Bird Feeder project.を試したいと思っています。(https://coral.ai/projects/bird-feeder/

添付コード bird_classify.pyを実行しようとしていますが実行できず、次のようなエラーになります。

#■エラー

error内容

1engine = ClassificationEngine(args.model) 2NameError: name 'ClassificationEngine' is not defined

pycoral apiを使って pycoralを使うように書き換え、コードを作り直す必要があることまではわかるのですが、具体的にどのように直すと良いでしょうか。良い方法があれば教えていただければ幸いです。

#■やってみたこと
ClassificationEngineの箇所をpycoral仕様にするため、edgetpu.classification.engineを削除し、pycoralで必要になりそうなコードを追加しようと思い、次のように記述しました。

bird_classify.py もともとあった記述(削除済み)

1from edgetpu.classification.engine import ClassificationEngine

https://coral.ai/docs/reference/py/pycoral.utils/#module-pycoral.utils.dataset を参照し、ClassificationEngineの代わりになりそうなもの?と思い次のようにしてみました。

bird_classify.py 追加

1from pycoral.adapters import classify 2from pycoral.adapters import common 3from pycoral.utils.dataset import read_label_file 4from pycoral.utils.edgetpu import make_interpreter

bird_classify.py 全文 python3

1import argparse 2import time 3import re 4import imp 5import logging 6import gstreamer 7 8from PIL import Image 9from playsound import playsound 10 11from pycoral.adapters import classify 12from pycoral.adapters import common 13from pycoral.utils.dataset import read_label_file 14from pycoral.utils.edgetpu import make_interpreter 15 16def save_data(image,results,path,ext='png'): 17 """Saves camera frame and model inference results 18 to user-defined storage directory.""" 19 tag = '%010d' % int(time.monotonic()*1000) 20 name = '%s/img-%s.%s' %(path,tag,ext) 21 image.save(name) 22 print('Frame saved as: %s' %name) 23 logging.info('Image: %s Results: %s', tag,results) 24 25def load_labels(path): 26 """Parses provided label file for use in model inference.""" 27 p = re.compile(r'\s*(\d+)(.+)') 28 with open(path, 'r', encoding='utf-8') as f: 29 lines = (p.match(line).groups() for line in f.readlines()) 30 return {int(num): text.strip() for num, text in lines} 31 32def print_results(start_time, last_time, end_time, results): 33 """Print results to terminal for debugging.""" 34 inference_rate = ((end_time - start_time) * 1000) 35 fps = (1.0/(end_time - last_time)) 36 print('\nInference: %.2f ms, FPS: %.2f fps' % (inference_rate, fps)) 37 for label, score in results: 38 print(' %s, score=%.2f' %(label, score)) 39 40def do_training(results,last_results,top_k): 41 """Compares current model results to previous results and returns 42 true if at least one label difference is detected. Used to collect 43 images for training a custom model.""" 44 new_labels = [label[0] for label in results] 45 old_labels = [label[0] for label in last_results] 46 shared_labels = set(new_labels).intersection(old_labels) 47 if len(shared_labels) < top_k: 48 print('Difference detected') 49 return True 50 51def user_selections(): 52 parser = argparse.ArgumentParser() 53 parser.add_argument('--model', required=True, 54 help='.tflite model path') 55 parser.add_argument('--labels', required=True, 56 help='label file path') 57 parser.add_argument('--top_k', type=int, default=3, 58 help='number of classes with highest score to display') 59 parser.add_argument('--threshold', type=float, default=0.1, 60 help='class score threshold') 61 parser.add_argument('--storage', required=True, 62 help='File path to store images and results') 63 parser.add_argument('--sound', required=True, 64 help='File path to deterrent sound') 65 parser.add_argument('--print', default=False, required=False, 66 help='Print inference results to terminal') 67 parser.add_argument('--training', default=False, required=False, 68 help='Training mode for image collection') 69 args = parser.parse_args() 70 return args 71 72 73def main(): 74 75 args = user_selections() 76 print("Loading %s with %s labels."%(args.model, args.labels)) 77 engine = ClassificationEngine(args.model) 78 labels = load_labels(args.labels) 79 storage_dir = args.storage 80 81 82 logging.basicConfig(filename='%s/results.log'%storage_dir, 83 format='%(asctime)s-%(message)s', 84 level=logging.DEBUG) 85 86 last_time = time.monotonic() 87 last_results = [('label', 0)] 88 def user_callback(image,svg_canvas): 89 nonlocal last_time 90 nonlocal last_results 91 start_time = time.monotonic() 92 results = engine.classify_with_image(image, threshold=args.threshold, top_k=args.top_k) 93 end_time = time.monotonic() 94 results = [(labels[i], score) for i, score in results] 95 96 if args.print: 97 print_results(start_time,last_time, end_time, results) 98 99 if args.training: 100 if do_training(results,last_results,args.top_k): 101 save_data(image,results, storage_dir) 102 else: 103 104 if results[0][0] !='background': 105 save_data(image, storage_dir,results) 106 if 'fox squirrel, eastern fox squirrel, Sciurus niger' in results: 107 playsound(args.sound) 108 logging.info('Deterrent sounded') 109 110 last_results=results 111 last_time = end_time 112 result = gstreamer.run_pipeline(user_callback) 113 114if __name__ == '__main__': 115 main() 116

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.46%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問