Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_method_color(method):
    """Return the plot color declared in the docstring of ``methods.<method>``.

    The docstring of ``getattr(methods, method)`` is scanned line by line for
    the first line that, once stripped, starts with ``color = ``.

    * If the value after ``=`` starts with ``red_blue_circle(``, the text
      between the parentheses is parsed as a float and passed to
      ``colors.red_blue_circle``; that result is returned.
    * Otherwise the value string itself is returned.

    Returns ``"#000000"`` when no such line exists or when the attribute has
    no docstring at all.
    """
    circle_prefix = "red_blue_circle("
    doc = getattr(methods, method).__doc__
    if doc is None:
        # Objects defined without a docstring have __doc__ == None; fall back
        # to the default color instead of raising AttributeError on None.split.
        return "#000000"
    for raw_line in doc.split("\n"):
        line = raw_line.strip()
        if not line.startswith("color = "):
            continue
        # Split only on the first "=" so a value containing "=" is kept intact.
        value = line.split("=", 1)[1].strip()
        if value.startswith(circle_prefix):
            # Drop the "red_blue_circle(" prefix and the trailing ")" to get
            # the numeric argument.
            return colors.red_blue_circle(float(value[len(circle_prefix):-1]))
        return value
    return "#000000"
fig = pl.gcf()
ax = pl.gca()
xticks = ax.get_xticks()
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
width, height = bbox.width, bbox.height
bbox_to_xscale = xlen/width
hl_scaled = bbox_to_xscale * head_length
renderer = fig.canvas.get_renderer()
# draw the positive arrows
for i in range(len(pos_inds)):
dist = pos_widths[i]
arrow_obj = pl.arrow(
pos_lefts[i], pos_inds[i], max(dist-hl_scaled, 0.000001), 0,
head_length=min(dist, hl_scaled),
color=colors.red_rgb, width=bar_width,
head_width=bar_width
)
txt_obj = pl.text(
pos_lefts[i] + 0.5*dist, pos_inds[i], format_value(pos_widths[i], '%+0.02f'),
horizontalalignment='center', verticalalignment='center', color="white",
fontsize=12
)
text_bbox = txt_obj.get_window_extent(renderer=renderer)
arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
if text_bbox.width > arrow_bbox.width:
txt_obj.remove()
# draw the negative arrows
for i in range(len(neg_inds)):
dist = neg_widths[i]
if max_display is None:
max_display = 7
else:
max_display = min(len(feature_names), max_display)
feature_order = np.argsort(-np.abs(shap_values))
#
feature_inds = feature_order[:max_display]
y_pos = np.arange(len(feature_inds), 0, -1)
pl.barh(
y_pos, shap_values[feature_inds],
0.7, align='center',
color=[colors.red_rgb if shap_values[feature_inds[i]] > 0 else colors.blue_rgb for i in range(len(y_pos))]
)
pl.yticks(y_pos, fontsize=13)
if features is not None:
features = list(features)
# try and round off any trailing zeros after the decimal point in the feature values
for i in range(len(features)):
try:
if round(features[i]) == features[i]:
features[i] = int(features[i])
except TypeError:
pass # features[i] must not be a number
yticklabels = []
for i in feature_inds:
if features is not None:
yticklabels.append(feature_names[i] + " = " + str(features[i]))
plot_types = ['contour', 'grid'] if plot_types is None else [plot_types]
for plot_type in plot_types:
figs, ax = pdp.pdp_interact_plot(
pdp_interact_out = ft_plot,
feature_names = var_name or feature,
plot_type= plot_type, plot_pdp=True,
which_classes=which_classes, plot_params = plot_params)
plt.show()
def sample(self, sample):
    """Return the wrapped frame, optionally subsampled.

    When *sample* is None, ``self.df`` itself is returned (no copy);
    otherwise the result of ``self.df.sample(sample)`` is returned.
    """
    if sample is None:
        return self.df
    return self.df.sample(sample)
