質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
scikit-learn

scikit-learnは、Pythonで使用できるオープンソースプロジェクトの機械学習用ライブラリです。多くの機械学習アルゴリズムが実装されていますが、どのアルゴリズムも同じような書き方で利用できます。

Matplotlib

MatplotlibはPythonのおよび、NumPy用のグラフ描画ライブラリです。多くの場合、IPythonと連携して使われます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

解決済

1回答

1092閲覧

SVMの識別境界面を3次元で描画したい

mato999

総合スコア3

scikit-learn

scikit-learnは、Pythonで使用できるオープンソースプロジェクトの機械学習用ライブラリです。多くの機械学習アルゴリズムが実装されていますが、どのアルゴリズムも同じような書き方で利用できます。

Matplotlib

MatplotlibはPythonのおよび、NumPy用のグラフ描画ライブラリです。多くの場合、IPythonと連携して使われます。

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2023/03/31 15:04

編集2023/04/02 04:37

実現したいこと

SVMの識別境界面をdecision_functionの出力値を使って3次元空間に描画したい。
まずはlinearカーネルで正しく面ができることを確認した後、RBFカーネルでも曲面を作りたい。

前提

直接、識別関数の式を使って平面自体は描画することができた。
ただ、decision_functionの出力値が0になる位置=境界面となることを利用して
描画を実行すると平面はできあがるが、位置がデータの分布に対してズレてしまう。
(decision_functionの出力値を使う理由はRBFで曲面を作りたいため。)

該当のソースコード

Python

1from sklearn.datasets import load_iris 2import matplotlib.pyplot as plt 3import numpy as np 4from sklearn.preprocessing import StandardScaler 5from sklearn.svm import SVC 6 7# irisデータセットの読み込み 8iris = load_iris() 9 10# 特徴量とクラスラベルの取得 11X = iris.data 12y = iris.target 13 14# 3つの特徴量を使用 15X = X[:, :3] 16# クラス0と1のみにするために2を1に変換 17y[y == 2] = 1 18 19# データのスケーリング 20scaler = StandardScaler() 21X = scaler.fit_transform(X) 22 23# SVMのRBFカーネルによる分類器の学習 24clf = SVC(kernel='linear', gamma=1, C=1) 25clf.fit(X, y) 26 27# 3Dグラフの作成 28fig = plt.figure(figsize=(8, 8)) 29ax = fig.add_subplot(111, projection='3d') 30 31# データのプロット 32ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap='bwr') 33 34# 決定境界の作成 35xx, yy = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), 50), 36 np.linspace(X[:, 1].min(), X[:, 1].max(), 50)) 37zz = - clf.decision_function(np.c_[xx.ravel(), yy.ravel(), np.zeros(xx.shape).ravel()]) 38zz = zz.reshape(xx.shape) 39 40# 決定境界のプロット 41ax.plot_surface(xx, yy, zz, alpha=0.5) 42 43ax.set_xlabel('X Label') 44ax.set_ylabel('Y Label') 45ax.set_zlabel('Z Label') 46 47plt.show()

試したこと

np.meshgridの領域を変更したが変化なし。
clf.decision_functionの符号を反転したが位置が変わることはなかった。

補足情報(FW/ツールのバージョンなど)

Python3.9
PyCharm 2022.2 (Community Edition)

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

PondVillege

2023/03/31 15:22 編集

それもそのはず,clf.decision_functionの出力値は境界からの距離なのですから,Z座標値の分布とはズレます.この流れで三次元曲面を描画するには distance = clf.desicion_funcition(x, y, z) という現状の関数の逆(?)関数を導出して z = desicion_function_inv(x, y, distance) のように作ることになります.現状,zが0の場合の距離を描画しています.
mato999

2023/03/31 15:35

ありがとうございます。 z = desicion_function_inv(x, y, distance)のdistanceには0が入るということになりますでしょうか?
PondVillege

2023/03/31 15:43

そうですね. この逆関数を求めるのは現実的じゃないので,zの値を探索的に入力して曲面の点群を求めるのが良さそうです.
mato999

2023/04/01 01:32

上記のコードで考えると、zの値を探索的に入力してclf.decision_functionが0になる点群を求めるということでしょうか?よろしくお願いします。
PondVillege

2023/04/01 03:27

はい,実際は0と完全に等しくなる点を探すのは無理かと思いますので,許容できる誤差まで納まったら,という条件にして探索すると良いでしょう.
mato999

2023/04/01 03:46

ありがとうございます。なんとなくやり方のイメージができました。もう一度トライしてみます。
guest

回答1

0

自己解決

頂いたコメントの内容を実践したところ、期待していた曲面を描画することができました。
ありがとうございました。

全体のコードとしては以下のようにしました。
もし、改善点等がありましたらコメント頂ければと思います。

Python

1from sklearn.datasets import load_iris 2import matplotlib.pyplot as plt 3import numpy as np 4from sklearn.preprocessing import StandardScaler 5from sklearn.svm import SVC 6from scipy import interpolate 7import itertools 8 9# irisデータセットの読み込み 10iris = load_iris() 11 12# 特徴量とクラスラベルの取得 13X = iris.data 14y = iris.target 15 16# 3つの特徴量を使用 17X = X[:, :3] 18 19# クラス0と1のみにするために2を1に変換 20y[y == 2] = 1 21 22# データのスケーリング 23scaler = StandardScaler() 24X = scaler.fit_transform(X) 25 26# SVMのRBFカーネルによる分類器の学習 27clf = SVC(kernel='rbf', gamma=1, C=100) 28clf.fit(X, y) 29 30# 3Dグラフの作成 31fig = plt.figure(figsize=(8, 8)) 32ax = fig.add_subplot(111, projection='3d') 33 34gridsize = 50 35 36# 識別関数の出力値が0になる点群を探索するためのxyzグリッドデータを作成 37xx_s, yy_s, zz_s = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), gridsize), 38 np.linspace(X[:, 1].min(), X[:, 1].max(), gridsize), 39 np.linspace(X[:, 2].min(), X[:, 2].max(), gridsize)) 40 41# 識別関数の出力値が0になるときの座標データ格納用 42points_x, points_y, points_z = [], [], [] 43 44# グリッドデータ点毎に探索 45r_iter = itertools.product(range(xx_s.shape[0]), repeat=3) 46for i, j, k in r_iter: 47 dec_val = clf.decision_function(np.array([xx_s[i, j, k], yy_s[i, j, k], zz_s[i, j, k]]).reshape(1, -1)) 48 # 0になる条件を決めてxx_s, yy_s, zz_sの値を格納 49 if -1e-3 <= dec_val <= 1e-3: 50 points_x.append(xx_s[i, j, k]) 51 points_y.append(yy_s[i, j, k]) 52 points_z.append(zz_s[i, j, k]) 53arr = np.array([np.array(points_x), np.array(points_y), np.array(points_z)]).T 54 55# 描画用のxyグリッドデータを作成 56xx, yy = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), gridsize), 57 np.linspace(X[:, 0].min(), X[:, 0].max(), gridsize)) 58 59# 識別境界を描画するために座標間のデータを補間 60Z = interpolate.griddata(arr[:, 0:2], arr[:, 2], (xx, yy), method="cubic") 61 62# 学習データと識別境界を描画 63ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y, cmap='bwr') 64ax.plot_wireframe(xx, yy, Z, alpha=0.5) 65 66ax.set_xlabel('X Label') 67ax.set_ylabel('Y Label') 68ax.set_zlabel('Z Label') 69 70plt.show()

イメージ説明

投稿2023/04/02 04:18

編集2023/04/02 04:24
mato999

総合スコア3

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問