機械学習の勉強中です
https://github.com/deer-dslab/keras-example/tree/master/fine-tuning
のfine-tuning/predict.pyを参考に以下のプログラムを実行しようとしています
python3
1import os 2import sys 3from keras.applications.vgg16 import VGG16 4from keras.models import Sequential, Model 5from keras.layers import Input, Activation, Dropout, Flatten, Dense 6from keras.preprocessing import image 7import numpy as np 8 9f=open('test_labels.txt','r') 10line = f.readline() 11 12while line: 13 print(line) 14 filename = (line.strip()) 15 line = f.readline() 16 17 result_dir = 'dataset/results' 18 19 classes = ['art','nat'] 20 nb_classes = len(classes) 21 22 img_height, img_width = 150, 150 23 channels = 3 24 25 # VGG16 26 input_tensor = Input(shape=(img_height, img_width, channels)) 27 vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor) 28 29 # FC 30 fc = Sequential() 31 fc.add(Flatten(input_shape=vgg16.output_shape[1:])) 32 fc.add(Dense(256, activation='relu')) 33 fc.add(Dropout(0.5)) 34 fc.add(Dense(nb_classes, activation='softmax')) 35 36# VGG16とFCを接続 37 model = Model(input=vgg16.input, output=fc(vgg16.output)) 38 39# 学習済みの重みをロード 40 model.load_weights(os.path.join(result_dir, 'finetuning.h5')) 41 42 model.compile(loss='categorical_crossentropy', 43 optimizer='adam', 44 metrics=['accuracy']) 45 # model.summary() 46 47# 画像を読み込んで4次元テンソルへ変換 48 img = image.load_img(filename, target_size=(img_height, img_width)) 49 x = image.img_to_array(img) 50 x = np.expand_dims(x, axis=0) 51 52# 学習時にImageDataGeneratorのrescaleで正規化したので同じ処理が必要! 53# これを忘れると結果がおかしくなるので注意 54 x = x / 255.0 55 56# print(x) 57# print(x.shape) 58 59# クラスを予測 60# 入力は1枚の画像なので[0]のみ 61 pred = model.predict(x)[0] 62 63#result_fileを読み込む 64#file_name = "./result_data.txt" 65 66# 予測確率が高いトップを出力 67 top = 1 68 top_indices = pred.argsort()[-top:][::-1] 69 result = [(classes[i], pred[i]) for i in top_indices] 70 file = open('result_predict.txt', 'a') 71 file.write(str(result[0]) + str(line)) 72 file.close() 73 74f.close() 75
一行ごとにテスト画像の場所が書いてあるtest_labels.txtから一行づつ読み込み二値分類の結果を./result_data.txtに書き出すのプログラムになっています
実行自体は上手くいき、最初は順調に2~3秒で1行ほど排出してくれるのですがだんだんと動作が遅くなり半分ほど進んだところで強制終了してしまいました・・・
原因・改善方法をご教授ください
回答1件
あなたの回答
tips
プレビュー