teratail header banner
teratail header banner
質問するログイン新規登録

回答編集履歴

1

edit

2017/12/11 11:25

投稿

mkgrei
mkgrei

スコア8562

answer CHANGED
@@ -22,4 +22,82 @@
22
22
 
23
23
  ---
24
24
 
25
- 後は純粋にMNISTのほうが意地悪なサンプルが割合多く含まれている可能性もありますが、上記の可能性を排除できないにはこのような結論を下すのは時期尚早でしょうね。
25
+ 後は純粋にMNISTのほうが意地悪なサンプルが割合多く含まれている可能性もありますが、上記の可能性を排除できないにはこのような結論を下すのは時期尚早でしょうね。
26
+
27
+ ---
28
+
29
+ 追記:
30
+ 気になったので試してみました。
31
+ MNISTのほうがずっと難しいですね。
32
+ digitsは8x8に対して、MNISTは28x28ですので、自由度がずっと高いですね。
33
+ 例えば、MNISTから2000だけ取り出して8x8にリサイズしてやると、正答率は
34
+ digits:~98%、MNIST:~92%になります。
35
+
36
+ ```python
37
+ from sklearn.model_selection import StratifiedKFold
38
+ from sklearn.neighbors import KNeighborsClassifier
39
+ from sklearn.metrics import accuracy_score
40
+
41
+ from sklearn import datasets
42
+ from keras.datasets import mnist
43
+
44
+ from scipy.misc import imresize
45
+
46
+ import numpy as np
47
+
48
+ try:
49
+ from tqdm import tqdm
50
+ except (ImportError) as e:
51
+ tqdm = lambda x:x
52
+
53
+ def main(key='digits', random_state=2017):
54
+ if key == 'digits':
55
+ dataset = datasets.load_digits()
56
+ X = dataset.data
57
+ Y = dataset.target
58
+ elif key == 'mnist':
59
+ (X_train, y_train), (X_test, y_test) = mnist.load_data()
60
+ kfold = StratifiedKFold(5, shuffle=True, random_state=0)
61
+ tr, ts = next(kfold.split(X_test, y_test))
62
+ X = X_test[ts]
63
+ X = np.array([imresize(x, (8, 8)) for x in X])
64
+ X = X.reshape(-1, np.prod(X.shape[1:]))
65
+ Y = y_test[ts]
66
+ Y = Y.reshape(-1)
67
+ else:
68
+ return [], []
69
+
70
+ ks = np.linspace(1, 10, 5).astype('i')
71
+
72
+ accuracy_scores = []
73
+ for k in tqdm(ks):
74
+ pY = np.zeros(Y.shape)
75
+ kfold = StratifiedKFold(5, shuffle=True, random_state=random_state)
76
+ for tr, ts in kfold.split(X, Y):
77
+ x_tr = X[tr]
78
+ y_tr = Y[tr]
79
+
80
+ model = KNeighborsClassifier(n_neighbors=k, metric='manhattan')
81
+ model.fit(x_tr, y_tr)
82
+
83
+ py = model.predict(X[ts])
84
+ pY[ts] = py
85
+
86
+ score = accuracy_score(Y, pY)
87
+ accuracy_scores.append(score)
88
+ return ks, accuracy_scores
89
+
90
+ if __name__ == '__main__':
91
+ colors = ['red', 'blue']
92
+ for ic, key in enumerate(['digits', 'mnist']):
93
+ for i in np.linspace(1, 1000, 10).astype('i'):
94
+ ks, accuracy_scores = main(key=key, random_state=2017+i)
95
+ plt.plot(ks, accuracy_scores, marker='.', color=colors[ic])
96
+
97
+ plt.xlabel('k')
98
+ plt.ylabel('Accuracy')
99
+ plt.grid()
100
+ plt.xlim((0, np.max(ks)))
101
+ plt.ylim((0.8, 1.0))
102
+ plt.show()
103
+ ```