12: Allow passing a KFold iterator directly in the CrossValidation class #14

Merged
16 changes: 6 additions & 10 deletions flexcv/core.py
@@ -1,6 +1,6 @@
import logging
import warnings
from typing import Dict
from typing import Dict, Iterator

import numpy as np
import optuna
@@ -11,6 +11,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_random_state
from statsmodels.tools.sm_exceptions import ConvergenceWarning
from sklearn.model_selection._split import BaseCrossValidator
from tqdm import tqdm

from .cv_logging import (
@@ -111,8 +112,8 @@ def cross_validate(
run: NeptuneRun,
groups: pd.Series,
slopes: pd.DataFrame | pd.Series,
split_out: CrossValMethod,
split_in: CrossValMethod,
split_out: CrossValMethod | BaseCrossValidator | Iterator,
split_in: CrossValMethod | BaseCrossValidator | Iterator,
break_cross_val: bool,
scale_in: bool,
scale_out: bool,
@@ -140,7 +141,7 @@ def cross_validate(
run (NeptuneRun): A Run object to log to.
groups (pd.Series): The grouping or clustering variable.
slopes (pd.DataFrame | pd.Series): Random slopes variable(s)
split_out (CrossValMethod): Outer split strategy.
split_out (CrossValMethod | BaseCrossValidator | Iterator): Outer split strategy.
split_in (CrossValMethod): Inner split strategy.
break_cross_val (bool): If True, only the first outer fold is evaluated.
scale_in (bool): If True, the features are scaled in the inner cross-validation to zero mean and unit variance. This works independently of the outer scaling.
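
With the widened annotation, split_out and split_in accept three kinds of values. A minimal sketch of the accepted forms (the toy feature matrix is illustrative only):

import numpy as np
from sklearn.model_selection import GroupKFold, KFold
from flexcv.split import CrossValMethod

X = np.random.rand(20, 3)  # toy feature matrix used only to build an iterator

split_out = CrossValMethod.KFOLD               # enum member, resolved internally
split_out = GroupKFold(n_splits=5)             # any instantiated sklearn BaseCrossValidator
split_out = iter(KFold(n_splits=5).split(X))   # a pre-built iterator of (train_index, test_index) tuples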
@@ -303,7 +304,7 @@ def cross_validate(
else:
# this block performs the inner cross-validation with Optuna

n_trials
n_trials = mapping[model_name]["n_trials"]
n_jobs_cv_int = mapping[model_name]["n_jobs_cv"]

pipe_in = Pipeline(
@@ -328,11 +329,6 @@
study_update_freq=10, # log every 10th trial,
)

if n_trials == "mapped": # TODO can this be automatically detected by the CrossValidation Class?
n_trials = mapping[model_name]["n_trials"]
if not isinstance(n_trials, int):
raise ValueError("Invalid value for n_trials.")

# generate numpy random_state object for seeding the sampler
random_state = check_random_state(random_seed)
sampler_seed = random_state.randint(0, np.iinfo("int32").max)
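
For context, the sampler seeding shown above draws a fresh 32-bit seed from a validated random state so the Optuna sampler can be seeded reproducibly. A standalone sketch of the same pattern:

import numpy as np
from sklearn.utils.validation import check_random_state

random_state = check_random_state(42)  # accepts None, an int seed, or a RandomState instance
sampler_seed = random_state.randint(0, np.iinfo("int32").max)
print(sampler_seed)  # deterministic for a fixed input seed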
27 changes: 22 additions & 5 deletions flexcv/interface.py
@@ -5,10 +5,13 @@
import logging
from dataclasses import dataclass
from pprint import pformat
from typing import Iterator

import pandas as pd
import numpy as np
from neptune.metadata_containers.run import Run as NeptuneRun
from neptune.types import File
from sklearn.model_selection import BaseCrossValidator

from .core import cross_validate
from .metrics import MetricsDict
@@ -199,8 +202,8 @@ def set_data(

def set_splits(
self,
split_out: str | CrossValMethod = CrossValMethod.KFOLD,
split_in: str | CrossValMethod = CrossValMethod.KFOLD,
split_out: str | CrossValMethod | BaseCrossValidator | Iterator = CrossValMethod.KFOLD,
split_in: str | CrossValMethod | BaseCrossValidator | Iterator = CrossValMethod.KFOLD,
n_splits_out: int = 5,
n_splits_in: int = 5,
scale_out: bool = True,
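
A hedged usage sketch of the widened set_splits signature; the class name CrossValidation and the plain (non-chained) call style are assumptions based on this file and the PR title:

from sklearn.model_selection import GroupKFold, KFold
from flexcv.interface import CrossValidation  # class name assumed from the PR title

cv = CrossValidation()

# previously accepted: a string or a CrossValMethod member
cv.set_splits(split_out="KFold", split_in="KFold")

# newly accepted: an instantiated sklearn splitter ...
cv.set_splits(split_out=GroupKFold(n_splits=5), split_in=KFold(n_splits=3))

# ... or a pre-built iterator of (train_index, test_index) tuples computed elsewhere.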
@@ -441,6 +444,15 @@ def _prepare_before_perform(self):

if not hasattr(self.config, "run"):
self.config["run"] = DummyRun()

# for every model key in the mapping, check whether "n_trials" is set;
# if it is missing, fall back to the global self.config["n_trials"]
for model_key, inner_dict in self.config["mapping"].items():
if "n_trials" not in inner_dict:
self.config["mapping"][model_key]["n_trials"] = self.config["n_trials"]

elif not isinstance(inner_dict["n_trials"], int):
raise TypeError("n_trials must be an integer")

def _log(self):
"""Logs the config to Neptune. If None, a Dummy is instantiated.
@@ -468,9 +480,14 @@ def _log(self):
run["data/slopes_name"].log(
pd.DataFrame(self.config["slopes"]).columns.tolist()
)

run["cross_val/cross_val_method_out"].log(self.config["split_out"].value)
run["cross_val/cross_val_method_in"].log(self.config["split_in"].value)
try:
run["cross_val/cross_val_method_out"].log(self.config["split_out"].value)
except AttributeError:
run["cross_val/cross_val_method_out"].log(self.config["split_out"])
try:
run["cross_val/cross_val_method_in"].log(self.config["split_in"].value)
except AttributeError:
run["cross_val/cross_val_method_in"].log(self.config["split_in"])
run["cross_val/n_splits_out"].log(self.config["n_splits_out"])
run["cross_val/n_splits_in"].log(self.config["n_splits_in"])
run["cross_val/scale_in"].log(self.config["scale_in"])
207 changes: 63 additions & 144 deletions flexcv/split.py
@@ -9,17 +9,23 @@
from typing import Callable, Iterator

import pandas as pd
import numpy as np
from numpy import ndarray
from numpy.core._exceptions import UFuncTypeError
from sklearn.model_selection import (
from sklearn.model_selection._split import (
BaseCrossValidator,
GroupsConsumerMixin,
)
from sklearn.model_selection import (
GroupKFold,
KFold,
StratifiedGroupKFold,
StratifiedKFold,
)
from sklearn.preprocessing import KBinsDiscretizer

from .stratification import (
ContinuousStratifiedKFold,
ContinuousStratifiedGroupKFold,
ConcatenatedStratifiedKFold
)


class CrossValMethod(Enum):
@@ -28,20 +34,21 @@ class CrossValMethod(Enum):

Members:
- `KFOLD`: Regular sklearn `KFold` cross validation. No grouping information is used.
- `CUSTOMSTRAT`: Applies stratification on the target variable using a custom discretization of the target variable.
I.e. uses the sklearn `StratifiedKFold` cross validation but for a continuous target variable instead of a multi-class target variable.
- `GROUP`: Applies grouping information on the samples. I.e. uses the sklearn `GroupKFold` cross validation.
- `STRATGROUP`: Uses the sklearn `StratifiedGroupKFold` cross validation.
- `CUSTOMSTRATGROUP`: Applies stratification to both the target variable and the grouping information.
I.e. uses the sklearn `StratifiedGroupKFold` cross validation but for a continuous target variable instead of a multi-class target variable.

- `GROUP`: Regular sklearn `GroupKFold` cross validation. Grouping information is used.
- `STRAT`: Regular sklearn `StratifiedKFold` cross validation. No grouping information is used.
- `STRATGROUP`: Regular sklearn `StratifiedGroupKFold` cross validation. Grouping information is used.
- `CONTISTRAT`: Stratified cross validation for continuous targets. No grouping information is used.
- `CONTISTRATGROUP`: Stratified cross validation for continuous targets. Grouping information is used.
- `CONCATSTRATKFOLD`: Stratified cross validation. Leaky stratification on element-wise-concatenated target and group labels.
"""

KFOLD = "KFold"
GROUP = "GroupKFold"
CUSTOMSTRAT = "CustomStratifiedKFold"
STRAT = "StratifiedKFold"
STRATGROUP = "StratifiedGroupKFold"
CUSTOMSTRATGROUP = "CustomStratifiedGroupKFold"
CONTISTRAT = "ContinuousStratifiedKFold"
CONTISTRATGROUP = "ContinuousStratifiedGroupKFold"
CONCATSTRATKFOLD = "ConcatenatedStratifiedKFold"


def string_to_crossvalmethod(method: str) -> CrossValMethod:
@@ -67,113 +74,6 @@ def string_to_crossvalmethod(method: str) -> CrossValMethod:
raise ValueError("Invalid Cross Validation method given.")


class CustomStratifiedGroupKFold(BaseCrossValidator):
"""sklearn's StratifiedGroupKFold adapted for continuous target variables."""

def __init__(self, n_splits, shuffle=True, random_state=42, groups=None):
self.n_splits = n_splits
self.shuffle = shuffle
self.random_state = random_state
self.groups = groups

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
The data is first grouped by groups and then split into n_splits folds. The folds are made by preserving the percentage of samples for each class.
This is a variation of StratifiedGroupKFold that uses a custom discretization of the target variable.

Args:
X (array-like): Features
y (array-like): target
groups (array-like): Grouping/clustering variable (Default value = None)

Returns:
(Iterator[tuple[ndarray, ndarray]]): Iterator over the indices of the training and test set.
"""
self.sgkf = StratifiedGroupKFold(
n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state
)
assert y is not None, "y cannot be None"
kbins = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")
if isinstance(y, pd.Series):
y_cat = (
kbins.fit_transform(y.to_numpy().reshape(-1, 1)).flatten().astype(int)
)
y_cat = pd.Series(y_cat, index=y.index)
else:
y_cat = kbins.fit_transform(y.reshape(-1, 1)).flatten().astype(int) # type: ignore
return self.sgkf.split(X, y_cat, groups)

def get_n_splits(self, X, y=None, groups=None):
"""
Returns the number of splitting iterations in the cross-validator.

Returns:
(int): The number of splitting iterations in the cross-validator.
"""
return self.n_splits


class CustomStratifiedKFold(BaseCrossValidator):
"""Cross Validation Method.
This is a variation of StratifiedKFold that uses a custom discretization of the target variable.
Stratification is done on the concatination of discretized target variable and group instead of the original target variable.
This ensures, that distributions of the target variable per group are similar in each fold.
"""
def __init__(self, n_splits, shuffle=True, random_state=42, groups=None):
self.n_splits = n_splits
self.shuffle = shuffle
self.random_state = random_state
self.groups = groups

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
The data is first grouped by groups and then split into n_splits folds. The folds are made by preserving the percentage of samples for each class.
This is a variation of StratifiedGroupKFold that uses a custom discretization of the target variable.

Args:
X (array-like): Features
y (array-like): target
groups (array-like): Grouping variable (Default value = None)

Returns:
(Iterator[tuple[ndarray, ndarray]]): Iterator over the indices of the training and test set.
"""
self.skf = StratifiedKFold(
n_splits=self.n_splits, shuffle=self.shuffle, random_state=self.random_state
)
assert y is not None, "y cannot be None"
kbins = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")
if isinstance(y, pd.Series):
y_cat = (
kbins.fit_transform(y.to_numpy().reshape(-1, 1)).flatten().astype(int)
)
y_cat = pd.Series(y_cat, index=y.index)
else:
y_cat = kbins.fit_transform(y.reshape(-1, 1)).flatten().astype(int) # type: ignore
# concatenate y_cat and groups such that the stratification is done on both
# elementwise concatenation of three arrays
try:
y_cat = y_cat.astype(str) + "_" + groups.astype(str)
except UFuncTypeError:
# Why easy when you can do it the hard way?
y_concat = np.core.defchararray.add(np.core.defchararray.add(y_cat.astype(str), "_"), groups.astype(str))

return self.skf.split(X, y_concat)

def get_n_splits(self, X, y=None, groups=None):
"""

Args:
X (array-like): Features
y (array-like): target values. (Default value = None)
groups (array-like): grouping values. (Default value = None)

Returns:
(int) : The number of splitting iterations in the cross-validator.
"""
return self.n_splits


def make_cross_val_split(
*,
groups: pd.Series | None,
@@ -196,45 +96,64 @@ def make_cross_val_split(
(TypeError): If the given method is not one of KFOLD

"""

match method:
case CrossValMethod.KFOLD:
cross_val_obj = KFold(
kf = KFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return cross_val_obj.split

return kf.split

case CrossValMethod.STRAT:
strat_skf = StratifiedKFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return strat_skf.split

case CrossValMethod.CONTISTRAT:
conti_skf = ContinuousStratifiedKFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return conti_skf.split

case CrossValMethod.GROUP:
if groups is None:
raise ValueError("Groups must be specified for GroupKFold.")
cross_val_obj = GroupKFold(n_splits=n_splits)
return partial(cross_val_obj.split, groups=groups)

gkf = GroupKFold(n_splits=n_splits)
return partial(gkf.split, groups=groups)

case CrossValMethod.STRATGROUP:
if groups is None:
raise ValueError("Groups must be specified for StratGroupKFold.")
cross_val_obj = StratifiedGroupKFold(
strat_gkf = StratifiedGroupKFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return partial(cross_val_obj.split, groups=groups)
return partial(strat_gkf.split, groups=groups)

case CrossValMethod.CUSTOMSTRATGROUP:
if groups is None:
raise ValueError("Groups must be specified for CustomStratGroupKFold.")
cross_val_obj = CustomStratifiedGroupKFold(
case CrossValMethod.CONTISTRATGROUP:
conti_sgkf = ContinuousStratifiedGroupKFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return partial(cross_val_obj.split, groups=groups)
return partial(conti_sgkf.split, groups=groups)

case CrossValMethod.CUSTOMSTRAT:
if groups is None:
raise ValueError("Groups must be specified for our StratifiedKFold.")
cross_val_obj = CustomStratifiedKFold(
case CrossValMethod.CONCATSTRATKFOLD:
concat_skf = ConcatenatedStratifiedKFold(
n_splits=n_splits, random_state=random_state, shuffle=True
)
return partial(cross_val_obj.split, groups=groups)
return partial(concat_skf.split, groups=groups)

case _:
raise TypeError("Invalid Cross Validation method given.")

is_cross_validator = isinstance(method, BaseCrossValidator)
is_groups_consumer = isinstance(method, GroupsConsumerMixin)

if is_cross_validator and is_groups_consumer:
return partial(method.split, groups=groups)

if is_cross_validator:
return method.split

if isinstance(method, Iterator):
return method

else:
raise TypeError("Invalid Cross Validation method given.")


if __name__ == "__main__":