Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
augustebaum committed Oct 9, 2024
1 parent 1cb3d60 commit c49b6e6
Showing 1 changed file with 44 additions and 9 deletions.
53 changes: 44 additions & 9 deletions skore/tests/integration/test_cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,60 @@
import pytest
import sklearn.model_selection
from skore.cross_validate import cross_validate, plot_cross_validation
from skore.item.cross_validate_item import CrossValidationItem
from skore.cross_validate import cross_validate
from skore.item.cross_validate_item import CrossValidationItem, plot_cross_validation


def test_cross_validate(in_memory_project):
@pytest.fixture()
def lasso():
from sklearn import datasets, linear_model

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()
return lasso, X, y

n_splits = 3
cv_results = cross_validate(lasso, X, y, cv=n_splits, project=in_memory_project)
cv_results_sklearn = sklearn.model_selection.cross_validate(
lasso, X, y, cv=n_splits

def test_cross_validate(in_memory_project, lasso):
args = lasso
kwargs = {"cv": 3}

cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)

assert isinstance(
in_memory_project.get_item("cross_validation"), CrossValidationItem
)
assert cv_results.keys() == cv_results_sklearn.keys()
assert all(len(v) == kwargs["cv"] for v in cv_results.values())


def test_cross_validate_extra_metrics(in_memory_project, lasso):
args = list(lasso)
kwargs = {"scoring": "r2", "cv": 3}

cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)

assert isinstance(
in_memory_project.get_item("cross_validation"), CrossValidationItem
)
assert cv_results.keys() == cv_results_sklearn.keys()
assert all(len(v) == kwargs["cv"] for v in cv_results.values())


assert isinstance(in_memory_project.get_item("cross_validation"), CrossValidationItem)
def test_cross_validate_2_extra_metrics(in_memory_project, lasso):
args = list(lasso)
kwargs = {"scoring": ["r2", "neg_mean_squared_error"], "cv": 3}

cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)

assert isinstance(
in_memory_project.get_item("cross_validation"), CrossValidationItem
)
assert cv_results.keys() == cv_results_sklearn.keys()
assert all(len(v) == n_splits for v in cv_results.values())
assert all(len(v) == kwargs["cv"] for v in cv_results.values())


def test_plot_cross_validation():
Expand Down

0 comments on commit c49b6e6

Please sign in to comment.