Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def group_r2_score(y_true, y_pred, group_membership, *,
                   multioutput='uniform_average',
                   sample_weight=None):
    """Expose :py:func:`sklearn.metrics.r2_score` as a group metric.

    Takes the same arguments as ``r2_score`` plus ``group_membership``.
    Only ``y_true``, ``y_pred`` and ``group_membership`` may be given
    positionally; every other argument is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    # Fixed options forwarded unchanged to every r2_score invocation.
    r2_options = {'multioutput': multioutput}

    def _r2_for_group(y_true, y_pred, sample_weight=None):
        return skm.r2_score(y_true, y_pred,
                            sample_weight=sample_weight,
                            **r2_options)

    return metric_by_group(_r2_for_group,
                           y_true, y_pred, group_membership,
                           sample_weight=sample_weight)
def group_precision_score(y_true, y_pred, group_membership, *,
                          labels=None, pos_label=1, average='binary',
                          sample_weight=None):
    """Expose :py:func:`sklearn.metrics.precision_score` as a group metric.

    Takes the same arguments as ``precision_score`` plus
    ``group_membership``. Only ``y_true``, ``y_pred`` and
    ``group_membership`` may be given positionally; every other argument
    is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    # Options shared by every precision_score invocation.
    precision_options = dict(labels=labels,
                             pos_label=pos_label,
                             average=average)

    def _precision_for_group(y_true, y_pred, sample_weight=None):
        return skm.precision_score(y_true, y_pred,
                                   sample_weight=sample_weight,
                                   **precision_options)

    return metric_by_group(_precision_for_group,
                           y_true, y_pred, group_membership,
                           sample_weight)
def group_zero_one_loss(y_true, y_pred, group_membership, *,
                        normalize=True,
                        sample_weight=None):
    """Expose :py:func:`sklearn.metrics.zero_one_loss` as a group metric.

    Takes the same arguments as ``zero_one_loss`` plus
    ``group_membership``. Only ``y_true``, ``y_pred`` and
    ``group_membership`` may be given positionally; every other argument
    is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.

    NOTE(review): this module defines ``group_zero_one_loss`` twice; a
    later definition with the same name replaces this one at import time.
    """
    def _zero_one_for_group(y_true, y_pred, sample_weight=None):
        loss = skm.zero_one_loss(y_true, y_pred,
                                 normalize=normalize,
                                 sample_weight=sample_weight)
        return loss

    return metric_by_group(_zero_one_for_group,
                           y_true, y_pred, group_membership,
                           sample_weight)
def group_mean_squared_error(y_true, y_pred, group_membership, *,
                             multioutput='uniform_average',
                             sample_weight=None):
    """Expose :py:func:`sklearn.metrics.mean_squared_error` as a group metric.

    Takes the same arguments as ``mean_squared_error`` plus
    ``group_membership``. Only ``y_true``, ``y_pred`` and
    ``group_membership`` may be given positionally; every other argument
    is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.

    NOTE(review): this module defines ``group_mean_squared_error`` twice;
    a later definition with the same name replaces this one at import time.
    """
    mse_options = {'multioutput': multioutput}

    def _mse_for_group(y_true, y_pred, sample_weight=None):
        return skm.mean_squared_error(y_true, y_pred,
                                      sample_weight=sample_weight,
                                      **mse_options)

    return metric_by_group(
        _mse_for_group,
        y_true,
        y_pred,
        group_membership,
        sample_weight=sample_weight,
    )
def group_zero_one_loss(y_true, y_pred, group_membership, *,
                        normalize=True,
                        sample_weight=None):
    """Expose :py:func:`sklearn.metrics.zero_one_loss` as a group metric.

    Takes the same arguments as ``zero_one_loss`` plus
    ``group_membership``. Only ``y_true``, ``y_pred`` and
    ``group_membership`` may be given positionally; every other argument
    is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.

    NOTE(review): an earlier definition with the same name exists in this
    module; this later one is the definition that takes effect.
    """
    zol_options = dict(normalize=normalize)

    def _zol_for_group(y_true, y_pred, sample_weight=None):
        return skm.zero_one_loss(y_true, y_pred,
                                 sample_weight=sample_weight,
                                 **zol_options)

    return metric_by_group(
        _zol_for_group,
        y_true,
        y_pred,
        group_membership,
        sample_weight,
    )
