前提・実現したいこと
pythonで画像判別をしたくネットのサンプルコードを使用してkerasで学習しようとしましたところ、エラーが発生しました。
発生している問題・エラーメッセージ
Using TensorFlow backend. --- 田村 を処理中 ok, 0 WARNING:tensorflow:From C:\Users\tocch\Anaconda3\envs\opencv\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead. learn.py:51: UserWarning: Update your `Conv2D` call to the Keras 2 API: `Conv2D(32, (3, 3), input_shape=(), padding="same")` model.add(Convolution2D(32, 3, 3, border_mode='same', input_shape=in_shape)) WARNING:tensorflow:From C:\Users\tocch\Anaconda3\envs\opencv\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead. Traceback (most recent call last): File "learn.py", line 97, in <module> main() File "learn.py", line 46, in main model = model_train(X_train, y_train) File "learn.py", line 70, in model_train model = build_model(X.shape[1:]) File "learn.py", line 51, in build_model model.add(Convolution2D(32, 3, 3, border_mode='same', input_shape=in_shape)) File "C:\Users\tocch\Anaconda3\envs\opencv\lib\site-packages\keras\engine\sequential.py", line 165, in add layer(x) File "C:\Users\tocch\Anaconda3\envs\opencv\lib\site-packages\keras\engine\base_layer.py", line 414, in __call__ self.assert_input_compatibility(inputs) File "C:\Users\tocch\Anaconda3\envs\opencv\lib\site-packages\keras\engine\base_layer.py", line 311, in assert_input_compatibility str(K.ndim(x))) ValueError: Input 0 is incompatible with layer conv2d_1: expected ndim=4, found ndim=1
該当のソースコード
python
1# -*- coding: utf-8 -*- 2from sklearn.model_selection import train_test_split 3from keras.preprocessing.image import load_img, img_to_array 4from keras.models import Sequential 5from keras.layers import Dense, Dropout, Activation, Flatten 6from keras.layers import Convolution2D, MaxPooling2D 7from keras.utils import np_utils 8import os, glob 9import numpy as np 10import matplotlib.pyplot as plt 11 12root_dir = "C://Users//tocch//Desktop//face_認証//学習用" 13categories = ["田村"] 14nb_classes = len(categories) 15image_size = 64 16 17# CNN設定 18BATCH_SIZE = 32 19EPOCHS = 15 20 21def init(): 22 X = [] 23 Y = [] 24 for idx, cat in enumerate(categories): 25 files = glob.glob(root_dir + "/" + cat + "/_trimming/*") 26 print("---", cat, "を処理中") 27 for i, f in enumerate(files): 28 img = load_img(f, target_size=(image_size,image_size)) 29 data = img_to_array(img) 30 X.append(data) 31 Y.append(idx) 32 X = np.array(X) 33 Y = np.array(Y) 34 35 X_train, X_test, y_train, y_test = train_test_split(X, Y) 36 xy = (X_train, X_test, y_train, y_test) 37 np.save(root_dir + "//npy/face.npy", xy) 38 print("ok,", len(Y)) 39 40def main(): 41 X_train, X_test, y_train, y_test = np.load(root_dir + "/npy/face.npy") 42 X_train = X_train.astype("float") / 256 43 X_test = X_test.astype("float") / 256 44 y_train = np_utils.to_categorical(y_train, nb_classes) 45 y_test = np_utils.to_categorical(y_test, nb_classes) 46 model = model_train(X_train, y_train) 47 model_eval(model, X_test, y_test) 48 49def build_model(in_shape): 50 model = Sequential() 51 model.add(Convolution2D(32, 3, 3, border_mode='same', input_shape=in_shape)) 52 model.add(Activation('relu')) 53 model.add(MaxPooling2D(pool_size=(2, 2))) 54 model.add(Dropout(0.25)) 55 model.add(Convolution2D(64, 3, 3, border_mode='same')) 56 model.add(Activation('relu')) 57 model.add(Convolution2D(64, 3, 3)) 58 model.add(MaxPooling2D(pool_size=(2, 2))) 59 model.add(Dropout(0.25)) 60 model.add(Flatten()) 61 model.add(Dense(512)) 62 model.add(Activation('relu')) 63 model.add(Dropout(0.5)) 64 model.add(Dense(nb_classes)) 65 model.add(Activation('softmax')) 66 model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy']) 67 return model 68 69def model_train(X, y): 70 model = build_model(X.shape[1:]) 71 #history = model.fit(X, y, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.1) 72 history = model.fit(X, y, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.1) 73 74 json_string = model.to_json() 75 open(os.path.join(root_dir + '/h5/cnn_model.json'), 'w').write(json_string) 76 model.save_weights(root_dir + "/h5/face-model.h5") 77 78 # グラフ表示 79 plt.plot(range(1, EPOCHS + 1), history.history['acc'], label = "train-acc") 80 plt.plot(range(1, EPOCHS + 1), history.history['loss'], label = "train-loss") 81 plt.plot(range(1, EPOCHS + 1), history.history['val_acc'], label = "val-acc") 82 plt.plot(range(1, EPOCHS + 1), history.history['val_loss'], label = "val-loss") 83 plt.title('list') 84 plt.xlabel('Epochs') 85 plt.ylabel('Accuracy') 86 plt.legend() 87 plt.show() 88 return model 89 90def model_eval(model, X, y): 91 score = model.evaluate(X, y) 92 print('loss=', score[0]) 93 print('accuracy=', score[1]) 94 95if __name__ == "__main__": 96 init() 97 main()
補足情報(FW/ツールのバージョンなど)
ここにより詳細な情報を記載してください。
リンクは「リンクの挿入」で記入しましょう。
ああああああああああああああああああああああああああああああああああああああ
回答1件
あなたの回答
tips
プレビュー