前提・実現したいこと
混合ガウスモデルを使ってアヤメのデータに等高線を引きたいです。
アヤメのデータの図に等高線をプロットしたいのですがエラーが出てきてしまい、どのようにして特徴量を変更したらいいのかわかりません。
発生している問題・エラーメッセージ
ValueError Traceback (most recent call last) <ipython-input-13-eb81e5c63a46> in <module> 49 X1, Y = np.meshgrid(x, y) 50 XX = np.array([X1.ravel(), Y.ravel()]).T ---> 51 Z = -gmc.score_samples(XX) 52 Z = Z.reshape(X1.shape) 53 ~\anaconda3\lib\site-packages\sklearn\mixture\_base.py in score_samples(self, X) 334 """ 335 check_is_fitted(self) --> 336 X = _check_X(X, None, self.means_.shape[1]) 337 338 return logsumexp(self._estimate_weighted_log_prob(X), axis=1) ~\anaconda3\lib\site-packages\sklearn\mixture\_base.py in _check_X(X, n_components, n_features, ensure_min_samples) 59 raise ValueError("Expected the input data X have %d features, " 60 "but got %d features" ---> 61 % (n_features, X.shape[1])) 62 return X 63 ValueError: Expected the input data X have 4 features, but got 2 features
該当のソースコード
python
1from sklearn import cluster 2from sklearn.cluster import AgglomerativeClustering 3import numpy as np 4import seaborn as sns 5import pandas as pd 6import matplotlib 7import matplotlib.pyplot as plt 8from matplotlib.colors import LogNorm 9from sklearn import mixture 10 11 12with open("iris.csv", 'r') as file: 13 14 header = file.readline() 15 16 data = np.loadtxt(file, delimiter=',', usecols=(0,1,2,3,4)) 17 18 19 20X = data[:,0:4] 21 22y = data[:,4] 23 24 25 26gmc = mixture.GaussianMixture(n_components=3, covariance_type='full') 27 28gmc.fit(X) 29 30yy=gmc.fit_predict(X) 31 32plt.scatter(X[yy == 0][:,2], 33 X[yy == 0][:,3], 34 c='green', 35 label='versicolor') 36plt.scatter(X[yy == 1][:,2], 37 X[yy == 1][:,3], 38 c='yellow', 39 label='setosa') 40plt.scatter(X[yy == 2][:,2], 41 X[yy == 2][:,3], 42 c='red', 43 label='versinica') 44 45 46 47x = np.linspace(0., 7.) 48y = np.linspace(0.,3.) 49X1, Y = np.meshgrid(x, y) 50XX = np.array([X1.ravel(), Y.ravel()]).T 51Z = -gmc.score_samples(XX) 52Z = Z.reshape(X1.shape) 53 54CS = plt.contour(X1, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0), 55 levels=np.logspace(0, 3, 10)) 56CB = plt.colorbar(CS, shrink=0.8, extend='both') 57plt.scatter(X1_train[:, 0], X1_train[:, 1], .8) 58 59plt.title('GMM') 60plt.axis('tight') 61 62 63 64 65 66 67 68 69 70plt.grid() 71 72plt.legend(loc="upper left") 73plt.xlabel("petal length") 74plt.ylabel("petal width") 75plt.show() 76 77 78 79print("Original : {0}".format(y.astype(np.int64))) 80 81print("Clustering: {0}".format(yy))
回答2件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
退会済みユーザー
2020/06/20 11:40