前提・実現したいこと
CNNで使用したテスト用データをファイルに画像として保存するためのコードを実現したい。
発生している問題・エラーメッセージ
テスト用に使用したデータのうち推定値と正解値の異なる不正解画像のみを取り出し画像を保存したいのですが、上手くいっていない状態です。
該当のソースコード
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
import tensorflow.keras
from keras.utils import np_utils
import numpy as np
classes = ["concrete","tile","mortar","brick","wood"]
num_classes = len(classes)
image_size = 50
def main():
,,,,X_train, X_test, y_train, y_test = np.load("./gaiheki.npy",allow_pickle=True)
,,,,X_train = X_train.astype("float") / 256
,,,,X_test = X_test.astype("float") / 256
,,,,y_train = np_utils.to_categorical(y_train, num_classes)
,,,,y_test = np_utils.to_categorical(y_test, num_classes)
,,,,model = model_train(X_train, y_train) ,,,,model_eval(model, X_test, y_test)
def model_train(X, y):
,,,,model = Sequential()
,,,,model.add(Conv2D(2,(3,3), padding='same',input_shape=X.shape[1:]))
,,,,model.add(Activation('relu'))
,,,,model.add(Conv2D(2,(3,3)))
,,,,model.add(Activation('relu'))
,,,,model.add(MaxPooling2D(pool_size=(2,2)))
,,,,model.add(Dropout(0.25))
,,,,model.add(Flatten()) ,,,,model.add(Dense(512)) ,,,,model.add(Activation('relu')) ,,,,model.add(Dropout(0.5)) ,,,,model.add(Dense(5)) ,,,,model.add(Activation('softmax')) ,,,,import tensorflow ,,,,opt = tensorflow.keras.optimizers.RMSprop(lr=0.0001, decay=1e-6) ,,,,model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=['accuracy']) ,,,,model.fit(X, y, batch_size=32, epochs=100) ,,,,model.save('./gaiheki_cnn.h5') ,,,,return model
def model_eval(model, X, y):
,,,,scores = model.evaluate(X, y, verbose=1)
,,,,print('Test Loss: ', scores[0])
,,,,print('Test Accuracy: ', scores[1])
def model_predict(model, X, y):
,,,,result = model.predict(X)
,,,,for i in range(X.shape[0]):
,,,,,,,,print('推定値: ', result[i].argmax())
,,,,,,,,print('正解値: ', y[i].argmax())
if name == "main":
,,,,main()
試したこと
各画像ごとにクラス分けはしており、推定値と正解値はクラス名で表示はできていますが、それをどう画像として新たに保存するのかわからない状態です。
補足情報(FW/ツールのバージョンなど)
python
tensorflow
あなたの回答
tips
プレビュー