diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index 3d25c6f758..5ebade4f7d 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -261,9 +261,9 @@ def test_two_tower_advanced_options(ecommerce_data): log_to_wandb=False, ) assert metrics["loss-final"] > 0.0 - assert metrics["recall_at_100-final"] > 0.0 assert metrics["runtime_sec-final"] > 0.0 assert metrics["avg_examples_per_sec-final"] > 0.0 + assert metrics["recall_at_10-final"] > 0.0 def test_mf_advanced_options(ecommerce_data): @@ -280,9 +280,9 @@ def test_mf_advanced_options(ecommerce_data): log_to_wandb=False, ) assert metrics["loss-final"] > 0.0 - assert metrics["recall_at_100-final"] > 0.0 assert metrics["runtime_sec-final"] > 0.0 assert metrics["avg_examples_per_sec-final"] > 0.0 + assert metrics["recall_at_10-final"] > 0.0 # def test_retrieval_evaluation_without_negatives(ecommerce_data: Dataset):