Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] merge sktime BaseEstimator into skbase BaseEstimator #370

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,13 +1445,13 @@ class BaseEstimator(BaseObject):
def __init__(self):
"""Construct BaseEstimator."""
self._is_fitted = False
super(BaseEstimator, self).__init__()
super().__init__()

@property
def is_fitted(self):
"""Whether `fit` has been called.
"""Whether ``fit`` has been called.

Inspects object's `_is_fitted` attribute that should initialize to False
Inspects object's ``_is_fitted` attribute that should initialize to ``False``
during object construction, and be set to True in calls to an object's
`fit` method.

Expand All @@ -1460,25 +1460,43 @@ def is_fitted(self):
bool
Whether the estimator has been `fit`.
"""
return self._is_fitted
if hasattr(self, "_is_fitted"):
return self._is_fitted
else:
return False

def check_is_fitted(self):
def check_is_fitted(self, method_name=None):
"""Check if the estimator has been fitted.

Inspects object's `_is_fitted` attribute that should initialize to False
during object construction, and be set to True in calls to an object's
`fit` method.
Check if ``_is_fitted`` attribute is present and ``True``.
The ``is_fitted``
attribute should be set to ``True`` in calls to an object's ``fit`` method.

If not, raises a ``NotFittedError``.

Parameters
----------
method_name : str, optional
Name of the method that called this function. If provided, the error
message will include this information.

Raises
------
NotFittedError
If the estimator has not been fitted yet.
"""
if not self.is_fitted:
raise NotFittedError(
f"This instance of {self.__class__.__name__} has not been fitted yet. "
f"Please call `fit` first."
)
if method_name is None:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` first."
)
else:
msg = (
f"This instance of {self.__class__.__name__} has not been fitted "
f"yet. Please call `fit` before calling `{method_name}`."
)
raise NotFittedError(msg)

def get_fitted_params(self, deep=True):
"""Get fitted parameters.
Expand All @@ -1503,19 +1521,15 @@ def get_fitted_params(self, deep=True):
Dictionary of fitted parameters, paramname : paramvalue
keys-value pairs include:

* always: all fitted parameters of this object, as via `get_param_names`
* always: all fitted parameters of this object, as via ``get_param_names``
values are fitted parameter value for that key, of this object
* if `deep=True`, also contains keys/value pairs of component parameters
parameters of components are indexed as `[componentname]__[paramname]`
all parameters of `componentname` appear as `paramname` with its value
* if `deep=True`, also contains arbitrary levels of component recursion,
e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc
* if ``deep=True``, also contains keys/value pairs of component parameters
parameters of components are indexed as ``[componentname]__[paramname]``
all parameters of ``componentname`` appear as ``paramname`` with its value
* if ``deep=True``, also contains arbitrary levels of component recursion,
e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc
"""
if not self.is_fitted:
raise NotFittedError(
f"estimator of type {type(self).__name__} has not been "
"fitted yet, please call fit on data before get_fitted_params"
)
self.check_is_fitted(method_name="get_fitted_params")

# collect non-nested fitted params of self
fitted_params = self._get_fitted_params()
Expand Down
Loading