scikit-learnのRandomForestclassifierで分類をやっています。
結果を可視化したいのでdtreevizを使いたいのですが、エラーが解決できず悩んでいます。
dtreevizでfeature_namesとclass_namesを与える際に、別にリストを作ったことが問題でしょうか?ご検討いただけますと幸いです。
以下分析の流れです。
データのまとめ方はpandas.dataframeを使っています。
被験者9名分のデータds被験者番号に特徴量と解答ラベルが含まれたものがあり、それらをpd.concatにて分割してtrain_X,train_Yを設定しています。
python
1#trainデータ作り 2train_X = pd.concat([ds10.drop('ans', axis=1),ds9.drop('ans', axis=1),ds8.drop('ans', axis=1),ds7.drop('ans', axis=1),ds6.drop('ans', axis=1),ds5.drop('ans', axis=1),ds4.drop('ans', axis=1),ds3.drop('ans', axis=1),ds2.drop('ans', axis=1)]) 3train_y = pd.concat([ds10.ans,ds9.ans,ds8.ans,ds7.ans,ds6.ans,ds5.ans,ds4.ans,ds3.ans,ds2.ans]) 4#ランダムフォレスト 5from sklearn.ensemble import RandomForestClassifier 6clf = RandomForestClassifier(random_state=0) 7clf = clf.fit(train_X, train_y) 8#dtreeviz用のリスト作り、x,y,z~の特徴量を使って1,2,3~のラベルで分類 9features = [] 10features = ['x','y','z'] 11names = [] 12names =[1,2,3] 13
ここまでは通っているのですが、以下を実行したところエラーが出ました。
python
1from dtreeviz.trees import dtreeviz 2estimators = clf.estimators_ 3viz=dtreeviz( 4 estimators[0], 5 train_X, 6 train_y, 7 target_name='features', 8 feature_names=features, 9 class_names=[str(i) for i in names], 10) 11 12viz
error
1IndexError Traceback (most recent call last) 2<ipython-input-36-4ce17c1ec598> in <module> 3 7 target_name='features', 4 8 feature_names=features, 5----> 9 class_names=[str(i) for i in names], 6 10 ) 7 11 8 9~\Anaconda3\lib\site-packages\dtreeviz\trees.py in dtreeviz(tree_model, X_train, y_train, feature_names, target_name, class_names, precision, orientation, show_root_edge_labels, show_node_labels, fancy, histtype, highlight_path, X, max_X_features_LR, max_X_features_TD, label_fontsize, ticks_fontsize, fontname, colors) 10 697 11 698 n_classes = shadow_tree.nclasses() 12--> 699 color_values = colors['classes'][n_classes] 13 700 14 701 # Fix the mapping from target value to color for entire tree 15 16IndexError: list index out of range
参考にしているのは以下のサイトです。
https://qiita.com/go50/items/38c7757b444db3867b17
追記)
https://github.com/parrt/dtreeviz/issues/26
上記を確認したところ、dtreevizでクラス分けする時は10色までしか使えないようで、今回は14色に分類したかったので不適合という話のようです。
あなたの回答
tips
プレビュー