Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_explain_weights_feature_names_pandas(boston_train):
    """Feature names are picked up from DataFrame columns by default,
    and an explicit ``feature_names`` argument overrides them."""
    pd = pytest.importorskip('pandas')
    X, y, feature_names = boston_train
    df = pd.DataFrame(X, columns=feature_names)
    reg = XGBRegressor().fit(df, y)
    # it should pick up feature names from DataFrame columns
    res = explain_weights(reg)
    for expl in format_as_all(res, reg):
        assert 'PTRATIO' in expl
    # it is possible to override DataFrame feature names
    numeric_feature_names = ["zz%s" % idx for idx in range(len(feature_names))]
    res = explain_weights(reg, feature_names=numeric_feature_names)
    for expl in format_as_all(res, reg):
        assert 'zz12' in expl
def test_explain_xgboost_regressor(boston_train):
    """XGBRegressor weights: generic ``f<idx>`` names by default,
    real names when ``feature_names`` is passed."""
    X, y, feature_names = boston_train
    reg = XGBRegressor().fit(X, y)
    # without feature names, xgboost's generic f<idx> names show up
    for expl in format_as_all(explain_weights(reg), reg):
        assert 'f12' in expl
    # with feature names, the real column names appear instead
    explicit = explain_weights(reg, feature_names=feature_names)
    for expl in format_as_all(explicit, reg):
        assert 'LSTAT' in expl
def test_explain_xgboost_regressor(boston_train):
    # NOTE(review): exact duplicate of a test defined earlier in this file;
    # the later definition shadows the earlier one at collection time.
    """XGBRegressor weights: default ``f<idx>`` names, then overrides."""
    data, target, feature_names = boston_train
    reg = XGBRegressor()
    reg.fit(data, target)
    default_expl = explain_weights(reg)
    for rendered in format_as_all(default_expl, reg):
        # default xgboost feature naming
        assert 'f12' in rendered
    named_expl = explain_weights(reg, feature_names=feature_names)
    for rendered in format_as_all(named_expl, reg):
        # user-supplied names are used in the output
        assert 'LSTAT' in rendered
def test_explain_xgboost_booster(boston_train):
    """A raw ``xgboost.Booster`` (not the sklearn wrapper) is also
    supported by explain_weights."""
    X, y, feature_names = boston_train
    dtrain = xgboost.DMatrix(X, label=y)
    booster = xgboost.train(
        params={'objective': 'reg:linear', 'silent': True},
        dtrain=dtrain,
    )
    # default f<idx> naming
    for expl in format_as_all(explain_weights(booster), booster):
        assert 'f12' in expl
    # explicit names override the defaults
    named = explain_weights(booster, feature_names=feature_names)
    for expl in format_as_all(named, booster):
        assert 'LSTAT' in expl
def test_explain_xgboost_booster(boston_train):
    # NOTE(review): exact duplicate of a test defined earlier in this file;
    # the later definition shadows the earlier one at collection time.
    """Raw ``xgboost.Booster`` works with explain_weights."""
    data, target, feature_names = boston_train
    booster = xgboost.train(
        dtrain=xgboost.DMatrix(data, label=target),
        params={'objective': 'reg:linear', 'silent': True},
    )
    for rendered in format_as_all(explain_weights(booster), booster):
        # default xgboost feature naming
        assert 'f12' in rendered
    result = explain_weights(booster, feature_names=feature_names)
    for rendered in format_as_all(result, booster):
        # user-supplied names are used in the output
        assert 'LSTAT' in rendered
@explain_weights.register(XGBRegressor)
@explain_weights.register(Booster)
def explain_weights_xgboost(xgb,
vec=None,
top=20,
target_names=None, # ignored
targets=None, # ignored
feature_names=None,
feature_re=None, # type: Pattern[str]
feature_filter=None,
importance_type='gain',
):
"""
Return an explanation of an XGBoost estimator (via scikit-learn wrapper
XGBClassifier or XGBRegressor, or via xgboost.Booster)
as feature importances.
]
# Lightning regressors that all share the generic linear-regressor
# explanation handlers (registered in the loop below).
_REGRESSORS = [
    regression.AdaGradRegressor,
    regression.CDRegressor,
    regression.FistaRegressor,
    regression.LinearSVR,
    regression.SAGARegressor,
    regression.SAGRegressor,
    regression.SDCARegressor,
    regression.SGDRegressor,
    # regression.SVRGRegressor
]
# Wire every lightning classifier/regressor to the shared linear-model
# explanation handlers, in both the generic dispatch registries and the
# lightning-specific ones.
for clf in _CLASSIFIERS:
    explain_weights.register(clf, explain_linear_classifier_weights)
    explain_weights_lightning.register(clf, explain_linear_classifier_weights)
    explain_prediction.register(clf, explain_prediction_linear_classifier)
    explain_prediction_lightning.register(clf, explain_prediction_linear_classifier)
for reg in _REGRESSORS:
    explain_weights.register(reg, explain_linear_regressor_weights)
    explain_weights_lightning.register(reg, explain_linear_regressor_weights)
    explain_prediction.register(reg, explain_prediction_linear_regressor)
    explain_prediction_lightning.register(reg, explain_prediction_linear_regressor)
@explain_weights.register(BaseEstimator)
def explain_weights_sklearn_not_supported(
        estimator, vec=None, top=_TOP,
        target_names=None,
        targets=None,
        feature_names=None, coef_scale=None,
        feature_re=None, feature_filter=None):
    """Fallback for estimators with no registered handler: return an
    Explanation carrying an error message instead of raising."""
    message = "estimator %r is not supported" % estimator
    return Explanation(estimator=repr(estimator), error=message)
@explain_weights.register(catboost.CatBoostClassifier)
@explain_weights.register(catboost.CatBoostRegressor)
def explain_weights_catboost(catb,
vec=None,
top=20,
importance_type='PredictionValuesChange',
feature_names=None,
pool=None
):
"""
Return an explanation of an CatBoost estimator (CatBoostClassifier,
CatBoost, CatBoostRegressor) as feature importances.
See :func:`eli5.explain_weights` for description of
``top``, ``feature_names``,
``feature_re`` and ``feature_filter`` parameters.
def explain_weights_ovr(ovr, **kwargs):
    """Explain a one-vs-rest wrapper by dispatching on the class of its
    underlying base estimator and calling that handler on the wrapper."""
    handler = explain_weights.dispatch(ovr.estimator.__class__)
    return handler(ovr, **kwargs)