前提・実現したいこと
サンプルコードを使って、手持ちの文書の5クラス分類をしたいです。
訓練の開始までは進みますが、以下のエラーが出て止まります。
発生している問題・エラーメッセージ
ValueError Traceback (most recent call last) <ipython-input-15-ced522624eac> in <module>() 49 50 for epoch in range(max_epoch): ---> 51 train_ = train(model) 52 test_ = validation(model) 53 print(f'epoch {epoch} loss : {test_}') 5 frames /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight) 2978 2979 if not (target.size() == input.size()): -> 2980 raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size())) 2981 2982 return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum) ValueError: Target size (torch.Size([32])) must be the same as input size (torch.Size([32, 5]))
該当のソースコード
Google colabでBERTを使ってライブドアニュースコーパスを多クラス分類をする
↑このコードから以下の部分を**「削って」**、手持ちの文書データを入力データ(df)とするように変更しています。
def remove_brackets(inp): # 記号とかを除く brackets_tail = re.compile('【[^】]*】$') brackets_head = re.compile('^【[^】]*】') output = re.sub(brackets_head, '', re.sub(brackets_tail, '', inp)) return output def read_title(f): # 2行スキップ next(f) # URL next(f) # タイムスタンプ title = next(f) # 3行目を返す:タイトル title = remove_brackets(title.decode('utf-8')) return title[:-1] # all_text.tsvを作る with tarfile.open(tgz_fname) as tf: # 対象ファイルの選定 for ti in tf: """ ・ライセンスファイルはスキップ ・genre内のtxt意外ならスキップ ・txtファイル意外ならスキップ ・用意したgenre意外ならスキップ """ if "LICENSE.txt" in ti.name: continue if len(ti.name.split('/')) < 3: continue if not ti.name.endswith(".txt"): continue genre = ti.name.split('/')[1] if not genre in target_genres: continue genre_index = target_genres.index(genre) fname_class_list[target_genres[genre_index]].append(ti.name) with open(tsv_fname, "w") as wf: writer = csv.writer(wf, delimiter='\t') for i, genre in enumerate(target_genres): for fname in fname_class_list[genre]: f = tf.extractfile(fname) title = read_title(f) row = [genre, i, title] writer.writerow(row)
# 作成したデータの読み込み df = pd.read_csv("all_text.tsv", delimiter='\t', header=None, names=['media_name', 'label', 'sentence']) df = df.dropna(how='any') # nanのところは落とす # データの確認 print(f'データサイズ: {df.shape}') display(df.sample(10))
また
# 分類したい種類の対象や数はここで調整する fname_class_list = { "dokujo-tsushin": [], "it-life-hack": [], "kaden-channel": [], "livedoor-homme": [], "movie-enter": [], "peachy": [], "smax": [], "sports-watch": [], "topic-news": [] } target_genres = list(fname_class_list.keys())
を
fname_class_list = { "class_1": [], "class_2": [], "class_3": [], "class_4": [], "class_5": [], } target_genres = list(fname_class_list.keys())
に修正しています。
試したこと
・ labelsを5列のワンホット・データに変更 → エラー
・ def train(model): と def validation(model): の
loss = F.cross_entropy(outputs.logits, b_labels)
を
loss = F.cross_entropy(outputs.logits, b_labels.unsqueeze(5))
に変更 → 同じエラーのまま
クラスのタイプをfloatからintに変更したところ、訓練が進むようになりました。
訓練が終わり次第、質問を閉じる予定です。
お手間を取らせて、申し訳ありませんでした。
回答1件
あなたの回答
tips
プレビュー