From 9fbe0662abcd91e12a7777460ea13e0ab137b954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 11 Aug 2023 13:40:24 +0100 Subject: [PATCH] [ENH] remove `sklearn` dependency in `test_get_params` (#212) This PR removes the `sklearn` dependency in `get_test_params`, by replacing `_check_get_params_invariance` with an `skbase`-native implementation. The call to `clone` is replaced with the object's own `clone`method. --- skbase/testing/test_all_objects.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/skbase/testing/test_all_objects.py b/skbase/testing/test_all_objects.py index 62b31f07..f16a9265 100644 --- a/skbase/testing/test_all_objects.py +++ b/skbase/testing/test_all_objects.py @@ -659,21 +659,17 @@ def test_no_between_test_case_side_effects(self, object_instance, a): assert not hasattr(object_instance, "test__attr") object_instance.test__attr = 42 - @pytest.mark.skipif( - not _check_soft_dependencies("sklearn", severity="none"), - reason="skip test if sklearn is not available", - ) # sklearn is part of the dev dependency set, test should be executed with that def test_get_params(self, object_instance): """Check that get_params works correctly, against sklearn interface.""" - from sklearn.utils.estimator_checks import ( - check_get_params_invariance as _check_get_params_invariance, - ) - params = object_instance.get_params() assert isinstance(params, dict) - _check_get_params_invariance( - object_instance.__class__.__name__, object_instance - ) + + e = object_instance.clone() + + shallow_params = e.get_params(deep=False) + deep_params = e.get_params(deep=True) + + assert all(item in deep_params.items() for item in shallow_params.items()) def test_set_params(self, object_instance): """Check that set_params works correctly."""