現在、SSD-kerasについて勉強しています。以下のリンクからダウンロードして使用しています。
git
サンプルコードであるSSD.ipynb
は動きました。そのため、自作のデータセットを使用し学習させできたモデルで自分で撮った写真の物体検出を行おうと思いました。
以下がコードとなります。
python
1import cv2 2import keras 3from keras.applications.imagenet_utils import preprocess_input 4from keras.backend.tensorflow_backend import set_session 5from keras.models import Model 6from keras.preprocessing import image 7import matplotlib.pyplot as plt 8import numpy as np 9from scipy.misc import imread 10import tensorflow as tf 11 12from ssd import SSD300 13from ssd_utils import BBoxUtility 14 15#%matplotlib inline 16plt.rcParams['figure.figsize'] = (8, 8) 17plt.rcParams['image.interpolation'] = 'nearest' 18 19np.set_printoptions(suppress=True) 20 21config = tf.ConfigProto() 22config.gpu_options.per_process_gpu_memory_fraction = 0.45 23set_session(tf.Session(config=config)) 24voc_classes = ['glass', 'Bicycle', 'Bird', 'Boat', 'Bottle', 25 'Bus', 'Car', 'Cat', 'Chair', 'Cow', 'Diningtable', 26 'Dog', 'Horse','Motorbike', 'Person', 'Pottedplant', 27 'Sheep', 'Sofa', 'Train', 'Tvmonitor'] 28NUM_CLASSES = len(voc_classes) + 1 29 30input_shape=(300, 300, 3) 31model = SSD300(input_shape, num_classes=NUM_CLASSES) 32model.load_weights('./checkpoints/weights.00-2.28.hdf5', by_name=True) 33bbox_util = BBoxUtility(NUM_CLASSES) 34 35inputs = [] 36images = [] 37img_path = './pics/car_cat.jpg' 38img = image.load_img(img_path, target_size=(300, 300)) 39img = image.img_to_array(img) 40images.append(imread(img_path)) 41inputs.append(img.copy()) 42 43preds = model.predict(inputs, batch_size=1, verbose=1) 44 45results = bbox_util.detection_out(preds) 46 47a = model.predict(inputs, batch_size=1) 48b = bbox_util.detection_out(preds) 49 50for i, img in enumerate(images): 51 # Parse the outputs. 52 det_label = results[i][:, 0] 53 det_conf = results[i][:, 1] 54 det_xmin = results[i][:, 2] 55 det_ymin = results[i][:, 3] 56 det_xmax = results[i][:, 4] 57 det_ymax = results[i][:, 5] 58 59 # Get detections with confidence higher than 0.6. 60 top_indices = [i for i, conf in enumerate(det_conf) if conf >= 0.6] 61 62 top_conf = det_conf[top_indices] 63 top_label_indices = det_label[top_indices].tolist() 64 top_xmin = det_xmin[top_indices] 65 top_ymin = det_ymin[top_indices] 66 top_xmax = det_xmax[top_indices] 67 top_ymax = det_ymax[top_indices] 68 69 colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist() 70 71 plt.imshow(img / 255.) 72 currentAxis = plt.gca() 73 74 for i in range(top_conf.shape[0]): 75 xmin = int(round(top_xmin[i] * img.shape[1])) 76 ymin = int(round(top_ymin[i] * img.shape[0])) 77 xmax = int(round(top_xmax[i] * img.shape[1])) 78 ymax = int(round(top_ymax[i] * img.shape[0])) 79 score = top_conf[i] 80 label = int(top_label_indices[i]) 81 label_name = voc_classes[label - 1] 82 display_txt = '{:0.2f}, {}'.format(score, label_name) 83 coords = (xmin, ymin), xmax-xmin+1, ymax-ymin+1 84 color = colors[label] 85 currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2)) 86 currentAxis.text(xmin, ymin, display_txt, bbox={'facecolor':color, 'alpha':0.5}) 87 88 plt.show() 89コード
動かしてみたおころ、以下のようなエラーが出てしまいました。
expected input_1 to have 4 dimentions, but got array with shape (300,300,3)
撮影した画像を流すことが出来るようにするには何を改善する必要があるか教えていただきたいです。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。