現在,Deep Learningで背景削除をしてみる
https://tech.fusic.co.jp/posts/2020-01-20-remove-background/
というページのコードを実際に動かそうとしていますが、うまくいきません。
具体的には、下記のコードの" back = back.astype(float)/255"という部分が
AttributeError: 'NoneType' object has no attribute 'astype'というエラーを返します。
それと、途中までのコードで画像をtrimapという前景が白、後景が黒、どちらかわからない部分が灰色という風に変換する部分があるのですが、そのtrimapの画像が黒と灰色しか出ません。その結果として最終的に作られた画像が、背景が消えてはいるのですが、残したい部分も薄くなります。←上記のattributeErrorがでるまではこのような形で出力されていました。
#ライブラリインポート import numpy as np import cv2 import matplotlib.pyplot as plt import torch import torchvision from torchvision import transforms #画像読み込み img = cv2.imread("upperbody1.jpg") img = img[...,::-1] h,w,_ = img.shape img = cv2.resize(img,(320,320)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True) model = model.to(device) model.eval() preprocess = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) input_tensor = preprocess(img) input_batch = input_tensor.unsqueeze(0).to(device) with torch.no_grad(): output = model(input_batch)['out'][0] output = output.argmax(0) mask = output.byte().cpu().numpy() mask = cv2.resize(mask,(w,h)) img = cv2.resize(img,(w,h)) cv2.imwrite('./examples/trimaps/picture.jpg',mask) henkan = cv2.imread("./examples/trimaps/picture.jpg") black = [0, 0, 0] white = [255, 255, 255] henkan[np.where((henkan != black).all(axis=2))] = white cv2.imwrite("./examples/trimaps/henkan.jpg",henkan) plt.figure(figsize=(20,20)) plt.subplot(1,2,1) plt.imshow(img) plt.subplot(1,2,2) plt.imshow(henkan) plt.show() img = cv2.imread("upperbody1.jpg") img = img[...,::-1] back = cv2.imread("./example/trimaps/picture.jpg") h,w,_ = img.shape bg = np.full_like(img,255) img = img.astype(float) bg = bg.astype(float) back = back.astype(float)/255 img = cv2.multiply(img, back) bg = cv2.multiply(bg, 1.0 - back) output = cv2.add(img, bg) cv2.imwrite("./example/extraction.jpg",output) h = cv2.imread("./example/extraction.jpg") cv2.imshow("output",h) cv2.waitKey(0) cv2.destroyAllWindows()
書き忘れていましたが環境は
windows10
python3.9.6 64bit版
opencv
です。
あなたの回答
tips
プレビュー