Skip to content

Commit

Permalink
[ENH] safer get_fitted_params default functionality to avoid except…
Browse files Browse the repository at this point in the history
…ion on `getattr` (#353)

This PR makes the `get_fitted_params` core functionality and defaults
safer against exceptions on `getattr`.

In rare cases, `getattr` can cause an exception, namely if a property is
being accessed in a way that generates an exception.

Examples are some fitted parameter arguments in `sklearn` that are
decorated as property in newer versions, e.g.,
`RandomForestRegressor.estimator_` in unfitted state.

In most use cases, there will be no change in behaviour; in the
described case where an exception would be raised, this is now caught
and suppressed, and the corresponding parameter is considered as not
present.

Changes occur only in cases that would have previously raised genuine
exceptions, so no working code is affected, and no deprecation is
necessary despite this being a change to a core interface element.
  • Loading branch information
fkiraly authored Aug 23, 2024
1 parent bfb1b8f commit 2d6dcfe
Showing 1 changed file with 41 additions and 4 deletions.
45 changes: 41 additions & 4 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,10 +1292,47 @@ def _get_fitted_params_default(self, obj=None):
fitted_params = [
attr for attr in dir(obj) if attr.endswith("_") and not attr.startswith("_")
]
# remove the "_" at the end
fitted_param_dict = {
p[:-1]: getattr(obj, p) for p in fitted_params if hasattr(obj, p)
}

def getattr_safe(obj, attr):
"""Get attribute of object, safely.
Safe version of getattr, that returns None if attribute does not exist,
or if an exception is raised during getattr.
Also returns a boolean indicating whether the attribute was successfully
retrieved, to distinguish between None value and non-existent attribute,
or exception during getattr.
Parameters
----------
obj : any object
object to get attribute from
attr : str
attribute name to get from obj
Returns
-------
attr : Any
attribute of obj, if it exists and does not raise on getattr;
otherwise None
success : bool
whether the attribute was successfully retrieved
"""
try:
if hasattr(obj, attr):
attr = getattr(obj, attr)
return attr, True
else:
return None, False
except Exception:
return None, False

fitted_param_dict = {}

for p in fitted_params:
attr, success = getattr_safe(obj, p)
if success:
p_name = p[:-1] # remove the "_" at the end to get the parameter name
fitted_param_dict[p_name] = attr

return fitted_param_dict

Expand Down

0 comments on commit 2d6dcfe

Please sign in to comment.