今年に入ってから機械学習勉強開始した初心者です。
deeplabの学習データを使ってtflte作成、検出までをやろうとしています。
https://github.com/tensorflow/models/tree/master/research/deeplab
一応エラーが出ない状態で検出を実行できるようになったのですが、どの画像を投げても返ってくるのが0しか入っていない配列になります。
検出方法が間違っているのか?そもそもtfliteの作り方がおかしいのか切り分けしたく、検出のやり方がこれであっているのか、わかる方に教えてほしいです。
PASCAL VOCデータセットを使っています。
やったこと
deeplabのサンプルスクリプトでpbファイル作成(colabで)
tf/modelsのレポジトリをcloneしてきて local_test_mobilenetv2.sh
を使いました。
https://github.com/tensorflow/models/blob/master/research/deeplab/local_test_mobilenetv2.sh
local_test_mobilenetv2.sh
のmodelのエクスポート部分
python
1python "${WORK_DIR}"/export_model.py \ 2 --logtostderr \ 3 --checkpoint_path="${CKPT_PATH}" \ 4 --export_path="${EXPORT_PATH}" \ 5 --model_variant="mobilenet_v2" \ 6 --num_classes=21 \ 7 --crop_size=513 \ 8 --crop_size=513 \ 9 --inference_scales=1.0
tfliteに変換
整数量子化して軽量化するための設定とかあるようですが、まだそこまで理解できていないのでなしで。
shell
1!tflite_convert \ 2 --graph_def_file=/content/models/research/deeplab/datasets/pascal_voc_seg/exp/train_on_trainval_set_mobilenetv2/export/frozen_inference_graph.pb \ 3 --output_file=/content/models/research/deeplab/datasets/pascal_voc_seg/exp/train_on_trainval_set_mobilenetv2/export/frozen_inference_graph.tflite \ 4 --output_format=TFLITE \ 5 --input_shape=1,513,513,3 \ 6 --input_arrays="MobilenetV2/MobilenetV2/input" \ 7 --change_concat_input_ranges=true \ 8 --output_arrays="ArgMax"
8MBぐらいになりました。
検出
python
1from PIL import Image 2import numpy 3import sys 4import cv2 5 6# tflite読み込み(tfliteファイルは移動してます) 7interpreter = tf.lite.Interpreter(model_path="/frozen_inference_graph.tflite") 8interpreter.allocate_tensors() 9 10# input output tensor取得 11input_details = interpreter.get_input_details() 12output_details = interpreter.get_output_details() 13 14# 入出力フォーマットを確認 15print('入出力フォーマットを確認') 16print(input_details) 17print(output_details) 18 19# 入力のshape取得 20input_shape = input_details[0]['shape'] 21print('shape確認') 22print(input_shape) 23 24# テスト画像 25test_img = "/content/deeplab_sample/bicycle513x513.jpg" 26image = Image.open(test_img) 27image = image.convert("RGB") 28image = image.resize((513, 513)) 29img_data = np.asarray(image, dtype=np.uint8) 30 31# 画像shape変換 32reshaped_img = img_data.reshape(input_shape) 33print('入力データ') 34print(reshaped_img) 35interpreter.set_tensor(input_details[0]['index'], reshaped_img) 36 37# 実行 38interpreter.invoke() 39output_data = interpreter.get_tensor(output_details[0]['index']) 40print('出力データ') 41print(output_data) 42print(np.count_nonzero(output_data))
出力
入出力フォーマットを確認 [{'name': 'MobilenetV2/MobilenetV2/input', 'index': 6, 'shape': array([ 1, 513, 513, 3], dtype=int32), 'shape_signature': array([ 1, 513, 513, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] [{'name': 'ArgMax', 'index': 0, 'shape': array([ 1, 513, 513], dtype=int32), 'shape_signature': array([ 1, 513, 513], dtype=int32), 'dtype': <class 'numpy.int64'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}] shape確認 [ 1 513 513 3] 入力データ [[[[239. 247. 250.] [239. 247. 250.] [238. 246. 249.] ... [244. 247. 252.] [244. 247. 252.] [244. 248. 251.]] [[239. 247. 250.] [238. 246. 249.] [237. 245. 248.] ... [244. 247. 252.] [244. 247. 252.] [244. 248. 251.]] [[238. 245. 251.] [237. 244. 250.] [236. 243. 249.] ... [244. 248. 251.] [244. 248. 251.] [244. 248. 251.]] ... [[112. 120. 45.] [102. 111. 44.] [105. 112. 58.] ... [127. 113. 64.] [115. 101. 52.] [108. 88. 37.]] [[109. 126. 0.] [111. 128. 16.] [105. 118. 26.] ... [135. 121. 72.] [132. 118. 69.] [132. 112. 61.]] [[128. 135. 65.] [104. 111. 43.] [ 94. 100. 36.] ... [145. 124. 71.] [152. 131. 78.] [143. 126. 72.]]]] 出力データ [[[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]]] 0
試したこと
- 違う画像を使う
補足情報
本来は mask r-cnn で学習を行っていて、そっちで作成した学習データを使いたいのですが、h5ファイルからtfliteに変換するための情報がうまく見つけらないのと、変換時にわたすパラメータの理解がまだ悪く一旦あきらめました。でとりあえずdeeplabのサンプルを使ってtflite変換して検出する部分をやってみているところです。
https://github.com/matterport/Mask_RCNN
回答2件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/04/09 02:16
2020/04/09 03:16