回答編集履歴

1

d

2019/02/06 06:44

投稿

tiitoi
tiitoi

スコア21956

test CHANGED
@@ -93,3 +93,91 @@
93
93
 
94
94
 
95
95
  ![イメージ説明](314d1fedb5843e39d39b9ea712911b51.png)
96
+
97
+
98
+
99
+ ## 数でなく、割合で混合行列を表示したい場合
100
+
101
+
102
+
103
+ ```python
104
+
105
+ import matplotlib.pyplot as plt
106
+
107
+ import numpy as np
108
+
109
+ import pandas as pd
110
+
111
+ import seaborn as sns
112
+
113
+ from sklearn.metrics import confusion_matrix
114
+
115
+ from sklearn.neighbors import KNeighborsClassifier
116
+
117
+
118
+
119
+ sns.set()
120
+
121
+
122
+
123
+ df = pd.read_csv('data.csv')
124
+
125
+
126
+
127
+ # 学習データ、テストデータに分割する。
128
+
129
+ train_x, test_x, train_y, test_y = MS.train_test_split(
130
+
131
+ df.drop('result', axis=1), df['result'], test_size=0.2)
132
+
133
+
134
+
135
+ clf = KNeighborsClassifier(n_neighbors=3)
136
+
137
+ clf.fit(train_x, train_y)
138
+
139
+
140
+
141
+ # テストデータを推論する。
142
+
143
+ pred_y = clf.predict(test_x)
144
+
145
+
146
+
147
+ # 混合行列を作成する。
148
+
149
+ cm = confusion_matrix(test_y, pred_y)
150
+
151
+ cm = cm / cm.sum()
152
+
153
+
154
+
155
+ # 混合行列を描画する。
156
+
157
+ def print_confusion_matrix(confusion_matrix, class_names):
158
+
159
+ heatmap = sns.heatmap(
160
+
161
+ confusion_matrix, xticklabels=class_names, yticklabels=class_names,
162
+
163
+ annot=True, fmt='.2%', cbar=False, square=True, cmap='YlGnBu')
164
+
165
+ plt.ylabel('True label')
166
+
167
+ plt.xlabel('Predicted label')
168
+
169
+ plt.show()
170
+
171
+
172
+
173
+
174
+
175
+ labels = ['class {}'.format(i) for i in df['result'].unique()]
176
+
177
+ print_confusion_matrix(cm, labels)
178
+
179
+ ```
180
+
181
+
182
+
183
+ ![イメージ説明](dec1aabb6e73866c6c6c99b680777ee7.png)