tensorflowのkerasを使って勉強しています。
レイヤー層の重要性
レイヤー層の数の重要性に関して調べていて、多ければ多いほど学習はするが、勾配が消失したりする恐れがあるとのことでした。なのでBatchNormalizationを入れれば解決するのではないか、と思い実験してみることにしました。
モデルのイメージ画像は以下です。(あえて勉学のために極端な設計です。)
質問したいことは以下の4点です。
質問
1.2つのインプットをConcatenateする前と後でそれぞれの層数をそろえたのですが、多すぎますでしょうか。一般的にはつなげる前は5層くらいで、つなげた後に7~8層用意する位でしょうか?(それぞれ3層だとPearson scoreは正常なのですが、これだとnanになってしまいます。)
2.BatchNormalizationは、どこに入れるべきなのでしょうか。Concatenateする前と後両方に入れるべきでしょうか。
3. 自分なりに、勾配消失に対処するため、ところどころにregularizerを指定するようにしてみたのですが、これでもPearson_scoreはnanになってしまいます。もっと減らすべきなのでしょか。
4.DropoutとBatchNormalizationを併用する際、このページにとても勉強になるまとめがあるのですが、FC層の前にだけDropoutを入れるとありました。コードは以下のようになっているのですが、この場合どうするべきなのでしょうか。
python
1def get_model(): 2 investment_id_inputs = tf.keras.Input((1, ), dtype=tf.uint16) 3 features_inputs = tf.keras.Input((300, ), dtype=tf.float16) 4 5 investment_id_x = investment_id_lookup_layer(investment_id_inputs) 6 investment_id_x = keras.layers.Embedding(investment_id_size, 32, input_length=1)(investment_id_x) 7 investment_id_x = keras.layers.Reshape((-1, ))(investment_id_x) 8 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 9 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 10 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 11 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 12 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 13 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 14 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 15 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 16 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 17 investment_id_x = keras.layers.Dense(64, activation='swish')(investment_id_x) 18 19 feature_x = keras.layers.Dense(256, activation='swish')(features_inputs) 20 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 21 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 22 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 23 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 24 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 25 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 26 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 27 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 28 feature_x = keras.layers.Dense(256, activation='swish')(feature_x) 29 30 x = keras.layers.Dropout(0.45)(feature_x) 31 32 33 x = keras.layers.Concatenate(axis=1)([investment_id_x, feature_x]) 34 x = keras.layers.Dense(512, activation='swish', kernel_regularizer=regularizers.l1_l2())(x) 35 x = keras.layers.Dense(256, activation='swish')(x) 36 x = keras.layers.Dense(128, activation='swish', kernel_regularizer=regularizers.l1_l2())(x) 37 x = keras.layers.Dense(128, activation='swish')(x) 38 x = keras.layers.Dense(64, activation='swish', kernel_regularizer=regularizers.l1_l2())(x) 39 x = keras.layers.Dense(64, activation='swish')(x) 40 x = keras.layers.Dense(32, activation='swish', kernel_regularizer=regularizers.l1_l2())(x) 41 x = keras.layers.Dense(32, activation='swish')(x) 42 # x = keras.layers.Dense(32)(x) 43 # x = keras.layers.BatchNormalization()(x) 44 # x = keras.layers.Activation('swish')(x) 45 x = keras.layers.Dense(32, activation='swish', kernel_regularizer=regularizers.l1_l2())(x) 46 x = keras.layers.Dense(32, activation='swish')(x) 47 # x = keras.layers.Dense(32)(x) 48 # x = keras.layers.BatchNormalization()(x) 49 # x = keras.layers.Activation('swish')(x) 50 51 x = keras.layers.Dropout(0.45)(x) 52 output = keras.layers.Dense(1)(x) 53 ###layer num are increased, we have to care about gradient vanishing 54 #this one is out... Pearson score is nan... 55 56 rmse = keras.metrics.RootMeanSquaredError(name="rmse") 57 model = tf.keras.Model(inputs=[investment_id_inputs, features_inputs], outputs=[output]) 58 model.compile(optimizer=tf.optimizers.Adam(0.001), loss='mse', metrics=['mse', "mae", "mape", rmse]) 59 return model 60 61model = get_model() 62print(model.summary()) 63keras.utils.plot_model(model, show_shapes = True)
質問させていただいた点以外に、何かご指摘・アドバイス等ありましたら、コメントいただけると有難いです。どうぞよろしくお願い致します。
勉強に使用させていただいた文献
過学習を防ぐ「正則化」とは?
バッチ正規化:ディープラーニングの最大のブレークスルー
深層学習(勾配消失問題~CNN)
機械学習で「分からん!」となりがちな正則化の図を分かりやすく解説
深層学習 Day 2 - Section 1 勾配消失問題 のまとめ
L2 Regularization and Batch Norm
Dropoutによる過学習の抑制
あなたの回答
tips
プレビュー