以下のpytorchのmodelのクラスがあります。
このクラスに
python
1self.width = int(model.input.width) # 右項は例 2self.height = int(model.input.height)# 右項は例
を追加したいです。そのためにmodelのinputのwidthとheightが知りたいです。
このクラスのパラメータから、inputするshapeのwidthとheightを特定したいのですが、このクラスからinputするwidthとheightの値は何なるんでしょうか?
python
1# model 用Class 2class DeepLabHeadV3Plus(nn.Module): 3 def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 4 super(DeepLabHeadV3Plus, self).__init__() 5 self.project = nn.Sequential( 6 nn.Conv2d(low_level_channels, 48, 1, bias=False), 7 nn.BatchNorm2d(48), 8 nn.ReLU(inplace=True), 9 ) 10 # self.width = ここにinputのwidth追加したい 11 # self.height = ここにinputのheight追加したい 12 self.aspp = ASPP(in_channels, aspp_dilate) 13 14 self.classifier = nn.Sequential( 15 nn.Conv2d(304, 256, 3, padding=1, bias=False), 16 nn.BatchNorm2d(256), 17 nn.ReLU(inplace=True), 18 nn.Conv2d(256, num_classes, 1) 19 ) 20 self._init_weight() 21 22 def forward(self, feature): 23 low_level_feature = self.project( feature['low_level'] ) 24 output_feature = self.aspp(feature['out']) 25 output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 26 return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) 27 28 def _init_weight(self): 29 for m in self.modules(): 30 if isinstance(m, nn.Conv2d): 31 nn.init.kaiming_normal_(m.weight) 32 elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 33 nn.init.constant_(m.weight, 1) 34 nn.init.constant_(m.bias, 0)
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。