質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

6479閲覧

KerasでのDeep Learningでlossがあまり下がらない

IwasakiYuuki

総合スコア9

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2019/05/20 05:09

当方,Pythonも自然言語処理も初心者なため,コードが読みにくかったり,場違いな質問をしてしまっているかもしれませんが,どうかお許しください.

1.タイトルにもある通り,KerasでDeep Learning(名古屋大学会話コーパスをもちいたチャットボットの作成)をしているのですが,出力されるlossの値があまり下がりません.オプティマイザーなどを変えたほうが良いでしょうか?

2.GPUがGTX970のためバッチサイズが頑張っても4くらいしか取れない状態なのですが,それも関係してくるのでしょうか?
また,この特徴ベクトルの次元数と語彙数のままバッチサイズを上げる方法などがあれば是非お教え願いたいです.

Python

1import pickle 2import numpy as np 3 4import tensorflow as tf 5import keras 6import keras_transformer 7import keras_bert 8import keras.backend as K 9 10 11config_file_path = 'data/config/config.json' 12checkpoint_file_path = 'data/checkpoint/bert_model.ckpt' 13 14with open('data/nagoya_corpus/nagoya_encoder_inputs.pickle', 'rb') as f: 15 encoder_inputs = pickle.load(f) 16with open('data/nagoya_corpus/nagoya_decoder_inputs.pickle', 'rb') as f: 17 decoder_inputs = pickle.load(f) 18with open('data/nagoya_corpus/nagoya_decoder_outputs.pickle', 'rb') as f: 19 decoder_outputs = pickle.load(f) 20 21 22def get_transformer_on_bert_model( 23 token_num: int, 24 embed_dim: int, 25 encoder_num: int, 26 decoder_num: int, 27 head_num: int, 28 hidden_dim: int, 29 embed_weights, 30 attention_activation=None, 31 feed_forward_activation: str = 'relu', 32 dropout_rate: float = 0.0, 33 use_same_embed: bool = True, 34 embed_trainable=True, 35 trainable: bool = True 36) -> keras.engine.training.Model: 37 """ 38 Transformerのモデルのinputsを特徴ベクトルにしたモデル.それ以外は特に変わらない. 39 inputsのshapeは (None, seq_len, embed_dim) となっている, 40 41 Parameters 42 ---------- 43 token_num 44 トークンのサイズ.(vocab_sizeと同じ) 45 embed_dim 46 特徴ベクトルの次元.inputsの次元数と同じにする. 47 encoder_num 48 エンコーダの層の数. 49 decoder_num 50 デコーダの層の数. 51 head_num 52 Multi-Head Attentionレイヤの分割ヘッド数. 53 hidden_dim 54 隠し層の次元数. 55 embed_weights 56 特徴ベクトルの初期化. 57 attention_activation 58 Attentionレイヤの活性化関数. 59 feed_forward_activation 60 FFNレイヤの活性化関数. 61 dropout_rate 62 Dropoutのレート. 63 use_same_embed 64 エンコーダとデコーダで同じweightsを使用するか. 65 embed_trainable 66 特徴ベクトルがトレーニング可能かどうか. 67 trainable 68 モデルがトレーニング可能かどうか. 69 70 Returns 71 ------- 72 model 73 日本語学習済みのBERTの特徴ベクトルを用いたTransformerモデル 74 """ 75 return keras_transformer.get_model( 76 token_num=token_num, 77 embed_dim=embed_dim, 78 encoder_num=encoder_num, 79 decoder_num=decoder_num, 80 head_num=head_num, 81 hidden_dim=hidden_dim, 82 embed_weights=embed_weights, 83 attention_activation=attention_activation, 84 feed_forward_activation=feed_forward_activation, 85 dropout_rate=dropout_rate, 86 use_same_embed=use_same_embed, 87 embed_trainable=embed_trainable, 88 trainable=trainable 89 ) 90 91 92def train(use_checkpoint=True): 93 if use_checkpoint: 94 transformer_model = keras_transformer.get_model( 95 token_num=32006, 96 embed_dim=768, 97 encoder_num=3, 98 decoder_num=3, 99 head_num=8, 100 hidden_dim=128, 101 dropout_rate=0.1, 102 ) 103 transformer_model.load_weights('data/checkpoint/transformer_model.ckpt') 104 else: 105 bert_model = keras_bert.load_trained_model_from_checkpoint( 106 checkpoint_file=checkpoint_file_path, 107 config_file=config_file_path 108 ) 109 bert_weights = bert_model.get_layer(name='Embedding-Token').get_weights()[0] 110 transformer_model = get_transformer_on_bert_model( 111 token_num=32006, 112 embed_dim=768, 113 encoder_num=3, 114 decoder_num=3, 115 head_num=8, 116 hidden_dim=128, 117 dropout_rate=0.1, 118 embed_weights=bert_weights, 119 ) 120 transformer_model.compile( 121 optimizer=keras.optimizers.Adam(beta_1=0.9, beta_2=0.98), 122# optimizer=keras.optimizers.SGD(), 123 loss=keras.losses.sparse_categorical_crossentropy, 124 metrics=[keras.metrics.mae, keras.metrics.categorical_accuracy], 125 ) 126 transformer_model.summary() 127 transformer_model.fit_generator( 128 generator=_generator(), 129 steps_per_epoch=100, 130 epochs=100, 131 validation_data=_generator(), 132 validation_steps=20, 133 callbacks=[ 134 keras.callbacks.ModelCheckpoint('./data/checkpoint/transformer_model.ckpt', monitor='val_loss'), 135 keras.callbacks.TensorBoard(log_dir='./data/tflog/'), 136 keras.callbacks.LearningRateScheduler(_decay), 137 ] 138 ) 139 140 141def main(): 142 train(use_checkpoint=False) 143 144 145def _decay(epochs): 146 if epochs == 0: 147 step_num = 1 148 else: 149 step_num = epochs*100 150 warmup_steps = 4000 151 d = 768 152 return (768**-0.5) * min(step_num**-0.5, step_num*(warmup_steps**-1.5)) 153 154 155def _generator(): 156 i = 0 157 data_len = len(encoder_inputs) 158 batch_size = 4 159 while True: 160 if (i + batch_size) >= data_len: 161 i = 0 162 else: 163 i += 1 164 yield [encoder_inputs[i:i+batch_size], decoder_inputs[i:i+batch_size]], decoder_outputs[i:i+batch_size] 165 166 167if __name__ == '__main__': 168 main()

ちなみに,コーパスのサイズは約170000文,モデルはTransformerのEmbeddingレイヤにBERTの特徴量を差し込んだものとなっています.

この状態で試しに100ステップ*100エポックほど回してみると,はじめはlossが8,7,6のようになさがっていったのですが,4.1あたりになると,ほぼ下がらなくなってしまいました.これくらいのステップ数とエポック数であれば,そんなものなのでしょうか?また,ある程度長時間回すと最終的にlossはどれくらいまで下がるのが普通なのでしょうか?

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

ベストアンサー

コードを読んでいないので、一般論で回答します。

4.1あたりになると,ほぼ下がらなくなってしまいました

収束しているような下がり方なら、step/epoch増やしても無駄です。

精度を上げるには。。。
・ハイパーパラメータ変えましょう。
・データの前処理しましょう。

投稿2019/05/20 05:15

編集2019/05/20 05:15
yamato_user

総合スコア2321

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

IwasakiYuuki

2019/05/20 08:19

ご回答有り難うございます. お教えいただいた通り,ハイパーパラメータの調整をしてみようと思います.
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問