前提・実現したいこと
OpenCV-python を使って入力した画像のマスクを作成する関数を作っています。具体的には左下のような複数の物体が写った画像を入力した際に特定の物体の領域だけを白、それ以外の部分を黒にした右下のようなマスクを出力するようなものです。領域抽出をしたい物体は基本的に最も大きいため、現在のソースコードでは最大の領域を白、それ以外を黒にするようにしていますが、このコードだと画像によっては別の物体の領域が抽出されたり、背景の領域のほうが大きな領域だと認識され、右下のマスクの白と黒を反転させたようなものが出力されたりしてしまいます。
そこで、領域抽出したい物体は必ず画像の中央の座標を含んでいるということはわかっているので、中央の座標を含んだ領域を白とし、それ以外を黒とするようなコードに変えたいと思っています。それを実現するためにはどのようにコードを変更すればいいでしょうか?
前提として、pytorch のプログラム中での関数で、関数の引数のimageは(batch size, channel, row, col)の4次元のtorch.tensorとなっています。自分のコードで入力する image の shape は (1,2,256,256) です。batch size は1なのでbatch size でのループはありません。また、仮に thresh がすべて0などの配列だった場合などのためにエラー処理としてすべて1で埋めた配列を渡すようにしています。
該当のソースコード
python
1 2def create_mask(image): 3 """Function to create mask 4 Args: 5 image: Input image as torch tensor. Shape must be (batch size, channel, row, col) 6 Returns: 7 Torch tensor of mask image with shape (batch size, channel, row, col) 8 """ 9 10 image = image.cpu().detach().numpy() 11 mask_list = [] 12 13 for c in range(image.shape[1]): 14 slc = image[0][c].astype(np.uint8) 15 thresh = cv2.threshold(slc, 0, 255, cv2.THRESH_OTSU)[1] 16 if not thresh.any(): 17 thresh = np.ones_like(thresh) # black image if thresh is all zero 18 # extract contours 19 contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] 20 # pick the contour that has the largest area 21 max_cnt = max(contours, key=lambda x: cv2.contourArea(x)) 22 # create the mask of the largest area filling with white 23 black = np.zeros_like(slc) 24 mask = cv2.drawContours(black, max_cnt, -1, color=255, thickness=-1) 25 cv2.fillPoly(mask, pts =[max_cnt], color=255) 26 mask = torch.from_numpy(mask).float() 27 mask_list.append(mask) 28 mask = torch.stack(mask_list) 29 return mask.unsqueeze(0)
###使用環境・ライブラリのバージョン
Ubuntu 20.04.2 LTS
OpenCV: 4.5.1
pyTorch: 1.7.1
numPy: 1.18.1