diff --git a/examples/basic_usage.ipynb b/examples/basic_usage.ipynb
index 5e0c9153..db3199b1 100644
--- a/examples/basic_usage.ipynb
+++ b/examples/basic_usage.ipynb
@@ -631,13 +631,39 @@
"cell_type": "markdown",
"id": "57",
"metadata": {},
+ "source": [
+ "---\n",
+ "## Cross-validation with skore"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "58",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn import datasets, linear_model\n",
+ "from skore.cross_validate import cross_validate\n",
+ "diabetes = datasets.load_diabetes()\n",
+ "X = diabetes.data[:150]\n",
+ "y = diabetes.target[:150]\n",
+ "lasso = linear_model.Lasso()\n",
+ "\n",
+ "cv_results = cross_validate(lasso, X, y, cv=3, project=project)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "59",
+ "metadata": {},
"source": [
"_Stay tuned for some new features!_"
]
},
{
"cell_type": "markdown",
- "id": "58",
+ "id": "60",
"metadata": {},
"source": [
"---\n",
@@ -649,7 +675,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "59",
+ "id": "61",
"metadata": {},
"outputs": [],
"source": [
@@ -662,7 +688,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "60",
+ "id": "62",
"metadata": {},
"outputs": [],
"source": [
@@ -675,7 +701,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "61",
+ "id": "63",
"metadata": {},
"outputs": [],
"source": [
@@ -687,7 +713,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "62",
+ "id": "64",
"metadata": {},
"outputs": [],
"source": [
@@ -700,7 +726,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "63",
+ "id": "65",
"metadata": {},
"outputs": [],
"source": [
@@ -713,7 +739,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "64",
+ "id": "66",
"metadata": {},
"outputs": [],
"source": [
@@ -726,7 +752,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "65",
+ "id": "67",
"metadata": {},
"outputs": [],
"source": []
diff --git a/examples/basic_usage.py b/examples/basic_usage.py
index 964997fc..c117fe72 100644
--- a/examples/basic_usage.py
+++ b/examples/basic_usage.py
@@ -303,6 +303,20 @@ def my_func(x):
project.put("my_fitted_pipeline", my_pipeline)
+# %% [markdown]
+# ---
+# ## Cross-validation with skore
+
+# %%
+from sklearn import datasets, linear_model
+from skore.cross_validate import cross_validate
+diabetes = datasets.load_diabetes()
+X = diabetes.data[:150]
+y = diabetes.target[:150]
+lasso = linear_model.Lasso()
+
+cv_results = cross_validate(lasso, X, y, cv=3, project=project)
+
# %% [markdown]
# _Stay tuned for some new features!_
diff --git a/skore-ui/src/components/VegaWidget.vue b/skore-ui/src/components/VegaWidget.vue
index 1b3f0744..191bbd44 100644
--- a/skore-ui/src/components/VegaWidget.vue
+++ b/skore-ui/src/components/VegaWidget.vue
@@ -1,8 +1,8 @@
diff --git a/skore/pyproject.toml b/skore/pyproject.toml
index cbe6f60a..621d5803 100644
--- a/skore/pyproject.toml
+++ b/skore/pyproject.toml
@@ -9,6 +9,7 @@ dynamic = [
requires-python = ">=3.9, <3.13"
maintainers = [{name = "skore developers", email="skore@signal.probabl.ai"}]
dependencies = [
+ "altair>=5,<6",
"diskcache",
"fastapi",
"rich",
@@ -61,7 +62,6 @@ artifacts = ["src/skore/ui/static/"]
[project.optional-dependencies]
test = [
- "altair",
"httpx",
"matplotlib",
"pandas",
diff --git a/skore/src/skore/__init__.py b/skore/src/skore/__init__.py
index 3315055f..9ecccba7 100644
--- a/skore/src/skore/__init__.py
+++ b/skore/src/skore/__init__.py
@@ -4,11 +4,13 @@
import rich.logging
+from skore.cross_validate import cross_validate
from skore.project import Project, load
from .utils._show_versions import show_versions
__all__ = [
+ "cross_validate",
"load",
"show_versions",
"Project",
diff --git a/skore/src/skore/cross_validate.py b/skore/src/skore/cross_validate.py
new file mode 100644
index 00000000..d06d25c4
--- /dev/null
+++ b/skore/src/skore/cross_validate.py
@@ -0,0 +1,234 @@
+"""cross_validate function.
+
+This function implements a wrapper over scikit-learn's [cross_validate](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html#sklearn.model_selection.cross_validate)
+function in order to enrich it with more information and enable more analysis.
+"""
+
+import contextlib
+from typing import Literal, Optional
+
+from skore.item.cross_validate_item import CrossValidationItem
+from skore.project import Project
+
+
+def _find_ml_task(
+ estimator, y
+) -> Literal[
+ "binary-classification",
+ "multiclass-classification",
+ "regression",
+ "clustering",
+ "unknown",
+]:
+ """Guess the ML task being addressed based on an estimator and a target array.
+
+ Parameters
+ ----------
+ estimator : sklearn.base.BaseEstimator
+ An estimator.
+ y : numpy.ndarray
+ A target vector.
+
+ Returns
+ -------
+    Literal["binary-classification", "multiclass-classification", "regression", "clustering", "unknown"]
+ The guess of the kind of ML task being performed.
+ """
+ import sklearn.utils.multiclass
+ from sklearn.base import is_classifier, is_regressor
+
+ if y is None:
+ # NOTE: The task might not be clustering
+ return "clustering"
+
+ if is_regressor(estimator):
+ return "regression"
+
+ type_of_target = sklearn.utils.multiclass.type_of_target(y)
+
+ if is_classifier(estimator):
+ if type_of_target == "binary":
+ return "binary-classification"
+
+ if type_of_target == "multiclass":
+ return "multiclass-classification"
+
+ if type_of_target == "unknown":
+ return "unknown"
+
+ if "continuous" in type_of_target:
+ return "regression"
+
+    return "unknown"
+
+
+def _add_scorers(estimator, y, scorers):
+ """Expand `scorers` with other scorers, based on `estimator` and `y`.
+
+ Parameters
+ ----------
+ estimator : sklearn.base.BaseEstimator
+ An estimator.
+    y : numpy.ndarray
+        A target vector.
+    scorers : any type that is accepted by scikit-learn's cross_validate
+        The scorer(s) to expand.
+
+ Returns
+ -------
+ new_scorers : dict[str, str | None]
+ The scorers after adding `scorers_to_add`.
+ added_scorers : list[str]
+ The scorers that were actually added (i.e. the ones that were not already
+ in `scorers`).
+ """
+ ml_task = _find_ml_task(estimator, y)
+
+ # Add scorers based on the ML task
+ if ml_task == "regression":
+ scorers_to_add = ["r2", "neg_mean_squared_error"]
+ elif ml_task == "binary-classification":
+ scorers_to_add = ["roc_auc", "neg_brier_score", "recall", "precision"]
+ elif ml_task == "multiclass-classification":
+ scorers_to_add = ["recall_weighted", "precision_weighted"]
+
+ if hasattr(estimator, "predict_proba"):
+ scorers_to_add += ["roc_auc_ovr_weighted", "neg_log_loss"]
+ else:
+ scorers_to_add = []
+
+ added_scorers = []
+
+ if scorers is None:
+ new_scorers = {"score": None}
+ for s in scorers_to_add:
+ new_scorers[s] = s
+ added_scorers.append(s)
+ elif isinstance(scorers, str):
+ new_scorers = {"score": scorers}
+ for s in scorers_to_add:
+ if s == scorers:
+ continue
+ new_scorers[s] = s
+ added_scorers.append(s)
+ elif isinstance(scorers, dict):
+ new_scorers = scorers.copy()
+ for s in scorers_to_add:
+ if s in scorers:
+ continue
+ new_scorers[s] = s
+ added_scorers.append(s)
+ elif isinstance(scorers, list):
+ new_scorers = scorers.copy()
+ for s in scorers_to_add:
+ if s in scorers:
+ continue
+ new_scorers.append(s)
+ added_scorers.append(s)
+ elif isinstance(scorers, tuple):
+ scorers = list(scorers)
+        new_scorers, added_scorers = _add_scorers(estimator, y, scorers)
+
+ return new_scorers, added_scorers
+
+
+def _strip_cv_results_scores(cv_results: dict, added_scorers: list[str]) -> dict:
+ """Remove information about `added_scorers` in `cv_results`.
+
+ Parameters
+ ----------
+ cv_results : dict
+ A dict of the form returned by scikit-learn's cross_validate function.
+ added_scorers : list[str]
+ A list of scorers in `cv_results` which should be removed.
+
+ Returns
+ -------
+ dict
+ A new cv_results dict, with the specified scorers information removed.
+ """
+ # Takes care both of train and test scores
+ return {
+ k: v
+ for k, v in cv_results.items()
+ if not any(added_scorer in k for added_scorer in added_scorers)
+ }
+
+
+def cross_validate(
+ *args, project: Optional[Project] = None, **kwargs
+) -> dict:
+ """Evaluate estimator by cross-validation and output UI-friendly object.
+
+ This function wraps scikit-learn's [cross_validate](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_validate.html#sklearn.model_selection.cross_validate)
+ function, to provide more context and facilitate the analysis.
+ As such, the arguments are the same as scikit-learn's cross_validate function.
+
+ Parameters
+ ----------
+ The same parameters as scikit-learn's cross_validate function, except for
+
+ project : Project, optional
+ A project to save cross-validation data into. If None, no save is performed.
+
+ Returns
+ -------
+ cv_results : dict
+ A dict of the form returned by scikit-learn's cross_validate function.
+
+ Examples
+ --------
+ >>> def prepare_cv():
+ ... from sklearn import datasets, linear_model
+ ... diabetes = datasets.load_diabetes()
+ ... X = diabetes.data[:150]
+ ... y = diabetes.target[:150]
+ ... lasso = linear_model.Lasso()
+ ... return lasso, X, y
+
+ >>> project = skore.load("project.skore") # doctest: +SKIP
+ >>> lasso, X, y = prepare_cv() # doctest: +SKIP
+ >>> cross_validate(lasso, X, y, cv=3, project=project) # doctest: +SKIP
+ alt.Chart(...)
+ {'fit_time': array(...), 'score_time': array(...), 'test_score': array(...)}
+ """
+ import sklearn.model_selection
+
+ # Recover specific arguments
+ estimator = args[0] if len(args) >= 1 else kwargs.get("estimator")
+ X = args[1] if len(args) >= 2 else kwargs.get("X")
+    y = args[2] if len(args) >= 3 else kwargs.get("y")
+
+ try:
+ scorers = kwargs.pop("scoring")
+ except KeyError:
+ scorers = None
+
+ # Extend scorers with other relevant scorers
+ new_scorers, added_scorers = _add_scorers(estimator, y, scorers)
+
+ cv_results = sklearn.model_selection.cross_validate(
+ *args, **kwargs, scoring=new_scorers
+ )
+ # Add explicit metric to result (rather than just "test_score")
+ if isinstance(scorers, str):
+        if kwargs.get("return_train_score"):
+ cv_results[f"train_{scorers}"] = cv_results["train_score"]
+ cv_results[f"test_{scorers}"] = cv_results["test_score"]
+
+ cross_validation_item = CrossValidationItem.factory(cv_results, estimator, X, y)
+
+ if project is not None:
+ project.put_item("cross_validation", cross_validation_item)
+
+ # If in a IPython context (e.g. Jupyter notebook), display the plot
+ with contextlib.suppress(ImportError):
+ from IPython.display import display
+
+ display(cross_validation_item.plot)
+
+ # Remove information related to our scorers, so that our return value is
+ # the same as sklearn's
+ stripped_cv_results = _strip_cv_results_scores(cv_results, added_scorers)
+
+ return stripped_cv_results
diff --git a/skore/src/skore/item/__init__.py b/skore/src/skore/item/__init__.py
index a2fd83c8..2a3e5480 100644
--- a/skore/src/skore/item/__init__.py
+++ b/skore/src/skore/item/__init__.py
@@ -5,6 +5,7 @@
from contextlib import suppress
from typing import Any
+from skore.item.cross_validate_item import CrossValidationItem
from skore.item.item import Item
from skore.item.item_repository import ItemRepository
from skore.item.media_item import MediaItem
@@ -23,6 +24,7 @@ def object_to_item(object: Any) -> Item:
PandasSeriesItem,
NumpyArrayItem,
SklearnBaseEstimatorItem,
+ CrossValidationItem,
MediaItem,
):
with suppress(ImportError, TypeError):
diff --git a/skore/src/skore/item/cross_validate_item.py b/skore/src/skore/item/cross_validate_item.py
new file mode 100644
index 00000000..76d47007
--- /dev/null
+++ b/skore/src/skore/item/cross_validate_item.py
@@ -0,0 +1,233 @@
+"""CrossValidationItem class.
+
+This class represents the output of a cross-validation workflow.
+"""
+
+from __future__ import annotations
+
+import contextlib
+import hashlib
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+import altair
+import numpy
+
+from skore.item.item import Item
+
+if TYPE_CHECKING:
+ import sklearn.base
+
+
+def plot_cross_validation(cv_results: dict) -> altair.Chart:
+ """Plot the result of a cross-validation run.
+
+ Parameters
+ ----------
+ cv_results : dict
+ The output of scikit-learn's cross_validate function.
+
+ Returns
+ -------
+ altair.Chart
+ A plot of the cross-validation results
+ """
+ import altair
+ import pandas
+
+ _cv_results = cv_results.copy()
+
+    for key in ("indices", "estimator"):
+        with contextlib.suppress(KeyError):
+            del _cv_results[key]
+
+ df = (
+ pandas.DataFrame(_cv_results)
+ .reset_index(names="split")
+ .melt(id_vars="split", var_name="metric", value_name="score")
+ )
+
+ input_dropdown = altair.binding_select(
+ options=df["metric"].unique().tolist(), name="Metric: "
+ )
+ selection = altair.selection_point(
+ fields=["metric"], bind=input_dropdown, value="test_score"
+ )
+
+ return (
+ altair.Chart(df, title="Cross-validation scores per split")
+ .mark_bar()
+ .encode(
+ altair.X("split:N").axis(
+ title="Split number",
+ labelAngle=0,
+ ),
+ altair.Y("score:Q").axis(
+ title="Score",
+ titleAngle=0,
+ titleAlign="left",
+ titleX=0,
+ titleY=-5,
+ labelLimit=300,
+ ),
+ tooltip=["metric:N", "split:N", "score:Q"],
+ )
+ .interactive()
+ .add_params(selection)
+ .transform_filter(selection)
+ .properties(
+ width=500,
+ height=200,
+ padding=15,
+ autosize=altair.AutoSizeParams(type="pad", contains="padding"),
+ )
+ )
+
+
+def _hash_numpy(arr: numpy.ndarray) -> str:
+ """Compute a hash string from a numpy array.
+
+ Parameters
+ ----------
+ arr : numpy array
+ The numpy array whose hash will be computed.
+
+ Returns
+ -------
+ hash : str
+ A hash corresponding to the input array.
+ """
+    return hashlib.sha256(arr.tobytes()).hexdigest()
+
+
+# Data used for training, passed as input to scikit-learn
+Data = Any
+# Target used for training, passed as input to scikit-learn
+Target = Any
+
+
+class CrossValidationItem(Item):
+ """
+ A class to represent the output of a cross-validation workflow.
+
+ This class encapsulates the output of scikit-learn's cross-validate function along
+ with its creation and update timestamps.
+ """
+
+ def __init__(
+ self,
+ cv_results_serialized: dict,
+ estimator_info: dict,
+ X_info: dict,
+ y_info: dict,
+ plot_bytes: bytes,
+ created_at: str | None = None,
+ updated_at: str | None = None,
+ ):
+ """
+ Initialize a CrossValidationItem.
+
+ Parameters
+ ----------
+ cv_results_serialized : dict
+ The dict output of scikit-learn's cross_validate function,
+ in a form suitable for serialization.
+ estimator_info : dict
+ The estimator that was cross-validated.
+ X_info : dict
+ A summary of the data, input of scikit-learn's cross_validation function.
+ y_info : dict
+ A summary of the target, input of scikit-learn's cross_validation function.
+ plot_bytes : bytes
+ A plot of the cross-validation results, in the form of bytes.
+ created_at : str
+ The creation timestamp in ISO format.
+ updated_at : str
+ The last update timestamp in ISO format.
+ """
+ super().__init__(created_at, updated_at)
+
+ self.cv_results_serialized = cv_results_serialized
+ self.estimator_info = estimator_info
+ self.X_info = X_info
+ self.y_info = y_info
+ self.plot_bytes = plot_bytes
+
+ @classmethod
+ def factory(
+ cls,
+ cv_results: dict,
+ estimator: sklearn.base.BaseEstimator,
+ X: Data,
+ y: Target | None,
+ ) -> CrossValidationItem:
+ """
+ Create a new CrossValidationItem instance.
+
+ Parameters
+ ----------
+ cv_results : dict
+ The dict output of scikit-learn's cross_validate function.
+ estimator : sklearn.base.BaseEstimator,
+ The estimator that was cross-validated.
+ X
+ The data, input of scikit-learn's cross_validation function.
+ y
+ The target, input of scikit-learn's cross_validation function.
+
+ Returns
+ -------
+ CrossValidationItem
+ A new CrossValidationItem instance.
+ """
+ if not isinstance(cv_results, dict):
+ raise TypeError(f"Type '{cv_results.__class__}' is not supported.")
+
+ cv_results_serialized = {}
+ for k, v in cv_results.items():
+ if k == "estimator":
+ continue
+ if k == "indices":
+ cv_results_serialized["indices"] = {
+ "train": tuple(arr.tolist() for arr in v["train"]),
+ "test": tuple(arr.tolist() for arr in v["test"]),
+ }
+ if isinstance(v, numpy.ndarray):
+ cv_results_serialized[k] = v.tolist()
+
+ estimator_info = {
+ "name": estimator.__class__.__name__,
+ "params": repr(estimator.get_params()),
+ }
+
+ y_array = y if isinstance(y, numpy.ndarray) else numpy.array(y)
+ y_info = None if y is None else {"hash": _hash_numpy(y_array)}
+
+ X_array = X if isinstance(X, numpy.ndarray) else numpy.array(X)
+ X_info = {
+ "nb_rows": X_array.shape[0],
+ "nb_cols": X_array.shape[1],
+ "hash": _hash_numpy(X_array),
+ }
+
+ # Keep plot itself as well as bytes so we can cache it
+ plot = plot_cross_validation(cv_results)
+ plot_bytes = plot.to_json().encode("utf-8")
+
+ instance = cls(
+ cv_results_serialized=cv_results_serialized,
+ estimator_info=estimator_info,
+ X_info=X_info,
+ y_info=y_info,
+ plot_bytes=plot_bytes,
+ )
+
+ # Cache plot
+ instance.plot = plot
+
+ return instance
+
+ @cached_property
+ def plot(self):
+ """A plot of the cross-validation results."""
+ return altair.Chart.from_json(self.plot_bytes.decode("utf-8"))
diff --git a/skore/src/skore/item/item_repository.py b/skore/src/skore/item/item_repository.py
index 80554dd9..ce73a011 100644
--- a/skore/src/skore/item/item_repository.py
+++ b/skore/src/skore/item/item_repository.py
@@ -13,6 +13,7 @@
from skore.persistence.abstract_storage import AbstractStorage
+from skore.item.cross_validate_item import CrossValidationItem
from skore.item.media_item import MediaItem
from skore.item.numpy_array_item import NumpyArrayItem
from skore.item.pandas_dataframe_item import PandasDataFrameItem
@@ -34,6 +35,7 @@ class ItemRepository:
"PandasDataFrameItem": PandasDataFrameItem,
"PandasSeriesItem": PandasSeriesItem,
"PrimitiveItem": PrimitiveItem,
+ "CrossValidationItem": CrossValidationItem,
"SklearnBaseEstimatorItem": SklearnBaseEstimatorItem,
}
diff --git a/skore/src/skore/project.py b/skore/src/skore/project.py
index e0701deb..011a7885 100644
--- a/skore/src/skore/project.py
+++ b/skore/src/skore/project.py
@@ -6,6 +6,7 @@
from typing import Any, Literal, Union
from skore.item import (
+ CrossValidationItem,
Item,
ItemRepository,
MediaItem,
@@ -125,6 +126,8 @@ def get(self, key: str) -> Any:
return item.series
elif isinstance(item, SklearnBaseEstimatorItem):
return item.estimator
+ elif isinstance(item, CrossValidationItem):
+ return item.cv_results_serialized
elif isinstance(item, MediaItem):
return item.media_bytes
else:
diff --git a/skore/src/skore/ui/project_routes.py b/skore/src/skore/ui/project_routes.py
index aa04aaa7..3921500c 100644
--- a/skore/src/skore/ui/project_routes.py
+++ b/skore/src/skore/ui/project_routes.py
@@ -9,6 +9,7 @@
from fastapi.params import Depends
from fastapi.templating import Jinja2Templates
+from skore.item.cross_validate_item import CrossValidationItem
from skore.item.media_item import MediaItem
from skore.item.numpy_array_item import NumpyArrayItem
from skore.item.pandas_dataframe_item import PandasDataFrameItem
@@ -69,6 +70,11 @@ def __serialize_project(project: Project) -> SerializedProject:
elif isinstance(item, MediaItem):
value = base64.b64encode(item.media_bytes).decode()
media_type = item.media_type
+ elif isinstance(item, CrossValidationItem):
+ # Convert plot to MediaItem
+ item = MediaItem.factory(item.plot)
+ value = base64.b64encode(item.media_bytes).decode()
+ media_type = item.media_type
else:
raise ValueError(f"Item {item} is not a known item type.")
diff --git a/skore/tests/conftest.py b/skore/tests/conftest.py
index 79a645f4..6b2e3e0a 100644
--- a/skore/tests/conftest.py
+++ b/skore/tests/conftest.py
@@ -1,6 +1,10 @@
from datetime import datetime, timezone
import pytest
+from skore.item.item_repository import ItemRepository
+from skore.persistence.in_memory_storage import InMemoryStorage
+from skore.project import Project
+from skore.view.view_repository import ViewRepository
@pytest.fixture
@@ -23,3 +27,13 @@ def now(*args, **kwargs):
return mock_now
return MockDatetime
+
+
+@pytest.fixture
+def in_memory_project():
+ item_repository = ItemRepository(storage=InMemoryStorage())
+ view_repository = ViewRepository(storage=InMemoryStorage())
+ return Project(
+ item_repository=item_repository,
+ view_repository=view_repository,
+ )
diff --git a/skore/tests/integration/test_cross_validate.py b/skore/tests/integration/test_cross_validate.py
new file mode 100644
index 00000000..da644707
--- /dev/null
+++ b/skore/tests/integration/test_cross_validate.py
@@ -0,0 +1,196 @@
+import numpy
+import pandas
+import pytest
+import sklearn.model_selection
+from numpy import array
+from sklearn import datasets, linear_model
+from sklearn.cluster import KMeans
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.multiclass import OneVsOneClassifier
+from sklearn.svm import SVC
+from skore.cross_validate import cross_validate
+from skore.item.cross_validate_item import CrossValidationItem, plot_cross_validation
+
+
+@pytest.fixture()
+def lasso():
+ diabetes = datasets.load_diabetes()
+ X = diabetes.data[:150]
+ y = diabetes.target[:150]
+ lasso = linear_model.Lasso()
+ return lasso, X, y
+
+
+def test_cross_validate_regression(in_memory_project, lasso):
+ args = lasso
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validate_regression_data_is_list(in_memory_project, lasso):
+ model, X, y = lasso
+ args = [model, X.tolist(), y.tolist()]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validate_regression_data_is_pandas(in_memory_project, lasso):
+ model, X, y = lasso
+ args = [model, pandas.DataFrame(X), pandas.Series(y)]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validate_regression_extra_metrics(in_memory_project, lasso):
+ args = list(lasso)
+ kwargs = {"scoring": "r2", "cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert set(cv_results.keys()).issuperset(cv_results_sklearn.keys())
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validate_2_extra_metrics(in_memory_project, lasso):
+ args = list(lasso)
+ kwargs = {"scoring": ["r2", "neg_mean_squared_error"], "cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validation_binary_classification_no_predict_proba(in_memory_project):
+ X, y = datasets.load_iris(return_X_y=True)
+ model = SVC()
+
+ args = [model, X, y]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validation_multi_class_classification_sub_estimator(in_memory_project):
+ X, y = datasets.load_iris(return_X_y=True)
+ model = OneVsOneClassifier(SVC())
+
+ args = [model, X, y]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validation_multi_class_classification(in_memory_project):
+ iris = datasets.load_iris()
+ X = iris.data[:250]
+ y = iris.target[:250]
+ rf = RandomForestClassifier()
+
+ args = [rf, X, y]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validation_binary_classification(in_memory_project):
+ iris = datasets.load_iris()
+ X = iris.data[:150]
+ y = numpy.random.randint(2, size=150)
+ rf = RandomForestClassifier()
+
+ args = [rf, X, y]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_cross_validation_clustering(in_memory_project):
+ iris = datasets.load_iris()
+ X = iris.data[:150]
+ kmeans = KMeans()
+
+ args = [kmeans, X]
+ kwargs = {"cv": 3}
+
+ cv_results = cross_validate(*args, **kwargs, project=in_memory_project)
+ cv_results_sklearn = sklearn.model_selection.cross_validate(*args, **kwargs)
+
+ assert isinstance(
+ in_memory_project.get_item("cross_validation"), CrossValidationItem
+ )
+ assert cv_results.keys() == cv_results_sklearn.keys()
+ assert all(len(v) == kwargs["cv"] for v in cv_results.values())
+
+
+def test_plot_cross_validation():
+ cv_results = {
+ "fit_time": array([0.00058246, 0.00041819, 0.00039363]),
+ "score_time": array([0.00101399, 0.00072646, 0.00072432]),
+ "test_score": array([0.3315057, 0.08022103, 0.03531816]),
+ "test_r2": array([0.3315057, 0.08022103, 0.03531816]),
+ "test_neg_mean_squared_error": array(
+ [-3635.52042005, -3573.35050281, -6114.77901585]
+ ),
+ }
+ plot_cross_validation(cv_results)
diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py
index a1dc11e1..0212ce21 100644
--- a/skore/tests/integration/ui/test_ui.py
+++ b/skore/tests/integration/ui/test_ui.py
@@ -1,26 +1,12 @@
import pytest
from fastapi.testclient import TestClient
-from skore.item.item_repository import ItemRepository
-from skore.persistence.in_memory_storage import InMemoryStorage
-from skore.project import Project
from skore.ui.app import create_app
from skore.view.view import View
-from skore.view.view_repository import ViewRepository
@pytest.fixture
-def project():
- item_repository = ItemRepository(storage=InMemoryStorage())
- view_repository = ViewRepository(storage=InMemoryStorage())
- return Project(
- item_repository=item_repository,
- view_repository=view_repository,
- )
-
-
-@pytest.fixture
-def client(project):
- return TestClient(app=create_app(project=project))
+def client(in_memory_project):
+ return TestClient(app=create_app(project=in_memory_project))
def test_app_state(client):
@@ -34,14 +20,14 @@ def test_skore_ui_index(client):
assert b"" in response.content
-def test_get_items(client, project):
+def test_get_items(client, in_memory_project):
response = client.get("/api/project/items")
assert response.status_code == 200
assert response.json() == {"views": {}, "items": {}}
- project.put("test", "test")
- item = project.get_item("test")
+ in_memory_project.put("test", "test")
+ item = in_memory_project.get_item("test")
response = client.get("/api/project/items")
assert response.status_code == 200
@@ -58,8 +44,8 @@ def test_get_items(client, project):
}
-def test_share_view(client, project):
- project.put_view("hello", View(layout=[]))
+def test_share_view(client, in_memory_project):
+ in_memory_project.put_view("hello", View(layout=[]))
response = client.post("/api/project/views/share?key=hello")
assert response.status_code == 200
@@ -76,8 +62,8 @@ def test_put_view_layout(client):
assert response.status_code == 201
-def test_delete_view(client, project):
- project.put_view("hello", View(layout=[]))
+def test_delete_view(client, in_memory_project):
+ in_memory_project.put_view("hello", View(layout=[]))
response = client.delete("/api/project/views?key=hello")
assert response.status_code == 202
diff --git a/skore/tests/unit/test_project.py b/skore/tests/unit/test_project.py
index b65ff17b..a60ae202 100644
--- a/skore/tests/unit/test_project.py
+++ b/skore/tests/unit/test_project.py
@@ -10,71 +10,62 @@
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
-from skore.item import ItemRepository
-from skore.persistence.in_memory_storage import InMemoryStorage
from skore.project import Project, ProjectLoadError, ProjectPutError, load
from skore.view.view import View
-from skore.view.view_repository import ViewRepository
-@pytest.fixture
-def project():
- return Project(
- item_repository=ItemRepository(InMemoryStorage()),
- view_repository=ViewRepository(InMemoryStorage()),
- )
-
+def test_put_string_item(in_memory_project):
+ in_memory_project.put("string_item", "Hello, World!")
+ assert in_memory_project.get("string_item") == "Hello, World!"
-def test_put_string_item(project):
- project.put("string_item", "Hello, World!")
- assert project.get("string_item") == "Hello, World!"
+def test_put_int_item(in_memory_project):
+ in_memory_project.put("int_item", 42)
+ assert in_memory_project.get("int_item") == 42
-def test_put_int_item(project):
- project.put("int_item", 42)
- assert project.get("int_item") == 42
+def test_put_float_item(in_memory_project):
+ in_memory_project.put("float_item", 3.14)
+ assert in_memory_project.get("float_item") == 3.14
-def test_put_float_item(project):
- project.put("float_item", 3.14)
- assert project.get("float_item") == 3.14
+def test_put_bool_item(in_memory_project):
+ in_memory_project.put("bool_item", True)
+ assert in_memory_project.get("bool_item") is True
-def test_put_bool_item(project):
- project.put("bool_item", True)
- assert project.get("bool_item") is True
+def test_put_list_item(in_memory_project):
+ in_memory_project.put("list_item", [1, 2, 3])
+ assert in_memory_project.get("list_item") == [1, 2, 3]
-def test_put_list_item(project):
- project.put("list_item", [1, 2, 3])
- assert project.get("list_item") == [1, 2, 3]
+def test_put_dict_item(in_memory_project):
+ in_memory_project.put("dict_item", {"key": "value"})
+ assert in_memory_project.get("dict_item") == {"key": "value"}
-def test_put_dict_item(project):
- project.put("dict_item", {"key": "value"})
- assert project.get("dict_item") == {"key": "value"}
-
-def test_put_pandas_dataframe(project):
+def test_put_pandas_dataframe(in_memory_project):
dataframe = pandas.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
- project.put("pandas_dataframe", dataframe)
- pandas.testing.assert_frame_equal(project.get("pandas_dataframe"), dataframe)
+ in_memory_project.put("pandas_dataframe", dataframe)
+ pandas.testing.assert_frame_equal(
+ in_memory_project.get("pandas_dataframe"), dataframe
+ )
-def test_put_pandas_series(project):
+def test_put_pandas_series(in_memory_project):
series = pandas.Series([0, 1, 2])
- project.put("pandas_series", series)
- pandas.testing.assert_series_equal(project.get("pandas_series"), series)
+ in_memory_project.put("pandas_series", series)
+ pandas.testing.assert_series_equal(in_memory_project.get("pandas_series"), series)
-def test_put_numpy_array(project):
+def test_put_numpy_array(in_memory_project):
# Add a Numpy array
arr = numpy.array([1, 2, 3, 4, 5])
- project.put("numpy_array", arr) # NumpyArrayItem
- numpy.testing.assert_array_equal(project.get("numpy_array"), arr)
+ in_memory_project.put("numpy_array", arr) # NumpyArrayItem
+ numpy.testing.assert_array_equal(in_memory_project.get("numpy_array"), arr)
-def test_put_mpl_figure(project, monkeypatch):
+def test_put_mpl_figure(in_memory_project, monkeypatch):
# Add a Matplotlib figure
def savefig(*args, **kwargs):
return ""
@@ -83,35 +74,35 @@ def savefig(*args, **kwargs):
fig, ax = plt.subplots()
ax.plot([1, 2, 3, 4])
- project.put("mpl_figure", fig) # MediaItem (SVG)
- assert isinstance(project.get("mpl_figure"), bytes)
+ in_memory_project.put("mpl_figure", fig) # MediaItem (SVG)
+ assert isinstance(in_memory_project.get("mpl_figure"), bytes)
-def test_put_vega_chart(project):
+def test_put_vega_chart(in_memory_project):
# Add an Altair chart
altair_chart = altair.Chart().mark_point()
- project.put("vega_chart", altair_chart)
- assert isinstance(project.get("vega_chart"), bytes)
+ in_memory_project.put("vega_chart", altair_chart)
+ assert isinstance(in_memory_project.get("vega_chart"), bytes)
-def test_put_pil_image(project):
+def test_put_pil_image(in_memory_project):
# Add a PIL Image
pil_image = Image.new("RGB", (100, 100), color="red")
with BytesIO() as output:
# FIXME: Not JPEG!
pil_image.save(output, format="jpeg")
- project.put("pil_image", pil_image) # MediaItem (PNG)
- assert isinstance(project.get("pil_image"), bytes)
+ in_memory_project.put("pil_image", pil_image) # MediaItem (PNG)
+ assert isinstance(in_memory_project.get("pil_image"), bytes)
-def test_put_rf_model(project, monkeypatch):
+def test_put_rf_model(in_memory_project, monkeypatch):
# Add a scikit-learn model
monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "")
model = RandomForestClassifier()
model.fit(numpy.array([[1, 2], [3, 4]]), [0, 1])
- project.put("rf_model", model) # ScikitLearnModelItem
- assert isinstance(project.get("rf_model"), RandomForestClassifier)
+ in_memory_project.put("rf_model", model) # ScikitLearnModelItem
+ assert isinstance(in_memory_project.get("rf_model"), RandomForestClassifier)
def test_load(tmp_path):
@@ -127,82 +118,82 @@ def test_load(tmp_path):
assert isinstance(p, Project)
-def test_put(project):
- project.put("key1", 1)
- project.put("key2", 2)
- project.put("key3", 3)
- project.put("key4", 4)
+def test_put(in_memory_project):
+ in_memory_project.put("key1", 1)
+ in_memory_project.put("key2", 2)
+ in_memory_project.put("key3", 3)
+ in_memory_project.put("key4", 4)
- assert project.list_item_keys() == ["key1", "key2", "key3", "key4"]
+ assert in_memory_project.list_item_keys() == ["key1", "key2", "key3", "key4"]
-def test_put_twice(project):
- project.put("key2", 2)
- project.put("key2", 5)
+def test_put_twice(in_memory_project):
+ in_memory_project.put("key2", 2)
+ in_memory_project.put("key2", 5)
- assert project.get("key2") == 5
+ assert in_memory_project.get("key2") == 5
-def test_put_int_key(project, caplog):
+def test_put_int_key(in_memory_project, caplog):
# Warns that 0 is not a string, but doesn't raise
- project.put(0, "hello")
+ in_memory_project.put(0, "hello")
assert len(caplog.record_tuples) == 1
- assert project.list_item_keys() == []
+ assert in_memory_project.list_item_keys() == []
-def test_get(project):
- project.put("key1", 1)
- assert project.get("key1") == 1
+def test_get(in_memory_project):
+ in_memory_project.put("key1", 1)
+ assert in_memory_project.get("key1") == 1
with pytest.raises(KeyError):
- project.get("key2")
+ in_memory_project.get("key2")
-def test_delete(project):
- project.put("key1", 1)
- project.delete_item("key1")
+def test_delete(in_memory_project):
+ in_memory_project.put("key1", 1)
+ in_memory_project.delete_item("key1")
- assert project.list_item_keys() == []
+ assert in_memory_project.list_item_keys() == []
with pytest.raises(KeyError):
- project.delete_item("key2")
+ in_memory_project.delete_item("key2")
-def test_keys(project):
- project.put("key1", 1)
- project.put("key2", 2)
- assert project.list_item_keys() == ["key1", "key2"]
+def test_keys(in_memory_project):
+ in_memory_project.put("key1", 1)
+ in_memory_project.put("key2", 2)
+ assert in_memory_project.list_item_keys() == ["key1", "key2"]
-def test_view(project):
+def test_view(in_memory_project):
layout = ["key1", "key2"]
view = View(layout=layout)
- project.put_view("view", view)
- assert project.get_view("view") == view
+ in_memory_project.put_view("view", view)
+ assert in_memory_project.get_view("view") == view
-def test_list_view_keys(project):
+def test_list_view_keys(in_memory_project):
view = View(layout=[])
- project.put_view("view", view)
- assert project.list_view_keys() == ["view"]
+ in_memory_project.put_view("view", view)
+ assert in_memory_project.list_view_keys() == ["view"]
-def test_put_several_happy_path(project):
- project.put({"a": "foo", "b": "bar"})
- assert project.list_item_keys() == ["a", "b"]
+def test_put_several_happy_path(in_memory_project):
+ in_memory_project.put({"a": "foo", "b": "bar"})
+ assert in_memory_project.list_item_keys() == ["a", "b"]
-def test_put_several_canonical(project):
+def test_put_several_canonical(in_memory_project):
"""Use `put_several` instead of the `put` alias."""
- project.put_several({"a": "foo", "b": "bar"})
- assert project.list_item_keys() == ["a", "b"]
+ in_memory_project.put_several({"a": "foo", "b": "bar"})
+ assert in_memory_project.list_item_keys() == ["a", "b"]
-def test_put_several_some_errors(project, caplog):
- project.put(
+def test_put_several_some_errors(in_memory_project, caplog):
+ in_memory_project.put(
{
0: "hello",
1: "hello",
@@ -210,34 +201,34 @@ def test_put_several_some_errors(project, caplog):
}
)
assert len(caplog.record_tuples) == 3
- assert project.list_item_keys() == []
+ assert in_memory_project.list_item_keys() == []
-def test_put_several_nested(project):
- project.put({"a": {"b": "baz"}})
- assert project.list_item_keys() == ["a"]
- assert project.get("a") == {"b": "baz"}
+def test_put_several_nested(in_memory_project):
+ in_memory_project.put({"a": {"b": "baz"}})
+ assert in_memory_project.list_item_keys() == ["a"]
+ assert in_memory_project.get("a") == {"b": "baz"}
-def test_put_several_error(project):
+def test_put_several_error(in_memory_project):
"""If some key-value pairs are wrong, add all that are valid and print a warning."""
- project.put({"a": "foo", "b": (lambda: "unsupported object")})
- assert project.list_item_keys() == ["a"]
+ in_memory_project.put({"a": "foo", "b": (lambda: "unsupported object")})
+ assert in_memory_project.list_item_keys() == ["a"]
-def test_put_key_is_a_tuple(project):
+def test_put_key_is_a_tuple(in_memory_project):
"""If key is not a string, warn."""
- project.put(("a", "foo"), ("b", "bar"))
- assert project.list_item_keys() == []
+ in_memory_project.put(("a", "foo"), ("b", "bar"))
+ assert in_memory_project.list_item_keys() == []
-def test_put_key_is_a_set(project):
+def test_put_key_is_a_set(in_memory_project):
"""Cannot use an unhashable type as a key."""
with pytest.raises(ProjectPutError):
- project.put(set(), "hello", on_error="raise")
+ in_memory_project.put(set(), "hello", on_error="raise")
-def test_put_wrong_key_and_value_raise(project):
+def test_put_wrong_key_and_value_raise(in_memory_project):
"""When `on_error` is "raise", raise the first error that occurs."""
with pytest.raises(ProjectPutError):
- project.put(0, (lambda: "unsupported object"), on_error="raise")
+ in_memory_project.put(0, (lambda: "unsupported object"), on_error="raise")