sklearnのfit関数でのエラーについて質問させていただきます。
現在アンサンブル学習をやりたいと思い、以下のサイトのコードを試しています。
参考HP
エラーが出るコードの箇所と内容は以下になります。
python
1# feature importance using random forest 2from sklearn.ensemble import RandomForestRegressor 3rf = RandomForestRegressor(n_estimators=80, max_features='auto') 4rf.fit(X_train, y_train) 5print('Training done using Random Forest') 6 7ranking = np.argsort(-rf.feature_importances_) 8f, ax = plt.subplots(figsize=(11, 9)) 9sns.barplot(x=rf.feature_importances_[ranking], y=X_train.columns.values[ranking], orient='h') 10ax.set_xlabel("feature importance") 11plt.tight_layout() 12plt.show() 13 ...
error
1--------------------------------------------------------------------------- 2ValueError Traceback (most recent call last) 3<ipython-input-17-f9c4d1e292fc> in <module>() 4 2 from sklearn.ensemble import RandomForestRegressor 5 3 rf = RandomForestRegressor(n_estimators=80, max_features='auto') 6----> 4 rf.fit(X_train, y_train) 7 5 print('Training done using Random Forest') 8 6 9 10/usr/local/lib/python2.7/dist-packages/sklearn/ensemble/forest.pyc in fit(self, X, y, sample_weight) 11 245 """ 12 246 # Validate or convert input data 13--> 247 X = check_array(X, accept_sparse="csc", dtype=DTYPE) 14 248 y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None) 15 249 if sample_weight is not None: 16 17/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.pyc in check_array(array, accept_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, warn_on_dtype, estimator) 18 451 % (array.ndim, estimator_name)) 19 452 if force_all_finite: 20--> 453 _assert_all_finite(array) 21 454 22 455 shape_repr = _shape_repr(array.shape) 23 24/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.pyc in _assert_all_finite(X) 25 42 and not np.isfinite(X).all()): 26 43 raise ValueError("Input contains NaN, infinity" 27---> 44 " or a value too large for %r." % X.dtype) 28 45 29 46 30 31ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
試したことはX_train.drop(X_train.columns[np.isnan(X_train).any()], axis=1)
を入れて`NaN`を削除しようとしましたが、変化なしでした。
分かる方がいましたら、回答いただけると助かります。
※ご回答いただいた内容に質問させていただくこともあるかと思いますので、
※よろしければご返信いただければと思います。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2019/11/02 06:18