#前提・実現したいこと
[Keras]MobileNetV2+ArcFaceを使ってペットボトルを分類してみた!
上記URLのサイト様のコードを参考に、自前の画像で分類を行いたいと考えております。
#発生している問題・エラーメッセージ
途中まではサイト様のコード通りに動いたのですが、途中のコードでエラーが発生してしまいました。コードのどの点を変更すればいいのかわからない状況となっております。下記に発生したエラーメッセージを添付いたします。
Python
1touketu157 2Epoch 1/30 3--------------------------------------------------------------------------- 4ValueError Traceback (most recent call last) 5<ipython-input-11-e2c75c3c8750> in <module> 6 51 verbose=1)] 7 52 8---> 53 model.fit_generator(train_gene, steps_per_epoch=80, epochs=30, validation_steps=20, validation_data=val_gane, callbacks=callbacks_list) 9 10~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs) 11 89 warnings.warn('Update your `' + object_name + '` call to the ' + 12 90 'Keras 2 API: ' + signature, stacklevel=2) 13---> 91 return func(*args, **kwargs) 14 92 wrapper._original_function = func 15 93 return wrapper 16 17~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch) 18 1730 use_multiprocessing=use_multiprocessing, 19 1731 shuffle=shuffle, 20-> 1732 initial_epoch=initial_epoch) 21 1733 22 1734 @interfaces.legacy_generator_methods_support 23 24~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch) 25 218 sample_weight=sample_weight, 26 219 class_weight=class_weight, 27--> 220 reset_metrics=False) 28 221 29 222 outs = to_list(outs) 30 31~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics) 32 1506 x, y, 33 1507 sample_weight=sample_weight, 34-> 1508 class_weight=class_weight) 35 1509 if self._uses_dynamic_learning_phase(): 36 1510 ins = x + y + sample_weights + [1] 37 38~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size) 39 577 feed_input_shapes, 40 578 check_batch_axis=False, # Don't enforce the batch size. 41--> 579 exception_prefix='input') 42 580 43 581 if y is not None: 44 45~/.pyenv/versions/3.6.5/lib/python3.6/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix) 46 143 ': expected ' + names[i] + ' to have shape ' + 47 144 str(shape) + ' but got array with shape ' + 48--> 145 str(data_shape)) 49 146 return data 50 147 51 52ValueError: Error when checking input: expected input_2 to have shape (5,) but got array with shape (2,)
#コード
Yohei-Kawakami
/
201901_self_checkout
元のコードの全文は上記URLのサイト様にございます。
Python
1# model構築の準備 2from keras.models import Model 3from keras.layers import Dense, GlobalAveragePooling2D,Input,Dropout,Activation 4from keras.applications import MobileNetV2 5from keras.preprocessing.image import ImageDataGenerator 6from keras.optimizers import Adam 7from keras.callbacks import CSVLogger,EarlyStopping 8import numpy as np 9from keras import backend as K 10from keras.engine.topology import Layer 11import numpy as np 12import tensorflow as tf 13from keras.preprocessing.image import load_img, img_to_array, array_to_img 14import os 15import matplotlib.pyplot as plt 16%matplotlib inline 17 18# Arcfacelayerの実行 19%run arcface.py 20 21def create_mobilenet_with_arcface(n_categories, file_path=None): 22 base_model=MobileNetV2(input_shape=(583,438,3), 23 weights='imagenet', 24 include_top=False) 25 26 27 #add new layers instead of FC networks 28 x = base_model.output 29 yinput = Input(shape=(n_categories,)) 30 # stock hidden model 31 hidden = GlobalAveragePooling2D()(x) 32 # stock Feature extraction 33 #x = Dropout(0.5)(hidden) 34 x = Arcfacelayer(5, 30, 0.05)([hidden,yinput]) 35 #x = Dense(1024,activation='relu')(x) 36 prediction = Activation('softmax')(x) 37 model = Model(inputs=[base_model.input,yinput],outputs=prediction) 38 39 if file_path: 40 model.load_weights(file_path) 41 print('weightは{}'.format(file_path)) 42 43 return model 44 45def create_predict_model(n_categories, file_path): 46 arcface_model = create_mobilenet_with_arcface(n_categories, file_path) 47 predict_model = Model(arcface_model.get_layer(index=0).input, arcface_model.get_layer(index=-4).output) 48 predict_model.summary() 49 return predict_model 50 51# learn model 52model = create_mobilenet_with_arcface(5)# 重みloadしない 53model.summary() 54 55# folderをtrainとtestに分ける 56%run gazo_sprit_many_class.py 57 58class train_Generator_xandy(object): # rule1 59 def __init__(self): 60 datagen = ImageDataGenerator( 61 vertical_flip = False, 62 width_shift_range = 0.1, 63 height_shift_range = 0.1, 64 rescale=1.0/255., 65 zoom_range=0.2, 66 fill_mode = "constant", 67 cval=0) 68 train_generator=datagen.flow_from_directory( 69 train_dir, 70 target_size=(583,438), 71 batch_size=25, 72 class_mode='categorical', 73 shuffle=True) 74 75 self.gene = train_generator 76 77 def __iter__(self): 78 # __next__()はselfが実装してるのでそのままselfを返す 79 return self 80 81 def __next__(self): 82 X, Y = self.gene.next() 83 return [X,Y], Y 84 85 86class val_Generator_xandy(object): 87 def __init__(self): 88 validation_datagen=ImageDataGenerator(rescale=1.0/255.) 89 90 validation_generator=validation_datagen.flow_from_directory( 91 validation_dir, 92 target_size=(583,438), 93 batch_size=25, 94 class_mode='categorical', 95 shuffle=True) 96 97 self.gene = validation_generator 98 99 def __iter__(self): 100 # __next__()はselfが実装してるのでそのままselfを返す 101 return self 102 103 def __next__(self): 104 X, Y = self.gene.next() 105 return [X,Y], Y 106 107train_dir = './zidolegi_data/train' 108validation_dir = './zidolegi_data/validation' 109train_gene = train_Generator_xandy() 110val_gane = val_Generator_xandy() 111 112**#以下のコードからエラーが発生いたします** 113# layerを徐々に解凍しながら学習する 114from keras import callbacks 115 116touketulayerlists = [ 117 model.layers.index(model.get_layer("arcfacelayer_1")), 118 model.layers.index(model.get_layer("block_16_expand")), 119 model.layers.index(model.get_layer("block_15_expand")), 120 model.layers.index(model.get_layer("block_14_expand")), 121 model.layers.index(model.get_layer("block_13_expand")), 122 model.layers.index(model.get_layer("block_12_expand")), 123 model.layers.index(model.get_layer("block_11_expand")), 124 model.layers.index(model.get_layer("block_10_expand")), 125 model.layers.index(model.get_layer("block_9_expand")), 126 model.layers.index(model.get_layer("block_8_expand")), 127 model.layers.index(model.get_layer("block_7_expand")), 128 model.layers.index(model.get_layer("block_6_expand")) 129] 130 131maenosavepath = None 132for touketu in touketulayerlists: 133 print('touketu{}'.format(touketu)) 134 135 modelsavepath = "zidolege_model/m02_fine{}kara_weights".format(touketu) 136 if maenosavepath: 137 model.load_weights(maenosavepath) 138 139 maenosavepath = modelsavepath 140 #凍結 141 for layer in model.layers[:touketu]: 142 layer.trainable=False 143 for layer in model.layers[touketu:]: 144 layer.trainable=True 145 146 model.compile(optimizer=Adam(lr=0.001), 147 loss='categorical_crossentropy', 148 metrics=['accuracy']) 149 150 callbacks_list = [ 151 #バリデーションlossが改善したらモデルをsave 152 callbacks.ModelCheckpoint( 153 filepath=modelsavepath, 154 monitor="val_loss", 155 save_weights_only=True, 156 save_best_only=True), 157 158 #バリデーションlossが改善しなくなったら学習率を変更する 159 callbacks.ReduceLROnPlateau( 160 monitor="val_loss", 161 factor=0.8, 162 patience=5, 163 verbose=1)] 164 165 model.fit_generator(train_gene, steps_per_epoch=80, epochs=30, validation_steps=20, validation_data=val_gane, callbacks=callbacks_list) 166
元のコードを見ていただけるとわかると思いますが、上記のコードはエラーが発生したところまでの表示にしてあります。
#試していること
英語力に乏しい自分でも、エラーメッセージが「本当はshape(5,)が必要であるのにshape(2,)が入力されているよ」という意味であることがわかりました。確かに元のサイト様は5つのカテゴリーに分けるため、shape(5,)としていることまでは自分でも理解できました。しかし、自分の分類したいカテゴリーは2つですので、どのようにコードを変更させればいいのかわからない状況となっております。
#補足
使っているPCはmacOS Catalina バージョン10.15.5
Pythonのバージョンは3.6.5です
jupyter notebookを使用しております。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/07/13 12:42
2020/07/13 12:45
2020/07/13 13:01 編集
2020/07/13 13:09
2020/07/13 13:36