Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_ClassificationInterpretation(learn):
    "Check the confusion matrix and `most_confused` output of a `ClassificationInterpretation`."
    this_tests(ClassificationInterpretation)
    interp = ClassificationInterpretation.from_learner(learn)
    # Build the matrix once and reuse it rather than recomputing it for every assertion.
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    # The matrix counts should add up to the number of validation samples.
    assert cm.sum() == len(learn.data.valid_ds)
    conf = interp.most_confused()
    expect = {'3', '7'}
    # Only the first two fields of each entry (the two class labels) are compared.
    # The test expects at most two confusions, each pairing exactly the labels '3' and '7'.
    # Parentheses make the intended `and`-before-`or` grouping explicit.
    assert (len(conf) == 0
            or (len(conf) == 1 and set(conf[0][:2]) == expect)
            or (len(conf) == 2 and set(conf[0][:2]) == set(conf[1][:2]) == expect)
            ), f"conf={conf}"
def test_confusion_tabular(learn):
    "Check that the confusion matrix built for `learn` is an array covering the validation set."
    interp = ClassificationInterpretation.from_learner(learn)
    # Register the tested API before asserting, so the registration still happens
    # when one of the assertions below fails.
    this_tests(interp.confusion_matrix)
    # Build the matrix once instead of recomputing it for each assertion.
    cm = interp.confusion_matrix()
    assert isinstance(cm, np.ndarray)
    assert cm.sum() == len(learn.data.valid_ds)
def test_interp(learn):
    "`top_losses` should yield as many losses and indices as there are validation samples."
    this_tests(ClassificationInterpretation.from_learner)
    interpretation = ClassificationInterpretation.from_learner(learn)
    loss_values, loss_idxs = interpretation.top_losses()
    n_valid = len(learn.data.valid_ds)
    assert n_valid == len(loss_values) == len(loss_idxs)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    """Build a `ClassificationInterpretation` for `learn` on the `ds_type` dataset.

    Both `ds_type` and the `tta` flag are forwarded unchanged to
    `ClassificationInterpretation.from_learner`.
    """
    options = dict(ds_type=ds_type, tta=tta)
    return ClassificationInterpretation.from_learner(learn, **options)

# Expose the helper as the `Learner.interpret` method.
Learner.interpret = _learner_interpret
def _learner_interpret(learn:Learner, ds_type:DatasetType = DatasetType.Valid):
    """Return a `ClassificationInterpretation` of `learn` computed on `ds_type`.

    Thin wrapper around `ClassificationInterpretation.from_learner`; this variant
    exposes no `tta` option.
    """
    build = ClassificationInterpretation.from_learner
    return build(learn, ds_type=ds_type)
print(f'{str(len(mismatches))} misclassified samples over {str(len(self.data.valid_ds))} samples in the validation set.')
samples = min(samples, len(mismatches))
for ima in range(len(mismatches_ordered_byloss)):
mismatchescontainer.append(mismatches_ordered_byloss[ima][0])
for sampleN in range(samples):
actualclasses = ''
for clas in infolist[ordlosses_idxs[sampleN]][2]:
actualclasses = f'{actualclasses} -- {str(classes_ids[clas][1])}'
imag = mismatches_ordered_byloss[sampleN][0]
imag = show_image(imag, figsize=figsize)
imag.set_title(f"""Predicted: {classes_ids[infolist[ordlosses_idxs[sampleN]][1]][1]} \nActual: {actualclasses}\nLoss: {infolist[ordlosses_idxs[sampleN]][3]}\nProbability: {infolist[ordlosses_idxs[sampleN]][4]}""",
loc='left')
plt.show()
if save_misclassified: return mismatchescontainer
# Attach the `_cl_int_*` implementations to `ClassificationInterpretation`
# under their public method names.
setattr(ClassificationInterpretation, 'from_learner', _cl_int_from_learner)
setattr(ClassificationInterpretation, 'plot_top_losses', _cl_int_plot_top_losses)
setattr(ClassificationInterpretation, 'plot_multi_top_losses', _cl_int_plot_multi_top_losses)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid, tta=False):
    "Shortcut for `ClassificationInterpretation.from_learner(learn, ds_type=ds_type, tta=tta)`."
    interp_cls = ClassificationInterpretation
    return interp_cls.from_learner(learn, ds_type=ds_type, tta=tta)

# Make the shortcut callable as `learn.interpret(...)`.
setattr(Learner, 'interpret', _learner_interpret)
def _learner_interpret(learn:Learner, ds_type:DatasetType=DatasetType.Valid):
    "Create a `ClassificationInterpretation` object from `learn` on `ds_type`."
    # This variant takes no `tta` argument, so the docstring must not advertise one.
    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)

# NOTE(review): if a tta-aware `_learner_interpret` was bound to `Learner.interpret`
# earlier in this module, this assignment replaces it and `learn.interpret(tta=...)`
# will raise a TypeError — confirm the override is intended.
Learner.interpret = _learner_interpret