from itertools import chain
from typing import Dict, List

import numpy as np

# register_metric, recall_at_k, mean_reciprocal_rank_at_k and CorefEvaluator are
# assumed to be provided by the surrounding project and imported elsewhere.


@register_metric('ner_token_f1')
def ner_token_f1(y_true, y_pred, print_results=False):
    # Flatten per-sentence tag sequences into flat token-level lists
    y_true = list(chain(*y_true))
    y_pred = list(chain(*y_pred))
    # Drop BIO or BIOES markup, keeping only the entity type of each tag
    assert all(len(tag.split('-')) <= 2 for tag in y_true)
    y_true = [tag.split('-')[-1] for tag in y_true]
    y_pred = [tag.split('-')[-1] for tag in y_pred]
    # Map each tag to an integer index so counts can be taken on numpy arrays
    tags = set(y_true) | set(y_pred)
    tags_dict = {tag: n for n, tag in enumerate(tags)}
    y_true_inds = np.array([tags_dict[tag] for tag in y_true])
    y_pred_inds = np.array([tags_dict[tag] for tag in y_pred])
    results = {}
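    # NOTE: the original snippet is truncated here. What follows is a hedged
    # sketch of the remainder: per-tag token-level precision/recall/F1 collected
    # into `results`, with a micro-averaged F1 returned. The exact aggregation
    # used upstream may differ.
    total_tp, total_fp, total_fn = 0, 0, 0
    for tag, tag_ind in tags_dict.items():
        if tag == 'O':
            continue  # the outside tag is not an entity and is skipped
        tp = int(np.sum((y_true_inds == tag_ind) & (y_pred_inds == tag_ind)))
        fp = int(np.sum((y_true_inds != tag_ind) & (y_pred_inds == tag_ind)))
        fn = int(np.sum((y_true_inds == tag_ind) & (y_pred_inds != tag_ind)))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results[tag] = {'precision': precision, 'recall': recall, 'f1': f1}
        total_tp, total_fp, total_fn = total_tp + tp, total_fp + fp, total_fn + fn
    if print_results:
        print(results)
    precision = total_tp / (total_tp + total_fp) if total_tp + total_fp else 0.0
    recall = total_tp / (total_tp + total_fn) if total_tp + total_fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0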

@register_metric('mrr@10')
def r_at_10(labels, predictions):
    # Despite the r_at_10 name, this metric computes mean reciprocal rank at k=10
    return mean_reciprocal_rank_at_k(labels, predictions, k=10)


@register_metric('r@5')
def r_at_5(labels, predictions):
    return recall_at_k(labels, predictions, k=5)

@register_metric('accuracy')
def accuracy(y_true, y_predicted):
    """
    Calculate accuracy as the portion of samples whose prediction exactly
    coincides with the true value.

    Args:
        y_true: array of true values
        y_predicted: array of predicted values

    Returns:
        portion of exactly coincident samples
    """
    examples_len = len(y_true)
    correct = sum(y1 == y2 for y1, y2 in zip(y_true, y_predicted))
    return correct / examples_len if examples_len else 0

@register_metric('coref_metrics')
def coref_metrics(predicted_clusters_l: List, mention_to_predicted_l: List, gold_clusters_l: List) -> Dict:
    evaluator = CorefEvaluator()
    for gold_clusters, predicted_clusters, mention_to_predicted in zip(gold_clusters_l, predicted_clusters_l,
                                                                       mention_to_predicted_l):
        gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters]
        # Build the reverse mapping from each gold mention to its gold cluster
        mention_to_gold = {}
        for gc in gold_clusters:
            for mention in gc:
                mention_to_gold[mention] = gc
        evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
    p, r, f = evaluator.get_prf()
    return dict(F1=(f * 100), precision=(p * 100), recall=(r * 100))

@register_metric('r@1')
def r_at_1(y_true, y_pred):
    return recall_at_k(y_true, y_pred, k=1)


@register_metric('personachat_hits@1_tf')
def personachat_hits1_tf(y_true, metrics):
    # Average the per-example hits@1 values already present in `metrics`
    score = list(map(lambda x: x['hits@1'], metrics))
    return float(np.mean(score))

@register_metric('f1_macro')
def round_f1_macro(y_true, y_predicted):
    """
    Calculates the macro-averaged F1 measure.

    Args:
        y_true: list of true values
        y_predicted: list of predicted values

    Returns:
        F1 score
    """
    # Round real-valued predictions to class labels; fall back to the raw
    # predictions when rounding is not applicable
    try:
        predictions = [np.round(x) for x in y_predicted]
    except TypeError:
        predictions = y_predicted
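    # NOTE: the original snippet is cut off here; the return below is a hedged
    # sketch that macro-averages F1 with scikit-learn, which matches the
    # docstring but may not be the exact upstream implementation.
    from sklearn.metrics import f1_score
    return f1_score(np.array(y_true), np.array(predictions), average='macro')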

@register_metric('r@2')
def r_at_2(y_true, y_pred):
    return recall_at_k(y_true, y_pred, k=2)

@register_metric('gapping_position_f1')
def gapping_position_f1(y_true, y_pred):
    tp_words, fp_words, fn_words = 0, 0, 0
    has_positive_sents = False
    # Debug output: print only the first (true, pred) pair
    for elem in zip(y_true, y_pred):
        print(elem[0], elem[1])
        break
    for (true_verbs, true_gaps), (pred_verbs, pred_gaps) in zip(y_true, y_pred):
        has_positive_sents |= (len(true_gaps) > 0)
        # Count word-level true positives and false negatives over gold gaps
        for gap in true_gaps:
            if gap in pred_gaps:
                tp_words += 1
            else:
                fn_words += 1
        # Count false positives over predicted gaps absent from the gold data
        for gap in pred_gaps:
            if gap not in true_gaps:
                fp_words += 1
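    # NOTE: the original snippet ends here. A hedged completion follows: the
    # word-level F1 over gap positions, which is what the metric name implies.
    # The upstream code may additionally use `has_positive_sents` (e.g. for
    # sentence-level reporting); that part is unknown and not reproduced.
    precision = tp_words / (tp_words + fp_words) if tp_words + fp_words else 0.0
    recall = tp_words / (tp_words + fn_words) if tp_words + fn_words else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0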