Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement cross_validate functionality #443

Draft
wants to merge 42 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
bb726cd
Move altair to required dependencies
augustebaum Oct 8, 2024
849c5c4
refactor: Factorize `project` fixture to `conftest.py`
augustebaum Oct 4, 2024
ab0479e
Add `CrossValidateItem`
augustebaum Oct 4, 2024
e579b71
add `cross_validate` function
augustebaum Oct 4, 2024
c090117
[squashme] wip
augustebaum Oct 4, 2024
d38414c
add _find_ml_task
augustebaum Oct 4, 2024
5b275f8
use new scorers to run cross-validate
augustebaum Oct 4, 2024
164e2d6
remove unit test
augustebaum Oct 9, 2024
39ea8c8
make project an argument of cross_validate
augustebaum Oct 9, 2024
5832c39
squashme wip
augustebaum Oct 9, 2024
7219e97
rename CrossValidateItem to CrossValidationItem
augustebaum Oct 8, 2024
4e7bccf
Make plot an attribute of the CVItem
augustebaum Oct 8, 2024
1ef04a4
Make plot_cross_validation a component of CrossValidationItem
augustebaum Oct 8, 2024
7a23511
Define how to output CrossValidationItem for skore-ui
augustebaum Oct 8, 2024
d868cb1
refine cross_validate docstring example
augustebaum Oct 8, 2024
37c017c
adjust cv plot dimensions to fit y axis labels with more certainty
augustebaum Oct 8, 2024
9376fc9
remove logging in case `display` is not found
augustebaum Oct 8, 2024
9ff7bf1
update tests
augustebaum Oct 8, 2024
ee207eb
update basic_usage example
augustebaum Oct 14, 2024
47f2b44
address cam's feedback
augustebaum Oct 9, 2024
5f9e55a
Update docstrings and type-hints
augustebaum Oct 9, 2024
93eb731
Make `from skore import cross_validate` possible
augustebaum Oct 9, 2024
6e9060e
replace `try-except` with `contextlib.suppress`
augustebaum Oct 10, 2024
f807be6
remove added scorers for clustering
augustebaum Oct 10, 2024
4a4c80c
use `else` when adding scorers
augustebaum Oct 10, 2024
cd4e1ed
Replace CrossValidationItem.plot with a bytes attribute and a property
augustebaum Oct 10, 2024
3adcada
properly refresh VegaWidget when underlying chart changed
augustebaum Oct 10, 2024
e487a8b
refactor _expand_scorers to _add_scorers
augustebaum Oct 10, 2024
60d4de3
Make it so metric name is always included in cv_results, alongside
augustebaum Oct 10, 2024
31c6929
Store estimator params as a string in CrossValidationItem
augustebaum Oct 10, 2024
56d9140
fix bug with plotting when cv_results have keys `indices` or `estimator`
augustebaum Oct 10, 2024
969468e
make cv_results attribute of CrossValidationItem primitive
augustebaum Oct 10, 2024
4875745
[squashme] wip
augustebaum Oct 11, 2024
9ac246a
Add Marie's tests
augustebaum Oct 14, 2024
3bbe03e
add multi-class classification case
augustebaum Oct 14, 2024
67c69ac
Deal with non-numpy cases when instantiating CrossValidationItem
augustebaum Oct 14, 2024
f930a10
Make PR compatible with Python>=3.9
augustebaum Oct 14, 2024
5efd10c
make plot a cached property
augustebaum Oct 15, 2024
75d6764
don't use MediaItem to convert plot to bytes
augustebaum Oct 15, 2024
d3f9729
cache plot at initialization
augustebaum Oct 15, 2024
a199749
remove property `cv_results`, use `cv_results_serialized` directly
augustebaum Oct 15, 2024
76dce8d
clean docstring
augustebaum Oct 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions examples/basic_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,39 @@
"cell_type": "markdown",
"id": "57",
"metadata": {},
"source": [
"---\n",
"## Cross-validation with skore"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58",
"metadata": {},
"outputs": [],
"source": [
"from sklearn import datasets, linear_model\n",
"from skore.cross_validate import cross_validate\n",
"diabetes = datasets.load_diabetes()\n",
"X = diabetes.data[:150]\n",
"y = diabetes.target[:150]\n",
"lasso = linear_model.Lasso()\n",
"\n",
"cv_results = cross_validate(lasso, X, y, cv=3, project=project)"
]
},
{
"cell_type": "markdown",
"id": "59",
"metadata": {},
"source": [
"_Stay tuned for some new features!_"
]
},
{
"cell_type": "markdown",
"id": "58",
"id": "60",
"metadata": {},
"source": [
"---\n",
Expand All @@ -649,7 +675,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "59",
"id": "61",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -662,7 +688,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "60",
"id": "62",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -675,7 +701,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "61",
"id": "63",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -687,7 +713,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "62",
"id": "64",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -700,7 +726,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "63",
"id": "65",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -713,7 +739,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "64",
"id": "66",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -726,7 +752,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "65",
"id": "67",
"metadata": {},
"outputs": [],
"source": []
Expand Down
14 changes: 14 additions & 0 deletions examples/basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,20 @@ def my_func(x):

project.put("my_fitted_pipeline", my_pipeline)

# %% [markdown]
# ---
# ## Cross-validation with skore

# %%
from sklearn import datasets, linear_model
from skore.cross_validate import cross_validate
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

cv_results = cross_validate(lasso, X, y, cv=3, project=project)

# %% [markdown]
# _Stay tuned for some new features!_

Expand Down
28 changes: 26 additions & 2 deletions skore-ui/src/components/VegaWidget.vue
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<script setup lang="ts">
import { isUserInDarkMode } from "@/services/utils";
import { isDeepEqual, isUserInDarkMode } from "@/services/utils";
import { View as VegaView } from "vega";
import embed, { type Config, type VisualizationSpec } from "vega-embed";
import { onBeforeUnmount, onMounted, ref } from "vue";
import { onBeforeUnmount, onMounted, ref, watch } from "vue";

const props = defineProps<{ spec: VisualizationSpec }>();

Expand Down Expand Up @@ -49,6 +49,30 @@ onBeforeUnmount(() => {
vegaView.finalize();
}
});

watch(
() => props.spec,
async (newSpec, oldSpec) => {
if (isDeepEqual(newSpec, oldSpec)) {
return;
}
// Refresh view
// TODO: This perhaps could be done in a more fine-grained way
const r = await embed(
container.value!,
{
width: container.value?.clientWidth || 0,
...newSpec,
},
{
theme: isUserInDarkMode() ? "dark" : undefined,
config: vegaConfig,
actions: false,
}
);
vegaView = r.view;
}
);
</script>

<template>
Expand Down
2 changes: 1 addition & 1 deletion skore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dynamic = [
requires-python = ">=3.9, <3.13"
maintainers = [{name = "skore developers", email="skore@signal.probabl.ai"}]
dependencies = [
"altair>=5,<6",
"diskcache",
"fastapi",
"rich",
Expand Down Expand Up @@ -61,7 +62,6 @@ artifacts = ["src/skore/ui/static/"]

[project.optional-dependencies]
test = [
"altair",
"httpx",
"matplotlib",
"pandas",
Expand Down
2 changes: 2 additions & 0 deletions skore/src/skore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import rich.logging

from skore.cross_validate import cross_validate
from skore.project import Project, load

from .utils._show_versions import show_versions

__all__ = [
"cross_validate",
"load",
"show_versions",
"Project",
Expand Down
Loading