###前提・実現したいこと
Kerasで深層学習の勉強をしています。InceptionV3のモデルを利用してcifar10の画像分類をしたいです。
###発生している問題・エラーメッセージ
kerasに用意されているApplicationsからモデルをロードしたのですが、cifar10の32*32という小さなサイズには対応しておらず読み込みでエラーが出ます。
/home/***/keras_study/test.py:69: UserWarning: Update your `fit_generator` call to the Keras 2 API: `fit_generator(<keras.pre..., verbose=1, validation_data=(array([[[..., steps_per_epoch=1562, epochs=200, callbacks=[<keras.ca..., max_queue_size=100)` callbacks=[lr_reducer, csv_logger]) Traceback (most recent call last): File "/home/***/keras_study/test.py", line 69, in <module> callbacks=[lr_reducer, csv_logger]) File "/usr/local/lib/python2.7/dist-packages/keras/legacy/interfaces.py", line 87, in wrapper return func(*args, **kwargs) File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1777, in fit_generator val_x, val_y, val_sample_weight) File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1234, in _standardize_user_data exception_prefix='input') File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 140, in _standardize_input_data str(array.shape)) ValueError: Error when checking input: expected input_1 to have shape (None, 299, 299, 3) but got array with shape (10000, 32, 32, 3)
###該当のソースコード
Python
1from __future__ import print_function 2from keras.datasets import cifar10 3from keras.preprocessing.image import ImageDataGenerator 4from keras.utils import np_utils 5from keras.callbacks import ReduceLROnPlateau, CSVLogger 6import numpy as np 7 8from keras.applications.inception_v3 import InceptionV3 9 10lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=0.5e-6) 11csv_logger = CSVLogger('test_cifar10.csv') 12 13batch_size = 32 14nb_classes = 10 15nb_epoch = 200 16 17# input image dimensions 18img_rows, img_cols = 32, 32 19# The CIFAR10 images are RGB. 20img_channels = 3 21 22# The data, shuffled and split between train and test sets: 23(X_train, y_train), (X_test, y_test) = cifar10.load_data() 24 25# Convert class vectors to binary class matrices. 26Y_train = np_utils.to_categorical(y_train, nb_classes) 27Y_test = np_utils.to_categorical(y_test, nb_classes) 28 29X_train = X_train.astype('float32') 30X_test = X_test.astype('float32') 31 32# subtract mean and normalize 33mean_image = np.mean(X_train, axis=0) 34X_train -= mean_image 35X_test -= mean_image 36X_train /= 128. 37X_test /= 128. 38 39model = InceptionV3(weights='imagenet') 40model.compile(loss='categorical_crossentropy', 41 optimizer='adam', 42 metrics=['accuracy']) 43 44# This will do preprocessing and realtime data augmentation: 45datagen = ImageDataGenerator( 46 featurewise_center=False, # set input mean to 0 over the dataset 47 samplewise_center=False, # set each sample mean to 0 48 featurewise_std_normalization=False, # divide inputs by std of the dataset 49 samplewise_std_normalization=False, # divide each input by its std 50 zca_whitening=False, # apply ZCA whitening 51 rotation_range=10, # randomly rotate images in the range (degrees, 0 to 180) 52 width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 53 height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 54 shear_range=0.1, 55 zoom_range=0.1, 56 horizontal_flip=True, # randomly flip images 57 vertical_flip=False) # randomly flip images 58 59datagen.fit(X_train) 60 61# Fit the model on the batches generated by datagen.flow(). 62model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size), 63 steps_per_epoch=X_train.shape[0] // batch_size, 64 validation_data=(X_test, Y_test), 65 epochs=nb_epoch, verbose=1, max_q_size=100, 66 callbacks=[lr_reducer, csv_logger]) 67 68model_json_str = model.to_json() 69open('test_cifar10.json', 'w').write(model_json_str) 70model.save_weights('test_cifar10_param.hdf5') 71
###試したこと
cifar10の画像サイズを300300のように大きくリサイズして通そうとは考えたのですが、その場合元のデータに比べて大きくなりすぎて学習で問題が起こりそうな気がしています。
<追記>
3232→300*300にリサイズした画像を実際に表示すると、テレビのノイズ画面のような元の画像とかけ離れた画像に変換されており、これでは教師データとして問題があるんじゃないかと考えています。
###質問したいこと
cifar10のような小さい画像に対してInceptionV3モデルを適用させる方法があれば、Kerasでの実装方法も含めて教えて頂きたいです。
###補足情報(言語/FW/ツール等のバージョンなど)
Keras2.07
Python2.7
回答1件
あなたの回答
tips
プレビュー