
feat: Implement cross_validate functionality #443

Draft · wants to merge 42 commits into main from cross-validate-item

Conversation

@augustebaum (Contributor) commented Oct 4, 2024

Introduces a new item type, CrossValidateItem, as well as a top-level function skore.cross_validate, which detects the type of ML task based on the input estimator and target, adds some scorers, and runs scikit-learn's cross_validate with those extra arguments.
The results are stored in the Project given as input, always under the key cross_validation, along with an interactive summary plot.
The output of skore.cross_validate is as close as possible to scikit-learn's, except for some edge cases. Also, if the user passes scoring="my_metric", the output dict will contain the key test_my_metric as well as test_score, whereas in the scikit-learn version the dict only has test_score.

TL;DR:

  • Make Altair a mandatory dependency
  • Add CrossValidateItem
  • Add cross_validate function
  • Add plot_cross_validation function
  • Refactor tests (add a fixture in_memory_project and use it instead of project, for clarity -- it would have been great to do from conftest import in_memory_project as project, but pytest doesn't work that way; see the sketch just below)
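For reference, a minimal sketch of the fixture-naming point above (the body of in_memory_project is just a placeholder; the real fixture builds an in-memory skore Project):

```python
# conftest.py -- sketch only: pytest resolves fixtures by the name they are
# defined under, so `from conftest import in_memory_project as project` would not
# register a fixture named `project`. An explicit aliasing fixture would.
import pytest


@pytest.fixture
def in_memory_project():
    return {}  # placeholder standing in for an in-memory skore Project


@pytest.fixture
def project(in_memory_project):
    return in_memory_project
```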

How to test:

  • In a notebook or VSCode, with a scikit-learn estimator at hand, I run skore.cross_validate on it just as I would sklearn.model_selection.cross_validate, except I also pass the argument project=project to let skore know where to store things (if the argument is not passed, cross-validation still runs but nothing is saved); see the sketch after this list
  • A plot appears in the UI and in my notebook
  • I forgot a metric; I change the cell and re-run
  • A new plot is output and the one in the UI updates
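For reference, a minimal sketch of that workflow; the way the project handle is obtained is an assumption here (shown as a hypothetical skore.load), while the cross_validate call itself mirrors sklearn.model_selection.cross_validate:

```python
# Sketch of the intended usage from a notebook. `skore.load(...)` is a
# hypothetical placeholder for however a Project is opened or created.
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lasso

import skore

X, y = load_diabetes(return_X_y=True)
project = skore.load("project.skore")  # hypothetical: open the Project to store results in

# Same call shape as sklearn.model_selection.cross_validate, plus `project=`.
cv_results = skore.cross_validate(Lasso(), X, y, cv=5, project=project)
print(cv_results["test_score"])  # scikit-learn-like output dict
```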

Addresses the first part of #383


To-do:

  • Add Marie's tests as automated tests
  • Potentially use pytest-cases for cross-validation tests
  • Deal with multi-class classification (use ["roc_auc_ovr_weighted", "neg_log_loss", "recall_weighted", "precision_weighted"])
  • Deal with the case where X and y are not numpy arrays
    • At least set up the infrastructure to easily support pandas, numpy and lists; if only numpy works for now, that's okay
    • Look into struct.pack
  • Remove metrics depending on e.g. if estimator has predict_proba
    • Note that, for example, SVC has the method defined, but it raises if the model isn't initialized with probability=True, so a try-except will be more accurate than hasattr
  • Deal with the case where the estimator params are not serializable
    • Use repr for the time being
  • Make it so that when the user passes scoring="my_metric", the output contains test_my_metric as well as test_score
  • Ensure that the stored cv_results are serializable
    • Serialize the numpy arrays (see the sketch after this list)
    • Deal with the case where the user passes return_estimator=True (we can't pass that to CrossValidateItem because it's not serializable)
      • Asked Marie
    • Confirm that the get() output is appropriate
  • Fix plot not updating in the UI
  • Remove default metrics for clustering
  • Stop storing the plot directly in CrossValidateItem, instead store JSON
    • Rename the attribute to plot_serialized, add a property plot
  • Investigate Marie's issue of the empty plot
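On the serialization items above, a small sketch of the kind of conversion involved; the helper name is illustrative, not this PR's code:

```python
# Sketch: make a cv_results dict JSON-serializable by turning numpy arrays into
# lists and falling back to repr() for anything else (e.g. fitted estimators when
# return_estimator=True).
import json

import numpy as np


def _serialize_cv_results(cv_results):
    serialized = {}
    for key, value in cv_results.items():
        if isinstance(value, np.ndarray):
            serialized[key] = value.tolist()  # scores/timings become plain lists
        else:
            serialized[key] = repr(value)  # "use repr for the time being"
    return serialized


# json.dumps now succeeds on the converted dict.
json.dumps(_serialize_cv_results({"test_score": np.array([0.1, 0.2, 0.3])}))
```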

@augustebaum force-pushed the cross-validate-item branch 14 times, most recently from 1a1eeff to 32e26fd, on October 9, 2024 at 10:24
@augustebaum marked this pull request as ready for review on October 9, 2024 at 10:28
@MarieS-WiMLDS (Contributor) commented Oct 9, 2024

Hi!
When I ran make install-skore to test this PR, some files were created in the frontend folder, and Git wants to track them. Since they didn't exist in the folder before, I suppose they should be added to .gitignore?

// EDIT: might have been some old leftover files. Fixed.

@thomass-dev (Collaborator) left a comment


Solid work, but needs a few adjustments IMHO.
Please read all my comments at once. Thanks!

@MarieS-WiMLDS (Contributor) commented:

At first I tried this (cf. the screenshot); I understood afterwards that I had to do what the comment says.
[Screenshot: Capture d’écran du 2024-10-09 17-45-05]

Is there a technical constraint preventing us from having the cross_validate function at the root of the library? Would it be bad practice?

@augustebaum (Contributor, Author) commented Oct 9, 2024

No constraint; it's just a matter of design. I'll make the shortcut available now.
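Presumably something along these lines; the exact module layout is an assumption based on the files touched in this PR:

```python
# skore/src/skore/__init__.py -- sketch of the "shortcut": re-export the function
# at the package root so users can call skore.cross_validate(...) directly.
from skore.cross_validate import cross_validate

__all__ = ["cross_validate"]
```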

@MarieS-WiMLDS (Contributor) commented Oct 9, 2024

It's super cool to see this live 🤩

I tested:

  • creating an error in scikit-learn --> the error message is passed correctly!
  • a regression problem (with diabetes dataset & lasso)
  • a multi-class classification problem (with iris & random forest)
  • a binary classification problem
  • a clustering problem

Questions and remarks:

  1. Can we drop score_time? I don't find this metric very interesting, and I feel it clutters the options.
  2. What is the test_score?
  3. The cross_validate plot in the skore UI doesn't update.
  4. On multi-class classification: I hadn't thought of this use case before, actually. The default values for binary and multi-class classification shouldn't be the same.
    a. The defaults I chose aren't correct and they create an error. Let's create an additional elif and, instead of recall and precision, use recall_weighted and precision_weighted.
    b. Because the defaults aren't correct, there was an error in the process. Yet some of cross_validate on the scikit-learn side ran to the end, and we have some results. Is it normal not to have at least these results displayed? (cf. the screenshot below)
  5. Again, for clustering, the defaults are not correct on my side. The silhouette score exists in scikit-learn, but you have to pass it to cross_validate through a callable, not a string. I'm surprised, so I would tend to say we shouldn't use it as a default. @sylvaincom, what do you think about this? (I also asked in the OS channel on Slack.)

[Screenshot: Capture d’écran du 2024-10-09 19-36-18]

EDIT: based on Gaël's and Guillaume's feedback, let's remove all default scores for clustering, because there is no "usual" metric. cross_validate doesn't make much sense for clustering anyway.

@augustebaum (Contributor, Author) commented:

I tested:

  • creating an error in scikit-learn --> the error message is passed correctly!
  • a regression problem (with diabetes dataset & lasso)
  • a multi-class classification problem (with iris & random forest)
  • a binary classification problem
  • a clustering problem

That's really great, thanks! It would be really useful to add these use cases to our tests directly; let's sync to do this.

  • Can we drop score_time?

It depends on whether we want to uphold the requirement that our cross_validate returns the same thing as scikit-learn's. IMO it makes sense to follow the principle of least surprise, so let's keep it.

  • What is the test_score?

If you don't pass any scoring parameter (or if you pass a single metric, like scoring="r2"), that metric is the one shown under test_score. When you don't pass scoring, scikit-learn can in many cases guess an appropriate metric from the estimator.
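To illustrate with plain scikit-learn (not skore-specific):

```python
# How scikit-learn names the result keys depending on `scoring`.
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lasso
from sklearn.model_selection import cross_validate

X, y = load_diabetes(return_X_y=True)

# No scoring, or a single metric string: the score is reported under "test_score".
print(sorted(cross_validate(Lasso(), X, y, cv=3)))
# ['fit_time', 'score_time', 'test_score']
print(sorted(cross_validate(Lasso(), X, y, cv=3, scoring="r2")))
# ['fit_time', 'score_time', 'test_score']

# A list of metrics: the keys become "test_<metric>" and there is no "test_score".
print(sorted(cross_validate(Lasso(), X, y, cv=3, scoring=["r2", "neg_mean_squared_error"])))
# ['fit_time', 'score_time', 'test_neg_mean_squared_error', 'test_r2']
```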

  • The cross_validate plot in the skore UI doesn't update

Will investigate.

  • Because the defaults aren't correct, there was an error in the process (cf. the screenshot).

This is not normal; let's sync to reproduce this.

Let's remove all default scores for clustering

Got it!

@augustebaum (Contributor, Author) commented:

It looks like neg_brier_score requires the estimator to have a predict_proba method. I guess we can check the estimator beforehand and remove this scorer if it doesn't have that method.
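A rough sketch of that check (the helper name and scorer list are illustrative, not this PR's code):

```python
# Sketch: drop probability-based scorers when accessing predict_proba fails.
# With a default SVC (probability=False), the attribute access raises
# AttributeError, so those scorers are removed; with probability=True they stay.
from sklearn.svm import SVC

PROBABILITY_SCORERS = {"neg_brier_score", "neg_log_loss"}


def drop_unsupported_scorers(estimator, scorers):
    try:
        estimator.predict_proba  # attribute probe; raises AttributeError if unavailable
    except AttributeError:
        return [s for s in scorers if s not in PROBABILITY_SCORERS]
    return list(scorers)


print(drop_unsupported_scorers(SVC(), ["accuracy", "neg_brier_score"]))                  # ['accuracy']
print(drop_unsupported_scorers(SVC(probability=True), ["accuracy", "neg_brier_score"]))  # both kept
```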

@augustebaum (Contributor, Author) commented:

@MarieS-WiMLDS The scorer recall_weighted does not exist; did you mean recall_score(average="weighted")?

@augustebaum (Contributor, Author) commented:

@MarieS-WiMLDS What should the output of project.get("cross_validate") be? For now, it outputs something like what scikit-learn outputs ({"test_score": ..., ...}), but this doesn't work when the user e.g. sets return_estimator=True, because skore is not meant to store anything and everything. One option could be to return the Altair plot; another could be something like the scikit-learn output but without the non-serializable parts...

@MarieS-WiMLDS (Contributor) commented:

For multi-class classification, the following metrics should be used to be compliant:

  • roc_auc_ovr_weighted
  • recall_weighted
  • precision_weighted
  • neg_log_loss
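For reference, all four are valid scikit-learn scorer strings; a quick sanity check with plain scikit-learn on iris (not skore code):

```python
# Check that the proposed multi-class defaults work as scorer strings.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

X, y = load_iris(return_X_y=True)
scoring = ["roc_auc_ovr_weighted", "recall_weighted", "precision_weighted", "neg_log_loss"]

cv_results = cross_validate(RandomForestClassifier(random_state=0), X, y, cv=3, scoring=scoring)
print(sorted(key for key in cv_results if key.startswith("test_")))
# ['test_neg_log_loss', 'test_precision_weighted', 'test_recall_weighted', 'test_roc_auc_ovr_weighted']
```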

@thomass-dev (Collaborator) left a comment


Pair-reviewed with @rouk1.

Returns
-------
new_scorers : dict[str, str | None]
Collaborator: Please remove this type annotation, which doesn't seem valid.


def _add_scorers(estimator, y, scorers):
"""Expand `scorers` with other scorers, based on `estimator` and `y`.
Collaborator: Can you add a comment making explicit that the shape of scorers is important?

Collaborator: Please add unit tests for cross_validate_item too.


def test_cross_validate_2_extra_metrics(in_memory_project, lasso):
args = list(lasso)
kwargs = {"scoring": ["r2", "neg_mean_squared_error"], "cv": 3}
Collaborator: Please add a test for each type of scoring available. From the scikit-learn docs:

If scoring represents multiple scores, one can use:
  • a list or tuple of unique strings;
  • a callable returning a dictionary where the keys are the metric names and the values are the metric scores;
  • a dictionary with metric names as keys and callables as values.
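A sketch of what such tests could look like, written against plain scikit-learn here rather than this PR's fixtures:

```python
# One parametrized test covering the three scoring shapes quoted above.
import pytest
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lasso
from sklearn.metrics import make_scorer, mean_squared_error, r2_score
from sklearn.model_selection import cross_validate


def scoring_callable(estimator, X, y):
    """A callable returning a dict of metric name -> score."""
    y_pred = estimator.predict(X)
    return {"r2": r2_score(y, y_pred), "mse": mean_squared_error(y, y_pred)}


@pytest.mark.parametrize(
    "scoring",
    [
        ["r2", "neg_mean_squared_error"],  # list of unique strings
        scoring_callable,  # callable returning a dictionary of scores
        {  # dictionary with metric names as keys and callables as values
            "r2": make_scorer(r2_score),
            "mse": make_scorer(mean_squared_error, greater_is_better=False),
        },
    ],
)
def test_cross_validate_scoring_shapes(scoring):
    X, y = load_diabetes(return_X_y=True)
    cv_results = cross_validate(Lasso(), X, y, cv=3, scoring=scoring)
    assert any(key.startswith("test_") for key in cv_results)
```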

return "classification"


def _add_scorers(estimator, y, scorers):
Collaborator: Please add a dedicated test.

Successfully merging this pull request may close these issues: skore.cross_validate

5 participants