質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

951閲覧

PyTorchテキスト分類 学習済みモデルの利用

rukaff

総合スコア5

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/05/12 06:06

前提・実現したいこと

PyTorchでテキスト分類モデルを作りたいと考えています。

単語数105個で学習したモデルを
「torch.save」でpthファイルに保存して使っています。

学習時と同じ105単語のデータではエラーは出ずに、予測ができました。

しかし、108単語のデータを予測しようと思い
「model.load_state_dict」でpthファイルを読み込むと、下記のエラーになりました。

学習時には無かった新たな単語を含んでいる場合、どのように書くべきなのでしょうか。

そもそも、新しい単語がある場合は再度学習が必要になるなど
考え方が間違っていたらご指摘いただけると助かります。

何卒よろしくお願いいたします。

参考にしたサイト:https://qiita.com/m__k/items/db1a81bb06607d5b0ec5

発生している問題・エラーメッセージ

エラーメッセージ RuntimeError: Error(s) in loading state_dict for LSTMClassifier: size mismatch for word_embeddings.weight: copying a param with shape torch.Size([105, 100]) from checkpoint, the shape in current model is torch.Size([108, 100]).

該当のソースコード

python

1ソースコード 2class LSTMClassifier(nn.Module): 3 def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size): 4 super(LSTMClassifier, self).__init__() 5 self.hidden_dim = hidden_dim 6 self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) 7 self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True) 8 self.hidden2tag = nn.Linear(hidden_dim, tagset_size) 9 self.softmax = nn.LogSoftmax() 10 11 def forward(self, sentence): 12 embeds = self.word_embeddings(sentence) 13 _, lstm_out = self.lstm(embeds) 14 tag_space = self.hidden2tag(lstm_out[0]) 15 tag_scores = self.softmax(tag_space.squeeze()) 16 return tag_scores 17 18model = LSTMClassifier(EMBEDDING_DIM, HIDDEN_DIM, VOCAB_SIZE, TAG_SIZE).to(device) 19loss_function = nn.NLLLoss() 20 21optimizer = optim.Adam(model.parameters(), lr=0.001) 22 23model.load_state_dict(torch.load("xxxx.pth"))

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問