質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

716閲覧

ROC曲線について(python)

_sfgh3k

総合スコア18

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2020/08/02 09:34

Pythonのkerasを用いた機械学習で二値分類問題を扱っています。
そこでROC曲線を描きたいのですが検索すると予測における確率を用いていたのですが、どこをいじれば確率が出るのかわかりません。いくつか試してみましたが予測後の結果しか出ずに困っています。ご教授よろしくお願いいたします。
コードは以下の通りです。

python

1#モジュールの読み込み 2from __future__ import print_function 3import csv #csvモジュールの読み込み 4import random 5import pandas as pd 6from sklearn.model_selection import train_test_split 7from sklearn.metrics import accuracy_score 8from sklearn.metrics import confusion_matrix 9import seaborn as sns 10import numpy as np 11import matplotlib.pyplot as plt 12import keras 13from keras.models import Sequential 14from keras.layers import Dense, Dropout 15from keras.optimizers import RMSprop 16from pandas import Series,DataFrame 17from sklearn import metrics 18 19y_true=[] 20y_pred=[] 21 22 23 24csv_file =open("XXXX",'r') 25 26 27get_file = [] 28learn_data = [] 29check_data = [] 30 31 32 33if pattern == 1: 34 #分けて学習 35 for row in csv.reader(csv_file): 36 get_file.append(row) 37 check_data.append(get_file[0]) 38 learn_data.append(get_file[0]) 39 del get_file[0] 40 for i in range (len(get_file)): 41 num = random.randint(1,10) 42 if num == 1 or num == 2 : 43 check_data.append(get_file[i]) 44 else: 45 learn_data.append(get_file[i]) 46 47 #学習用CSV 48 with open("XXXX", "w", encoding="Shift_jis") as f: # 文字コードをShift_JISに指定 49 writer = csv.writer(f, lineterminator="\n") # writerオブジェクトの作成 改行記号で行を区切る 50 writer.writerows(learn_data) # csvファイルに書き込み 51 f.close() 52 #検証用のCSV 53 with open("XXXX", "w", encoding="Shift_jis") as f: # 文字コードをShift_JISに指定 54 writer = csv.writer(f, lineterminator="\n") # writerオブジェクトの作成 改行記号で行を区切る 55 writer.writerows(check_data) # csvファイルに書き込み 56 f.close() 57 58if pattern ==2: 59 #分けずに学習、検証 60 for row in csv.reader(csv_file): 61 get_file.append(row) 62 check_data.append(get_file[0]) 63 learn_data.append(get_file[0]) 64 del get_file[0] 65 for i in range (len(get_file)): 66 num = random.randint(1,10) 67 if num == 1 : 68 check_data.append(get_file[i]) 69 learn_data.append(get_file[i]) 70 else: 71 learn_data.append(get_file[i]) 72 73 #学習用CSV 74 with open("XXXX", "w", encoding="Shift_jis") as f: # 文字コードをShift_JISに指定 75 writer = csv.writer(f, lineterminator="\n") # writerオブジェクトの作成 改行記号で行を区切る 76 writer.writerows(learn_data) # csvファイルに書き込み 77 f.close() 78 #検証用のCSV 79 with open("XXXX", "w", encoding="Shift_jis") as f: # 文字コードをShift_JISに指定 80 writer = csv.writer(f, lineterminator="\n") # writerオブジェクトの作成 改行記号で行を区切る 81 writer.writerows(check_data) # csvファイルに書き込み 82 f.close() 83 84 85 86 87#CSVファイルの読み込み 88#分割データ 89if pattern ==1: 90 sig_data_set = pd.read_csv("XXXX",header=0,encoding='Shift_jis') 91 92#分割データ 93if pattern ==2: 94 sig_data_set = pd.read_csv("XXXX",header=0,encoding='Shift_jis') 95 96#説明変数 97x = DataFrame(sig_data_set.drop("judge",axis=1)) 98 99#目的変数 100y = DataFrame(sig_data_set["judge"]) 101 102#説明変数・目的変数をそれぞれ訓練データ・テストデータに分割 103x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.4) 104 105 106#データの整形 107x_train = x_train.astype(np.float) 108x_test = x_test.astype(np.float) 109 110y_train = keras.utils.to_categorical(y_train,2) 111y_test = keras.utils.to_categorical(y_test,2) 112 113 114 115#ニューラルネットワークの実装① 116model = Sequential() 117model.add(Dense(12,activation='relu', input_shape=(8,))) 118model.add(Dropout(0.2)) 119model.add(Dense(12, activation='sigmoid')) 120model.add(Dropout(0.2)) 121model.add(Dense(8, activation='sigmoid')) 122model.add(Dropout(0.2)) 123model.add(Dense(8, activation='sigmoid')) 124model.add(Dropout(0.2)) 125model.add(Dense(2, activation='sigmoid')) 126model.summary() 127 128#ニューラルネットワークの実装② 129#model.compile(loss='mean_squared_error',optimizer=RMSprop(lr=1e-3),metrics=['accuracy']) 130model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=1e-3),metrics=['accuracy']) 131#ニューラルネットワークの学習 132history = model.fit(x_train, y_train,batch_size=8,epochs=300,verbose=1,validation_data=(x_test, y_test)) 133 134#ニューラルネットワークの推論 135score = model.evaluate(x_test,y_test,verbose=1) 136print("\n") 137print(score) 138 139#評価したいデータを設定 140#分割データ 141if pattern ==1: 142 sig_check = pd.read_csv("XXXX",encoding='Shift_jis') 143 144#非分割データ 145if pattern ==2: 146 sig_check = pd.read_csv("XXXX",encoding='Shift_jis') 147 148sample = np.array(DataFrame(sig_check.drop("judge",axis=1))) 149 150 151#実測値を抽出 152#分割データ 153if pattern ==1: 154 csv_file = open("XXXX",'r') 155 156#非分割データ 157if pattern ==2: 158 csv_file = open("XXXX",'r') 159 160real_judge = [] 161for row in csv.reader(csv_file): 162 real_judge.append(row[-1]) 163del real_judge[0] 164 165predict = model.predict_classes(sample[0].reshape(1,-1),batch_size=1,verbose=0) 166 167#予測を行なう 168print("---予測値,実測値---") 169for i in range(len(sample)): 170 predict = model.predict_classes(sample[i].reshape(1,-1),batch_size=1,verbose=0) 171 y_true.append(int(real_judge[i])) 172 y_pred.append(int(predict[0])) 173 print(predict[0],real_judge[i]) 174 175 176 177 178cm = confusion_matrix(y_true, y_pred) 179tn, fp, fn, tp = cm.flatten() 180print("-----混同行列-----") 181print("TP(真陽性)=%d"%tp) 182print("TN(真陰性)=%d"%tn) 183print("FN(偽陰性)=%d"%fn) 184print("FP(偽陽性)=%d"%fp) 185 186#sns.heatmap(cm, annot=True, cmap='Blues') 187#plt.savefig('confusion_matrix.png') 188 189 190#評価指標 191Recall=tp/(tp+fn) 192Prediction=tp/(tp+fp) 193Acc=(tp+tn)/(tp+tn+fp+fn) 194B_Acc=1/2*((tp/(tp+fn))+(tn/(tn+fp))) 195HSS=(2*(tp*tn-fp*fn))/((tp+fn)*(fn+tn)+(tp+fp)*(fp+tn)) 196TSS=tp/(tp+fn)-fp/(tn+fp) 197 198 199print("-----評価指標-----") 200print("Recall=%f"%Recall) 201print("Prediction=%f"%Prediction) 202print("ACC=%f"%Acc) 203print("BACC=%f"%B_Acc) 204print("HSS=%f"%HSS) 205print("TSS=%f"%TSS) 206 207 208 209 210 211 212def plot_history(history): 213 # print(history.history.keys()) 214 215 # 精度の履歴をプロット 216 plt.plot(history.history['accuracy']) 217 plt.plot(history.history['val_accuracy']) 218 plt.title('model accuracy') 219 plt.xlabel('epoch') 220 plt.ylabel('accuracy') 221 plt.legend(['accuracy', 'val_accuracy'], loc='lower right') 222 plt.savefig('accuracy.png') 223 224 plt.figure() 225 226 # 損失の履歴をプロット 227 plt.plot(history.history['loss']) 228 plt.plot(history.history['val_loss']) 229 plt.title('model loss') 230 plt.xlabel('epoch') 231 plt.ylabel('loss') 232 plt.legend(['loss', 'val_loss'], loc='lower right') 233 plt.savefig('loss.png') 234# 学習履歴をプロット 235plot_history(history) 236 237 238

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問