# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_text_explainer_show_methods():
    """show_prediction / show_weights return IPython HTML objects."""
    pytest.importorskip('IPython')
    from IPython.display import HTML

    text = "Hello, world!"

    @_apply_to_list
    def predict_proba(doc):
        # Class 1 whenever the substring 'lo' survives in the sampled text.
        return [0.0, 1.0] if 'lo' in doc else [1.0, 0.0]

    explainer = TextExplainer()
    explainer.fit(text, predict_proba)

    prediction_html = explainer.show_prediction()
    assert isinstance(prediction_html, HTML)
    assert 'lo' in prediction_html.data

    weights_html = explainer.show_weights()
    assert isinstance(weights_html, HTML)
    assert 'lo' in weights_html.data

    # The default BoW pipeline cannot approximate this predictor well.
    assert explainer.metrics_['score'] < 0.9
    assert explainer.metrics_['mean_KL_divergence'] > 0.3

    # position_dependent=True can make it work
    explainer = TextExplainer(position_dependent=True, random_state=42)
    explainer.fit(text, predict_proba)
    print(explainer.metrics_)
    assert explainer.metrics_['score'] > 0.95
    assert explainer.metrics_['mean_KL_divergence'] < 0.3
    explanation = explainer.explain_prediction()
    format_as_all(explanation, explainer.clf_)

    # it is also possible to almost make it work using a custom vectorizer
    vectorizer = CountVectorizer(ngram_range=(1, 2))
    explainer = TextExplainer(vec=vectorizer, random_state=42)
    explainer.fit(text, predict_proba)
    print(explainer.metrics_)
    assert explainer.metrics_['score'] > 0.95
    assert explainer.metrics_['mean_KL_divergence'] < 0.3
    explanation = explainer.explain_prediction()
    format_as_all(explanation, explainer.clf_)

    # custom vectorizers are not supported when position_dependent is True
    with pytest.raises(ValueError):
        explainer = TextExplainer(position_dependent=True, vec=HashingVectorizer())
def test_text_explainer_char_based(token_pattern):
    """A char-ngram explainer recovers a substring-based black box."""
    text = "Hello, world!"
    predict_proba = substring_presence_predict_proba('lo')

    explainer = TextExplainer(char_based=True, token_pattern=token_pattern)
    explainer.fit(text, predict_proba)
    print(explainer.metrics_)
    assert explainer.metrics_['score'] > 0.95
    assert explainer.metrics_['mean_KL_divergence'] < 0.1

    prediction = explainer.explain_prediction()
    format_as_all(prediction, explainer.clf_)
    check_targets_scores(prediction)
    assert prediction.targets[0].feature_weights.pos[0].feature == 'lo'

    # another way to look at results (not that useful for char ngrams)
    weights = explainer.explain_weights()
    assert weights.targets[0].feature_weights.pos[0].feature == 'lo'
def test_text_explainer_rbf_sigma():
    """Smaller rbf_sigma yields smaller similarity weights, larger yields bigger."""
    text = 'foo bar baz egg spam'
    predict_proba = substring_presence_predict_proba('bar')

    default_sigma = TextExplainer().fit(text, predict_proba)
    narrow_sigma = TextExplainer(rbf_sigma=0.1).fit(text, predict_proba)
    wide_sigma = TextExplainer(rbf_sigma=1.0).fit(text, predict_proba)

    assert default_sigma.similarity_.sum() < wide_sigma.similarity_.sum()
    assert default_sigma.similarity_.sum() > narrow_sigma.similarity_.sum()
def test_text_explainer_token_pattern():
    """A custom token_pattern keeps hyphenated words as single features."""
    text = "foo-bar baz egg-spam"
    predict_proba = substring_presence_predict_proba('bar')

    # a different token_pattern
    explainer = TextExplainer(token_pattern=r'(?u)\b[-\w]+\b')
    explainer.fit(text, predict_proba)
    print(explainer.metrics_)
    assert explainer.metrics_['score'] > 0.95
    assert explainer.metrics_['mean_KL_divergence'] < 0.1

    explanation = explainer.explain_prediction()
    format_as_all(explanation, explainer.clf_)
    assert explanation.targets[0].feature_weights.pos[0].feature == 'foo-bar'
