Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_group_mean_squared_error_multioutput_list_ndarray():
    """Check group_mean_squared_error when inputs are lists of 1-D ndarrays.

    Every sample carries two target values (``np.random.rand(2)``), so the
    metric is requested with ``multioutput='raw_values'``. The overall result
    and each per-group result are compared against
    ``skm.mean_squared_error`` evaluated on the same samples.
    """
    # One 2-element random target and prediction per entry of ``groups``;
    # the element itself is unused, only the count matters.
    y_t = [np.random.rand(2) for _ in groups]
    y_p = [np.random.rand(2) for _ in groups]

    result = metrics.group_mean_squared_error(y_t, y_p, groups, multioutput='raw_values')

    expected_overall = skm.mean_squared_error(y_t, y_p, multioutput='raw_values')
    assert np.array_equal(result.overall, expected_overall)

    for target_group in np.unique(groups):
        # Select this group's samples while keeping the list-of-arrays input
        # form, so the reference sees the same kind of input as the metric.
        y_true = [yt for yt, g in zip(y_t, groups) if g == target_group]
        y_pred = [yp for yp, g in zip(y_p, groups) if g == target_group]
        expected = skm.mean_squared_error(y_true, y_pred, multioutput='raw_values')
        assert np.array_equal(result.by_group[target_group], expected)
def test_group_mean_squared_error_multioutput_single_ndarray():
    """Check group_mean_squared_error on 2-D ndarray inputs with two outputs.

    Targets and predictions are ``(len(groups), 2)`` arrays. With
    ``multioutput='raw_values'``, the overall result and the result for each
    group must equal ``skm.mean_squared_error`` on the matching rows.
    """
    n_samples = len(groups)
    targets = np.random.rand(n_samples, 2)
    preds = np.random.rand(n_samples, 2)

    grouped = metrics.group_mean_squared_error(targets, preds, groups, multioutput='raw_values')

    assert np.array_equal(
        grouped.overall,
        skm.mean_squared_error(targets, preds, multioutput='raw_values'),
    )

    # Convert once so each group's rows can be selected with a boolean mask.
    labels = np.asarray(groups)
    for label in np.unique(groups):
        rows = labels == label
        assert np.array_equal(
            grouped.by_group[label],
            skm.mean_squared_error(targets[rows], preds[rows], multioutput='raw_values'),
        )
},
"selection_rate": {
"model_type": [],
"function": group_selection_rate
},
"max_error": {
"model_type": ["regression"],
"function": group_max_error
},
"mean_absolute_error": {
"model_type": ["regression"],
"function": group_mean_absolute_error
},
"mean_squared_error": {
"model_type": ["regression"],
"function": group_mean_squared_error
},
"mean_squared_log_error": {
"model_type": ["regression"],
"function": group_mean_squared_log_error
},
"median_absolute_error": {
"model_type": ["regression"],
"function": group_median_absolute_error
},
"balanced_root_mean_squared_error": {
"model_type": ["regression"],
"function": group_balanced_root_mean_squared_error
},
"overprediction": {
"model_type": [],
"function": group_mean_overprediction
},
"auc": {
"model_type": ["probability"],
"function": group_roc_auc_score
},
"root_mean_squared_error": {
"model_type": ["regression", "probability"],
"function": group_root_mean_squared_error
},
"balanced_root_mean_squared_error": {
"model_type": ["probability"],
"function": group_balanced_root_mean_squared_error
},
"mean_squared_error": {
"model_type": ["regression", "probability"],
"function": group_mean_squared_error
},
"mean_absolute_error": {
"model_type": ["regression", "probability"],
"function": group_mean_absolute_error
},
"r2_score": {
"model_type": ["regression"],
"function": group_r2_score
},
"max_error": {
"model_type": [],
"function": group_max_error
},
"median_absolute_error": {
"model_type": [],
"function": group_median_absolute_error