Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_recall_at_k():
    """Check ``evaluation.recall_at_k`` against the reference ``_recall_at_k``.

    For several cut-offs ``k`` the mean recall reported by
    ``evaluation.recall_at_k`` must match ``_recall_at_k``, both without and
    with ``train_interactions`` being omitted.  The number of returned values
    is also checked, with and without ``preserve_rows``.  Finally the
    ``check_intersections`` flag of the metrics is exercised via
    ``_check_intersection_flags``.
    """
    no_users, no_items = (10, 100)

    train, test = _generate_data(no_users, no_items)

    model = LightFM(loss="bpr")
    model.fit_partial(test)

    # This call does not depend on k and the model does not change inside
    # the loop below, so it only needs to be checked once.
    assert (
        len(evaluation.recall_at_k(model, train, preserve_rows=True))
        == test.shape[0]
    )

    for k in (10, 5, 1):
        # Without omitting train interactions
        recall = evaluation.recall_at_k(model, test, k=k)
        expected_mean_recall = _recall_at_k(model, test, k)

        assert np.allclose(recall.mean(), expected_mean_recall)
        # Without preserve_rows, one value is expected per row that has at
        # least one test interaction.
        assert len(recall) == (test.getnnz(axis=1) > 0).sum()

        # With omitting train interactions
        recall = evaluation.recall_at_k(model, test, k=k, train_interactions=train)
        expected_mean_recall = _recall_at_k(model, test, k, train=train)

        assert np.allclose(recall.mean(), expected_mean_recall)

    _check_intersection_flags(model, train, test)


def _check_intersection_flags(model, train, test):
    """Exercise the ``check_intersections`` flag of the evaluation metrics.

    ``train`` and ``test`` are expected to share no interactions, whereas
    ``train`` passed as both the test and the train matrix trivially
    intersects itself.
    """
    metrics = (
        evaluation.auc_score,
        evaluation.recall_at_k,
        evaluation.precision_at_k,
        evaluation.reciprocal_rank,
    )

    # check an error is raised when test and train interactions overlap
    # NOTE(review): auc_score and recall_at_k are presumably expected to
    # raise here as well — confirm and add them to this tuple.
    for metric in (evaluation.precision_at_k, evaluation.reciprocal_rank):
        with pytest.raises(ValueError):
            metric(model, train, train_interactions=train, check_intersections=True)

    # check no errors raised when train and test have no interactions in common
    for metric in metrics:
        metric(model, test, train_interactions=train, check_intersections=True)

    # check no error is raised when there are intersections but flag is False
    for metric in metrics:
        metric(model, train, train_interactions=train, check_intersections=False)
def recall_at_k_on_ranks(
    ranks, test_interactions, train_interactions=None, k=10, preserve_rows=False
):
    """Run ``recall_at_k`` using a mock model built from ``ranks``.

    ``ranks`` is wrapped in ``ModelMockRanksCacher`` (presumably a stand-in
    model that serves these precomputed ranks — confirm against its
    definition) and the remaining arguments are forwarded unchanged to
    ``recall_at_k``; its result is returned as-is.
    """
    # The mock receives a copy, presumably so that anything done to the
    # ranks during evaluation does not leak back to the caller's array.
    mock_model = ModelMockRanksCacher(ranks.copy())
    return recall_at_k(
        model=mock_model,
        test_interactions=test_interactions,
        train_interactions=train_interactions,
        k=k,
        preserve_rows=preserve_rows,
    )