質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

838閲覧

PyTorchを使ったネットワーク(20行程度)の解釈を手伝ってほしいです

r.kanke

総合スコア1

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/05/16 22:18

編集2021/05/16 22:29

助けて…

PyTorchを使ったCNNの勉強をしています.
このAttentionに関する記事の,Pytorchを使ったネットワークの定義が理解できません.
助けてほしいです….

(これまで既存のコードのコピペしかしていなかったので,ネットワークの次元の変化が追えていないです????)

理解できないところ

forward関数です.

ResNet34の出力層を取ったものを使っているので,forward関数にてx = self.features(x)で512×3x3の特徴量xを得ていると思います.
次にself.attn_conv(x)nn.Sequential(nn.Conv2d(512,1,1), nn.Sigmoid())に通用していると思います.
この出力が(コメントより),[B, 1, H, W]次元となるのが理解できていないところです.

512×3×3 => B×1×H×W
となるのが追えません.どなたかご解説願えないでしょうか…!!

該当のネットワーク,ソースコード

イメージ説明

python

1class SimpleAttentionNetwork(nn.Module): 2 def __init__(self, num_classes): 3 super().__init__() 4 5 base_model = resnet34(pretrained=True) 6 self.features = nn.Sequential(*[layer for layer in base_model.children()][:-2]) 7 self.attn_conv = nn.Sequential( 8 nn.Conv2d(512, 1, 1), 9 nn.Sigmoid() 10 ) 11 self.fc = nn.Sequential( 12 nn.Dropout(0.5), 13 nn.Linear(512, num_classes) 14 ) 15 self.mask_ = None 16 17 def forward(self, x): 18 x = self.features(x) 19 20 attn = self.attn_conv(x) # [B, 1, H, W] 21 B, _, H, W = attn.shape 22 self.mask_ = attn.detach().cpu() 23 24 x = x * attn 25 x = F.adaptive_avg_pool2d(x, (1, 1)) 26 x = x.reshape(B, -1) 27 28 return self.fc(x)

参考

深層学習入門:画像分類(5)Attention機構|SBテクノロジー(SBT)
ResNet(Residual Network)の実装|AIdrops

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jbpb0

2021/05/16 23:56

> これまで既存のコードのコピペしかしていなかったので,ネットワークの次元の変化が追えていないです その自覚があるのでしたら、まずそこからやりましょうよ こういうツールもあるし https://qiita.com/tand826/items/ae3349495944048fd120
r.kanke

2021/05/17 00:16

やはりそうですか….正直どう手をつけたものかと敬遠していました. 向き合う時が来たんですね…. torchinfo入れてみました.こんなに見やすくなるんですね! ありがとうございます!
quickquip

2021/05/17 01:22 編集

BやHやWの実際の値を確認したけどどうしてその値になるのかわからない、という質問なのか BやHやWが実際にどんな値になるのかわからない、という質問なのか どうしてBやHやWに代入するのかわからない、という質問なのか BやHやWってなんですか、という質問なのか の区別が付かないですね……
r.kanke

2021/05/17 01:52

Bが何を表すのかが知りたかったです. しかし指摘して頂いた内,実際の値は確認していませんでした.やってきます.
quickquip

2021/05/17 02:14 編集

だとすると質問は「Conv2Dの中身を理解したい」と同義ですね。 Conv2Dで変換した結果の、テンソルの第1要素(インデックス0番)の次元数をBとしよう というだけなので。 追記:入れ違いました 追記: 次元数 は誤用でした
r.kanke

2021/05/17 02:11

なるほどそこの理解が足りなかったのですね….なんで詰まったのかが分かったので有り難いです! Conv2dの復習をしようと思います.
guest

回答1

0

自己解決

Attentionマスクのサイズ(次元)を勘違いしていました.

画像では入力画像と同じサイズになっていますが,それは可視化した場合.実際にはバッチサイズ×チャネル×タテ×ヨコ = 32×1×7×7となっていました.

xとattnのshapeを画面出力させて確認しました.  
x.shape: torch.Size([32, 512, 7, 7])
attn.shape: torch.Size([32, 1, 7, 7])

投稿2021/05/17 02:01

編集2021/05/17 02:06
r.kanke

総合スコア1

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

r.kanke

2021/05/17 02:02

@quickquipさんがご指摘くださったように,何が分からないのか,明確にすると対処法が見つかりそうです. @jbpb0さん,@quickquipさん,アドバイスありがとうございました.非常に助かりました????????
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問