前提・実現したいこと
PyTorchのTransformer
モジュールを使ってTransformerのモデルを作っています。
ですが、forward
の引数の***_mask
で何を隠せばよいかわからず悩んでいます。
Transformer — PyTorch master documentation
上のリファレンスによれば、src_mask
のサイズは(S,S)
、tgt_mask
のサイズは(T,T)
とあり、隠したいポジションにTrue
または1
を指定せよ、と書いてあると思います。
考えたこと
たとえばsrc_mask
のサイズは(S,S)
つまり(シーケンス長✕シーケンス長)ですが、バッチサイズの指定がありません。つまりこれはバッチの各文の<pad>
の部分を明示するものではないと考えました。
なので、バッチで取り込まれる各文に適用されるmaskとなる...と考えるとtgtなら未来を隠すmaskだろうと想像つくのですが、srcのほうには必要だっけ...?となり、わからなくなったという状況です。
さらに、この引数が実際に使われているところを見るとMultiHeadAttention
モジュールだとわかったのですが、こちらはmaskのサイズが(N,S)
で(バッチサイズ✕シーケンス長)だったので、もっとわからなくなってしまいました。
Transformerの仕組みはとりあえず理解できた程度で勘違いをしているかもしれませんが、どうかご回答よろしくおねがいします。
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。