質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

88.92%

CNNの2値分類で精度が変化しない

解決済

回答 1

投稿

  • 評価
  • クリップ 0
  • VIEW 367

yukawakota

score 14

前提・実現したいこと

CNNを用いて自前データセットの画像の2値分類を行おうと思っています。
データセットが不均衡なので(クラス0:クラス1=8:1くらい)、クラス1のデータを画像のデータ拡張によってオーバーサンプリングしています。

該当のソースコード

import os
import numpy as np
import keras
from keras import layers
from keras import models
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import plot_model
import tensorflow.keras.backend as K

from sklearn.utils import shuffle

import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from datasets import make_dir, extract_data

def cnn_model(image_size):
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 1)))
    model.add(layers.normalization.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(30, activation='relu'))
    model.add(layers.normalization.BatchNormalization())
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

def train(x_train, y_train, x_val, y_val, image_size, output_dir, fold_num=5):
    model =cnn_model(image_size)
    model.summary()
    plot_model(model, show_shapes=True, to_file=output_dir + '/model.png')
    history_list = []
    csv_dir = os.path.join(output_dir, 'training_result_csv')
    make_dir(csv_dir)
    model_dir = os.path.join(output_dir, 'saved_model')
    make_dir(model_dir)
    # データの水増し(Data Augmentation)
    datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=40,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True)
    x_0 = x_train[np.where(y_train==0)[0]]
    x_1 = x_train[np.where(y_train==1)[0]]
    y_0 = y_train[np.where(y_train==0)[0]]
    x_1, y_1 = over_sampling(x_1, datagen, 8, 1)
    x_train = np.concatenate([x_0, x_1])
    y_train = np.concatenate([y_0, y_1])
    x_train, y_train = shuffle(x_train, y_train, random_state=0)
    # 水増し画像を訓練用画像の形式に合わせる
    datagen.fit(x_train, augment=True, rounds=10)
    filename = os.path.join(csv_dir, 'training_result.csv')
    callbacks_list = [keras.callbacks.CSVLogger(filename)]
    model.compile(loss='binary_crossentropy',
          optimizer=optimizers.Adam(lr=1e-4),
          metrics=['acc'])
    history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                steps_per_epoch=x_train.shape[0] // 32,
                epochs=20,
                verbose=1,
                validation_data=(x_val, y_val),
                callbacks=callbacks_list)
    history_list.append(history)
    model.save(os.path.join(model_dir, 'saved_model.h5'))
    return history_list

