助けて…
PyTorchを使ったCNNの勉強をしています.
このAttentionに関する記事の,Pytorchを使ったネットワークの定義が理解できません.
助けてほしいです….
(これまで既存のコードのコピペしかしていなかったので,ネットワークの次元の変化が追えていないです????)
理解できないところ
forward関数です.
ResNet34の出力層を取ったものを使っているので,forward関数にてx = self.features(x)
で512×3x3の特徴量xを得ていると思います.
次にself.attn_conv(x)
でnn.Sequential(nn.Conv2d(512,1,1), nn.Sigmoid())
に通用していると思います.
この出力が(コメントより),[B, 1, H, W]次元となるのが理解できていないところです.
512×3×3 => B×1×H×W
となるのが追えません.どなたかご解説願えないでしょうか…!!
該当のネットワーク,ソースコード
python
1class SimpleAttentionNetwork(nn.Module): 2 def __init__(self, num_classes): 3 super().__init__() 4 5 base_model = resnet34(pretrained=True) 6 self.features = nn.Sequential(*[layer for layer in base_model.children()][:-2]) 7 self.attn_conv = nn.Sequential( 8 nn.Conv2d(512, 1, 1), 9 nn.Sigmoid() 10 ) 11 self.fc = nn.Sequential( 12 nn.Dropout(0.5), 13 nn.Linear(512, num_classes) 14 ) 15 self.mask_ = None 16 17 def forward(self, x): 18 x = self.features(x) 19 20 attn = self.attn_conv(x) # [B, 1, H, W] 21 B, _, H, W = attn.shape 22 self.mask_ = attn.detach().cpu() 23 24 x = x * attn 25 x = F.adaptive_avg_pool2d(x, (1, 1)) 26 x = x.reshape(B, -1) 27 28 return self.fc(x)
参考
深層学習入門:画像分類(5)Attention機構|SBテクノロジー(SBT)
ResNet(Residual Network)の実装|AIdrops
回答1件
あなたの回答
tips
プレビュー