質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

自然言語処理

自然言語処理は、日常的に使用される自然言語をコンピューターに処理させる技術やソフトウェアの総称です。

Q&A

0回答

1329閲覧

pytorchでseq2seqモデルにBERTを組み込んで翻訳がしたい

shell33

総合スコア0

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

自然言語処理

自然言語処理は、日常的に使用される自然言語をコンピューターに処理させる技術やソフトウェアの総称です。

0グッド

0クリップ

投稿2021/12/16 15:43

前提・実現したいこと

閲覧ありがとうございます。初めての質問、深層学習の初学者です。
参考書を元にpytorchでseq2seqモデルを構築しました。
このモデルの性能を上げるために、学習済みのBERTモデルを組み込んで学習を行いました。
BERTを組み込んだ学習済みモデルで日英翻訳を行ったがうまく学習できていません

発生している問題・エラーメッセージ

モデルの学習を行うプログラムは正常に動いているようで、損失が順調に減っていきますが、
このモデルで翻訳を行うと意味不明な文字列が出力されてしまいます

該当のソースコード

python

1class MyAttNMT(nn.Module): 2 def __init__(self, jv, ev, k, jmodel, emodel): 3 super(MyAttNMT, self).__init__() 4 self.jmodel = jmodel 5 self.emodel = emodel 6 self.lstm1 = nn.LSTM(k, k, num_layers=2, batch_first=True) 7 self.lstm2 = nn.LSTM(k, k, num_layers=2, batch_first=True) 8 self.Wc = nn.Linear(2*k, k) 9 self.W = nn.Linear(k, ev) 10 def forward(self, jline, eline): 11 x = self.jmodel(jline) 12 ox, (hnx, cnx) = self.lstm1(x[0]) 13 y = self.emodel(eline) 14 oy, (hny, cny) = self.lstm2(y[0],(hnx, cnx)) 15 ox1 = ox.permute(0,2,1) 16 sim = torch.bmm(oy,ox1) 17 bs, yws, xws = sim.shape 18 sim2 = sim.reshape(bs*yws,xws) 19 alpha = F.softmax(sim2,dim=1).reshape(bs, yws, xws) 20 ct = torch.bmm(alpha,ox) 21 oy1 = torch.cat([ct,oy],dim=2) 22 oy2 = self.Wc(oy1) 23 return self.W(oy2)

引数はjv=日本語BERTモデルの語彙数、ev=英語BERTモデルの語彙数、k=設定次元数、
jmodelは日本語BERTモデル、emodelは英語BERTモデル、
jlineはテンソル型の和文の単語ID配列、elineはテンソル型の英文の単語ID配列

jv,evはどちらも30000語ほどで設定次元数はBERTモデルの単語ベクトルの次元数に合わせて768にしてあります
学習用の対訳ペアは50000文で100エポックほど学習を行いました

試したこと

最初、BERTモデルのところはpytorchのnn.Embedding()を使用しており、埋め込み語彙数は日本語9000語、英語14000語、設定次元数は200で50000文を100エポック分の学習を行いました。そのときはそれなりに翻訳できていて、比較的短め文はちゃんと翻訳できていました。
自分の中で考えでは最後のnn.Linearが入力次元数が768次元で出力次元数が30000次元とかなり差が大きくなっていてうまく学習できていないのかなと考えています。この考えは正しいのでしょうか?このような場合は根気強くうまくいく次元数を探していかなければならないのでしょうか?
一応、ネットで同じような構成を探してみたり、次元数の設定のコツなどについて調べてみたりしたのですが、見つかりませんでした
ご助言をいただけるとありがたいです

補足情報(FW/ツールのバージョンなど)

現在、実行環境、ソースコードが手元になく、上記ソースコードのBERTモデルの定義部分に間違いがあるかもしれません。
プログラムの動作は確認できているので、ソースコードはネットワークの構成の参考にしてください

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問