質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

1回答

5306閲覧

ValueErrorについて

退会済みユーザー

退会済みユーザー

総合スコア0

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2019/01/22 11:06

編集2022/01/12 10:55

機械学習で言語を学習をさせたいのですが以下のエラーで困っています。
jupyter notebookで実行しました。

python

1import logging 2import torch 3import torch.nn as nn 4import torch.nn.functional as F 5import torch.optim as optim 6from text_encoder import JapaneseTextEncoder 7 8 9logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) 10logging.info("Building dictionary and dataset.") 11encoder = JapaneseTextEncoder(corpus, append_eos=True, maxlen=10, padding=True) 12encoder.build() 13logging.info("Done...") 14 15n_vocab = len(encoder.word2id) 16EMBEDDING_DIM = HIDDEN_DIM =128 17batch_size = 10 18logging.info("Vocab has %i elements.", n_vocab) 19logging.info("The longest sentence has %i elements.", len(max(encoder.dataset, key=len))) 20device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21class RNNLM(nn.Module): 22 def __init__(self, embedding_dim, hidden_dim, vocab_size, batch_size=10, num_layers=1): 23 super(RNNLM, self).__init__() 24 self.batch_size = batch_size 25 self.num_layers = num_layers 26 self.hidden_dim = hidden_dim 27 28 self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=1) 29 self.dropout = nn.Dropout(p=0.5) 30 31 self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True, num_layers=self.num_layers) 32 33 self.output = nn.Linear(hidden_dim, vocab_size) 34 35 def init_hidden(self): 36 self.hidden_state = torch.zeros(self.num_layers, self.batch_size, self.hidden_dim, device=device) 37 38 def forward(self, indices): 39 embed = self.word_embeddings(indices) # batch_len x sequence_length x embedding_dim 40 drop_out = self.dropout(embed) 41 if drop_out.dim() == 1: 42 drop_out = torch.unsqueeze(drop_out, 1) 43 gru_out, self.hidden_state = self.gru(drop_out, self.hidden_state)# batch_len x sequence_length x hidden_dim 44 gru_out = gru_out.contiguous() 45 return self.output(gru_out) 46 47 48def train2batch(dataset): 49 batch_dataset = [] 50 for i in range(0, 1350, batch_size): 51 batch_dataset.append(dataset[i:i+batch_size]) 52 return batch_dataset 53 54n_epoch = 75000 55model = RNNLM(EMBEDDING_DIM, HIDDEN_DIM, n_vocab).to(device) 56optimizer = optim.SGD(model.parameters(), lr=0.01) 57criterion = nn.CrossEntropyLoss(ignore_index=0)# Training 58logging.info("Training mode") 59model.train() 60for epoch in range(1, n_epoch+1): 61 62 if epoch % 100 == 0: 63 logging.info("Epoch %i: %.2f", epoch, loss.item()) 64 65 encoder.shuffle() 66 # len(encoder.dataset) == 1361 67 batch_dataset = train2batch(encoder.dataset) 68 for batch_data in batch_dataset: 69 model.zero_grad() 70 model.init_hidden() 71 72 batch_tensor = torch.tensor(batch_data, device=device) 73 input_tensor = batch_tensor[:, :-1] 74 target_tensor = batch_tensor[:, 1:].contiguous() 75 outputs = model(input_tensor) 76 outputs = outputs.view(-1, n_vocab) 77 targets = target_tensor.view(-1) 78 loss = criterion(outputs, targets) 79 loss.backward() 80 optimizer.step() 81 82torch.save(model.state_dict(), "wikipedia.model") 83 84 85```ValueError Traceback (most recent call last) 86<ipython-input-29-d4ab6d8de40f> in <module> 87 72 model.init_hidden() 88 73 89---> 74 batch_tensor = torch.tensor(batch_data, device=device) 90 75 input_tensor = batch_tensor[:, :-1] 91 76 target_tensor = batch_tensor[:, 1:].contiguous() 92 93ValueError: expected sequence of length 12 at dim 1 (got 10) 94 95ValueErrorについてどのようにすればいいのかを教えてください。よろしくお願いします。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

y_waiwai

2019/01/22 11:19

このままではコードが見づらいので、質門を編集し、<code>ボタンで、出てくる’’’の枠の中にコードを貼り付けてください
guest

回答1

0

「調べてみた」の詳細を描いてください。
次元1の長さに12を期待しているが10である
なので、データがおかしいと思いますが、それは「調べてみた」に入っているでしょうか?

投稿2019/01/23 11:59

Q71

総合スコア995

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問