実現したいこと
ResNet50の最初の畳み込み層からの出力を特徴マップとして可視化したいのですが上手くいきません。チャンネル数が大きいため可視化できないということはわかるのですが具体的にどうすれば解決できるかがわからないです。
環境は変えることができないため、現在の環境でのやり方を教えていただきたいです。
発生している問題・エラーメッセージ
Traceback (most recent call last): File "feat_sample.py", line 88, in <module> main() File "feat_sample.py", line 62, in main plt.imshow(result) File "/home/xxx/.local/lib/python3.6/site-packages/matplotlib/pyplot.py", line 2730, in imshow **kwargs) File "/home/xxx/.local/lib/python3.6/site-packages/matplotlib/__init__.py", line 1447, in inner return func(ax, *map(sanitize_sequence, args), **kwargs) File "/home/xxx/.local/lib/python3.6/site-packages/matplotlib/axes/_axes.py", line 5523, in imshow im.set_data(X) File "/home/xxx/.local/lib/python3.6/site-packages/matplotlib/image.py", line 712, in set_data .format(self._A.shape)) TypeError: Invalid shape (112, 112, 64) for image data
該当のソースコード
import os import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision from torchvision import datasets,transforms,models,utils import numpy as np import matplotlib.pyplot as plt import os import pandas as pd from PIL import Image import sys import random import shutil import argparse from datetime import datetime from tqdm import tqdm from sklearn import metrics import warnings import datetime import matplotlib.pyplot as plt warnings.filterwarnings('ignore') def main(): testdir = '/media/8e2435dd-0595-4145-bfff-6c1a049dfde41/data/dataset/' test_dataset = torchvision.datasets.ImageFolder( testdir, transforms.Compose([ transforms.ToTensor(), ])) testloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8) model = models.resnet50(pretrained=False) num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 2) model.to('cuda') model.eval() print(model) feat=model.conv1 print(feat) cnt = 0 for n, (sample, cls) in enumerate(testloader): if n == n: sample = sample.to(device) cls = cls.to(device) output = feat(sample) print(output.shape) feat_map = output.to('cpu').detach().numpy().copy()[0] feat_map = feat_map.transpose(1, 2, 0) result = (feat_map * 255).astype(np.uint8) plt.figure(figsize=(5, 5)) plt.imshow(result) plt.tick_params(bottom=False, left=False, right=False, top=False) plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False) plt.savefig(f'img{cnt}.pdf') plt.close() cnt += 1 if __name__=='__main__': seed=1 random.seed(seed) torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False device = torch.device('cuda') main()
補足情報(FW/ツールのバージョンなど)
torch 1.8.1+cu111
torch-tb-profiler 0.4.1
torchaudio 0.8.1
torchinfo 1.5.4
torchmetrics 0.8.2
torchsummary 1.5.1
torchvision 0.9.1+cu111
回答1件
あなたの回答
tips
プレビュー