質問編集履歴
2
タイトルの修正
title
CHANGED
@@ -1,1 +1,1 @@
|
|
1
|
-
PyTorchを使ったネットワークの解釈を手伝ってほしいです
|
1
|
+
PyTorchを使ったネットワーク(20行程度)の解釈を手伝ってほしいです
|
body
CHANGED
File without changes
|
1
コードの無関係な部分を削除しました.
title
CHANGED
File without changes
|
body
CHANGED
@@ -50,19 +50,6 @@
|
|
50
50
|
x = x.reshape(B, -1)
|
51
51
|
|
52
52
|
return self.fc(x)
|
53
|
-
|
54
|
-
def save_attention_mask(self, x, path):
|
55
|
-
B = x.shape[0]
|
56
|
-
self.forward(x)
|
57
|
-
x = x.cpu() * torch.Tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
|
58
|
-
x = x + torch.Tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
|
59
|
-
fig, axs = plt.subplots(4, 2, figsize=(6, 8))
|
60
|
-
plt.axis('off')
|
61
|
-
for i in range(4):
|
62
|
-
axs[i, 0].imshow(x[i].permute(1, 2, 0))
|
63
|
-
axs[i, 1].imshow(self.mask_[i][0])
|
64
|
-
plt.savefig(path)
|
65
|
-
plt.close()
|
66
53
|
```
|
67
54
|
|
68
55
|
### 参考
|