def draw_graph(history_list, output_dir):
    for i, history in enumerate(history_list):
        acc = history.history['acc']
        val_acc = history.history['val_acc']
        loss = history.history['loss']
        val_loss = history.history['val_loss']

        epochs = range(1, len(acc) + 1)

        plt.figure()
        plt.plot(epochs, acc, 'bo', label='Training acc')
        plt.plot(epochs, val_acc, 'b', label='Validation acc')
        plt.title('Training and validation accuracy')
        plt.legend()
        plt.savefig(output_dir + '/acc{}.png'.format(i))
        plt.show()

        plt.figure()
        plt.plot(epochs, loss, 'bo', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.title('Training and validation loss')
        plt.legend()
        plt.savefig(output_dir + '/loss{}.png'.format(i))
        plt.show()

def over_sampling(datas, datagen, num, label_num):
    imgs = []
    label_list = []
    for x in datas:
        x = np.expand_dims(x, axis=0)
        for d in datagen.flow(x, batch_size=1):
            d = np.squeeze(d, axis=0)
            imgs.append(d)
            label_list.append(label_num)
            if (len(imgs) % num) == 0:
                break
    return np.array(imgs), np.expand_dims(np.array(label_list), axis=-1)

if __name__ == '__main__':
    train_dir = 'datasets/train_binary/'
    valid_dir = 'datasets/valid_binary/'
    output_dir = 'model_out_shuffle'
    classes = ['0', '1']
    image_size = 150

    make_dir(output_dir)
    x_train, y_train = extract_data(train_dir, classes, image_size)
    x_val, y_val = extract_data(valid_dir, classes, image_size)
    history_list = train(x_train, y_train, x_val, y_val, image_size, output_dir)
    graph_dir = os.path.join(output_dir, 'training_result_graph')
    make_dir(graph_dir)
    draw_graph(history_list, graph_dir)
    df_dir = os.path.join(output_dir, 'df_csv')
    make_dir(df_dir)
    for i, history in enumerate(history_list):
        df = pd.DataFrame(history.history)
        filename = os.path.join(df_dir, 'training{}_history.csv'.format(i+1))
        df.to_csv(filename)

発生している問題・エラーメッセージ

実際に学習させてログを見ると、二回目のエポックから訓練データの精度は1.0、検証データの精度は0.87で停止してしまっています。

Epoch 1/20
53/53 [==============================] - 118s 2s/step - loss: 0.0361 - acc: 0.9870 - val_loss: 2.5961 - val_acc: 0.8700
Epoch 2/20
53/53 [==============================] - 118s 2s/step - loss: 0.0029 - acc: 1.0000 - val_loss: 4.4108 - val_acc: 0.8700
Epoch 3/20
53/53 [==============================] - 117s 2s/step - loss: 0.0044 - acc: 0.9994 - val_loss: 7.1868 - val_acc: 0.8700
Epoch 4/20
53/53 [==============================] - 115s 2s/step - loss: 0.0028 - acc: 1.0000 - val_loss: 12.1159 - val_acc: 0.8700
Epoch 5/20
53/53 [==============================] - 115s 2s/step - loss: 0.0018 - acc: 1.0000 - val_loss: 18.7604 - val_acc: 0.8700
Epoch 6/20
53/53 [==============================] - 116s 2s/step - loss: 0.0022 - acc: 1.0000 - val_loss: 27.9707 - val_acc: 0.8700
Epoch 7/20
53/53 [==============================] - 116s 2s/step - loss: 0.0021 - acc: 1.0000 - val_loss: 40.4534 - val_acc: 0.8700
Epoch 8/20
53/53 [==============================] - 115s 2s/step - loss: 0.0014 - acc: 1.0000 - val_loss: 54.9640 - val_acc: 0.8700
Epoch 9/20
53/53 [==============================] - 115s 2s/step - loss: 0.0015 - acc: 1.0000 - val_loss: 72.0444 - val_acc: 0.8700
Epoch 10/20
53/53 [==============================] - 117s 2s/step - loss: 0.0014 - acc: 1.0000 - val_loss: 90.2816 - val_acc: 0.8700
Epoch 11/20
53/53 [==============================] - 120s 2s/step - loss: 0.0012 - acc: 1.0000 - val_loss: 105.6784 - val_acc: 0.8700
Epoch 12/20
53/53 [==============================] - 117s 2s/step - loss: 0.0014 - acc: 1.0000 - val_loss: 121.3914 - val_acc: 0.8700
Epoch 13/20
53/53 [==============================] - 115s 2s/step - loss: 0.0010 - acc: 1.0000 - val_loss: 140.8129 - val_acc: 0.8700
Epoch 14/20
53/53 [==============================] - 116s 2s/step - loss: 0.0012 - acc: 1.0000 - val_loss: 160.8408 - val_acc: 0.8700
Epoch 15/20
53/53 [==============================] - 118s 2s/step - loss: 9.9551e-04 - acc: 1.0000 - val_loss: 169.9402 - val_acc: 0.8700
Epoch 16/20
53/53 [==============================] - 116s 2s/step - loss: 0.0013 - acc: 1.0000 - val_loss: 181.2626 - val_acc: 0.8700
Epoch 17/20
53/53 [==============================] - 116s 2s/step - loss: 0.0012 - acc: 1.0000 - val_loss: 178.5618 - val_acc: 0.8700
Epoch 18/20
53/53 [==============================] - 116s 2s/step - loss: 8.5294e-04 - acc: 1.0000 - val_loss: 183.2046 - val_acc: 0.8700
Epoch 19/20
53/53 [==============================] - 123s 2s/step - loss: 9.1750e-04 - acc: 1.0000 - val_loss: 187.8404 - val_acc: 0.8700
Epoch 20/20
53/53 [==============================] - 123s 2s/step - loss: 7.0443e-04 - acc: 1.0000 - val_loss: 179.0444 - val_acc: 0.8700

試したこと

検証の精度0.87はクラス0とクラス1の比なので、おそらくすべてクラス0と予測して出力しているだろうことはわかりました。しかし、訓練の精度が1.0で停止している理由がわかりません。また、その解決策があればよろしくおねがいします。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 過去に投稿した質問と同じ内容の質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • Q71

    2019/12/16 22:58

    conv
    max pool
    conv
    max pool
    conv
    flat * 2
    で間違いないでしょうか。
    畳み込み3層は薄いです。6層は重ねてみましょう。
    ちゃんとどんな特徴を取ったか確認しましょう。

    キャンセル

  • yukawakota

    2019/12/16 23:52

    層が薄いと過学習の原因になるのですね。。。
    層を増やしてハイパーパラメータを試行錯誤してみます。

    キャンセル

  • Q71

    2019/12/17 08:27

    過学習というか、学習できていません。Faster R-CNNやGrad-CAMの論文になにを学習したかを判断するヒントが書かれています。読んでみましょう。

    キャンセル

回答 1

check解決した方法

0

自分でデータをファイルから読み込んだ部分で画像を1./255で正規化したあとに、kerasのDatageneratorの方でも同じ処理を二重で行ってしまっており、trainのデータとvalidationのデータで正規化のスケーリングが一致していないのがtrainでの精度が1.0で停滞してしまっていた原因でした。最終的にはImageDataGenerator.flow_from_directoryで画像の前処理を一本化することで解決しました。不均衡データに対する不具合はunder samplingや、ハイパーパラメータの調整でより良い精度を目指します。

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 88.92%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

関連した質問

同じタグがついた質問を見る