From 04085acf1ed43921b638ead432d654695b84d4ea Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Wed, 30 Aug 2023 09:55:38 -0500 Subject: [PATCH] Fix `name` selection in `Index.difference` and `Index.intersection` (#13986) closes #13985 This PR fixes issues with `Index.difference` and `Index.intersection` API where the name selection was incorrect and `NA` values handling wasn't happening in these two APIs. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Bradley Dice (https://github.com/bdice) - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/13986 --- python/cudf/cudf/core/_base_index.py | 21 +++++++++++---------- python/cudf/cudf/tests/test_index.py | 16 ++++++++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index d593f0df138..829ca33d8a5 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -651,7 +651,7 @@ def _get_reconciled_name_object(self, other): case make a shallow copy of self. """ name = _get_result_name(self.name, other.name) - if self.name != name: + if not _is_same_name(self.name, name): return self.rename(name) return self @@ -943,17 +943,18 @@ def difference(self, other, sort=None): other = cudf.Index(other) + res_name = _get_result_name(self.name, other.name) + if is_mixed_with_object_dtype(self, other): difference = self.copy() else: other = other.copy(deep=False) - other.names = self.names difference = cudf.core.index._index_from_data( - cudf.DataFrame._from_data(self._data) + cudf.DataFrame._from_data({"None": self._column}) .merge( - cudf.DataFrame._from_data(other._data), + cudf.DataFrame._from_data({"None": other._column}), how="leftanti", - on=self.name, + on="None", ) ._data ) @@ -961,6 +962,8 @@ def difference(self, other, sort=None): if self.dtype != other.dtype: difference = difference.astype(self.dtype) + difference.name = res_name + if sort is None and len(other): return difference.sort_values() @@ -1323,14 +1326,12 @@ def _union(self, other, sort=None): return union_result def _intersection(self, other, sort=None): - other_unique = other.unique() - other_unique.names = self.names intersection_result = cudf.core.index._index_from_data( - cudf.DataFrame._from_data(self.unique()._data) + cudf.DataFrame._from_data({"None": self.unique()._column}) .merge( - cudf.DataFrame._from_data(other_unique._data), + cudf.DataFrame._from_data({"None": other.unique()._column}), how="inner", - on=self.name, + on="None", ) ._data ) diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 2e6b45058ef..359b3c519de 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -804,12 +804,16 @@ def test_index_to_series(data): ], ) @pytest.mark.parametrize("sort", [None, False]) -def test_index_difference(data, other, sort): - pd_data = pd.Index(data) - pd_other = pd.Index(other) +@pytest.mark.parametrize( + "name_data,name_other", + [("abc", "c"), (None, "abc"), ("abc", pd.NA), ("abc", "abc")], +) +def test_index_difference(data, other, sort, name_data, name_other): + pd_data = pd.Index(data, name=name_data) + pd_other = pd.Index(other, name=name_other) - gd_data = cudf.core.index.as_index(data) - gd_other = cudf.core.index.as_index(other) + gd_data = cudf.from_pandas(pd_data) + gd_other = cudf.from_pandas(pd_other) expected = pd_data.difference(pd_other, sort=sort) actual = gd_data.difference(gd_other, sort=sort) @@ -2066,7 +2070,7 @@ def test_union_index(idx1, idx2, sort): (pd.RangeIndex(0, 10), pd.RangeIndex(3, 7)), (pd.RangeIndex(0, 10), pd.RangeIndex(-10, 20)), (pd.RangeIndex(0, 10, name="a"), pd.RangeIndex(90, 100, name="b")), - (pd.Index([0, 1, 2, 30], name="a"), pd.Index([30, 0, 90, 100])), + (pd.Index([0, 1, 2, 30], name=pd.NA), pd.Index([30, 0, 90, 100])), (pd.Index([0, 1, 2, 30], name="a"), [90, 100]), (pd.Index([0, 1, 2, 30]), pd.Index([0, 10, 1.0, 11])), (pd.Index(["a", "b", "c", "d", "c"]), pd.Index(["a", "c", "z"])),