質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

1回答

606閲覧

apply() missing 1 required positional argument: 'fn' が出てしまいます

kun_monimoni

総合スコア26

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2022/06/21 09:39

編集2022/06/21 09:42

Point Netのコードを動かしており、STN3dのモデルの可視化を行いたいのですが、以下のエラーが出て困っています。
ソースコードの最後の行にエラーが発生(summary(STN3d, input_size=(1, 32 * 3 * 2500))の部分)しておりますが、input_sizeの引数に何を入れればいいかわからず困っています。
初歩的な間違いかもしれませんが、わからないため教えていただけると幸いです。

該当するエラーメッセージ

error

1TypeError Traceback (most recent call last) 2<ipython-input-31-b268eb61794c> in <module>() 3 219 4 220 device = 'cuda' if torch.cuda.is_available() else 'cpu' 5--> 221 summary(STN3d, input_size=(1, 32 * 3 * 2500)) 6 7/usr/local/lib/python3.7/dist-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device) 8 66 9 67 # register hook 10---> 68 model.apply(register_hook) 11 69 12 70 # make a forward pass 13 14TypeError: apply() missing 1 required positional argument: 'fn'

ソースコード

python

1from __future__ import print_function 2import torch 3import torch.nn as nn 4import torch.nn.parallel 5import torch.utils.data 6from torch.autograd import Variable 7import numpy as np 8import torch.nn.functional as F 9 10class STN3d(nn.Module): 11 def __init__(self): 12 super(STN3d, self).__init__() 13 self.conv1 = torch.nn.Conv1d(3, 64, 1) 14 self.conv2 = torch.nn.Conv1d(64, 128, 1) 15 self.conv3 = torch.nn.Conv1d(128, 1024, 1) 16 self.fc1 = nn.Linear(1024, 512) 17 self.fc2 = nn.Linear(512, 256) 18 self.fc3 = nn.Linear(256, 9) 19 self.relu = nn.ReLU() 20 21 self.bn1 = nn.BatchNorm1d(64) 22 self.bn2 = nn.BatchNorm1d(128) 23 self.bn3 = nn.BatchNorm1d(1024) 24 self.bn4 = nn.BatchNorm1d(512) 25 self.bn5 = nn.BatchNorm1d(256) 26 27 28 def forward(self, x): 29 batchsize = x.size()[0] 30 x = F.relu(self.bn1(self.conv1(x))) 31 x = F.relu(self.bn2(self.conv2(x))) 32 x = F.relu(self.bn3(self.conv3(x))) 33 x = torch.max(x, 2, keepdim=True)[0] 34 x = x.view(-1, 1024) 35 36 x = F.relu(self.bn4(self.fc1(x))) 37 x = F.relu(self.bn5(self.fc2(x))) 38 x = self.fc3(x) 39 40 iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1) 41 if x.is_cuda: 42 iden = iden.cuda() 43 x = x + iden 44 x = x.view(-1, 3, 3) 45 return x 46 47 48class STNkd(nn.Module): 49 def __init__(self, k=64): 50 super(STNkd, self).__init__() 51 self.conv1 = torch.nn.Conv1d(k, 64, 1) 52 self.conv2 = torch.nn.Conv1d(64, 128, 1) 53 self.conv3 = torch.nn.Conv1d(128, 1024, 1) 54 self.fc1 = nn.Linear(1024, 512) 55 self.fc2 = nn.Linear(512, 256) 56 self.fc3 = nn.Linear(256, k*k) 57 self.relu = nn.ReLU() 58 59 self.bn1 = nn.BatchNorm1d(64) 60 self.bn2 = nn.BatchNorm1d(128) 61 self.bn3 = nn.BatchNorm1d(1024) 62 self.bn4 = nn.BatchNorm1d(512) 63 self.bn5 = nn.BatchNorm1d(256) 64 65 self.k = k 66 67 def forward(self, x): 68 batchsize = x.size()[0] 69 x = F.relu(self.bn1(self.conv1(x))) 70 x = F.relu(self.bn2(self.conv2(x))) 71 x = F.relu(self.bn3(self.conv3(x))) 72 x = torch.max(x, 2, keepdim=True)[0] 73 x = x.view(-1, 1024) 74 75 x = F.relu(self.bn4(self.fc1(x))) 76 x = F.relu(self.bn5(self.fc2(x))) 77 x = self.fc3(x) 78 79 iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1) 80 if x.is_cuda: 81 iden = iden.cuda() 82 x = x + iden 83 x = x.view(-1, self.k, self.k) 84 return x 85 86class PointNetfeat(nn.Module): 87 def __init__(self, global_feat = True, feature_transform = False): 88 super(PointNetfeat, self).__init__() 89 self.stn = STN3d() 90 self.conv1 = torch.nn.Conv1d(3, 64, 1) 91 self.conv2 = torch.nn.Conv1d(64, 128, 1) 92 self.conv3 = torch.nn.Conv1d(128, 1024, 1) 93 self.bn1 = nn.BatchNorm1d(64) 94 self.bn2 = nn.BatchNorm1d(128) 95 self.bn3 = nn.BatchNorm1d(1024) 96 self.global_feat = global_feat 97 self.feature_transform = feature_transform 98 if self.feature_transform: 99 self.fstn = STNkd(k=64) 100 101 def forward(self, x): 102 n_pts = x.size()[2] 103 trans = self.stn(x) 104 x = x.transpose(2, 1) 105 x = torch.bmm(x, trans) 106 x = x.transpose(2, 1) 107 x = F.relu(self.bn1(self.conv1(x))) 108 109 if self.feature_transform: 110 trans_feat = self.fstn(x) 111 x = x.transpose(2,1) 112 x = torch.bmm(x, trans_feat) 113 x = x.transpose(2,1) 114 else: 115 trans_feat = None 116 117 pointfeat = x 118 x = F.relu(self.bn2(self.conv2(x))) 119 x = self.bn3(self.conv3(x)) 120 x = torch.max(x, 2, keepdim=True)[0] 121 x = x.view(-1, 1024) 122 if self.global_feat: 123 return x, trans, trans_feat 124 else: 125 x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) 126 return torch.cat([x, pointfeat], 1), trans, trans_feat 127 128class PointNetCls(nn.Module): 129 def __init__(self, k=2, feature_transform=False): 130 super(PointNetCls, self).__init__() 131 self.feature_transform = feature_transform 132 self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform) 133 self.fc1 = nn.Linear(1024, 512) 134 self.fc2 = nn.Linear(512, 256) 135 self.fc3 = nn.Linear(256, k) 136 self.dropout = nn.Dropout(p=0.3) 137 self.bn1 = nn.BatchNorm1d(512) 138 self.bn2 = nn.BatchNorm1d(256) 139 self.relu = nn.ReLU() 140 141 def forward(self, x): 142 x, trans, trans_feat = self.feat(x) 143 x = F.relu(self.bn1(self.fc1(x))) 144 x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 145 x = self.fc3(x) 146 return F.log_softmax(x, dim=1), trans, trans_feat 147 148 149class PointNetDenseCls(nn.Module): 150 def __init__(self, k = 2, feature_transform=False): 151 super(PointNetDenseCls, self).__init__() 152 self.k = k 153 self.feature_transform=feature_transform 154 self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform) 155 self.conv1 = torch.nn.Conv1d(1088, 512, 1) 156 self.conv2 = torch.nn.Conv1d(512, 256, 1) 157 self.conv3 = torch.nn.Conv1d(256, 128, 1) 158 self.conv4 = torch.nn.Conv1d(128, self.k, 1) 159 self.bn1 = nn.BatchNorm1d(512) 160 self.bn2 = nn.BatchNorm1d(256) 161 self.bn3 = nn.BatchNorm1d(128) 162 163 def forward(self, x): 164 batchsize = x.size()[0] 165 n_pts = x.size()[2] 166 x, trans, trans_feat = self.feat(x) 167 x = F.relu(self.bn1(self.conv1(x))) 168 x = F.relu(self.bn2(self.conv2(x))) 169 x = F.relu(self.bn3(self.conv3(x))) 170 x = self.conv4(x) 171 x = x.transpose(2,1).contiguous() 172 x = F.log_softmax(x.view(-1,self.k), dim=-1) 173 x = x.view(batchsize, n_pts, self.k) 174 return x, trans, trans_feat 175 176def feature_transform_regularizer(trans): 177 d = trans.size()[1] 178 batchsize = trans.size()[0] 179 I = torch.eye(d)[None, :, :] 180 if trans.is_cuda: 181 I = I.cuda() 182 loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2))) 183 return loss 184 185if __name__ == '__main__': 186 sim_data = Variable(torch.rand(32,3,2500)) 187 trans = STN3d() 188 out = trans(sim_data) 189 print('stn', out.size()) 190 print('loss', feature_transform_regularizer(out)) 191 192 sim_data_64d = Variable(torch.rand(32, 64, 2500)) 193 trans = STNkd(k=64) 194 out = trans(sim_data_64d) 195 print('stn64d', out.size()) 196 print('loss', feature_transform_regularizer(out)) 197 198 pointfeat = PointNetfeat(global_feat=True) 199 out, _, _ = pointfeat(sim_data) 200 print('global feat', out.size()) 201 202 pointfeat = PointNetfeat(global_feat=False) 203 out, _, _ = pointfeat(sim_data) 204 print('point feat', out.size()) 205 206 cls = PointNetCls(k = 5) 207 out, _, _ = cls(sim_data) 208 print('class', out.size()) 209 210 seg = PointNetDenseCls(k = 3) 211 out, _, _ = seg(sim_data) 212 print('seg', out.size()) 213 214 device = 'cuda' if torch.cuda.is_available() else 'cpu' 215 summary(STN3d, input_size=(1, 32 * 3 * 2500))

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

python

1 summary(STN3d, input_size=(1, 32 * 3 * 2500))

↓ 修正

python

1 summary(STN3d(), input_size=(3, 2500))

投稿2022/08/15 08:03

jbpb0

総合スコア7651

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問