Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_to_pyplot(self, mock_pyplot):
    """Check that ``_to_pyplot`` renders a pycm confusion matrix via pyplot.

    ``mock_pyplot`` is presumably the patched ``matplotlib.pyplot`` injected by a
    ``mock.patch`` decorator that is not visible here -- TODO confirm.

    With ``normalize=True`` the array passed to ``imshow`` must peak at exactly 1,
    and the title template must be formatted with the epoch taken from state.
    With ``normalize=False`` raw counts are drawn, so the peak must exceed 1.
    """
    if sys.version_info[0] < 3:
        # The original guarded the whole body on Python 3 and silently passed
        # otherwise; report a skip instead so nothing is mistaken for a pass.
        self.skipTest('pycm tests require Python 3')
    import pycm

    y_true = [2, 0, 2, 2, 0, 1, 1, 2, 2, 0, 1, 2]
    y_pred = [0, 0, 2, 1, 0, 2, 1, 0, 2, 0, 2, 2]

    # Normalised plot with a formatted title.
    handler = torchbearer.callbacks.pycm._to_pyplot(True, 'test {epoch}')
    # Fresh list copies for each matrix keep the two runs independent.
    cm = pycm.ConfusionMatrix(list(y_true), list(y_pred))
    handler(cm, {torchbearer.EPOCH: 3})
    self.assertEqual(mock_pyplot.imshow.call_args[0][0].max(), 1)
    mock_pyplot.title.assert_called_once_with('test 3')

    # Un-normalised plot: raw counts, so some cell must exceed 1.
    handler = torchbearer.callbacks.pycm._to_pyplot(False)
    cm = pycm.ConfusionMatrix(list(y_true), list(y_pred))
    handler(cm, {})
    self.assertGreater(mock_pyplot.imshow.call_args[0][0].max(), 1)
1 if (np.asarray(target["tag_idxs"]) == np.asarray(pred["tag_idxs"])).all() else 0
)
pred_tags = [
[self._dataset.tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in pred_tag_idxs_list
]
target_tags = [
[self._dataset.tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in target_tag_idxs_list
]
flat_pred_tags = list(common_utils.flatten(pred_tags))
flat_target_tags = list(common_utils.flatten(target_tags))
# confusion matrix
try:
pycm_obj = pycm.ConfusionMatrix(actual_vector=flat_target_tags, predict_vector=flat_pred_tags)
except pycmVectorError as e:
if str(e) == "Number of the classes is lower than 2":
logger.warning("Number of tags in the batch is 1. Sanity check is highly recommended.")
return {
"accuracy": 1.,
"tag_accuracy": 1.,
"macro_f1": 1.,
"macro_precision": 1.,
"macro_recall": 1.,
"conlleval_accuracy": 1.,
"conlleval_f1": 1.,
}
raise
def make_cm(y_pred, y_true):
    """Build a confusion matrix for one batch and pass it to every handler.

    Args:
        y_pred: model output. The predicted class is the index of the maximum
            along dimension 1 (presumably a (batch, num_classes) score tensor --
            TODO confirm).
        y_true: target class labels for the same batch.

    ``self`` and ``state`` are not parameters; they are captured from the
    enclosing scope. The matrix is built with ``self.kwargs``, and each entry of
    ``self._handlers`` is called as ``handler(matrix, state)``. Returns None.
    """
    # torch.max(..., 1) returns (values, indices); only the indices are needed.
    predicted_labels = torch.max(y_pred, 1)[1]
    actual = y_true.cpu().numpy()
    predicted = predicted_labels.cpu().numpy()
    matrix = ConfusionMatrix(actual, predicted, **self.kwargs)
    for callback in self._handlers:
        callback(matrix, state)
class_metrics = {
"class_macro_f1": 1.,
"class_macro_precision": 1.,
"class_macro_recall": 1.,
"class_accuracy": 1.,
}
for k in range(self.K):
class_metrics.update({
f"class_accuracy_top{k + 1}": 1.,
})
else:
raise
# tag confusion matrix
try:
tag_pycm_obj = pycm.ConfusionMatrix(actual_vector=flat_target_tags, predict_vector=flat_pred_tags)
self.write_predictions(
{"target": flat_target_tags, "predict": flat_pred_tags}, pycm_obj=tag_pycm_obj, label_type="tag"
)
tag_sequence_accuracy = sum(accurate_sequence) / len(accurate_sequence)
tag_metrics = {
"tag_sequence_accuracy": tag_sequence_accuracy,
"tag_accuracy": tag_pycm_obj.Overall_ACC,
"tag_macro_f1": macro_f1(tag_pycm_obj),
"tag_macro_precision": macro_precision(tag_pycm_obj),
"tag_macro_recall": macro_recall(tag_pycm_obj),
"tag_conlleval_accuracy": conlleval_accuracy(target_tags, pred_tags),
"tag_conlleval_f1": conlleval_f1(target_tags, pred_tags),
}
- 'macro_recall': class prediction macro(unweighted mean) recall
- 'accuracy': class prediction accuracy
"""
pred_classes = []
target_classes = []
for data_id, pred in predictions.items():
target = self._dataset.get_ground_truth(data_id)
pred_classes.append(self._dataset.class_idx2text[pred["class_idx"]])
target_classes.append(target["class_text"])
# confusion matrix
try:
pycm_obj = pycm.ConfusionMatrix(
actual_vector=target_classes, predict_vector=pred_classes
)
except pycmVectorError as e:
if str(e) == "Number of the classes is lower than 2":
logger.warning("Number of classes in the batch is 1. Sanity check is highly recommended.")
return {
"macro_f1": 1.,
"macro_precision": 1.,
"macro_recall": 1.,
"accuracy": 1.,
}
raise
self.write_predictions(
{"target": target_classes, "predict": pred_classes}, pycm_obj=pycm_obj
)
for k in range(self.K):
pred_topk[k + 1].append(self._dataset.class_idx2text[pred[f"top{k + 1}"]])
pred_tags = [
[self._dataset.tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in pred_tag_idxs_list
]
target_tags = [
[self._dataset.tag_idx2text[tag_idx] for tag_idx in tag_idxs] for tag_idxs in target_tag_idxs_list
]
flat_pred_tags = list(common_utils.flatten(pred_tags))
flat_target_tags = list(common_utils.flatten(target_tags))
# class confusion matrix
try:
class_pycm_obj = pycm.ConfusionMatrix(actual_vector=target_classes, predict_vector=pred_classes)
self.write_predictions(
{"target": target_classes, "predict": pred_classes}, pycm_obj=class_pycm_obj, label_type="class"
)
# topk
for k in range(self.K):
self.write_predictions(
{"target": target_classes, "predict": pred_topk[k + 1]}, label_type=f"class_top{k + 1}"
)
class_metrics = {
"class_macro_f1": macro_f1(class_pycm_obj),
"class_macro_precision": macro_precision(class_pycm_obj),
"class_macro_recall": macro_recall(class_pycm_obj),
"class_accuracy": class_pycm_obj.Overall_ACC,
}