前提・実現したいこと
scikit-learnを使って決定木を実装しています。
可視化しようとdtreevizのコードを書いていたらエラーが発生しました。
dtreevizは関係ないかもしれません。
実行結果・エラーメッセージ
0.7368421052631579 0.9642857142857143 変数 重要度 0 食料 0.000000 1 住居 0.000000 2 水道光熱費 0.183482 3 家具家事用品 0.000000 4 衣類 0.095694 5 保険医療 0.000000 6 交通通信 0.000000 7 教育 0.000000 8 教養娯楽 0.720824 9 諸雑費 0.000000 --------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-104-2d435d3c4d5a> in <module>() 36 target_name='大都市圏分類', 37 feature_names=X_train.columns, ---> 38 class_names=[str(i) for i in class_names] 39 ) 40 4 frames /usr/local/lib/python3.6/dist-packages/sklearn/tree/tree.py in _validate_X_predict(self, X, check_input) 400 "match the input. Model n_features is %s and " 401 "input n_features is %s " --> 402 % (self.n_features_, n_features)) 403 404 return X ValueError: Number of features of the model must match the input. Model n_features is 4 and input n_features is 10
該当のソースコード
python
1import pandas as pd 2from sklearn.model_selection import train_test_split 3from sklearn import tree 4from dtreeviz.trees import dtreeviz 5data = pd.read_csv('consumerPrices_tree.csv') 6 7X = data.drop(['都道府県', '大都市圏分類'], axis=1) 8Y = data['大都市圏分類'] 9 10from sklearn.tree import DecisionTreeClassifier 11model = DecisionTreeClassifier(max_depth=3, random_state=0) 12model.fit(X, Y) 13model.predict(X) 14 15predicted = pd.DataFrame({'Predicted':model.predict(X)}) 16data_predicted = pd.concat([data, predicted], axis =1) 17sum(model.predict(X)==Y)/len(Y) 18 19X_train,X_test,Y_train,Y_test = train_test_split(X, Y,random_state=0, test_size=0.4) 20 21model = DecisionTreeClassifier(max_depth=3, random_state=0) 22model.fit(X_train, Y_train) 23model.predict(X_train) 24from sklearn import metrics 25print(metrics.accuracy_score(Y_test, model.predict(X_test))) 26print(metrics.accuracy_score(Y_train, model.predict(X_train))) 27 28importance = pd.DataFrame({ '変数':X.columns, '重要度':model.feature_importances_}) 29print(importance) 30 31class_names = Y_train.unique().tolist() 32viz = dtreeviz( 33 classifier, 34 X_train, 35 Y_train, 36 target_name='大都市圏分類', 37 feature_names=X_train.columns, 38 class_names=[str(i) for i in class_names] 39 ) 40 41display(viz) 42 43
csvです。他のサイトが公開しているデータなので一応伏せています。
ここからダウンロードできます。
都道府県,食料,住居,水道光熱費,家具家事用品,衣類,保険医療,交通通信,教育,教養娯楽,諸雑費,大都市圏分類 北 海 道,98.7,82.6,116.3,99.3,103.8,100.2,99.5,93.2,97.1,100.9,1 ...... ......
試したこと
モデルの特徴が4個というのがよく分かりません。
なぜ10個ではないのでしょうか?
補足情報(FW/ツールのバージョンなど)
google colaboratory (python3)
回答1件
あなたの回答
tips
プレビュー