From 4b869319cc8f681f504648ad2fd5860c277e09ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Mon, 30 Sep 2024 12:18:46 +0100 Subject: [PATCH] Update _base.py --- skbase/base/_base.py | 264 +++++++++++++++++++++++++++++++------------ 1 file changed, 192 insertions(+), 72 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index ee377414..261e4890 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -89,10 +89,11 @@ def __init__(self): def __eq__(self, other): """Equality dunder. Checks equal class and parameters. - Returns True iff result of get_params(deep=False) + Returns True iff result of ``get_params(deep=False)`` results in equal parameter sets. - Nested BaseObject descendants from get_params are compared via __eq__ as well. + Nested BaseObject descendants from ``get_params`` are compared via + ``__eq__`` as well. """ from skbase.utils.deep_equals import deep_equals @@ -107,24 +108,33 @@ def __eq__(self, other): def reset(self): """Reset the object to a clean post-init state. - Using reset, runs __init__ with current values of hyper-parameters - (result of get_params). This Removes any object attributes, except: + Results in setting ``self`` to the state it had directly + after the constructor call, with the same hyper-parameters. + Config values set by ``set_config`` are also retained. - - hyper-parameters = arguments of __init__ - - object attributes containing double-underscores, i.e., the string "__" + A ``reset`` call deletes any object attributes, except: + + - hyper-parameters = arguments of ``__init__`` written to ``self``, + e.g., ``self.paramname`` where ``paramname`` is an argument of ``__init__`` + - object attributes containing double-underscores, i.e., the string "__". + For instance, an attribute named "__myattr" is retained. + - config attributes, configs are retained without change. + That is, results of ``get_config`` before and after ``reset`` are equal. 
Class and object methods, and class attributes are also unaffected. + Equivalent to ``clone``, with the exception that ``reset`` + mutates ``self`` instead of returning a new object. + + After a ``self.reset()`` call, + ``self`` is equal in value and state to the object obtained after + a constructor call ``type(self)(**self.get_params(deep=False))``. + Returns ------- self Instance of class reset to a clean post-init state but retaining the current hyper-parameter values. - - Notes - ----- - Equivalent to sklearn.clone but overwrites self. After self.reset() - call, self is equal in value to `type(self)(**self.get_params(deep=False))` """ # retrieve parameters to copy them later params = self.get_params(deep=False) @@ -149,13 +159,21 @@ def clone(self): A clone is a different object without shared references, in post-init state. This function is equivalent to returning ``sklearn.clone`` of ``self``. + Equivalent to constructing a new instance of ``type(self)``, with + parameters of ``self``, that is, + ``type(self)(**self.get_params(deep=False))``. + + If configs were set on ``self``, the clone will also have the same configs + as the original, + equivalent to calling ``cloned_self.set_config(**self.get_config())``. + + Also equivalent in value to a call of ``self.reset``, + with the exception that ``clone`` returns a new object, + instead of mutating ``self`` like ``reset``. + Raises ------ RuntimeError if the clone is non-conforming, due to faulty ``__init__``. - - Notes - ----- - If successful, equal in value to ``type(self)(**self.get_params(deep=False))``. """ self_clone = _clone(self) if self.get_config()["check_clone"]: @@ -175,7 +193,7 @@ def _get_init_signature(cls): Raises ------ - RuntimeError if cls has varargs in __init__. + RuntimeError if ``cls`` has varargs in ``__init__``.
""" # fetch the constructor or the original constructor before # deprecation wrapping if any @@ -218,7 +236,7 @@ def get_param_names(cls, sort=True): Returns ------- param_names: list[str] - List of parameter names of cls. + List of parameter names of ``cls``. If ``sort=False``, in same order as they appear in the class ``__init__``. If ``sort=True``, alphabetically ordered. """ @@ -238,8 +256,9 @@ def get_param_defaults(cls): Returns ------- default_dict: dict[str, Any] - Keys are all parameters of cls that have a default defined in __init__ - values are the defaults, as defined in __init__. + Keys are all parameters of ``cls`` that have + a default defined in ``__init__``. + Values are the defaults, as defined in ``__init__``. """ parameters = cls._get_init_signature() default_dict = { @@ -255,9 +274,11 @@ def get_params(self, deep=True): deep : bool, default=True Whether to return parameters of components. - * If True, will return a dict of parameter name : value for this object, - including parameters of components (= BaseObject-valued parameters). - * If False, will return a dict of parameter name : value for this object, + * If ``True``, will return a ``dict`` of + parameter name : value for this object, + including parameters of components (= ``BaseObject``-valued parameters). + * If ``False``, will return a ``dict`` + of parameter name : value for this object, but not include parameters of components. 
Returns @@ -266,14 +287,14 @@ def get_params(self, deep=True): Dictionary of parameters, paramname : paramvalue keys-value pairs include: - * always: all parameters of this object, as via `get_param_names` + * always: all parameters of this object, as via ``get_param_names`` values are parameter value for that key, of this object values are always identical to values passed at construction - * if `deep=True`, also contains keys/value pairs of component parameters - parameters of components are indexed as `[componentname]__[paramname]` - all parameters of `componentname` appear as `paramname` with its value - * if `deep=True`, also contains arbitrary levels of component recursion, - e.g., `[componentname]__[componentcomponentname]__[paramname]`, etc + * if ``deep=True``, also contains keys/value pairs of component parameters + parameters of components are indexed as ``[componentname]__[paramname]`` + all parameters of ``componentname`` appear as ``paramname`` with its value + * if ``deep=True``, also contains arbitrary levels of component recursion, + e.g., ``[componentname]__[componentcomponentname]__[paramname]``, etc """ params = {key: getattr(self, key) for key in self.get_param_names()} @@ -302,7 +323,7 @@ def set_params(self, **params): ---------- **params : dict BaseObject parameters, keys must be ``__`` strings. - __ suffixes can alias full strings, if unique among get_params keys. + ``__`` suffixes can alias full strings, if unique among get_params keys. Returns ------- @@ -375,16 +396,18 @@ def _alias_params(self, d, valid_params): ------ alias_dict: dict with str keys, all keys in valid_params values are as in d, with keys replaced by following rule: - If key is a __ suffix of exactly one key in valid_params, - it is replaced by that key. Otherwise an exception is raised. - A __ suffix of a str is any str obtained as suffix from partition by __. 
- Else, i.e., if key is in valid_params or not a __ suffix, - the key is replaced by itself, i.e., left unchanged. + + * If key is a ``__`` suffix of exactly one key in ``valid_params``, + it is replaced by that key. Otherwise an exception is raised. + * A ``__``-suffix of a ``str`` is any ``str`` obtained as suffix + from partition by the string ``"__"``. + Else, i.e., if key is in valid_params or not a ``__``-suffix, + the key is replaced by itself, i.e., left unchanged. Raises ------ - ValueError if at least one key of d is neither contained in valid_params, - nor is it a __ suffix of exactly one key in valid_params + ValueError if at least one key of d is neither contained in ``valid_params``, + nor is it a ``__``-suffix of exactly one key in ``valid_params`` """ def _is_suffix(x, y): @@ -421,10 +444,23 @@ def _get_alias(x, d): def get_class_tags(cls): """Get class tags from the class and all its parent classes. - Retrieves tag: value pairs from _tags class attribute. Does not return - information from dynamic tags (set via set_tags or clone_tags) + Returns a dictionary with keys being keys of any attribute of ``_tags`` + set in the class or any of its parent classes, + taking into account only class-level tags. + + Values are the corresponding tag values, with overrides in the following + order of descending priority: + + 1. Tags set via ``_tags`` in the class. + 2. Tags set via ``_tags`` in parent classes, + in order of inheritance. + + Does not take into account dynamic tag overrides on instances, + set via ``set_tags`` or ``clone_tags``, that are defined on instances. + For including overrides from dynamic tags, use ``get_tags``. + Returns ------- collected_tags : dict @@ -437,9 +473,19 @@ class attribute via nested inheritance. def get_class_tag(cls, tag_name, tag_value_default=None): """Get a class tag's value.
- Does not return information from dynamic tags (set via set_tags or clone_tags) + Returns the value of the tag with name ``tag_name`` from the object, + taking into account tag overrides, in the following + order of descending priority: + + 1. Tags set via ``_tags`` in the class. + 2. Tags set via ``_tags`` in parent classes, in order of inheritance. + + Does not take into account dynamic tag overrides on instances, + set via ``set_tags`` or ``clone_tags``, that are defined on instances. + For including overrides from dynamic tags, use ``get_tag``. + Parameters ---------- tag_name : str @@ -450,8 +496,8 @@ def get_class_tag(cls, tag_name, tag_value_default=None): Returns ------- tag_value : - Value of the `tag_name` tag in self. If not found, returns - `tag_value_default`. + Value of the ``tag_name`` tag in self. + If not found, returns ``tag_value_default``. """ return cls._get_class_flag( flag_name=tag_name, @@ -462,18 +508,37 @@ def get_class_tag(cls, tag_name, tag_value_default=None): def get_tags(self): """Get tags from skbase class and dynamic tag overrides. + Returns a dictionary with keys being keys of any attribute of ``_tags`` + set in the class or any of its parent classes, or tags set via ``set_tags`` + or ``clone_tags``. + + Values are the corresponding tag values, with overrides in the following + order of descending priority: + + 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance. + 2. Tags set via ``_tags`` in the class. + 3. Tags set via ``_tags`` in parent classes, in order of inheritance. + Returns ------- collected_tags : dict - Dictionary of tag name : tag value pairs. Collected from _tags + Dictionary of tag name : tag value pairs. Collected from ``_tags`` class attribute via nested inheritance and then any overrides - and new tags from _tags_dynamic object attribute. + and new tags from ``_tags_dynamic`` object attribute. 
""" return self._get_flags(flag_attr_name="_tags") def get_tag(self, tag_name, tag_value_default=None, raise_error=True): """Get tag value from object class and dynamic tag overrides. + Returns the value of the tag with name ``tag_name`` from the object, + taking into account tag overrides, in the following + order of descending priority: + + 1. Tags set via ``set_tags`` or ``clone_tags`` on the instance. + 2. Tags set via ``_tags`` in the class. + 3. Tags set via ``_tags`` in parent classes, in order of inheritance. + Parameters ---------- tag_name : str @@ -486,13 +551,14 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): Returns ------- tag_value : Any - Value of the `tag_name` tag in self. If not found, returns an error if - `raise_error` is True, otherwise it returns `tag_value_default`. + Value of the ``tag_name`` tag in self. If not found, returns an error if + ``raise_error`` is True, otherwise it returns ``tag_value_default``. Raises ------ - ValueError if raise_error is True i.e. if `tag_name` is not in - self.get_tags().keys() + ValueError, if ``raise_error`` is ``True``. + The ``ValueError`` is then raised if ``tag_name`` is + not in ``self.get_tags().keys()``. """ return self._get_flag( flag_name=tag_name, @@ -504,6 +570,20 @@ def get_tag(self, tag_name, tag_value_default=None, raise_error=True): def set_tags(self, **tag_dict): """Set dynamic tags to given values. + Tags are key-value pairs specific to an instance ``self``, + they are static flags that are not changed after construction + of the object. They may be used for metadata inspection, + or for controlling behaviour of the object. + + ``set_tags`` sets dynamic tag overrides + to the values + as specified in ``tag_dict``, with keys being the tag name, + and dict values being the value to set the tag to. + It should be called only in the ``__init__`` method of an object, + during construction, or directly after construction via ``__init__``. 
+ + Current tag values can be inspected by ``get_tags`` or ``get_tag``. + Parameters ---------- **tag_dict : dict @@ -513,10 +593,6 @@ def set_tags(self, **tag_dict): ------- Self Reference to self. - - Notes - ----- - Changes object state by setting tag values in tag_dict as dynamic tags in self. """ self._set_flags(flag_attr_name="_tags", **tag_dict) @@ -525,22 +601,34 @@ def set_tags(self, **tag_dict): def clone_tags(self, estimator, tag_names=None): """Clone tags from another object as dynamic override. + Tags are key-value pairs specific to an instance ``self``, + they are static flags that are not changed after construction + of the object. They may be used for metadata inspection, + or for controlling behaviour of the object. + + ``clone_tags`` sets dynamic tag overrides + from another object, ``estimator``. + It should be called only in the ``__init__`` method of an object, + during construction, or directly after construction via ``__init__``. + + The dynamic tags are set to the values of the tags in ``estimator``, + with the names specified in ``tag_names``. + + The default of ``tag_names`` writes all tags from ``estimator`` to ``self``. + + Current tag values can be inspected by ``get_tags`` or ``get_tag``. + Parameters ---------- estimator : An instance of :class:BaseObject or derived class tag_names : str or list of str, default = None - Names of tags to clone. If None then all tags in estimator are used - as `tag_names`. + Names of tags to clone. + The default (``None``) clones all tags from ``estimator``. Returns ------- - Self : - Reference to self. - - Notes - ----- - Changes object state by setting tag values in tag_set from estimator as - dynamic tags in self. + self : + Reference to ``self``. """ self._clone_flags( estimator=estimator, flag_names=tag_names, flag_attr_name="_tags" @@ -551,6 +639,17 @@ def clone_tags(self, estimator, tag_names=None): def get_config(self): """Get config flags for self. 
+ Configs are key-value pairs of ``self``, + typically used as transient flags for controlling behaviour. + + ``get_config`` returns dynamic configs, which override the default configs. + + Default configs are set in the class attribute ``_config`` of + the class or its parent classes, + and are overridden by dynamic configs set via ``set_config``. + + Configs are retained under ``clone`` or ``reset`` calls. + Returns ------- config_dict : dict @@ -563,6 +662,17 @@ class attribute via nested inheritance and then any overrides def set_config(self, **config_dict): """Set config flags to given values. + Configs are key-value pairs of ``self``, + typically used as transient flags for controlling behaviour. + + ``set_config`` sets dynamic configs, which override the default configs. + + Default configs are set in the class attribute ``_config`` of + the class or its parent classes, + and are overridden by dynamic configs set via ``set_config``. + + Configs are retained under ``clone`` or ``reset`` calls. + Parameters ---------- config_dict : dict @@ -584,6 +694,21 @@ def set_config(self, **config_dict): def get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the skbase object. + ``get_test_params`` is a unified interface point to store + parameter settings for testing purposes. This function is also + used in ``create_test_instance`` and ``create_test_instances_and_names`` + to construct test instances. + + ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. + + Each ``dict`` is a parameter configuration for testing, + and can be used to construct an "interesting" test instance. + A call to ``cls(**params)`` should + be valid for all dictionaries ``params`` in the return of ``get_test_params``. + + The ``get_test_params`` need not return fixed lists of dictionaries, + it can also return dynamic or stochastic parameter settings. 
+ Parameters ---------- parameter_set : str, default="default" @@ -630,11 +755,6 @@ def create_test_instance(cls, parameter_set="default"): ------- instance : instance of the class with default parameters - Notes - ----- - `get_test_params` can return dict or list of dict. - This function takes first or single dict that get_test_params returns, and - constructs the object with that. """ if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: params = cls.get_test_params(parameter_set=parameter_set) @@ -665,11 +785,11 @@ def create_test_instances_and_names(cls, parameter_set="default"): Returns ------- objs : list of instances of cls - i-th instance is cls(**cls.get_test_params()[i]) + i-th instance is ``cls(**cls.get_test_params()[i])`` names : list of str, same length as objs - i-th element is name of i-th instance of obj in tests - convention is {cls.__name__}-{i} if more than one instance - otherwise {cls.__name__} + i-th element is name of i-th instance of obj in tests. + The naming convention is ``{cls.__name__}-{i}`` if more than one instance, + otherwise ``{cls.__name__}`` """ if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: param_list = cls.get_test_params(parameter_set=parameter_set) @@ -760,7 +880,7 @@ def is_composite(self): ------- composite: bool Whether an object has any parameters whose values - are BaseObjects. + are ``BaseObject`` descendant instances. """ # walk through method resolution order and inspect methods # of classes and direct parents, "adjacent" classes in mro @@ -772,7 +892,7 @@ def is_composite(self): def _components(self, base_class=None): """Return references to all state changing BaseObject type attributes. - This *excludes* the blue-print-like components passed in the __init__. + This *excludes* the blue-print-like components passed in the ``__init__``. Caution: this method returns *references* and not *copies*. Writing to the reference will change the respective attribute of self. 
@@ -780,7 +900,7 @@ def _components(self, base_class=None): Parameters ---------- base_class : class, optional, default=None, must be subclass of BaseObject - if not None, sub-sets return dict to only descendants of base_class + if not ``None``, sub-sets return dict to only descendants of ``base_class`` Returns -------