feat: Implement cross_validate
functionality
#443
base: main
Conversation
Force-pushed from 1a1eeff to 32e26fd
hi! // EDIT: might have been some old stuff. Fixed.
Solid work, but needs a few adjustments IMHO.
Please read all my comments at once. Thanks!
No constraint, it's just a matter of design. I'll make the shortcut available now.
It's super cool to see this live 🤩 I tested:
Questions and remarks:
EDIT: based on Gaël's and Guillaume's feedback, let's remove all default scores for clustering, because there is no "usual" metric. `cross_validate` doesn't make much sense for clustering anyway.
That's really great, thanks! It would be really useful to add these use-cases in our tests directly, let's sync to do this.
It depends if we want to uphold the requirement that our
If you don't pass any
Will investigate.
This is not normal; let's sync to reproduce this.
Got it!
It looks like
@MarieS-WiMLDS The scorer
@MarieS-WiMLDS What should be the output of
For multi-class classification, the following metrics should be used to be compliant:
Pair reviewed with @rouk1.
Returns
-------
new_scorers : dict[str, str | None]
Please remove this type annotation, which doesn't seem valid.
def _add_scorers(estimator, y, scorers):
    """Expand `scorers` with other scorers, based on `estimator` and `y`.
Can you add a comment making explicit that the shape of `scorers` is important?
Please add unit tests for `cross_validate_item` too.
def test_cross_validate_2_extra_metrics(in_memory_project, lasso):
    args = list(lasso)
    kwargs = {"scoring": ["r2", "neg_mean_squared_error"], "cv": 3}
Please add a test for each type of score available. If scoring represents multiple scores, one can use:
- a list or tuple of unique strings;
- a callable returning a dictionary where the keys are the metric names and the values are the metric scores;
- a dictionary with metric names as keys and callables as values.
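To illustrate, each of these three forms can be passed straight to scikit-learn's `cross_validate`; the sketch below is illustrative only and is not taken from this PR's test suite:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.metrics import make_scorer, mean_squared_error, r2_score
from sklearn.model_selection import cross_validate

X, y = make_regression(n_samples=60, n_features=5, random_state=0)

# 1. A list (or tuple) of unique strings: results get one "test_<name>" key each.
as_list = cross_validate(Lasso(), X, y, cv=3,
                         scoring=["r2", "neg_mean_squared_error"])

# 2. A callable returning a dict of metric name -> score.
def multi_scorer(estimator, X, y):
    pred = estimator.predict(X)
    return {"r2": r2_score(y, pred), "mse": mean_squared_error(y, pred)}

as_callable = cross_validate(Lasso(), X, y, cv=3, scoring=multi_scorer)

# 3. A dict with metric names as keys and callables (scorers) as values.
as_dict = cross_validate(Lasso(), X, y, cv=3,
                         scoring={"r2": make_scorer(r2_score)})
```

In every case the result dict exposes one `test_<name>` key per metric, so a test can be parametrized over the three forms and assert on the same keys.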
    return "classification"


def _add_scorers(estimator, y, scorers):
Please add a dedicated test.
Force-pushed from 9eafcee to f930a10
Introduces a new item type, `CrossValidateItem`, as well as a top-level function `skore.cross_validate`, which detects the type of ML task based on the input estimator and target, adds some scorers, and runs scikit-learn's `cross_validate` with those extra arguments. The results are stored in the `Project` given as input, always under key `cross_validation`, along with an interactive summary plot. The output of `skore.cross_validate` is as close as possible to scikit-learn's, except for some edge cases. Also, if the user passes `scorer="my_metric"`, the output dict will contain key `test_my_metric` as well as `test_score`, whereas in the scikit-learn version the dict only has `test_score`.

TL;DR:
- `CrossValidateItem`
- `cross_validate` function
- `plot_cross_validation` function
- `in_memory_project` and use that instead of `project` for clarity (it would have been great to do `from conftest import in_memory_project as project` but pytest doesn't work that way)

How to test:
- `skore.cross_validate` on it just like I would with `sklearn.model_selection.cross_validate`, except I also give the argument `project=project` to let Skore know where to store things (if the argument is not passed we'll do cross-validation but nothing will be saved)

Addresses the first part of #383
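For contrast, plain scikit-learn with the default scoring only yields `test_score`. The `skore.cross_validate` call in the comments below follows the signature described in this PR and is an assumption, not an established API:

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import cross_validate

X, y = make_regression(n_samples=100, random_state=0)

# Plain scikit-learn: the default scoring produces a single "test_score" key.
cv_results = cross_validate(Lasso(), X, y, cv=3)

# The PR's wrapper would be called the same way, plus a `project` argument
# (signature assumed from the description above):
#   import skore
#   cv_results = skore.cross_validate(Lasso(), X, y, cv=3, project=project)
# With scoring="r2" its output would then contain both "test_r2" and
# "test_score", unlike plain scikit-learn.
```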
To-do:
- `pytest-cases` for cross-validation tests
- `["roc_auc_over_weighted", "neg_log_loss", "recall_weighted", "precision_weighted"]`)
- `X` and `y` are not numpy arrays
- `struct.pack`
- `predict_proba`. Note: SVC, for example, has the method defined, but it raises if the model isn't initialized with `probability=True`, so a `try-except` will be more accurate than `hasattr`
- `repr` for the time being
- `metric="my_metric"`, the output contains `test_my_metric` as well as `test_score`
- `return_estimator=True` (we can't pass that to `CrossValidateItem` because it's not serializable)
- `get()` output is appropriate
- `CrossValidateItem`, instead store JSON
- `plot_serialized`, add a property `plot`
- Investigate Marie's issue of the empty plot
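The `predict_proba` to-do above can be handled by probing with an actual call rather than `hasattr`; `supports_predict_proba` is a hypothetical helper name used here for illustration, not a function from this PR:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC


def supports_predict_proba(estimator, X, y):
    """Return True if the fitted estimator can actually produce probabilities.

    A try-except around a real call is more reliable than a bare `hasattr`:
    SVC defines predict_proba, but using it raises unless the model was
    constructed with probability=True.
    """
    try:
        estimator.fit(X, y).predict_proba(X[:1])
        return True
    except AttributeError:
        return False


X, y = make_classification(n_samples=50, random_state=0)
```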