def group_confusion_matrix(y_true, y_pred, group_membership, *,
                           labels=None,
                           sample_weight=None):
    """Wrap the :py:func:`sklearn.metrics.confusion_matrix` routine.

    The arguments remain the same, with ``group_membership`` added.
    However, the only positional arguments supported are ``y_true``,
    ``y_pred`` and ``group_membership``; all others must be specified
    by name.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    def internal_cm_wrapper(y_true, y_pred, sample_weight=None):
        # ``labels`` and ``sample_weight`` are keyword-only parameters of
        # confusion_matrix in current scikit-learn releases; passing them
        # positionally raises TypeError there, so always pass by name.
        return skm.confusion_matrix(y_true, y_pred,
                                    labels=labels,
                                    sample_weight=sample_weight)

    # Forward sample_weight by keyword, as the other wrappers in this
    # module (e.g. group_r2_score, group_selection_rate) already do.
    return metric_by_group(internal_cm_wrapper,
                           y_true, y_pred, group_membership,
                           sample_weight=sample_weight)
def group_mean_squared_error(y_true, y_pred, group_membership, *,
                             multioutput='uniform_average',
                             sample_weight=None):
    """Expose :py:func:`sklearn.metrics.mean_squared_error` as a group metric.

    Takes the same arguments as ``mean_squared_error`` plus
    ``group_membership``. Only ``y_true``, ``y_pred`` and
    ``group_membership`` may be given positionally; every other argument
    is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.

    NOTE(review): an earlier definition with the same name exists in this
    module; this later one is the definition that takes effect.
    """
    def _squared_error_for_group(y_true, y_pred, sample_weight=None):
        error = skm.mean_squared_error(y_true, y_pred,
                                       multioutput=multioutput,
                                       sample_weight=sample_weight)
        return error

    return metric_by_group(_squared_error_for_group,
                           y_true, y_pred, group_membership,
                           sample_weight=sample_weight)
def group_accuracy_score(y_true, y_pred, group_membership, *,
                         normalize=True,
                         sample_weight=None):
    """Wrap the :py:func:`sklearn.metrics.accuracy_score` routine.

    The arguments remain the same, with ``group_membership`` added.
    However, the only positional arguments supported are ``y_true``,
    ``y_pred`` and ``group_membership``; all others must be specified
    by name.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    def internal_acc_wrapper(y_true, y_pred, sample_weight=None):
        # ``normalize`` and ``sample_weight`` are keyword-only parameters of
        # accuracy_score in current scikit-learn releases; passing them
        # positionally raises TypeError there, so always pass by name.
        return skm.accuracy_score(y_true, y_pred,
                                  normalize=normalize,
                                  sample_weight=sample_weight)

    # Forward sample_weight by keyword, as the other wrappers in this
    # module (e.g. group_r2_score, group_selection_rate) already do.
    return metric_by_group(internal_acc_wrapper,
                           y_true, y_pred, group_membership,
                           sample_weight=sample_weight)
def group_recall_score(y_true, y_pred, group_membership, *,
                       labels=None, pos_label=1, average='binary',
                       sample_weight=None):
    """Expose :py:func:`sklearn.metrics.recall_score` as a group metric.

    Takes the same arguments as ``recall_score`` plus ``group_membership``.
    Only ``y_true``, ``y_pred`` and ``group_membership`` may be given
    positionally; every other argument is keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    # Options shared by every recall_score invocation.
    recall_options = {
        'labels': labels,
        'pos_label': pos_label,
        'average': average,
    }

    def _recall_for_group(y_true, y_pred, sample_weight=None):
        return skm.recall_score(y_true, y_pred,
                                sample_weight=sample_weight,
                                **recall_options)

    return metric_by_group(_recall_for_group,
                           y_true, y_pred, group_membership,
                           sample_weight)
def group_selection_rate(y_true, y_pred, group_membership,
                         *, pos_label=1, sample_weight=None):
    """Expose :func:`selection_rate` as a group metric.

    Takes the same arguments as ``selection_rate`` plus the
    ``group_membership`` array. ``pos_label`` and ``sample_weight`` are
    keyword-only.

    Returns whatever ``metric_by_group`` produces for the wrapped metric.
    """
    def _selection_rate_for_group(y_true, y_pred, sample_weight=None):
        rate = selection_rate(y_true, y_pred,
                              pos_label=pos_label,
                              sample_weight=sample_weight)
        return rate

    return metric_by_group(
        _selection_rate_for_group,
        y_true,
        y_pred,
        group_membership,
        sample_weight=sample_weight,
    )