diff --git a/skbase/base/_base.py b/skbase/base/_base.py index bf1d89bc..28512554 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -1294,10 +1294,35 @@ def _get_fitted_params_default(self, obj=None): ] def getattr_safe(obj, attr): + """Get attribute of object, safely. + + Safe version of getattr, that returns None if attribute does not exist, + or if an exception is raised during getattr. + Also returns a boolean indicating whether the attribute was successfully + retrieved, to distinguish between None value and non-existent attribute, + or exception during getattr. + + Parameters + ---------- + obj : any object + object to get attribute from + attr : str + attribute name to get from obj + + Returns + ------- + attr : Any + attribute of obj, if it exists and does not raise on getattr; + otherwise None + success : bool + whether the attribute was successfully retrieved + """ try: if hasattr(obj, attr): attr = getattr(obj, attr) return attr, True + else: + return None, False except Exception: return None, False @@ -1305,10 +1330,9 @@ def getattr_safe(obj, attr): for p in fitted_params: attr, success = getattr_safe(obj, p) - if not success: - continue - p_name = p[:-1] # remove the "_" at the end to get the parameter name - fitted_param_dict[p_name] = attr + if success: + p_name = p[:-1] # remove the "_" at the end to get the parameter name + fitted_param_dict[p_name] = attr return fitted_param_dict