Skip to content

Commit

Permalink
Fix contains check in string column (#8834)
Browse files Browse the repository at this point in the history
Fixes: #8832 

This PR fixes the `contains` check in `StringColumn`. We were building the regex `f"^{item}$"` and calling `contains_re` to test for an exact match of `item` in the `StringColumn`, but that approach breaks when `item` itself contains regex special characters. These checks now use `libcudf.search.contains`, which performs an exact-match check for `item` in the `StringColumn`.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #8834
  • Loading branch information
galipremsagar authored Jul 23, 2021
1 parent b803c4e commit fc95992
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5042,7 +5042,14 @@ def set_base_children(self, value: Tuple["column.ColumnBase", ...]):
super().set_base_children(value)

def __contains__(self, item: ScalarLike) -> bool:
    """Return whether ``item`` occurs in this column as an exact match.

    The lookup goes through ``libcudf.search.contains`` rather than a
    regex built from ``item``, so values containing regex special
    characters (``$``, ``^``, ``(``, ...) are compared literally.

    Parameters
    ----------
    item : ScalarLike
        The value to look for. A scalar is wrapped in a one-element list
        before being converted to a column; anything else is converted
        to a column as-is.

    Returns
    -------
    bool
        ``True`` if at least one search result is ``True``.
    """
    # Both cases share the same search; only the construction of the
    # column of values to look up differs.
    needles = column.as_column(
        [item] if is_scalar(item) else item, dtype=self.dtype
    )
    return True in libcudf.search.contains(self, needles)

def as_numerical_column(
self, dtype: Dtype, **kwargs
Expand Down Expand Up @@ -5303,7 +5310,9 @@ def fillna(
return super().fillna(method=method)

def _find_first_and_last(self, value: ScalarLike) -> Tuple[int, int]:
found_indices = libstrings.contains_re(self, f"^{value}$")
found_indices = libcudf.search.contains(
self, column.as_column([value], dtype=self.dtype)
)
found_indices = libcudf.unary.cast(found_indices, dtype=np.int32)
first = column.as_column(found_indices).find_first_value(np.int32(1))
last = column.as_column(found_indices).find_last_value(np.int32(1))
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,14 @@ def test_serialize_categorical_columns(data):
df = cudf.DataFrame(data)
recreated = df.__class__.deserialize(*df.serialize())
assert_eq(recreated, df)


@pytest.mark.parametrize(
    "data", [["$ 1", "$ 2", "hello"], ["($) 1", "( 2", "hello", "^1$"]]
)
@pytest.mark.parametrize("value", ["$ 1", "hello", "$", "^1$"])
def test_categorical_string_index_contains(data, value):
    """Check ``in`` on a cudf ``CategoricalIndex`` against pandas.

    The parametrized data and values include strings with regex special
    characters, so membership must be decided by exact comparison.
    """
    gpu_index = cudf.CategoricalIndex(data)
    cpu_index = gpu_index.to_pandas()

    actual = value in gpu_index
    expected = value in cpu_index
    assert_eq(actual, expected)

0 comments on commit fc95992

Please sign in to comment.