def test_lime_flat_neighbourhood(newsgroups_train):
    """LIME copes with a black box that never spreads mass over all labels."""
    docs, y, target_names = newsgroups_train
    doc = docs[0]

    @_apply_to_list
    def predict_proba(doc):
        """This function predicts non-zero probabilities only for 3 labels."""
        proba_graphics = [0, 1.0, 0, 0]
        proba_other = [0.9, 0, 0.1, 0]
        return proba_graphics if 'file' in doc else proba_other

    explainer = TextExplainer(expand_factor=None, random_state=42)
    explainer.fit(doc, predict_proba)
    print(explainer.metrics_)
    print(explainer.clf_.classes_, target_names)

    explanation = explainer.explain_prediction(top=20, target_names=target_names)
    for formatted in format_as_all(explanation, explainer.clf_):
        assert 'file' in formatted
        assert "comp.graphics" in formatted
def test_text_explainer_custom_classifier():
    """A shallow decision tree can be used as the white-box approximation."""
    text = "foo-bar baz egg-spam"
    predict_proba = substring_presence_predict_proba('bar')

    # use decision tree to explain the prediction
    explainer = TextExplainer(clf=DecisionTreeClassifier(max_depth=2))
    explainer.fit(text, predict_proba)
    print(explainer.metrics_)
    assert explainer.metrics_['score'] > 0.99
    assert explainer.metrics_['mean_KL_divergence'] < 0.01

    prediction = explainer.explain_prediction()
    format_as_all(prediction, explainer.clf_)

    # with explain_weights we can get a nice tree representation
    weights = explainer.explain_weights()
    print(weights.decision_tree.tree)
    assert weights.decision_tree.tree.feature_name == "bar"
    format_as_all(weights, explainer.clf_)
# NOTE(review): the lines below read like the body of a test that fits a real
# sklearn pipeline and checks TextExplainer approximates it, but the enclosing
# `def` line (and the `newsgroups_train` fixture binding) is missing from this
# chunk -- confirm against the original file.
docs, y, target_names = newsgroups_train
try:
    # `alternate_sign` is the sklearn >= 0.19 spelling of this option
    vec = HashingVectorizer(alternate_sign=False)
except TypeError:
    # sklearn < 0.19
    vec = HashingVectorizer(non_negative=True)
clf = MultinomialNB()
X = vec.fit_transform(docs)
clf.fit(X, y)
print(clf.score(X, y))
pipe = make_pipeline(vec, clf)
doc = docs[0]
te = TextExplainer(random_state=42)
# the black box here is the real pipeline's predict_proba
te.fit(doc, pipe.predict_proba)
print(te.metrics_)
# the local surrogate should approximate the pipeline reasonably well
assert te.metrics_['score'] > 0.7
assert te.metrics_['mean_KL_divergence'] < 0.1
res = te.explain_prediction(top=20, target_names=target_names)
expl = format_as_text(res)
print(expl)
assert 'file' in expl
# NOTE(review): this fragment references `self`, `query`, `model`, `vocab`,
# `max_len`, `indicies`, `label`, etc., so it appears to be the tail of a
# method whose `def` line is missing from this chunk -- confirm against the
# original file.
prepro_query = self.preprocess(query)
explainer_generator = ExplainerGenerator(model, vocab, max_len)
# Sampler that masks tokens by replacing them with the UNK token; bow=False
# keeps token order so positional information survives sampling.
sampler = MaskingTextSampler(
    replacement=UNK,
    max_replace=max_replace,
    token_pattern=None,
    bow=False
)
explainer_list = list()
# NOTE(review): `indicies` is a typo for `indices`; it is defined outside
# this chunk, so renaming must happen at the definition site too.
for i in indicies:
    # one prediction function (and one explanation) per requested output index
    predict_fn = explainer_generator.get_predict_function(i)
    te = TextExplainer(
        sampler=sampler,
        position_dependent=True,
        random_state=RANDOM_SEED,
    )
    te.fit(' '.join(prepro_query), predict_fn)
    # drops the first 3 entries of `label` -- presumably special labels;
    # TODO confirm against how `label` is built
    pred_explain = te.explain_prediction(target_names=[l for l in label][3:], top_targets=top_targets)
    explainer_list.append(pred_explain)
return explainer_list