Background / What I want to achieve
I am trying to implement Japanese text classification with BERT, but the following error message appears when building the model.
Problem / Error message
TypeError: Inputs to a layer should be tensors. Got: pooler_output
Traceback
TypeError                                 Traceback (most recent call last)
<ipython-input-28-a650ae2f3fb8> in <module>()
     63 x_train = to_features(train_texts, max_length)
     64 y_train = tf.keras.utils.to_categorical(train_labels, num_classes=num_classes)
---> 65 model = build_model(model_name, num_classes=num_classes, max_length=max_length)
     66
     67 # Training

2 frames
<ipython-input-28-a650ae2f3fb8> in build_model(model_name, num_classes, max_length)
     50         token_type_ids=token_type_ids
     51     )
---> 52     output = tf.keras.layers.Dense(num_classes, activation="softmax")(pooler_output)
     53     model = tf.keras.Model(inputs=[input_ids, attention_mask, token_type_ids], outputs=[output])
     54     optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)

/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1018         training=training_mode):
   1019
-> 1020           input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
   1021           if eager:
   1022             call_fn = self.call

/usr/local/lib/python3.7/dist-packages/keras/engine/input_spec.py in assert_input_compatibility(input_spec, inputs, layer_name)
    194     # have a `shape` attribute.
    195     if not hasattr(x, 'shape'):
--> 196       raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))
    197
    198     if len(inputs) != len(input_spec):

TypeError: Inputs to a layer should be tensors. Got: pooler_output
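The message shows that the Dense layer received the literal string 'pooler_output' rather than a tensor. A quick way to check what the BERT call actually returns in this runtime is to inspect the output object directly; a minimal sketch (the dummy token ids are hypothetical, only for inspection):

import tensorflow as tf
import transformers

bert_model = transformers.TFBertModel.from_pretrained("cl-tohoku/bert-base-japanese")
dummy_ids = tf.constant([[2, 1, 3]], dtype=tf.int32)  # arbitrary ids, just to trigger a forward pass
outputs = bert_model(dummy_ids)
print(type(outputs))
# On transformers 2.x this is a plain tuple of tensors; on 4.x it is a
# TFBaseModelOutputWithPooling, a dict-like object whose iteration (and thus
# tuple unpacking) yields the key names 'last_hidden_state' and 'pooler_output'
# rather than tensors, which matches the string in the error message.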
Relevant source code
import numpy as np
import tensorflow as tf
import transformers
from sklearn.metrics import accuracy_score

# model_name is taken from here (cf. https://huggingface.co/transformers/pretrained_models.html)
model_name = "cl-tohoku/bert-base-japanese"
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

# Training data
train_texts = [
    "この犬は可愛いです",
    "その猫は気まぐれです",
    "あの蛇は苦手です"
]
train_labels = [1, 0, 0]  # 1: like, 0: dislike

# Test data
test_texts = [
    "その猫はかわいいです",
    "どの鳥も嫌いです",
    "あのヤギは怖いです"
]
test_labels = [1, 0, 0]

# Convert a list of texts into transformers input features
def to_features(texts, max_length):
    shape = (len(texts), max_length)
    # input_ids, attention_mask, and token_type_ids are explained in the glossary (cf. https://huggingface.co/transformers/glossary.html)
    input_ids = np.zeros(shape, dtype="int32")
    attention_mask = np.zeros(shape, dtype="int32")
    token_type_ids = np.zeros(shape, dtype="int32")
    for i, text in enumerate(texts):
        encoded_dict = tokenizer.encode_plus(text, max_length=max_length, pad_to_max_length=True)
        input_ids[i] = encoded_dict["input_ids"]
        attention_mask[i] = encoded_dict["attention_mask"]
        token_type_ids[i] = encoded_dict["token_type_ids"]
    return [input_ids, attention_mask, token_type_ids]

# Build a model that classifies a single text
def build_model(model_name, num_classes, max_length):
    input_shape = (max_length, )
    input_ids = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    attention_mask = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    token_type_ids = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    bert_model = transformers.TFBertModel.from_pretrained(model_name)
    last_hidden_state, pooler_output = bert_model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    )
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(pooler_output)
    model = tf.keras.Model(inputs=[input_ids, attention_mask, token_type_ids], outputs=[output])
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
    model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
    return model

num_classes = 2
max_length = 15
batch_size = 10
epochs = 3

x_train = to_features(train_texts, max_length)
y_train = tf.keras.utils.to_categorical(train_labels, num_classes=num_classes)
model = build_model(model_name, num_classes=num_classes, max_length=max_length)

# Training
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs
)

# Prediction
x_test = to_features(test_texts, max_length)
y_test = np.asarray(test_labels)
y_preda = model.predict(x_test)
y_pred = np.argmax(y_preda, axis=1)
print("Accuracy: %.5f" % accuracy_score(y_test, y_pred))
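If the runtime is in fact running a newer transformers than stated, a common workaround is to avoid tuple-unpacking the model call and index the returned output object instead, since integer indexing works on both the 2.x tuple and the 4.x ModelOutput. A sketch of build_model with only that line changed, not a verified fix for this exact setup:

import tensorflow as tf
import transformers

def build_model(model_name, num_classes, max_length):
    input_shape = (max_length, )
    input_ids = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    attention_mask = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    token_type_ids = tf.keras.layers.Input(input_shape, dtype=tf.int32)
    bert_model = transformers.TFBertModel.from_pretrained(model_name)
    outputs = bert_model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids
    )
    # Index instead of unpacking: element 1 is the pooled [CLS] representation
    # whether outputs is a tuple (transformers 2.x) or a ModelOutput (4.x)
    pooler_output = outputs[1]
    output = tf.keras.layers.Dense(num_classes, activation="softmax")(pooler_output)
    model = tf.keras.Model(inputs=[input_ids, attention_mask, token_type_ids], outputs=[output])
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
    model.compile(optimizer=optimizer, loss="categorical_crossentropy", metrics=["acc"])
    return model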
Supplementary information (framework/tool versions, etc.)
I am implementing this in Google Colab.
Tensorflow==2.2.0
transformers==2.11.0
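Note that the traceback paths (/usr/local/lib/python3.7/dist-packages/keras/...) point at the standalone keras package that ships with TensorFlow 2.6+, so the runtime may actually be newer than the versions listed above. This can be confirmed directly in the notebook:

import tensorflow as tf
import transformers

print(tf.__version__)
print(transformers.__version__)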