質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.35%
機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

解決済

1回答

1724閲覧

SVMで決定境界線を引きたい(MNIST数字分類)

Poyoyo

総合スコア6

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/09/22 15:46

編集2021/09/23 05:04

#実現したいこと
MNISTの0から9までの数字を分類したのち、各数字がどのように線形分離されているか決定境界線を引いて見える化したい。
具体的には、2と8の二つの数字のみに対し、3次元で境界面、もしくは2次元で境界線を引きたい。

#背景
2と8の分類結果として誤認分類が多かったため、決定境界を可視化することで、2が8の領域、8が2の領域に存在していることを見える化する。

#試したこと
下記コードにて、3次元空間で0から9まで数値を可視化した。
この後、mixtendなどでいじってみたが、境界線を作るところが解決できなかった。

#環境
python3.8.5

python

1import tensorflow as tf 2import tensorflow.keras 3import matplotlib.pyplot as plt 4%matplotlib inline 5from tensorflow.keras.datasets import mnist 6import numpy as np 7 8 9(x_train, y_train), (x_test, y_test) = mnist.load_data() 10x_train = x_train.reshape(60000, 784) 11x_test = x_test.reshape(10000, 784) 12x_train = (x_train / 255.0 * 0.99) + 0.01 13x_test = (x_test / 255.0 * 0.99) + 0.01 14 15from sklearn.model_selection import train_test_split 16from sklearn import datasets, svm, metrics 17from sklearn.metrics import accuracy_score 18 19clf = svm.LinearSVC() 20clf.fit(x_train, y_train) 21 22y_pred = clf.predict(x_test) 23print(accuracy_score(y_test, y_pred)) 24 25#ここから3次元空間でプロットへ 26#下記4行追加修正しました 27all_features = x_test 28teacher_labels = y_test 29from sklearn import decomposition 30from mpl_toolkits.mplot3d import Axes3D 31 32def getcolor(color): 33 if color == 0: 34 return "red" 35 elif color == 1: 36 return "blue" 37 elif color == 2: 38 return "yellow" 39 elif color == 3: 40 return "greenyellow" 41 elif color == 4: 42 return "green" 43 elif color == 5: 44 return "cyan" 45 elif color == 6: 46 return "blue" 47 elif color == 7: 48 return "navy" 49 elif color == 8: 50 return "purple" 51 else: 52 return "black" 53 54#次元削減 55pca = decomposition.PCA(n_components=3) 56three_features = pca.fit_transform(all_features) 57 58#描画 59fig = plt.figure(figsize=(12,9)) 60subfig = fig.add_subplot(111, projection = "3d") 61colors = list(map(getcolor, teacher_labels)) 62subfig.scatter(three_features[:, 0], three_features[:, 1], three_features[:,2], s=50, c=colors, alpha=0.3) 63plt.show() 64

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Poyoyo

2021/09/23 05:05 編集

大変失礼いたしました。 下記4行をコピペミスしていました。下記4行を、「#ここから3次元空間でプロットへ」の直後に入れて修正しました。 all_features = x_test teacher_labels = y_test from sklearn import decomposition from mpl_toolkits.mplot3d import Axes3D -------------------- また、ご指摘いただいたリンク先は参照したことがなく、私は画像認識プログラミングレシピという書籍を参照して作りました。 とはいうものの、リンク先の三次元プロットとは類似の結果になるところまでは作れたのですが、決定境界が作れない状況です。
guest

回答1

0

ベストアンサー

2と8に限定して、2次元で決定境界を描画してみました。いかがでしょうか?

イメージ説明

参考: 【python】高次元の分離境界をなんとか2次元で見る

all_features以降のコードを示します。

Python

1all_features = x_test[(y_test == 2) | (y_test == 8)] 2teacher_labels = y_test[(y_test == 2) | (y_test == 8)] 3color = {2: 'blue', 8: 'red'} 4 5from sklearn import decomposition 6from mlxtend.plotting import plot_decision_regions 7 8#次元削減 9pca = decomposition.PCA(n_components=2) 10features = pca.fit_transform(all_features) 11 12#描画 13colors = list(map(lambda x: color[x], teacher_labels)) 14X, Y = features[:, 0], features[:, 1] 15plt.scatter(X, Y, c=colors, alpha=0.3) 16#境界描画 17grid = np.meshgrid(np.linspace(X.min(), X.max(), 100), np.linspace(Y.min(), Y.max(), 100)) 18Z = clf.predict(pca.inverse_transform(np.c_[grid[0].flatten(), grid[1].flatten()])) 19plt.pcolormesh(grid[0], grid[1], Z.reshape(grid[0].shape), alpha=0.1, shading='gouraud') 20plt.show()

投稿2021/09/23 12:17

toast-uz

総合スコア3266

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

Poyoyo

2021/09/24 07:04

ありがとうございました。 私も環境でもmlxtendを入れて、上記内容で作ることができました。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.35%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問