質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.50%

  • Python

    11797questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

SVMを可視化に際して表示される複数の境界線の理由について

解決済

回答 2

投稿 編集

  • 評価
  • クリップ 0
  • VIEW 548

shibattyo

score 2

前提・実現したいこと

SVMを用いて領域分類した結果を可視化したいと思っています。
ここに質問の内容を詳しく書いてください。
その際に、機械学習のコードを参考にコーディングをjupyter notebookで行いました。
すると境界線とは別に線が何本か引かれているグラフが表示されてしまいました。
この線は何を表しているのか、どうすれば消えるのかを解決したいです。

発生している問題・エラーメッセージ

エラーメッセージ

該当のソースコード

ソースコード
AV_df = pd.read_csv("AV1.0~3.0 nomal2.csv")
x = AV_df.iloc[0:75, [0,1]].values
y = AV_df.iloc[0:75, 2].values
model =  SVC(kernel = "rbf", random_state = 0, C = 10, gamma = 100 )
model.fit(x, y)

def plot_decision_regions(X, y, classifier, test_idx=None, resolution=0.02):
    markers = ('s','x','o','^','v')
    colors = ('red','blue','lightgreen','gray','cyan')
    cmap = ListedColormap(colors[:len(np.unique(y))])

    x1_min, x1_max = X[:, 0].min() -0.015, X[:, 0].max() +0.055
    x2_min, x2_max = X[:, 1].min() -0.02, X[:, 1].max() +0.02

    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution), np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    plt.contourf(xx1,xx2,Z,alpha = 0.4, cmap = cmap)
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())

    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(x=X[y == cl, 0], y=X[y == cl, 1],alpha = 0.8, c= cmap(idx), marker= markers[idx], label = cl)

    if test_idx:
        X_test, y_test = X[test_idx, :], y[test_idx]
        plt.scatter(X_test[:,0], X_test[:,1] ,c='', alpha=1.0, linewidth = 1, marker='o', s =55,label = 'test set')

plot_decision_regions(x, y, classifier = model)

試したこと

ここに問題に対して試したことを記載してください。

補足情報(FW/ツールのバージョンなど)

イメージ説明
ここにより詳細な情報を記載してください。

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

質問への追記・修正、ベストアンサー選択の依頼

  • tiitoi

    2019/02/01 15:51 編集

    コードの入力は markdown 記法を使ってください。
    https://qiita.com/Qiita/items/c686397e4a0f4f11683d#code---%E3%82%B3%E3%83%BC%E3%83%89%E3%81%AE%E6%8C%BF%E5%85%A5

    あとSVMのコード全体を貼ってください

    キャンセル

  • shibattyo

    2019/02/01 16:03

    申し訳ありませんでした。全文掲載いたしました

    キャンセル

回答 2

+1

すると境界線とは別に線が何本か引かれているグラフが表示されてしまいました。

質問者さんのデータが何クラス分類問題なのか等わからないのですが、SVM で決定境界を描画する場合、以下のようにすればよいと思います。

学習する。

iris データを学習する例

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# データを取得
iris = datasets.load_iris()
data = iris.data[:, [0, 2]]
label = iris.target

# 学習データとテストデータに分割する。
X_train, X_test, Y_train, Y_test = train_test_split(
    data, label, test_size=0.2, stratify=label, random_state=42)

# ロジスティック回帰モデルで学習する。
model = SVC(gamma='auto')
model.fit(X_train, Y_train)

# テストデータを推論し、精度を出力する。
Y_pred = model.score(X_test, Y_test)
print('test accuracy: {:.2%}'.format(Y_pred))

描画する。

fig, ax = plt.subplots(figsize=(8, 6))

# タイトル、x 軸、y 軸のラベルを設定する。
ax.set_title('classification data using SVM')
ax.set_xlabel('Sepal length')
ax.set_ylabel('Petal length')

# サンプルを描画する。
ax.scatter(data[:, 0], data[:, 1], c=label, s=7, cmap='tab10')

X, Y = np.meshgrid(np.linspace(*ax.get_xlim(), 1000),
                   np.linspace(*ax.get_ylim(), 1000))
XY = np.column_stack([X.ravel(), Y.ravel()])
Z = model.predict(XY).reshape(X.shape)

# 等高線を描画する。
ax.contourf(X, Y, Z, alpha=0.4, cmap='Paired')
plt.show()

イメージ説明

追記

質問者さんのコードでデータセットを iris に差し替えて実行した結果

イメージ説明

決定境界以外に直接が引かれるという現象は確認できません。

投稿

編集

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

  • 2019/02/01 17:01

    これは3クラス分類です。
    境界線の引き方ではなく上図に引かれている、境界線とは別の線が何を意味しているのかを知りたいです。
    これもcontourf()メソッドで引かれた等高線という解釈でいいんでしょうか??

    キャンセル

  • 2019/02/01 17:10

    等高線であることは確かなのですが、predict() の返り値は推定クラスなので、3クラスなら 0, 1, 2 のどれかのはずです。
    等高線が引かれたということは、その部分に段差があるということになりますが、分類境界以外の部分に段差があるということになり変な気がします。
    原因を見つけるには、Z がどのような値なのか確認する必要があります。
    当方は AV1.0~3.0 nomal2.csv がないので、質問者さんのコードを動かすことができないので、原因まではわかりません。

    キャンセル

  • 2019/02/01 17:26 編集

    matplotlib で描画する際、(numpy.meshgrid()で作成した)いくつかの点の関数値からそれ以外の点の関数値を補完するので、ごく一部だけラベルが違ったりして、そのような結果になっている可能性がありますね。
    つまり、境界線じゃないように見える線も実は拡大するとその部分に別のラベルに推定された点が存在したりして決定境界になっているのではないでしょうか

    キャンセル

  • 2019/02/02 14:35

    解答して頂きありがとうございました。
    アドバイス頂いた通りにzの値を確認する過程でplot_decision_regions関数の引数であるrelosutionの値を小さくしたところ境界線のみになりました。
    おそらく入力としていたデータの値が少数第6位まであり小さかったため、広い間隔で分割して予測を行なった結果、等高線が境界線以外に引かれてしまったのではないかと考えられます。
    問題も一応解決いたしました。本当にありがとうございました。

    キャンセル

check解決した方法

0

plot_decision_regions関数の引数であるrelosutionの値を小さくしたところ境界線のみになりました。
おそらく入力としていたデータの値が少数第6位まであり小さかったため、広い間隔で分割して予測を行なった結果、等高線が境界線以外に引かれてしまったのではないかと考えられます。
そのため入力するいわゆる学習データの値に応じてresolutionの値は変えないといけないと思います。下の画像はresolutionを0.02から0.001に変更した時の様子です。
イメージ説明

投稿

  • 回答の評価を上げる

    以下のような回答は評価を上げましょう

    • 正しい回答
    • わかりやすい回答
    • ためになる回答

    評価が高い回答ほどページの上位に表示されます。

  • 回答の評価を下げる

    下記のような回答は推奨されていません。

    • 間違っている回答
    • 質問の回答になっていない投稿
    • スパムや攻撃的な表現を用いた投稿

    評価を下げる際はその理由を明確に伝え、適切な回答に修正してもらいましょう。

同じタグがついた質問を見る

  • Python

    11797questions

    Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。