Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] _HeterogenousMetaObject to accept list of tuples of any length #206

Merged
merged 8 commits into from
Aug 14, 2023
11 changes: 7 additions & 4 deletions skbase/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def _set_params(self, attr: str, **params):
# 2. Step replacement
items = getattr(self, attr)
names = []
if items:
names, _ = zip(*items)
if items and isinstance(items, (list, tuple)):
names = list(zip(*items))[0]
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_object(attr, name, params.pop(name))
Expand All @@ -247,9 +247,12 @@ def _replace_object(self, attr: str, name: str, new_val: Any) -> None:
"""Replace an object in attribute that contains named objects."""
# assumes `name` is a valid object name
new_objects = list(getattr(self, attr))
for i, (object_name, _) in enumerate(new_objects):
for i, obj_tpl in enumerate(new_objects):
object_name = obj_tpl[0]
if object_name == name:
new_objects[i] = (name, new_val)
new_tpl = list(obj_tpl)
new_tpl[1] = new_val
new_objects[i] = tuple(new_tpl)
break
setattr(self, attr, new_objects)

Expand Down
57 changes: 48 additions & 9 deletions skbase/tests/test_meta.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Tests for BaseMetaObject and BaseMetaEstimator mixins.
"""Tests for BaseMetaObject and BaseMetaEstimator mixins."""

tests in this module:


"""

__author__ = ["RNKuhns"]
__author__ = ["RNKuhns", "fkiraly"]
import inspect

import pytest
Expand All @@ -23,37 +18,51 @@


class MetaObjectTester(BaseMetaObject):
"""Class to test meta object functionality."""
"""Class to test meta-object functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class MetaEstimatorTester(BaseMetaEstimator):
"""Class to test meta estimator functionality."""
"""Class to test meta-estimator functionality."""

def __init__(self, a=7, b="something", c=None, steps=None):
self.a = a
self.b = b
self.c = c
self.steps = steps
super().__init__()


class ComponentDummy(BaseObject):
Dismissed Show dismissed Hide dismissed
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved
"""Class to use as components in meta-estimator."""

def __init__(self, a=7, b="something"):
self.a = a
self.b = b
super().__init__()


@pytest.fixture
def fixture_metaestimator_instance():
"""BaseMetaEstimator instance fixture."""
return BaseMetaEstimator()


@pytest.fixture
def fixture_meta_object():
"""MetaObjectTester instance fixture."""
return MetaObjectTester()


@pytest.fixture
def fixture_meta_estimator():
"""MetaEstimatorTester instance fixture."""
return MetaEstimatorTester()


Expand Down Expand Up @@ -129,3 +138,33 @@ def test_basemetaestimator_check_is_fitted_raises_error_when_unfitted(

fixture_metaestimator_instance._is_fitted = True
assert fixture_metaestimator_instance.check_is_fitted() is None


@pytest.mark.parametrize("long_steps", (True, False))
def test_metaestimator_composite(long_steps):
"""Test composite meta-estimator functionality."""
if long_steps:
steps = [("foo", ComponentDummy(42)), ("bar", ComponentDummy(24))]
else:
steps = [("foo", ComponentDummy(42), 123), ("bar", ComponentDummy(24), 321)]

meta_est = MetaEstimatorTester(steps=steps)

meta_est_params = meta_est.get_params()
assert isinstance(meta_est_params, dict)
expected_keys = [
"a",
"b",
"c",
"steps",
"foo",
"bar",
"foo__a",
"foo__b",
"bar__a",
"bar__b",
]
assert set(meta_est_params.keys()) == set(expected_keys)

meta_est.set_params(bar__b="something else")
assert meta_est.get_params()["bar__b"] == "something else"
Loading