前提
学習済みBERTから2つの単語ベクトルを抽出し、それらのcos類似度を求めたいです。
東北大版BERTのモデルを読み込みました。
単語ベクトルは11層のものを使用します。
実現したいこと
モデルを読み込み、単語ベクトルを画面出力するところまではできました。
その単語ベクトルからcos類似度を求めたいです。
エラーメッセージのように、"shapes (1,7,768) and (1,6,768) not aligned"となっているため、行又は列が一致していないために起こるエラーだと思っています。
この2つの語句でcos類似度を求める方法はありませんでしょうか?
発生している問題・エラーメッセージ
Traceback (most recent call last): File "/home/acd13859jl/grad_work/bert_prog.py", line 32, in <module> cos = cos_similarity(A, B) File "/home/acd13859jl/grad_work/bert_prog.py", line 28, in cos_similarity return np.dot(na, nb) File "<__array_function__ internals>", line 180, in dot ValueError: shapes (1,7,768) and (1,6,768) not aligned: 768 (dim 2) != 6 (dim 1)
該当のソースコード
python
1# tohoku-BERT 2from transformers import BertConfig, BertModel 3 4config = BertConfig.from_json_file('config.json') 5config.output_hidden_states = True # 各層の情報の取り出し 6model = BertModel.from_pretrained('pytorch_model.bin', config=config) 7 8from transformers import BertJapaneseTokenizer 9tknz = BertJapaneseTokenizer(vocab_file='vocab.txt', do_lower_case=False, do_basic_tokenize=False) 10 11from transformers.models.bert_japanese import tokenization_bert_japanese 12tknz.word_tokenizer = tokenization_bert_japanese.MecabTokenizer() 13 14 15import torch 16import numpy as np 17 18x = tknz.encode("欠けていたピース") # [2, 15201, 16, 21, 10, 14802, 3] 19y = tknz.encode("ぴったりかみ合う") # [2, 10411, 21087, 11620, 7393, 3] 20 21x = torch.LongTensor(x).unsqueeze(0) 22y = torch.LongTensor(y).unsqueeze(0) 23a = model(x) 24b = model(y) 25 26# cos類似度を求める(epsは母数がゼロにならないようにするため) 27# cos_similarityの中の処理を一部変更しました。結果自体は変わらないと思います。 28def cos_similarity(a, b, eps=1e-8): 29 cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + eps 30 return cos 31 32A = (a[2][11]).to('cpu').detach().numpy() 33B = (b[2][11]).to('cpu').detach().numpy() 34# https://tzmi.hatenablog.com/entry/2020/02/16/170928#pytorch-tensor%E3%81%8B%E3%82%89-numpy-ndarray%E3%81%B8%E5%A4%89%E6%8F%9B 35# このサイトをもとに一部A,Bを修正しました 36 37cos = cos_similarity(A, B) 38print(cos) 39
補足情報(FW/ツールのバージョンなど)
python 3.10
torch 1.12.1+cu116
numpy 1.23.4
回答3件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。