🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Q&A

解決済

2回答

5910閲覧

Pythonの__call__メソッドと、Pytorchのクラスについて

fu_3823

総合スコア81

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

0グッド

1クリップ

投稿2021/03/19 07:38

呼び出し可能オブジェクトの振る舞いについて疑問があります。

以下のコードでimport文は省略しました。

Python

1INPUT_FEATURES = 640 * 640 2HIDDEN = 100 3OUTPUT_FEATUERS = 4 4 5class Net(nn.Module): 6 def __init_(self): 7 super().__init__() 8 self.layer1 = nn.Linear(INPUT_FEATURES, HIDDEN) 9 self.layer2 = nn.Linear(HIDDEN, HIDDEN) 10 self.layer3 = nn.Linear(HIDDEN, OUTPUT_FEATUERS) 11 12 def forward(self, x): 13 x = self.layer1(x) 14 x = self.layer2(x) 15 x = self.layer3(x) 16 return x

Pyhton

1net = Net()

このコードで、例えば、self.layer1()はnn.Linearのインスタンスで、nn.Linearクラスは__call__メソッドを
持っている。だから、x = self.layer1(x)でnn.Linear内のメソッドである、forward(self, x)が
実行されることはわかります。
しかし、net(data)で、上記Netで定義したforwardが実行される理由がわかりません。
これは、Netクラスがnn.Moduleを継承しているからでしょうか。
もし、そうであるなら、nn.Moduleを継承したクラスを実装したい場合、順伝播のメソッドはforwardという名前にしないと
いけないのでしょうか。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答2

0

  • Netクラスがnn.Moduleを継承しているからでしょうか。

はい。

  • nn.Moduleを継承したクラスを実装したい場合、順伝播のメソッドはforwardという名前にしないといけないのでしょうか。

ソースを見る限りでは、self._forward_pre_hooksなどを設定すれば別のものに変更可能なようですが、普通はforwardメソッド使うのでしょう。

ソースを見たければ、以下を行い、その結果のディレクトリの中のmodules/module.pyの中で、__call__を検索してください。(アンダーバーは小文字のアンダーバーにかえてください)

python

1from torch import nn 2print(nn.__path__)

投稿2021/03/19 08:30

ppaul

総合スコア24670

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

0

ベストアンサー

基本的におっしゃる通りです。

nn.Moduleのソースを見れば確認できますが、nn.Moduleクラス内に、forward()メソッドはdefで定義されていませんが、self.forward()は呼び出されています。つまり、継承したクラス内でforwardは定義しなければなりません。

ソース:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

投稿2021/03/19 08:22

編集2021/03/19 08:46
nanoseeing

総合スコア133

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

fu_3823

2021/03/26 00:47

ありがとうございました。ソースでself.forward()が呼び出されてるということに思い至りませんでした。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問