🎄teratailクリスマスプレゼントキャンペーン2024🎄』開催中!

\teratail特別グッズやAmazonギフトカード最大2,000円分が当たる!/

詳細はこちら
PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Q&A

1回答

4679閲覧

TransformerのPytorchでの実装のPositionalEncodingクラスのエラー

ryu1

総合スコア4

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

0グッド

1クリップ

投稿2019/09/27 08:47

編集2019/09/30 05:20

前提・実現したいこと

こちらの解説記事のソースコードを実行しようと試みましたができない

発生している問題・エラーメッセージ

positional encodingの部分でエラーが起きているっぽい

Traceback (most recent call last): File "C:../..", line 245, in <module> tmp_model = make_model(10, 10, 2) File "C:../..", line 228, in make_model position = PositionalEncoding(d_model, dropout) File "C:../..", line 212, in __init__ pe[:, 0::2] = torch.sin(position * div_term) RuntimeError: expected device cpu and dtype Float but got device cpu and dtype Long

該当のソースコード

python

1class PositionalEncoding(nn.Module): 2 3 def __init__(self, d_model, dropout, max_len=5000): 4 super(PositionalEncoding, self).__init__() 5 self.dropout = nn.Dropout(p=dropout) 6 7 8 pe = torch.zeros(max_len, d_model) 9 position = torch.arange(0, max_len).unsqueeze(1) 10 div_term = torch.exp(torch.arange(0., d_model, 2) * 11 -(math.log(10000.0) / d_model)) 12 pe[:, 0::2] = torch.sin(position * div_term) 13 pe[:, 1::2] = torch.cos(position * div_term) 14 pe = pe.unsqueeze(0) 15 self.register_buffer('pe', pe) 16 17 def forward(self, x): 18 x = x + Variable(self.pe[:, :x.size(1)], 19 requires_grad=False) 20 return self.dropout(x) 21 22def make_model(src_vocab, tgt_vocab, N=6, 23 d_model=512, d_ff=2048, h=8, dropout=0.1): 24 "Helper: Construct a model from hyperparameters." 25 c = copy.deepcopy 26 attn = MultiHeadedAttention(h, d_model) 27 ff = PositionwiseFeedForward(d_model, d_ff, dropout) 28 position = PositionalEncoding(d_model, dropout) 29 model = EncoderDecoder( 30 Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 31 Decoder(DecoderLayer(d_model, c(attn), c(attn), 32 c(ff), dropout), N), 33 nn.Sequential(Embeddings(d_model, src_vocab), c(position)), 34 nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), 35 Generator(d_model, tgt_vocab)) 36 37 # This was important from their code. 38 # Initialize parameters with Glorot / fan_avg. 39 for p in model.parameters(): 40 if p.dim() > 1: 41 nn.init.xavier_uniform(p) 42 return model 43 44# Small example model. 45tmp_model = make_model(10, 10, 2)

試したこと

解説記事のプログラムをそのまま実行すると別のエラー(RuntimeError: exp_vml_cpu not implemented for 'Long')が出ていたので

Traceback (most recent call last): File "C:../..", line 245, in <module> tmp_model = make_model(10, 10, 2) File "C:../..", line 228, in make_model position = PositionalEncoding(d_model, dropout) File "C:../..", line 211, in __init__ -(math.log(10000.0) / d_model)) RuntimeError: exp_vml_cpu not implemented for 'Long'

こちらを参考に0に小数点を付けました

補足情報

python 3.5.4
torch 1.1.0

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

meg_

2019/09/28 03:36 編集

・リンクは「リンクの挿入」で記入してください。 ・エラーメッセージ全文を質問に追記してください。
guest

回答1

0

position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

position = torch.arange(0., max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))

に変更するとできるようです.
これは,トーチexpとsinは以前LongTensorをサポートしていましたが、もうサポートしていない可能性があるためらしいです.(それについてはよくわかりません)

詳しくはこちらのサイトに書いてあります.

RuntimeError: “exp” not implemented for 'torch.LongTensor'

投稿2019/10/23 08:36

aqufiz

総合スコア70

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.36%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問