前提・実現したいこと
Google colaboratory上で自前の重みモデルkanji.pthをロードし学習した結果をバウンティングボックス付きの画像として出力したいです。
くずし字の漢字を検出・認識するプログラムを作っています。
現在 Create Pytorch Dataset for Classifying Characters
内にあるDemoModel↓
class DemoModel(torch.nn.Module): def __init__(self): super(DemoModel, self).__init__() self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=7) self.relu1 = torch.nn.ReLU(inplace=True) self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2) self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=128, kernel_size=6) self.relu2 = torch.nn.ReLU(inplace=True) self.maxpool2 = torch.nn.MaxPool2d(kernel_size=2) self.fc = torch.nn.Linear(in_features=128*8*8, out_features=4212, bias=True) self.log_softmax = torch.nn.LogSoftmax(dim=-1) def forward(self, x): out = self.conv1(x) # (batch, 1, 48, 48) -> (batch, 16, 42, 42) out = self.relu1(out) out = self.maxpool1(out) # (batch, 16, 42, 42) -> (batch, 16, 21, 21) out = self.conv2(out) # (batch, 16, 21, 21) -> (batch, 128, 16, 16) out = self.relu2(out) out = self.maxpool2(out) # (batch, 128, 16, 16) -> (batch, 128, 8, 8) out = out.view(out.size(0), -1) # (batch, 128, 8, 8) -> (batch, 8192) out = self.fc(out) # (batch, 8192) -> (batch, 4212) out = self.log_softmax(out) return out
をCreate Pytorch Dataset for Classifying Charactersに書かれている方法で学習した後、
torch.save(DemoModel().state_dict(),'/content/drive/MyDrive/Colab Notebooks/Kuzushiji_Visualisation/' + 'kanji' + '.pth')
で保存したモデルで物体検出を行おうとしたところ以下のエラーメッセージが発生しました。
発生している問題・エラーメッセージ
[[209. 223. 219. ... 209. 197. 140.] [180. 220. 211. ... 221. 219. 197.] [185. 182. 218. ... 212. 216. 180.] ... [ 84. 212. 208. ... 217. 220. 169.] [ 95. 217. 217. ... 230. 226. 198.] [ 24. 175. 192. ... 204. 186. 141.]] [[209. 223. 219. ... 209. 197. 140.] [180. 220. 211. ... 221. 219. 197.] [185. 182. 218. ... 212. 216. 180.] ... [ 84. 212. 208. ... 217. 220. 169.] [ 95. 217. 217. ... 230. 226. 198.] [ 24. 175. 192. ... 204. 186. 141.]] --------------------------------------------------------------------------- IndexError Traceback (most recent call last) <ipython-input-8-a04dab2fd286> in <module>() 116 117 if __name__ == '__main__': --> 118 main() 1 frames <ipython-input-8-a04dab2fd286> in detect(image, count) 91 # jは確信度上位200件のボックスのインデックス 92 # detections[0,i,j]は[conf,xmin,ymin,xmax,ymax]の形状 ---> 93 while detections[0,i,j,0] >= 0.6: 94 score = detections[0,i,j,0] 95 label_name = labels[i-1] IndexError: too many indices for tensor of dimension 2
該当のソースコード
PyTorch 学習済みモデルでサクッと物体検出をしてみる
上記のサイトを参考にさせていただきました。上記サイトではRGB画像(三層の画像)を物体検出にかけていますが、今回グレースケールの漢字の画像(48px×48px)を訓練用データとして使用したのに合わせて
##x = cv2.resize(gray_image, (300, 300)).astype(np.float32) # 300*300にリサイズ #print(x) #x -= (104.0, 117.0, 123.0) #x = x.astype(np.float32) #x = x[:, :, ::-1].copy() #x = torch.from_numpy(x).permute(2, 0, 1) # [300,300,3]→[3,300,300] #xx = Variable(x.unsqueeze(0)) # [3,300,300]→[1,3,300,300] #print(x)
を削除し、以下のdef detect(image, count):
から9文変更しました。
ipynb
1import os 2import sys 3import torch 4import torch.nn as nn 5from torch.autograd import Variable 6import numpy as np 7import cv2 8import glob 9from ssd import build_ssd 10from matplotlib import pyplot as plt 11import tensorflow as tf 12 13# SSDモデルを読み込み 14net = DemoModel() 15net.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/Kuzushiji_Visualisation/kanji.pth')) 16 17# 関数 detect 18def detect(image, count): 19 rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 20 gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 21 x = cv2.resize(gray_image, (48,48)).astype(np.float32) # 元元のコードは300*300にリサイズ.訓練データに合わせ48*48 22 print(x) 23 x = x.astype(np.float32) 24 xx = torch.from_numpy(x).clone() 25 xx = Variable(xx.unsqueeze(0)) # [48,48]→[1,48,48] 26 xx = Variable(xx.unsqueeze(0)) # [1,48,48]→[1,1,48,48] 27 print(x) 28 29 # 順伝播を実行し、推論結果を出力 30 y = net(xx) 31 from data import VOC_CLASSES as labels 32 plt.figure(figsize=(10,6)) 33 colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist() 34 #plt.imshow(rgb_image) 35 currentAxis = plt.gca() 36 # 推論結果をdetectionsに格納 37 detections = y.data 38 # scale each detection back up to the image 39 scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2) 40 41 # バウンディングボックスとクラス名を表示 42 for i in range(detections.size(1)): 43 j = 0 44 # 確信度confが0.6以上のボックスを表示 45 # jは確信度上位200件のボックスのインデックス 46 # detections[0,i,j]は[conf,xmin,ymin,xmax,ymax]の形状 47 while detections[0,i,j,0] >= 0.6: 48 score = detections[0,i,j,0] 49 label_name = labels[i-1] 50 display_txt = '%s: %.2f'%(label_name, score) 51 pt = (detections[0,i,j,1:]*scale).cpu().numpy() 52 coords = (pt[0], pt[1]), pt[2]-pt[0]+1, pt[3]-pt[1]+1 53 color = colors[i] 54 currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2)) 55 currentAxis.text(pt[0], pt[1], display_txt, bbox={'facecolor':color, 'alpha':0.5}) 56 j+=1 57 plt.savefig('/content/drive/MyDrive/Colab Notebooks/Kuzushiji_Visualisation/detect_img/'+'{0:04d}'.format(count)+'.jpg') 58 plt.close() 59 60def main(): 61 files = sorted(glob.glob('/content/drive/MyDrive/Colab Notebooks/Kuzushiji_Visualisation/image_dir/*.jpg')) 62 count = 1 63 for i, file in enumerate (files): 64 image = cv2.imread(file, cv2.IMREAD_COLOR) 65 detect(image, count) 66 print(count) 67 count +=1 68 69if __name__ == '__main__': 70 main()
動作環境
Google Colaboratory
Python 3.7.11

回答1件
あなたの回答
tips
プレビュー