質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

87.49%

Pytorchでnn.ModuleListで作成したモデルのtorchsummaryによる表示結果に不整合がある

解決済

回答 1

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 1,054

pytorchにて以下のAlexNetのコードを書きました。
torchsummaryにてモデルを確認しようとしたところ、すべてが二回繰り返されて表示されました。

コードは間違っていない気がするのですが、原因を知りたいです。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary


class AlexNet(nn.Module):
    def __init__(self, conv_channels_list=[96, 256, 384, 384, 256]):
        super(AlexNet, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=conv_channels_list[0],
                kernel_size=11,
                stride=4,
                padding=2
            ),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(size=5, k=2),
            nn.ReLU(inplace=True)
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(
                in_channels=conv_channels_list[0],
                out_channels=conv_channels_list[1],
                kernel_size=3,
                padding=1,
            ),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(size=5, k=2),
            nn.ReLU(inplace=True)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(
                in_channels=conv_channels_list[1],
                out_channels=conv_channels_list[2],
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True)
        )

        self.layer4 = nn.Sequential(
            nn.Conv2d(
                in_channels=conv_channels_list[2],
                out_channels=conv_channels_list[3],
                kernel_size=3,
                padding=1,
            ),
            nn.ReLU(inplace=True)
        )

        self.layer5 = nn.Sequential(
            nn.Conv2d(
                in_channels=conv_channels_list[3],
                out_channels=conv_channels_list[4],
                kernel_size=3,
                padding=1,
            ),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Flatten()
        )

        self.layer6 = nn.Sequential(
            nn.Linear(9216, 4096),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True)
        )

        self.layer7 = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True)
        )

        self.layer8 = nn.Sequential(
            nn.Linear(4096, 1000),
            nn.Softmax(dim=1)
        )

        self.layers = nn.ModuleList([
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
            self.layer6,
            self.layer7,
            self.layer8,
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x



device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = AlexNet()
net = net.to(device)

torchsummary.summary(net, input_size=(3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 96, 55, 55]          34,944
            Conv2d-2           [-1, 96, 55, 55]          34,944
         MaxPool2d-3           [-1, 96, 27, 27]               0
         MaxPool2d-4           [-1, 96, 27, 27]               0
 LocalResponseNorm-5           [-1, 96, 27, 27]               0
 LocalResponseNorm-6           [-1, 96, 27, 27]               0
              ReLU-7           [-1, 96, 27, 27]               0
              ReLU-8           [-1, 96, 27, 27]               0
            Conv2d-9          [-1, 256, 27, 27]         221,440
           Conv2d-10          [-1, 256, 27, 27]         221,440
        MaxPool2d-11          [-1, 256, 13, 13]               0
        MaxPool2d-12          [-1, 256, 13, 13]               0
LocalResponseNorm-13          [-1, 256, 13, 13]               0
LocalResponseNorm-14          [-1, 256, 13, 13]               0
             ReLU-15          [-1, 256, 13, 13]               0
             ReLU-16          [-1, 256, 13, 13]               0
           Conv2d-17          [-1, 384, 13, 13]         885,120
           Conv2d-18          [-1, 384, 13, 13]         885,120
             ReLU-19          [-1, 384, 13, 13]               0
             ReLU-20          [-1, 384, 13, 13]               0
           Conv2d-21          [-1, 384, 13, 13]       1,327,488
           Conv2d-22          [-1, 384, 13, 13]       1,327,488
             ReLU-23          [-1, 384, 13, 13]               0
             ReLU-24          [-1, 384, 13, 13]               0
           Conv2d-25          [-1, 256, 13, 13]         884,992
           Conv2d-26          [-1, 256, 13, 13]         884,992
        MaxPool2d-27            [-1, 256, 6, 6]               0
        MaxPool2d-28            [-1, 256, 6, 6]               0
             ReLU-29            [-1, 256, 6, 6]               0
             ReLU-30            [-1, 256, 6, 6]               0
          Flatten-31                 [-1, 9216]               0
          Flatten-32                 [-1, 9216]               0
           Linear-33                 [-1, 4096]      37,752,832
           Linear-34                 [-1, 4096]      37,752,832
          Dropout-35                 [-1, 4096]               0
          Dropout-36                 [-1, 4096]               0
             ReLU-37                 [-1, 4096]               0
             ReLU-38                 [-1, 4096]               0
           Linear-39                 [-1, 4096]      16,781,312
           Linear-40                 [-1, 4096]      16,781,312
          Dropout-41                 [-1, 4096]               0
          Dropout-42                 [-1, 4096]               0
             ReLU-43                 [-1, 4096]               0
             ReLU-44                 [-1, 4096]               0
           Linear-45                 [-1, 1000]       4,097,000
           Linear-46                 [-1, 1000]       4,097,000
          Softmax-47                 [-1, 1000]               0
          Softmax-48                 [-1, 1000]               0
================================================================
Total params: 123,970,256
Trainable params: 123,970,256
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 17.91
Params size (MB): 472.91
Estimated Total Size (MB): 491.39
----------------------------------------------------------------

=================補足====================
nn.ModuleListを通常のリストとした場合は正しく動いたのですが、以下のように新しいモデルを継承により作った場合に、エラーが発生しました。

class ZFNet(AlexNet):
    def __init__(self, conv_channels_list=[96, 256, 384, 384, 256]):
        super(ZFNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=conv_channels_list[0],
                kernel_size=7,
                stride=2,
                padding=2
            ),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.LocalResponseNorm(size=5, k=2),
            nn.ReLU(inplace=True)
        )
Traceback (most recent call last):
  File "/home/〇〇/workspace/machineLearning/popularModel/alexNet.py", line 824, in <module>
    summary(net, input_size=(3, 224, 224))
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torchsummary/torchsummary.py", line 72, in summary
    model(*x)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/〇〇/workspace/machineLearning/popularModel/alexNet.py", line 161, in forward
    x = layer(x)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/〇〇/.local/share/virtualenvs/machineLearning-WvLHzLrB/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

modelはGPUに乗せており、printでも確認できました。このエラーの理由も教えていただきたいです。

-----------補足追加----------------

以下が全体のコードです。
ModuleListからリストに変え、ZFnetを使ったものです。

import 同様


class AlexNet(nn.Module):

    //最初のコードと同様。文字数制限にて省略
    //self.layersをModuleListからListに変更


        self.layers = [
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.layer5,
            self.layer6,
            self.layer7,
            self.layer8,
        ]

    def forward(self, x):
        同様


class ZFNet(AlexNet):
    同様


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ZFNet()
net = net.to(device)

torchsummary.summary(net, input_size=(3, 224, 224))
  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • jbpb0

    2021/04/08 09:14

    質問内容とは関係無いですが、torchsummaryは開発が止まっているようなので、torchinfoに乗り換えるといいみたいです
    https://qiita.com/tand826/items/ae3349495944048fd120

    キャンセル

  • mikan_professor

    2021/04/08 10:38

    わかりました!ありがとうございます。

    キャンセル

回答 1

checkベストアンサー

+1

nn.ModuleListの使い方が間違っています。

PyTorchのモジュールは、フォワードメソッドと、各層のパラメータリストを保持最適化する構造になっています。質問者様のコードは、フォワードメソッドは意図通り動作すると思いますが、__init__で各層とModuleListでまとめた層のパラメータリストを二重に登録してしまっています。なお、余計なものが登録されているというだけで、動作は正常かもしれません。

修正する方針は2種類あります。

1) __init__で、self.layersnn.ModuleListは使わずに通常のリストとする。こうすると、self.layersはモジュールとして認識されませんので、パラメータリストは登録されません。フォワードメソッド側だけで使われます。

2) __init__で、各階層self.layerxをモジュールのメンバ変数としてとして登録するのはやめて、layerxというローカル変数にして、nn.ModuleListは各ローカル変数をモジュールリストにしてモジュールのメンバ変数として登録する。この場合、各層の名前は残りませんが、インデックスでアクセス可能になるという利点があります。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2021/04/09 22:37 編集

    計算をするのが大変ですが、単純に、層を書き換えたのでinput_sizeが変わったにもかかわわらず、同じサイズでtorchsummaryしたから怒られただけだと思います。それを確定するために、質問者様の完全なコードが必要です。実際、サイズ指定が不要なtorchinfoは普通に実行できます。

    キャンセル

  • 2021/04/15 17:27

    追記しました。すみませんが文字数制限で完全なコードはあげられず、以前の部分と変更がないところは同様と記述しております。

    キャンセル

  • 2021/04/15 18:26

    ありがとうございます。「層を書き換えたのでinput_sizeが変わったにもかかわわらず、同じサイズでtorchsummaryしたから怒られた」で確定だと思います。

    キャンセル

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 87.49%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る