# Cell
# HACK: hardcoded override of the shap colormaps.
# Build a 256-step colormap running from yellow (#ffff00) to dark navy
# (#002266) and install it as both `cl.red_blue` and `cl.red_blue_solid`.
# NOTE(review): `cl` is presumably shap's colors module (imported elsewhere);
# plotting code that reads `colors.red_blue` at call time would then pick this
# map up. Confirm against the import block.
green_blue = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffff00'), (1, '#002266')],
    N=256,
)
cl.red_blue = cl.red_blue_solid = green_blue
# Cell
class Shapley:
"""
SHAP value: https://github.com/slundberg/shap
"""
def __init__(self, explainer, shap_values, df, df_disp, features):
    """Initialise the wrapper and keep references to the SHAP inputs.

    Calls ``shap.initjs()`` first, then stores every argument unchanged as
    an attribute of the same name: ``explainer``, ``shap_values``, ``df``,
    ``df_disp`` and ``features``. No copying or validation is performed.
    """
    # Per shap's documentation, initjs() loads the JavaScript its
    # interactive notebook plots rely on.
    shap.initjs()
    self.explainer, self.shap_values = explainer, shap_values
    self.df = df
    self.df_disp = df_disp
    self.features = features
@classmethod
def from_Tree(cls, learner, ds, df_disp = None, sample = 10000, remove_outlier = True):
if remove_outlier:
ys = shap_values[:,ind]
xs = np.arange(len(ys))#np.linspace(0, 12*2, len(ys))
pvals = []
inc = 50
for i in range(inc, len(ys)-inc, inc):
#stat, pval = scipy.stats.mannwhitneyu(v[:i], v[i:], alternative="two-sided")
stat, pval = scipy.stats.ttest_ind(ys[:i], ys[i:])
pvals.append(pval)
min_pval = np.min(pvals)
min_pval_ind = np.argmin(pvals)*inc + inc
if min_pval < 0.05 / shap_values.shape[1]:
pl.axvline(min_pval_ind, linestyle="dashed", color="#666666", alpha=0.2)
pl.scatter(xs, ys, s=10, c=features[:,ind], cmap=colors.red_blue)
pl.xlabel("Sample index")
pl.ylabel(truncate_text(feature_names[ind], 30) + "\nSHAP value", size=13)
pl.gca().xaxis.set_ticks_position('bottom')
pl.gca().yaxis.set_ticks_position('left')
pl.gca().spines['right'].set_visible(False)
pl.gca().spines['top'].set_visible(False)
cb = pl.colorbar()
cb.outline.set_visible(False)
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
cb.ax.set_aspect((bbox.height - 0.7) * 20)
cb.set_label(truncate_text(feature_names[ind], 30), size=13)
if show:
pl.show()
else:
cvals = shap_values[:,ind]
fname = feature_names[ind]
# see if we need to compute the embedding
if type(method) == str and method == "pca":
pca = sklearn.decomposition.PCA(2)
embedding_values = pca.fit_transform(shap_values)
elif hasattr(method, "shape") and method.shape[1] == 2:
embedding_values = method
else:
print("Unsupported embedding method:", method)
pl.scatter(
embedding_values[:,0], embedding_values[:,1], c=cvals,
cmap=colors.red_blue, alpha=alpha, linewidth=0
)
pl.axis("off")
#pl.title(feature_names[ind])
cb = pl.colorbar()
cb.set_label("SHAP value for\n"+fname, size=13)
cb.outline.set_visible(False)
pl.gcf().set_size_inches(7.5, 5)
bbox = cb.ax.get_window_extent().transformed(pl.gcf().dpi_scale_trans.inverted())
cb.ax.set_aspect((bbox.height - 0.7) * 10)
cb.set_alpha(1)
if show:
pl.show()
plot_types = ['contour', 'grid'] if plot_types is None else [plot_types]
for plot_type in plot_types:
figs, ax = pdp.pdp_interact_plot(
pdp_interact_out = ft_plot,
feature_names = var_name or feature,
plot_type= plot_type, plot_pdp=True,
which_classes=which_classes, plot_params = plot_params)
plt.show()
def sample(self, sample):
    """Return the wrapped frame, optionally subsampled.

    When *sample* is None, ``self.df`` itself is returned (no copy);
    otherwise the result of ``self.df.sample(sample)`` is returned.
    """
    if sample is None:
        return self.df
    return self.df.sample(sample)
