以下のGANのDiscriminatorについて,全体の出力と各中間層のリストを出力にしたいです.
PyTorch
1class Discriminator(nn.Module): 2 def __init__(self,img_channel): 3 super(Discriminator, self).__init__() 4 self.block = nn.Sequential( 5 nn.Conv2d(img_channel,64,kernel_size=4,stride=2,padding=1), 6 nn.LeakyReLU(0.2,inplace=True), 7 8 nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1,bias=True), 9 nn.InstanceNorm2d(128), 10 nn.LeakyReLU(0.2,inplace=True), 11 12 nn.Conv2d(128,256,kernel_size=4,stride=2,padding=1,bias=True), 13 nn.InstanceNorm2d(256), 14 nn.LeakyReLU(0.2,inplace=True), 15 16 nn.Conv2d(256,512,kernel_size=4,stride=1,padding=1,bias=True), 17 nn.InstanceNorm2d(512), 18 nn.LeakyReLU(0.2,inplace=True), 19 20 nn.Conv2d(512,1,kernel_size=4,stride=1,padding=1) 21 ) 22 23 def forward(self, x): 24 x = self.block(x) 25 return x 26
現在は出力が各層を通った全体の出力のみが返り値になっていますが,これに加えて各LeakyReLU層を通った段階での中間層をリストにして
return x, [中間層のリスト]
というような返り値にしたいです.どのように変更を加えればそのようにできるでしょうか.
PyTorch はじめて 30 分くらいしか経ってないのでコメントにとどめておきますが、register_forward_hook というのが使えそうです。Pytorch - 中間層の出力を取得する方法 - pystyle https://pystyle.info/pytorch-extract-intermediate-layer-output/
あなたの回答
tips
プレビュー