回答編集履歴
2
誤字修正
test
CHANGED
@@ -1,97 +1,49 @@
|
|
1
|
-
そ
|
1
|
+
そもそもなのですが、散布図を描くためには4つの特徴量のうち2つを選ぶ必要があるのではないしょうか?
|
2
|
-
|
3
2
|
同様にガウス混合分類器も4つではなく2つの特徴量で訓練する必要があるかと思います。
|
4
3
|
|
5
|
-
|
6
|
-
|
7
4
|
以下`sepal length (cm)`と`sepal width (cm)`で散布図と等高線図を描く例です。
|
8
|
-
|
9
5
|
```Python
|
10
6
|
|
11
|
-
|
12
|
-
|
13
7
|
from sklearn.mixture import GaussianMixture
|
14
|
-
|
15
8
|
from sklearn.datasets import load_iris
|
16
|
-
|
17
9
|
import numpy as np
|
18
|
-
|
19
10
|
import matplotlib.pyplot as plt
|
20
11
|
|
21
|
-
|
22
|
-
|
23
12
|
data = load_iris()
|
24
|
-
|
25
13
|
N_CLASS = len(data.target_names) # 3
|
26
|
-
|
27
14
|
LEN_CLASS = len(data.target) // N_CLASS # 50
|
28
15
|
|
29
|
-
|
30
|
-
|
31
16
|
vx,vy = 0,1 # 対象の2変数
|
32
|
-
|
33
17
|
x_min, x_max = min(data.data[:,vx]), max(data.data[:,vx])
|
34
|
-
|
35
18
|
y_min, y_max = min(data.data[:,vy]), max(data.data[:,vy])
|
36
19
|
|
37
|
-
|
38
|
-
|
39
20
|
# 2変数で散布図を描画
|
40
|
-
|
41
21
|
cs = []
|
42
|
-
|
43
22
|
for c in 'rgb':
|
44
|
-
|
45
23
|
cs += list(c*LEN_CLASS)
|
46
|
-
|
47
24
|
plt.scatter(data.data[:,vx], data.data[:,vy], c=cs)
|
48
|
-
|
49
25
|
plt.xlabel(data.feature_names[vx])
|
50
|
-
|
51
26
|
plt.ylabel(data.feature_names[vy])
|
52
|
-
|
53
27
|
plt.xlim(x_min, x_max)
|
54
|
-
|
55
28
|
plt.ylim(y_min, y_max)
|
56
29
|
|
57
|
-
|
58
|
-
|
59
30
|
# 散布図に合わせた2変数で訓練
|
60
|
-
|
61
31
|
model = GaussianMixture(n_components=N_CLASS, random_state=3)
|
62
|
-
|
63
32
|
model.fit(data.data[:,[vx,vy]])
|
64
33
|
|
65
|
-
|
66
|
-
|
67
34
|
# 存在する範囲内で点群を生成
|
68
|
-
|
69
35
|
N = 100
|
70
|
-
|
71
36
|
x, y = np.meshgrid(np.linspace(x_min,x_max,N), np.linspace(y_min,y_max,N))
|
72
|
-
|
73
37
|
X = np.array([x, y]).reshape(2, -1).T
|
74
|
-
|
75
38
|
probs = model.predict_proba(X)
|
76
39
|
|
77
|
-
|
78
|
-
|
79
40
|
# 等高線図を各クラス毎に描画
|
80
|
-
|
81
41
|
cs = ['Reds', 'Greens', 'Blues'] # 散布図の色とは必ずしも対応しないことに注意
|
82
|
-
|
83
42
|
assert len(cs) == N_CLASS
|
84
|
-
|
85
43
|
for i in range(N_CLASS):
|
86
|
-
|
87
44
|
plt.contourf(x, y, probs[:,i].reshape(N,N), cmap=cs[i], alpha=0.2)
|
88
|
-
|
89
45
|
plt.xlim(x_min, x_max)
|
90
|
-
|
91
46
|
plt.ylim(y_min, y_max)
|
92
|
-
|
93
47
|
plt.show()
|
94
|
-
|
95
48
|
```
|
96
|
-
|
97
49
|
![イメージ説明](da564fce2c578bef066bf158e5cd66ac.png)
|
1
コード修正
test
CHANGED
@@ -7,6 +7,8 @@
|
|
7
7
|
以下`sepal length (cm)`と`sepal width (cm)`で散布図と等高線図を描く例です。
|
8
8
|
|
9
9
|
```Python
|
10
|
+
|
11
|
+
|
10
12
|
|
11
13
|
from sklearn.mixture import GaussianMixture
|
12
14
|
|
@@ -26,6 +28,14 @@
|
|
26
28
|
|
27
29
|
|
28
30
|
|
31
|
+
vx,vy = 0,1 # 対象の2変数
|
32
|
+
|
33
|
+
x_min, x_max = min(data.data[:,vx]), max(data.data[:,vx])
|
34
|
+
|
35
|
+
y_min, y_max = min(data.data[:,vy]), max(data.data[:,vy])
|
36
|
+
|
37
|
+
|
38
|
+
|
29
39
|
# 2変数で散布図を描画
|
30
40
|
|
31
41
|
cs = []
|
@@ -34,15 +44,15 @@
|
|
34
44
|
|
35
45
|
cs += list(c*LEN_CLASS)
|
36
46
|
|
37
|
-
plt.scatter(data.data[:,
|
47
|
+
plt.scatter(data.data[:,vx], data.data[:,vy], c=cs)
|
38
48
|
|
39
|
-
plt.xlabel(data.feature_names[
|
49
|
+
plt.xlabel(data.feature_names[vx])
|
40
50
|
|
41
|
-
plt.ylabel(data.feature_names[
|
51
|
+
plt.ylabel(data.feature_names[vy])
|
42
52
|
|
43
|
-
plt.xlim(
|
53
|
+
plt.xlim(x_min, x_max)
|
44
54
|
|
45
|
-
plt.ylim(
|
55
|
+
plt.ylim(y_min, y_max)
|
46
56
|
|
47
57
|
|
48
58
|
|
@@ -50,7 +60,7 @@
|
|
50
60
|
|
51
61
|
model = GaussianMixture(n_components=N_CLASS, random_state=3)
|
52
62
|
|
53
|
-
model.fit(data.data[:,
|
63
|
+
model.fit(data.data[:,[vx,vy]])
|
54
64
|
|
55
65
|
|
56
66
|
|
@@ -58,7 +68,7 @@
|
|
58
68
|
|
59
69
|
N = 100
|
60
70
|
|
61
|
-
x, y = np.meshgrid(np.linspace(
|
71
|
+
x, y = np.meshgrid(np.linspace(x_min,x_max,N), np.linspace(y_min,y_max,N))
|
62
72
|
|
63
73
|
X = np.array([x, y]).reshape(2, -1).T
|
64
74
|
|
@@ -76,12 +86,12 @@
|
|
76
86
|
|
77
87
|
plt.contourf(x, y, probs[:,i].reshape(N,N), cmap=cs[i], alpha=0.2)
|
78
88
|
|
79
|
-
plt.xlim(
|
89
|
+
plt.xlim(x_min, x_max)
|
80
90
|
|
81
|
-
plt.ylim(
|
91
|
+
plt.ylim(y_min, y_max)
|
82
92
|
|
83
93
|
plt.show()
|
84
94
|
|
85
95
|
```
|
86
96
|
|
87
|
-
![イメージ説明](
|
97
|
+
![イメージ説明](da564fce2c578bef066bf158e5cd66ac.png)
|