#分からないこと
下記のページを参考に(というかそのままコピペで)BERT分類モデルを作ろうとしましたが、エラーが出てしまいます。
どのようにすれば解消できますでしょうか。
https://kaeru-nantoka.hatenablog.com/entry/2020/05/29/144745
#試したこと
ソースコードは下記のとおりです。
Python
1import pandas as pd 2import numpy as np 3import torch 4import transformers 5 6from transformers import BertJapaneseTokenizer 7from tqdm import tqdm 8tqdm.pandas() 9 10class BertSequenceVectorizer: 11 def __init__(self): 12 self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 self.model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking' 14 self.tokenizer = BertJapaneseTokenizer.from_pretrained(self.model_name) 15 self.bert_model = transformers.BertModel.from_pretrained(self.model_name) 16 self.bert_model = self.bert_model.to(self.device) 17 self.max_len = 128 18 19 20 def vectorize(self, sentence : str) -> np.array: 21 inp = self.tokenizer.encode(sentence) 22 len_inp = len(inp) 23 24 if len_inp >= self.max_len: 25 inputs = inp[:self.max_len] 26 masks = [1] * self.max_len 27 else: 28 inputs = inp + [0] * (self.max_len - len_inp) 29 masks = [1] * len_inp + [0] * (self.max_len - len_inp) 30 31 inputs_tensor = torch.tensor([inputs], dtype=torch.long).to(self.device) 32 masks_tensor = torch.tensor([masks], dtype=torch.long).to(self.device) 33 34 seq_out, pooled_out = self.bert_model(inputs_tensor, masks_tensor) 35 36 if torch.cuda.is_available(): 37 return seq_out[0][0].cpu().detach().numpy() # 0番目は [CLS] token, 768 dim の文章特徴量 38 else: 39 return seq_out[0][0].detach().numpy() 40 41 42def cos_sim_matrix(matrix): 43 """ 44 item-feature 行列が与えられた際に 45 item 間コサイン類似度行列を求める関数 46 """ 47 d = matrix @ matrix.T # item-vector 同士の内積を要素とする行列 48 49 # コサイン類似度の分母に入れるための、各 item-vector の大きさの平方根 50 norm = (matrix * matrix).sum(axis=1, keepdims=True) ** .5 51 52 # それぞれの item の大きさの平方根で割っている(なんだかスマート!) 53 return d / norm / norm.T 54 55if __name__ == '__main__': 56 57 sample_df = pd.DataFrame(['お腹が痛いので遅れます。', 58 '頭が痛いので遅れます。', 59 'おはようございます!', 60 'kaggle が好きなかえる', 61 '味噌汁が好きなワニ' 62 ], columns=['text']) 63 64 65 BSV = BertSequenceVectorizer() 66 sample_df['text_feature'] = sample_df['text'].progress_apply(lambda x: BSV.vectorize(x)) 67 print(sample_df.head()) 68 69 print(cos_sim_matrix(np.stack(sample_df.text_feature))) 70
出力結果: --------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-8-2e8b8691372c> in <module> 10 11 BSV = BertSequenceVectorizer() ---> 12 sample_df['text_feature'] = sample_df['text'].progress_apply(lambda x: BSV.vectorize(x)) 13 print(sample_df.head()) 14 ~\anaconda3\lib\site-packages\tqdm\std.py in inner(df, func, *args, **kwargs) 795 # on the df using our wrapper (which provides bar updating) 796 try: --> 797 return getattr(df, df_function)(wrapper, **kwargs) 798 finally: 799 t.close() ~\anaconda3\lib\site-packages\pandas\core\series.py in apply(self, func, convert_dtype, args, **kwds) 4198 else: 4199 values = self.astype(object)._values -> 4200 mapped = lib.map_infer(values, f, convert=convert_dtype) 4201 4202 if len(mapped) and isinstance(mapped[0], Series): pandas\_libs\lib.pyx in pandas._libs.lib.map_infer() ~\anaconda3\lib\site-packages\tqdm\std.py in wrapper(*args, **kwargs) 790 # take a fast or slow code path; so stop when t.total==t.n 791 t.update(n=1 if not t.total or t.n < t.total else 0) --> 792 return func(*args, **kwargs) 793 794 # Apply the provided function (in **kwargs) <ipython-input-8-2e8b8691372c> in <lambda>(x) 10 11 BSV = BertSequenceVectorizer() ---> 12 sample_df['text_feature'] = sample_df['text'].progress_apply(lambda x: BSV.vectorize(x)) 13 print(sample_df.head()) 14 <ipython-input-7-341fefaa86f9> in vectorize(self, sentence) 28 return seq_out[0][0].cpu().detach().numpy() # 0番目は [CLS] token, 768 dim の文章特徴量 29 else: ---> 30 return seq_out[0][0].detach().numpy() 31 32 AttributeError: 'str' object has no attribute 'detach'
「AttributeError: 'str' object has no attribute 'detach'」について、検索したところ
pip install transformers==3.0.0 をするとよいと書かれていましたが、それも試しました。
関係ありそうなライブラリのバージョンは下記のとおりです。
pandas 1.1.5
numpy 1.19.5
torch 1.8.0
tokenizers 0.8.0rc4
transformers 3.0.0
tqdm 4.59.0
#動作環境
Python 3.6.12 :: Anaconda, Inc.
Windows10
回答1件
あなたの回答
tips
プレビュー