質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.46%
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

Model

MVCモデルの一部であるModelはアプリケーションで扱うデータとその動作を管理するために扱います。

コードレビュー

コードレビューは、ソフトウェア開発の一工程で、 ソースコードの検査を行い、開発工程で見過ごされた誤りを検出する事で、 ソフトウェア品質を高めるためのものです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

2124閲覧

swin transformerをmask r-cnnのbackboneとして使いたい

seiyouakadanuki

総合スコア25

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

CNN (Convolutional Neural Network)

CNN (Convolutional Neural Network)は、全結合層のみではなく畳み込み層とプーリング層で構成されるニューラルネットワークです。画像認識において優れた性能を持ち、畳み込みニューラルネットワークとも呼ばれています。

Model

MVCモデルの一部であるModelはアプリケーションで扱うデータとその動作を管理するために扱います。

コードレビュー

コードレビューは、ソフトウェア開発の一工程で、 ソースコードの検査を行い、開発工程で見過ごされた誤りを検出する事で、 ソフトウェア品質を高めるためのものです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/11/27 10:03

timmの'swin_base_patch4_window7_224_in22k'というモデルをmask r-cnnのbackboneとして使うため、以下のコードで実装しました。

backbone = timm.create_model(model_name="swin_base_patch4_window12_384_in22k", pretrained=True, num_classes=3) backbone.reset_classifier(num_classes=0, global_pool="") backbone.out_channels = backbone.num_features model = mask_rcnn.MaskRCNN(backbone=backbone, num_classes=3, rpn_anchor_generator=AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)), box_roi_pool=torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2), mask_roi_pool=torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=14, sampling_ratio=2))

しかし次のコードでbackboneが正しく接続されているかを試したところ、エラーが起きました。

from torchinfo import summary batch_size = 2 summary( model, input_size=(batch_size, 3, 384, 384), col_names=["input_size", "output_size", "num_params"], )

エラーは以下の通りです。

AssertionError Traceback (most recent call last) ~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs) 260 if isinstance(x, (list, tuple)): --> 261 _ = model.to(device)(*x, **kwargs) 262 elif isinstance(x, dict): ~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( ~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets) 93 ---> 94 features = self.backbone(images.tensors) 95 if isinstance(features, torch.Tensor): ~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( ~/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py in forward(self, x) 535 def forward(self, x): --> 536 x = self.forward_features(x) 537 x = self.head(x) ~/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer.py in forward_features(self, x) 524 def forward_features(self, x): --> 525 x = self.patch_embed(x) 526 if self.absolute_pos_embed is not None: ~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 888 else: --> 889 result = self.forward(*input, **kwargs) 890 for hook in itertools.chain( ~/anaconda3/lib/python3.8/site-packages/timm/models/layers/patch_embed.py in forward(self, x) 32 B, C, H, W = x.shape ---> 33 assert H == self.img_size[0] and W == self.img_size[1], \ 34 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." AssertionError: Input image size (800*800) doesn't match model (384*384). The above exception was the direct cause of the following exception: RuntimeError Traceback (most recent call last) <ipython-input-24-99c2e23322ab> in <module> 1 from torchinfo import summary 2 batch_size = 2 ----> 3 summary( 4 model, 5 input_size=(batch_size, 3, 384, 384), ~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, row_settings, verbose, **kwargs) 192 input_data, input_size, batch_dim, device, dtypes 193 ) --> 194 summary_list = forward_pass( 195 model, x, batch_dim, cache_forward_pass, device, **kwargs 196 ) ~/anaconda3/lib/python3.8/site-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, **kwargs) 268 except Exception as e: 269 executed_layers = [layer for layer in summary_list if layer.executed] --> 270 raise RuntimeError( 271 "Failed to run torchinfo. See above stack traces for more details. " 272 f"Executed layers up to: {executed_layers}" RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [GeneralizedRCNNTransform: 1]

なおコードのsummary関数の引数をmodelではなくbackboneとするとエラーは出ませんでした。

backboneの正しい繋げ方を教えてください。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.46%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問