Skip to content

Commit

Permalink
Merge pull request scikit-learn-contrib#393 from pimlock/fix-basen-inverse-transform
Browse files (browse the repository at this point in the history)

Fix basen_to_integer when column name contains regex metachar
  • Loading branch information
PaulWestenthanner authored Jan 13, 2023
2 parents 1def428 + 7a4482e commit a745057
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion category_encoders/basen.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def basen_to_integer(self, X, cols, base):
out_cols = X.columns.values.tolist()

for col in cols:
col_list = [col0 for col0 in out_cols if re.match(str(col)+'_\\d+', str(col0))]
col_list = [col0 for col0 in out_cols if re.match(re.escape(str(col))+'_\\d+', str(col0))]
insert_at = out_cols.index(col_list[0])

if base == 1:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_basen.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def test_inverse_transform_HaveHandleMissingValueAndHandleUnknownReturnNan_Expec

pd.testing.assert_frame_equal(expected, original)

def test_inverse_transform_HaveRegexMetacharactersInColumnName_ExpectInversed(self):
    """Round-trip a column whose name contains regex metacharacters.

    The column name 'state (2-letter code)' includes parentheses and a
    hyphen. Encoding the frame and then inverse-transforming the result
    must give back a frame equal to the input.
    """
    column_name = 'state (2-letter code)'
    frame = pd.DataFrame({column_name: ['il', 'ny', 'ca']})

    encoder = encoders.BaseNEncoder()
    encoder.fit(frame)
    encoded = encoder.transform(frame)
    restored = encoder.inverse_transform(encoded)

    pd.testing.assert_frame_equal(frame, restored)

def test_num_cols(self):
"""
Test that BaseNEncoder produces the correct number of output columns.
Expand Down

0 comments on commit a745057

Please sign in to comment.