前提
python(jupyter notebook)でSHAPを実装中に、Additivity check failed in TreeExplainer!というエラーがずっと出てしまうため、解消方法を知りたいです。
『機械学習を解釈する技術』の6章のコードを参考にしてSHAPの実装を図っています。
私が使うデータセットにて、ロジスティック回帰を用いたSHAPは実行可能なものの、ランダムフォレストやXgboostは実装することができません。
一方で、教科書のコードをコピペしたものとその中で使われるboston_housingデータセットでランダムフォレストとそのshapを実装することはできているため、SHAPのバージョンの問題ではないと考えています(最新の0.41を使っています)。
また、データセットはkaggleのhttps://www.kaggle.com/datasets/fedesoriano/company-bankruptcy-prediction より、不均衡データを用いているため、20%のアンダーサンプリングを行いました。
実現したいこと
- インスタンスごとのSHAP値を出す
- SHAPを用いてfeature importanceを出す
shap.plots.beeswarm(shap_values)
の実装
発生している問題・エラーメッセージ
- インスタンスごとのSHAP値、
shap_values[0]
などを出力する際に、additivity check failed in TreeExplainer!というエラーが発生する - SHAP値で計算するデータの数を減らすhttps://github.com/slundberg/shap/issues/urlと、
shap_values[0]
は出力できるものの、インスタンスごとのSHAP値がベクトルで出力されるため、waterallやマクロなfeature importance、shap.plots.beeswarm(shap_values)
が出力されない
#1 96%|=================== | 3624/3756 [00:21<00:00] --------------------------------------------------------------------------- ExplainerError Traceback (most recent call last) <ipython-input-147-12023fbc3284> in <module> ----> 1 shap_values = explainer(X_test) #X_resampledじゃないよね? check_additivity=False [0:50] 2 shap_values[0] ~\anaconda3\lib\site-packages\shap\explainers\_tree.py in __call__(self, X, y, interactions, check_additivity) 215 v = np.stack(v, axis=-1) # put outputs at the end 216 --> 217 # the explanation object expects an expected value for each row 218 if hasattr(self.expected_value, "__len__"): 219 ev_tiled = np.tile(self.expected_value, (v.shape[0],1)) ~\anaconda3\lib\site-packages\shap\explainers\_tree.py in shap_values(self, X, y, tree_limit, approximate, check_additivity, from_call) 408 return out 409 --> 410 # we pull off the last column and keep it as our expected_value 411 def _get_shap_output(self, phi, flat_output): 412 if self.model.num_outputs == 1: ~\anaconda3\lib\site-packages\shap\explainers\_tree.py in assert_additivity(self, phi, model_output) 540 541 @staticmethod --> 542 def supports_model_with_masker(model, masker): 543 """ Determines if this explainer can handle the given model. 544 ~\anaconda3\lib\site-packages\shap\explainers\_tree.py in check_sum(sum_val, model_output) 536 for i in range(len(phi)): 537 check_sum(self.expected_value[i] + phi[i].sum(-1), model_output[:,i]) --> 538 else: 539 check_sum(self.expected_value + phi.sum(-1), model_output) 540 ExplainerError: Additivity check failed in TreeExplainer! Please ensure the data matrix you passed to the explainer is the same shape that the model was trained on. If your data shape is correct then please report this on GitHub. This check failed because for one of the samples the sum of the SHAP values was 0.929200, while the model output was 0.950000. If this difference is acceptable you can set check_additivity=False to disable this check. #2 Exception Traceback (most recent call last) <ipython-input-168-09f62f536042> in <module> ----> 1 shap.plots.waterfall(shap_values[0]) ~\anaconda3\lib\site-packages\shap\plots\_waterfall.py in waterfall(shap_values, max_display, show) 52 raise Exception("waterfall_plot requires a scalar base_values of the model output as the first " \ 53 "parameter, but you have passed an array as the first parameter! " \ ---> 54 "Try shap.waterfall_plot(explainer.base_values[0], values[0], X[0]) or " \ 55 "for multi-output models try " \ 56 "shap.waterfall_plot(explainer.base_values[0], values[0][0], X[0]).") Exception: waterfall_plot requires a scalar base_values of the model output as the first parameter, but you have passed an array as the first parameter! Try shap.waterfall_plot(explainer.base_values[0], values[0], X[0]) or for multi-output models try shap.waterfall_plot(explainer.base_values[0], values[0][0], X[0]).
該当のソースコード
python
1#1 2shap_values = explainer(X_test) #X_test[0:50]などとデータを減らして渡せば動くものの、 3 shap_values[0] #.values =array([[ 2.27666661e-03, -2.27666668e-03], 4 #[ 6.60753955e-04, -6.60753973e-04],…とベクトルになる 5 6#2 7shap.plots.waterfall(shap_values[0])
試したこと
0.SHAPのバージョンを指定する
0. 丸め誤差を解消するため、データフレームをround()を用いて四捨五入した後にshapを計算する
0. データを減らしてshapを計算する
補足情報(FW/ツールのバージョンなど)
前提に記載
windowsのバージョンは最新です
あなたの回答
tips
プレビュー