#現したいこと
digitsデータでk近傍法 ~マンハッタン距離~を行って使う近傍数ごとの正解率を折れ線グラフを作成したところ以下のようなジグザグな図形になりました。
mnistデータでk近傍法 ~マンハッタン距離~を行って使う近傍数ごとの正解率を折れ線グラフを作成したところ以下のような図形になりました。
似たようなデータなのになぜここまでグラフが変わるのか教えてください。
###Digitsのソースコード
python
1def main(): 2 # データをロード 3 dataset = datasets.load_digits() 4 5 # 特徴データとラベルデータを取り出す 6 features = dataset.data 7 targets = dataset.target 8 9 # 検証する近傍数 10 K = 10 11 ks = range(1, K + 1) 12 13 # 使う近傍数ごとに正解率&各経過時間を計算 14 accuracy_scores = [] 15 start = time.time() 16 for k in ks: 17 predicted_labels = [] 18 loo = LeaveOneOut() 19 for train, test in loo.split(features): 20 train_data = features[train] 21 test_data = targets[train] 22 23 elapsed_time = time.time() - start 24 25 # モデルを学習 26 model = KNeighborsClassifier(n_neighbors=k, metric='manhattan') 27 model.fit(train_data, test_data) 28 29 # 一つだけ取り除いたテストデータを識別 30 predicted_label = model.predict(features[test]) 31 predicted_labels.append(predicted_label) 32 33 # 正解率を計算 34 score = accuracy_score(targets, predicted_labels) 35 print('k={}: {}'.format(k, score)) 36 37 accuracy_scores.append(score) 38 39 # 各経過時間を表示 40 print("経過時間:{:.2f}".format(elapsed_time)) 41 42 # 使う近傍数ごとの正解率を折れ線グラフ 43 X = list(ks) 44 plt.plot(X, accuracy_scores) 45 46 plt.xlabel('k') 47 plt.ylabel('正解率') 48 plt.show() 49 50 51if __name__ == '__main__': 52 main()
###mnistのソースコード
python
1def main(): 2 3 # 特徴データとラベルデータを取り出す 4 features = mnist.data 5 targets = mnist.target 6 7 #データを分割 8 train_dataX, test_dataX, train_dataY, test_dataY = model_selection.train_test_split(features,targets,test_size=0.3) 9 10 11 # 検証する近傍数 12 K = 10 13 ks = range(1, K + 1) 14 15 # 使う近傍数ごとに正解率&各経過時間を計算 16 accuracy_scores = [] 17 start = time.time() 18 for k in ks: 19 predicted_labels = [] 20 elapsed_time = time.time() - start 21 22 # モデルを学習 23 model = KNeighborsClassifier(n_neighbors=k, metric='manhattan') 24 model.fit(train_dataX,train_dataY) 25 26 # 一つだけ取り除いたテストデータを識別 27 predicted_label = model.predict(test_dataX) 28 29 # 正解率を計算 30 score = accuracy_score(test_dataY, predicted_label) 31 print('k={}: {}'.format(k, score)) 32 33 accuracy_scores.append(score) 34 35 # 各経過時間を表示 36 print("経過時間:{:.2f}".format(elapsed_time)) 37 38 # 使う近傍数ごとの正解率を折れ線グラフ 39 X = list(ks) 40 plt.plot(X, accuracy_scores) 41 42 plt.xlabel('k') 43 plt.ylabel('正解率') 44 plt.show() 45 46 47if __name__ == '__main__': 48 main()
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
退会済みユーザー
2017/12/11 09:47
2017/12/11 13:54
退会済みユーザー
2017/12/11 14:07