To use timm's 'swin_base_patch4_window12_384_in22k' model as the backbone of Mask R-CNN, I implemented it with the following code.
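import timm
import torchvision
from torchvision.models.detection import mask_rcnn
from torchvision.models.detection.rpn import AnchorGenerator

# Load the pretrained Swin Transformer, then remove its classification head
backbone = timm.create_model(
    model_name="swin_base_patch4_window12_384_in22k",
    pretrained=True,
    num_classes=3,
)
backbone.reset_classifier(num_classes=0, global_pool="")
# MaskRCNN reads the number of feature channels from backbone.out_channels
backbone.out_channels = backbone.num_features

model = mask_rcnn.MaskRCNN(
    backbone=backbone,
    num_classes=3,
    rpn_anchor_generator=AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ),
    box_roi_pool=torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'], output_size=7, sampling_ratio=2
    ),
    mask_roi_pool=torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'], output_size=14, sampling_ratio=2
    ),
)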
However, when I ran the following code to check whether the backbone was connected correctly, an error occurred.
from torchinfo import summary

batch_size = 2
summary(
    model,
    input_size=(batch_size, 3, 384, 384),
    col_names=["input_size", "output_size", "num_params"],
)
The error is as follows.
AssertionError                            Traceback (most recent call last)
~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
    260             if isinstance(x, (list, tuple)):
--> 261                 _ = model.to(device)(*x, **kwargs)
    262             elif isinstance(x, dict):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     93
---> 94         features = self.backbone(images.tensors)
     95         if isinstance(features, torch.Tensor):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(

~/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py in forward(self, x)
    535     def forward(self, x):
--> 536         x = self.forward_features(x)
    537         x = self.head(x)

~/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py in forward_features(self, x)
    524     def forward_features(self, x):
--> 525         x = self.patch_embed(x)
    526         if self.absolute_pos_embed is not None:

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(

~/anaconda3/lib/python3.8/site-packages/timm/models/layers/patch_embed.py in forward(self, x)
     32         B, C, H, W = x.shape
---> 33         assert H == self.img_size[0] and W == self.img_size[1], \
     34             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

AssertionError: Input image size (800*800) doesn't match model (384*384).

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-24-99c2e23322ab> in <module>
      1 from torchinfo import summary
      2 batch_size = 2
----> 3 summary(
      4     model,
      5     input_size=(batch_size, 3, 384, 384),

~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, row_settings, verbose, **kwargs)
    192         input_data, input_size, batch_dim, device, dtypes
    193     )
--> 194     summary_list = forward_pass(
    195         model, x, batch_dim, cache_forward_pass, device, **kwargs
    196     )

~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs)
    268     except Exception as e:
    269         executed_layers = [layer for layer in summary_list if layer.executed]
--> 270         raise RuntimeError(
    271             "Failed to run torchinfo. See above stack traces for more details. "
    272             f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [GeneralizedRCNNTransform: 1]
Note that when I passed backbone instead of model as the argument to summary, no error occurred.
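One likely reading of the size mismatch in the traceback: the only layer that ran before the failure was GeneralizedRCNNTransform, which rescales every input image before it reaches the backbone. The sketch below is plain Python, not torchvision's actual code. It assumes the default MaskRCNN settings min_size=800 and max_size=1333, and it uses rounding where the library's interpolation may truncate. The function name resized_shape is made up for illustration.

```python
# Simplified sketch of the resize rule applied before the backbone.
# Assumption: MaskRCNN defaults of min_size=800 and max_size=1333.
# resized_shape is a hypothetical helper, not a torchvision function.

def resized_shape(height, width, min_size=800, max_size=1333):
    """Return the (height, width) an image is scaled to before the backbone."""
    short_side = min(height, width)
    long_side = max(height, width)
    # Scale the short side to min_size, unless that would push
    # the long side past max_size.
    scale = min(min_size / short_side, max_size / long_side)
    return int(round(height * scale)), int(round(width * scale))


print(resized_shape(384, 384))                              # (800, 800)
print(resized_shape(384, 384, min_size=384, max_size=384))  # (384, 384)
```

This would account for both observations: through model, the 384*384 input becomes 800*800 and fails the fixed-size check in patch_embed, while backbone alone receives 384*384 unchanged. If this reading is right, passing min_size=384 and max_size=384 to MaskRCNN should keep the input at the size the backbone expects. Whether the backbone's output then has the (N, C, H, W) layout that MaskRCNN needs is a separate question.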
Please tell me the correct way to connect the backbone.