diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 9991bad5a9e..070837c127b 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -18,6 +18,7 @@ import cudf from cudf._typing import Dtype +from cudf.core._compat import PANDAS_GE_130 from cudf.core.abc import Serializable from cudf.core.buffer import Buffer @@ -545,7 +546,7 @@ class IntervalDtype(StructDtype): """ subtype: str, np.dtype The dtype of the Interval bounds. - closed: {‘right’, ‘left’, ‘both’, ‘neither’}, default ‘right’ + closed: {'right', 'left', 'both', 'neither'}, default 'right' Whether the interval is closed on the left-side, right-side, both or neither. See the Notes for more detailed explanation. """ @@ -555,6 +556,8 @@ class IntervalDtype(StructDtype): def __init__(self, subtype, closed="right"): super().__init__(fields={"left": subtype, "right": subtype}) + if closed is None: + closed = "right" if closed in ["left", "right", "neither", "both"]: self.closed = closed else: @@ -565,7 +568,7 @@ def subtype(self): return self.fields["left"] def __repr__(self): - return f"interval[{self.fields['left']}]" + return f"interval[{self.subtype}, {self.closed}]" @classmethod def from_arrow(cls, typ): @@ -579,9 +582,23 @@ def to_arrow(self): @classmethod def from_pandas(cls, pd_dtype: pd.IntervalDtype) -> "IntervalDtype": - return cls( - subtype=pd_dtype.subtype - ) # TODO: needs `closed` when we upgrade Pandas + if PANDAS_GE_130: + return cls(subtype=pd_dtype.subtype, closed=pd_dtype.closed) + else: + return cls(subtype=pd_dtype.subtype) + + def __eq__(self, other): + if isinstance(other, str): + # This means equality isn't transitive but mimics pandas + return other == self.name + return ( + type(self) == type(other) + and self.subtype == other.subtype + and self.closed == other.closed + ) + + def __hash__(self): + return hash((self.subtype, self.closed)) def serialize(self) -> Tuple[dict, list]: header = { diff --git a/python/cudf/cudf/tests/test_dtypes.py b/python/cudf/cudf/tests/test_dtypes.py index f6a0e41a0c7..811cae929d8 100644 --- a/python/cudf/cudf/tests/test_dtypes.py +++ b/python/cudf/cudf/tests/test_dtypes.py @@ -6,6 +6,7 @@ import pytest import cudf +from cudf.core._compat import PANDAS_GE_130 from cudf.core.column import ColumnBase from cudf.core.dtypes import ( CategoricalDtype, @@ -164,15 +165,34 @@ def test_max_precision(decimal_type, max_precision): decimal_type(scale=0, precision=max_precision + 1) -@pytest.mark.parametrize("fields", ["int64", "int32"]) -@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"]) -def test_interval_dtype_pyarrow_round_trip(fields, closed): - pa_array = pd.core.arrays._arrow_utils.ArrowIntervalType(fields, closed) +@pytest.fixture(params=["int64", "int32"]) +def subtype(request): + return request.param + + +@pytest.fixture(params=["left", "right", "both", "neither"]) +def closed(request): + return request.param + + +def test_interval_dtype_pyarrow_round_trip(subtype, closed): + pa_array = pd.core.arrays._arrow_utils.ArrowIntervalType(subtype, closed) expect = pa_array got = IntervalDtype.from_arrow(expect).to_arrow() assert expect.equals(got) +@pytest.mark.skipif( + not PANDAS_GE_130, + reason="pandas<1.3.0 doesn't have a closed argument for IntervalDtype", +) +def test_interval_dtype_from_pandas(subtype, closed): + expect = cudf.IntervalDtype(subtype, closed=closed) + pd_type = pd.IntervalDtype(subtype, closed=closed) + got = cudf.IntervalDtype.from_pandas(pd_type) + assert expect == got + + def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array): """ In cudf, each column holds its dtype. And since column may have child