import numpy as np


def test_sklearn_random_forest_multiclass():
    import shap
    from sklearn.ensemble import RandomForestClassifier

    # collapse iris to a binary problem so the forest has two classes
    X, y = shap.datasets.iris()
    y[y == 2] = 1
    model = RandomForestClassifier(n_estimators=100, max_depth=None, min_samples_split=2, random_state=0)
    model.fit(X, y)

    # check the first feature's attribution for the first sample in both class outputs
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    assert np.abs(shap_values[0][0, 0] - 0.05) < 1e-3
    assert np.abs(shap_values[1][0, 0] + 0.05) < 1e-3


def test_sklearn_interaction():
    import shap
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier

    # train a simple sklearn RF model on the iris dataset
    X, y = shap.datasets.iris()
    X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.2, random_state=0)
    rforest = RandomForestClassifier(n_estimators=100, max_depth=None, min_samples_split=2, random_state=0)
    model = rforest.fit(X_train, Y_train)

    # verify symmetry of the interaction values (this typically breaks if anything is wrong)
    interaction_vals = shap.TreeExplainer(model).shap_interaction_values(X)
    for i in range(len(interaction_vals)):
        for j in range(len(interaction_vals[i])):
            for k in range(len(interaction_vals[i][j])):
                for l in range(len(interaction_vals[i][j][k])):
                    assert abs(interaction_vals[i][j][k][l] - interaction_vals[i][j][l][k]) < 1e-6

    # ensure the interaction summary plot works for the first class
    shap.summary_plot(interaction_vals[0], X, show=False)


def test_lightgbm_multiclass():
    try:
        import lightgbm
    except ImportError:
        print("Skipping test_lightgbm_multiclass!")
        return
    import shap

    # train a LightGBM model on the three-class iris dataset
    X, Y = shap.datasets.iris()
    model = lightgbm.sklearn.LGBMClassifier()
    model.fit(X, Y)

    # explain the model's predictions using SHAP values
    shap_values = shap.TreeExplainer(model).shap_values(X)

    # ensure the dependence plot works for the first class
    shap.dependence_plot(0, shap_values[0], X, show=False)


def test_kernel_sparse_vs_dense_multirow_background():
    import shap
    from scipy import sparse
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression

    # train a logistic regression classifier
    X_train, X_test, Y_train, _ = train_test_split(*shap.datasets.iris(), test_size=0.1, random_state=0)
    lr = LogisticRegression(solver='lbfgs')
    lr.fit(X_train, Y_train)

    # use Kernel SHAP to explain test set predictions with dense data
    explainer = shap.KernelExplainer(lr.predict_proba, X_train, nsamples=100, link="logit", l1_reg="rank(3)")
    shap_values = explainer.shap_values(X_test)

    # refit the same model on sparse versions of the same data
    X_sparse_train = sparse.csr_matrix(X_train)
    X_sparse_test = sparse.csr_matrix(X_test)
    lr_sparse = LogisticRegression(solver='lbfgs')
    lr_sparse.fit(X_sparse_train, Y_train)

    # use Kernel SHAP again, this time with the sparse model and sparse data
    sparse_explainer = shap.KernelExplainer(lr_sparse.predict_proba, X_sparse_train, nsamples=100, link="logit", l1_reg="rank(3)")
    sparse_shap_values = sparse_explainer.shap_values(X_sparse_test)
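    # compare the dense and sparse explanations; without this check the test
    # exercises both code paths but verifies nothing (the tolerances here are
    # assumptions, not values from the original test)
    assert np.allclose(shap_values, sparse_shap_values, rtol=1e-5, atol=1e-5)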


def test_sklearn_decision_tree_multiclass():
    import shap
    from sklearn.tree import DecisionTreeClassifier

    # collapse iris to a binary problem so the tree has two classes
    X, y = shap.datasets.iris()
    y[y == 2] = 1
    model = DecisionTreeClassifier(max_depth=None, min_samples_split=2, random_state=0)
    model.fit(X, y)

    # check the first feature's attribution for the first sample in both class outputs
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    assert np.abs(shap_values[0][0, 0] - 0.05) < 1e-1
    assert np.abs(shap_values[1][0, 0] + 0.05) < 1e-1


def test_provided_background_tree_path_dependent():
    try:
        import xgboost
    except ImportError:
        print("Skipping test_provided_background_tree_path_dependent!")
        return
    from sklearn.model_selection import train_test_split
    import shap

    # restrict iris to the first two classes so binary:logistic applies
    np.random.seed(10)
    X, y = shap.datasets.iris()
    X = X[:100]
    y = y[:100]
    train_x, test_x, train_y, test_y = train_test_split(X, y, random_state=1)
    feature_names = ["a", "b", "c", "d"]
    dtrain = xgboost.DMatrix(train_x, label=train_y, feature_names=feature_names)
    dtest = xgboost.DMatrix(test_x, feature_names=feature_names)
    params = {
        'booster': 'gbtree',
        'objective': 'binary:logistic',
        'max_depth': 4,
        'eta': 0.1,
        'nthread': -1,
        'silent': 1
    }
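    # remaining steps, sketched from the usual TreeExplainer additivity check:
    # train the booster, explain with an explicitly provided background dataset
    # while keeping the tree_path_dependent perturbation, and verify the SHAP
    # values sum to the raw margin output (num_boost_round and the 1e-4
    # tolerance are assumptions)
    bst = xgboost.train(params=params, dtrain=dtrain, num_boost_round=100)
    explainer = shap.TreeExplainer(bst, train_x, feature_perturbation="tree_path_dependent")
    diffs = explainer.expected_value + explainer.shap_values(test_x).sum(1) - bst.predict(dtest, output_margin=True)
    assert np.max(np.abs(diffs)) < 1e-4, "SHAP values don't sum to model output!"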


def test_xgboost_multiclass():
    try:
        import xgboost
    except ImportError:
        print("Skipping test_xgboost_multiclass!")
        return
    import shap

    # train an XGBoost model on the three-class iris dataset (XGBClassifier
    # switches to a multiclass objective when it sees more than two labels)
    X, Y = shap.datasets.iris()
    model = xgboost.XGBClassifier(objective="binary:logistic", max_depth=4)
    model.fit(X, Y)

    # explain the model's predictions using SHAP values
    shap_values = shap.TreeExplainer(model).shap_values(X)

    # ensure the dependence plot works for the first class
    shap.dependence_plot(0, shap_values[0], X, show=False)