# Cell
# HACK: hardcoded override of the shap colormaps.
# Build a 256-step colormap running from yellow (#ffff00) to dark navy
# (#002266) and install it as both `cl.red_blue` and `cl.red_blue_solid`.
# NOTE(review): `cl` is presumably shap's colors module (imported elsewhere);
# plotting code that reads `colors.red_blue` at call time would then pick this
# map up. Confirm against the import block.
green_blue = LinearSegmentedColormap.from_list(
    'custom blue',
    [(0, '#ffff00'), (1, '#002266')],
    N=256,
)
cl.red_blue = cl.red_blue_solid = green_blue
# Cell
class Shapley:
"""
SHAP value: https://github.com/slundberg/shap
"""
def __init__(self, explainer, shap_values, df, df_disp, features):
    """Initialise the wrapper and keep references to the SHAP inputs.

    Calls ``shap.initjs()`` first, then stores every argument unchanged as
    an attribute of the same name: ``explainer``, ``shap_values``, ``df``,
    ``df_disp`` and ``features``. No copying or validation is performed.
    """
    # Per shap's documentation, initjs() loads the JavaScript its
    # interactive notebook plots rely on.
    shap.initjs()
    self.explainer, self.shap_values = explainer, shap_values
    self.df = df
    self.df_disp = df_disp
    self.features = features
@classmethod
def from_Tree(cls, learner, ds, df_disp = None, sample = 10000, remove_outlier = True):
# create a symmetric axis around base_value
a, b = (base_value - xmin), (xmax - base_value)
if a > b:
xlim = (base_value - a, base_value + a)
else:
xlim = (base_value - b, base_value + b)
# Adjust xlim to include a little visual margin.
a = (xlim[1] - xlim[0]) * 0.02
xlim = (xlim[0] - a, xlim[1] + a)
# Initialize style arguments
if alpha is None:
alpha = 1.0
if plot_color is None:
plot_color = colors.red_blue
__decision_plot_matplotlib(
base_value,
cumsum,
ascending,
feature_display_count,
features_display,
feature_names_display,
highlight,
plot_color,
axis_color,
y_demarc_color,
xlim,
alpha,
color_bar,
auto_size_plot,
else:
x_curr_gray = x_curr
axes[row,0].imshow(x_curr, cmap=pl.get_cmap('gray'))
axes[row,0].axis('off')
if len(shap_values[0][row].shape) == 2:
abs_vals = np.stack([np.abs(shap_values[i]) for i in range(len(shap_values))], 0).flatten()
else:
abs_vals = np.stack([np.abs(shap_values[i].sum(-1)) for i in range(len(shap_values))], 0).flatten()
max_val = np.nanpercentile(abs_vals, 99.9)
for i in range(len(shap_values)):
if labels is not None:
axes[row,i+1].set_title(labels[row,i], **label_kwargs)
sv = shap_values[i][row] if len(shap_values[i][row].shape) == 2 else shap_values[i][row].sum(-1)
axes[row,i+1].imshow(x_curr_gray, cmap=pl.get_cmap('gray'), alpha=0.15, extent=(-1, sv.shape[0], sv.shape[1], -1))
im = axes[row,i+1].imshow(sv, cmap=colors.red_transparent_blue, vmin=-max_val, vmax=max_val)
axes[row,i+1].axis('off')
if hspace == 'auto':
fig.tight_layout()
else:
fig.subplots_adjust(hspace=hspace)
cb = fig.colorbar(im, ax=np.ravel(axes).tolist(), label="SHAP value", orientation="horizontal", aspect=fig_size[0]/aspect)
cb.outline.set_visible(False)
if show:
pl.show()
features_tmp[:,ind0] = xs0[i]
features_tmp[:,ind1] = xs1[j]
x0[i, j] = xs0[i]
x1[i, j] = xs1[j]
vals[i, j] = model(features_tmp).mean()
fig = pl.figure()
ax = fig.add_subplot(111, projection='3d')
# x = y = np.arange(-3.0, 3.0, 0.05)
# X, Y = np.meshgrid(x, y)
# zs = np.array(fun(np.ravel(X), np.ravel(Y)))
# Z = zs.reshape(X.shape)
ax.plot_surface(x0, x1, vals, cmap=shap.plots.colors.red_blue_transparent)
ax.set_xlabel(feature_names[ind0], fontsize=13)
ax.set_ylabel(feature_names[ind1], fontsize=13)
ax.set_zlabel("E[f(x) | "+ str(feature_names[ind0]) + ", "+ str(feature_names[ind1]) + "]", fontsize=13)
if show:
pl.show()
else:
return fig, ax