pytorchでclassの中のnn.BatchNorm2d
層を取り除きたいのですが、コメントアウトするとエラーが出ます。
python
1class ConvBNReLU(nn.Sequential): 2 def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 3 #padding = (kernel_size - 1) // 2 4 super(ConvBNReLU, self).__init__( 5 nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 6 nn.BatchNorm2d(out_planes), # コメントアウトするとエラーになる 7 nn.ReLU6(inplace=True) 8 )
以下のようにBatchNorm2dを取り除きたいのですが、どのように書いたら良いでしょうか?
やりたいこと
python
1class ConvBNReLU(nn.Sequential): 2 def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): 3 #padding = (kernel_size - 1) // 2 4 super(ConvBNReLU, self).__init__( 5 nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), 6 # nn.BatchNorm2d(out_planes), 7 nn.ReLU6(inplace=True) 8 )
どのようなエラーがでるのでしょうか?
forward() 時にエラーが出るのか、重み読み込みでエラーが出るのか etc...
回答1件
あなたの回答
tips
プレビュー