前提・実現したいこと
現在9人の画像を集めて学習させ、別の写真を見せた時にどの人か特定させるというようなことをしたいと考えています。
乃木坂メンバーの顔をCNNで分類
こちらの記事を参考にKerasのCNNを利用して実装しているのですが、できたモデルで9人のいろんな写真を判定させるとある1人として判定されてしまうことが多いです。
9人をA、B、C...だとして、Aの写真、Bの写真、Cの写真を判定させても90%超えでAさんと判定されてしまうといった感じです。
画像もそこまで量を集められているわけではないので精度が多少悪くなるのは仕方ないとは思うのですが、なぜここまで偏ってしまうのかがわかりません。せめてこの偏った判定は解消したいと考えています。
単純に画像が足りないせいのか、コードの書き方に誤りがあるのか、層が足りないのか、調整が足りないのか...今後どこを改善していけばいいのか悩んでいるところです。
解決方法をご教示ください。
そもそも、CNNではなく最近の画像分類ではこの手法のほうがおすすめなどのアドバイスもあれば教えていただきたいです。
###用意したデータ
集めた画像は、メンバーごとに116枚~124枚
上記記事を参考に、9倍の量に水増し
上記記事を参考に、ランダムに選んだ8割がtrainフォルダに2割がtestフォルダに入っている状態です
該当のソースコード
python
1members = ['A','B','C','D','E','F','G','H','I'] 2 3TRAIN_FOLDER_PATH = 'D:\train' 4TEST_FOLDER_PATH = 'D:\test' 5 6# 教師データのラベル付け 7X_train = [] 8Y_train = [] 9for i in range(len(members)): 10 images = os.listdir(os.path.join(TRAIN_FOLDER_PATH, members[i])) 11 for image in images: 12 img = cv2.imread(os.path.join(TRAIN_FOLDER_PATH, members[i], image)) 13 b,g,r = cv2.split(img) 14 img = cv2.merge([r,g,b]) 15 X_train.append(img) 16 Y_train.append(i) 17 18# テストデータのラベル付け 19X_test = [] # 画像データ読み込み 20Y_test = [] # ラベル(名前) 21for i in range(len(members)): 22 images = os.listdir(os.path.join(TEST_FOLDER_PATH, members[i])) 23 for image in images: 24 img = cv2.imread(os.path.join(TEST_FOLDER_PATH, members[i], image)) 25 b,g,r = cv2.split(img) 26 img = cv2.merge([r,g,b]) 27 X_test.append(img) 28 Y_test.append(i) 29X_train=np.array(X_train) 30X_test=np.array(X_test) 31 32 33y_train = to_categorical(Y_train) 34y_test = to_categorical(Y_test) 35 36 37model = Sequential() 38model.add(Conv2D(32, (3, 3), activation='relu', 39 input_shape=(64, 64, 3), padding="same")) 40model.add(MaxPooling2D((2, 2))) 41model.add(Dropout(0.2)) 42model.add(Conv2D(64, (3, 3), activation='relu')) 43model.add(MaxPooling2D((2, 2))) 44model.add(Dropout(0.2)) 45model.add(Conv2D(128, (3, 3), activation='relu')) 46model.add(MaxPooling2D((2, 2))) 47model.add(Dropout(0.2)) 48model.add(Conv2D(128, (3, 3), activation='relu')) 49model.add(MaxPooling2D((2, 2))) 50model.add(Dropout(0.2)) 51model.add(Flatten()) 52model.add(Dense(512, activation='relu')) 53model.add(Dropout(0.2)) 54model.add(Dense(9, activation='softmax')) 55 56 57model.compile(loss='categorical_crossentropy', 58 optimizer=optimizers.RMSprop(lr=1e-4), 59 metrics=['accuracy']) 60 61 62# 学習 63history = model.fit(X_train, y_train, batch_size=32, 64 epochs=50, verbose=1, validation_data=(X_test, y_test)) 65 66# 汎化制度の評価・表示 67score = model.evaluate(X_test, y_test, batch_size=32, verbose=0) 68print('validation loss:{0[0]}\nvalidation accuracy:{0[1]}'.format(score)) 69
###model.summary()
Model:
1_________________________________________________________________ 2 Layer (type) Output Shape Param # 3================================================================= 4 conv2d (Conv2D) (None, 64, 64, 32) 896 5 6 max_pooling2d (MaxPooling2D (None, 32, 32, 32) 0 7 ) 8 9 dropout (Dropout) (None, 32, 32, 32) 0 10 11 conv2d_1 (Conv2D) (None, 30, 30, 64) 18496 12 13 max_pooling2d_1 (MaxPooling (None, 15, 15, 64) 0 14 2D) 15 16 dropout_1 (Dropout) (None, 15, 15, 64) 0 17 18 conv2d_2 (Conv2D) (None, 13, 13, 128) 73856 19 20 max_pooling2d_2 (MaxPooling (None, 6, 6, 128) 0 21 2D) 22 23 dropout_2 (Dropout) (None, 6, 6, 128) 0 24 25 conv2d_3 (Conv2D) (None, 4, 4, 128) 147584 26 27 max_pooling2d_3 (MaxPooling (None, 2, 2, 128) 0 28 2D) 29 30 dropout_3 (Dropout) (None, 2, 2, 128) 0 31 32 flatten (Flatten) (None, 512) 0 33 34 dense (Dense) (None, 512) 262656 35 36 dropout_4 (Dropout) (None, 512) 0 37 38 dense_1 (Dense) (None, 9) 4617 39 40================================================================= 41Total params: 508,105 42Trainable params: 508,105 43Non-trainable params: 0
###結果
validation loss:0.15804602205753326 validation accuracy:0.9585736989974976
試したこと
他のサイトも参考にして、上記参考記事から以下のような変更を加えてみてます。
活性化関数をsigmoid
-> relu
に変更
オプティマイザをsgd
-> RMSprop
に変更
その他、batchサイズやepoch数を大きくしたり小さくしたりしてみたりしたのですが、偏りはあまり解消されませんでした。
###追記
テストデータに分けてから、訓練データのみ水増し処理するように修正しました。
結果は以下のようになりました。
ただ、特定のメンバーが確率高く出る偏りは解消されないようでした。
validation loss:0.9875603318214417 validation accuracy:0.7211538553237915
回答1件
あなたの回答
tips
プレビュー