Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

if y is bool convert it into numerical #431

Merged
merged 3 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add label encoder for supervised transform
Remove converting to float in catboost
  • Loading branch information
mausam.singh committed Nov 30, 2023
commit c6de169f51a2c2f9fbe2f253d9ed4ac675dd8e66
1 change: 0 additions & 1 deletion category_encoders/cat_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False, return_df=True,

def _fit(self, X, y, **kwargs):
X = X.copy(deep=True)
y = y.astype(float) #Incase y is bool or categorical.
self._mean = y.mean()
self.mapping = {col: self._fit_column_map(X[col], y) for col in self.cols}

Expand Down
22 changes: 18 additions & 4 deletions category_encoders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import pandas as pd
import numpy as np
import sklearn.base
from pandas.api.types import is_object_dtype, is_string_dtype
from pandas.api.types import is_object_dtype, is_string_dtype, is_numeric_dtype
from pandas.core.dtypes.dtypes import CategoricalDtype
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
from typing import Dict, List, Optional, Union
from scipy.sparse import csr_matrix
from sklearn.preprocessing import LabelEncoder

__author__ = 'willmcginnis'

Expand Down Expand Up @@ -294,11 +295,18 @@ def fit(self, X, y=None, **kwargs):
Returns self.

"""
self._check_fit_inputs(X, y)
X, y = convert_inputs(X, y)
self._check_fit_inputs(X, y)
self.feature_names_in_ = X.columns.tolist()
self.n_features_in_ = len(self.feature_names_in_)

if self._get_tags().get('supervised_encoder'):
if not is_numeric_dtype(y):
self.lab_encoder_ = LabelEncoder()
y = self.lab_encoder_.fit_transform(y)
else:
self.lab_encoder_ = None

self._dim = X.shape[1]
self._determine_fit_columns(X)

Expand All @@ -324,8 +332,12 @@ def fit(self, X, y=None, **kwargs):
return self

def _check_fit_inputs(self, X, y):
if self._get_tags().get('supervised_encoder') and y is None:
raise ValueError('Supervised encoders need a target for the fitting. The target cannot be None')
if self._get_tags().get('supervised_encoder'):
if y is None:
raise ValueError('Supervised encoders need a target for the fitting. The target cannot be None')
else:
if y.isna().any(): # Target column should never have missing values
raise ValueError("The target column y must not contain missing values.")

def _check_transform_inputs(self, X):
if self.handle_missing == 'error':
Expand Down Expand Up @@ -435,6 +447,8 @@ def transform(self, X, y=None, override_return_df=False):
# first check the type
X, y = convert_inputs(X, y, deep=True)
self._check_transform_inputs(X)
if y is not None and self.lab_encoder_ is not None:
y = self.lab_encoder_.transform(y)

if not list(self.cols):
return X
Expand Down
2 changes: 2 additions & 0 deletions category_encoders/woe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from category_encoders.ordinal import OrdinalEncoder
import category_encoders.utils as util
from sklearn.utils.random import check_random_state
import pandas as pd

__author__ = 'Jan Motl'

Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, verbose=0, cols=None, drop_invariant=False, return_df=True,

def _fit(self, X, y, **kwargs):
# The label must be binary with values {0,1}
y = pd.Series(y)
unique = y.unique()
if len(unique) != 2:
raise ValueError("The target column y must be binary. But the target contains " + str(len(unique)) + " unique value(s).")
Expand Down
Loading