マルチラベル分類での重み付けについてお聞きしたいです。
現在tensorflowにて4種類の画像識別モデルを構築しようと思っています。
画像1枚を入力として受け取り、各クラスの物体がが写っている/いないという
判断(つまり4クラスのアウトプット)をさせています。
しかしデータ数に偏りがあり、ざっくりとした比は1:1:14:14という感じになっているため
データに重みをつけて学習をさせたいのですが、model.fit()のclass_weight引数で重みをつけてみても
イマイチ結果が変わっていないいような気がします。
調べている時、loss関数を(binary_crossentropyではなく)自前で用意するという方法も
あったのですが...今回のケースではそちらのほうが適しているのでしょうか?
以下コードです。
(学習の部分だけ載せていますが、他にも必要なものがあればお知らせください)
python
1 2# 今回のMETRICSで採用 3def macro_f1(y, y_hat, thresh=0.5): 4 y_pred = tf.cast(tf.greater(y_hat, thresh), tf.float32) 5 tp = tf.cast(tf.math.count_nonzero(y_pred * y, axis=0), tf.float32) 6 fp = tf.cast(tf.math.count_nonzero(y_pred * (1 - y), axis=0), tf.float32) 7 fn = tf.cast(tf.math.count_nonzero((1 - y_pred) * y, axis=0), tf.float32) 8 f1 = 2*tp / (2*tp + fn + fp + 1e-16) 9 macro_f1 = tf.reduce_mean(f1) 10 return macro_f1 11 12def main(): 13 14 (必要な変数は定義済) 15 16 # 学習 17 IMG_SHAPE = (img_size, img_size, channels) 18 base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, 19 include_top=False, 20 weights='imagenet') 21 global_average_layer = tf.keras.layers.GlobalAveragePooling2D() 22 prediction_layer = tf.keras.layers.Dense(n_classes, activation='sigmoid') 23 model = tf.keras.Sequential([base_model,global_average_layer,prediction_layer]) 24 model.compile(optimizer=optimizers.SGD(lr=lr, momentum=momentum, nesterov=nesterov),loss='binary_crossentropy', metrics=[macro_f1]) 25 26 ckpt_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 27 save_weights_only=True, 28 monitor='val_macro_f1', 29 mode='max', 30 verbose=1) 31 32 csv_logger = tf.keras.callbacks.CSVLogger(ckpt_dir+'/training.csv', separator=',') 33 34 #データ数の比は a:b:c:d = 14:1:14:1 35 class_weight = {0:0.07, 1:1, 2:0.0.07, 3:1} 36 37 history = model.fit(train_ds, 38 steps_per_epoch= int(num_train//batch_size), 39 validation_data=val_ds, 40 validation_steps= int(num_val//batch_size), 41 shuffle=True, 42 epochs=epochs, 43 class_weight = class_weight, 44 callbacks=[ckpt_cb, csv_logger],) 45 46 model.save_weights(ckpt_dir + '/my_checkpoint') 47
2021/05/21 00:22