diff --git a/python/cudf/cudf/core/_base_index.py b/python/cudf/cudf/core/_base_index.py index 031dd58478b..24a810e35a8 100644 --- a/python/cudf/cudf/core/_base_index.py +++ b/python/cudf/cudf/core/_base_index.py @@ -5,7 +5,7 @@ import pickle import warnings from functools import cached_property -from typing import Any, Set +from typing import Any, Set, TypeVar import pandas as pd @@ -64,6 +64,8 @@ Float64Index([1.0, 2.0, 3.0], dtype='float64') """ +BaseIndexT = TypeVar("BaseIndexT", bound="BaseIndex") + class BaseIndex(Serializable): """Base class for all cudf Index types.""" @@ -100,6 +102,9 @@ def __getitem__(self, key): def __contains__(self, item): return item in self._values + def _copy_type_metadata(self: BaseIndexT, other: BaseIndexT) -> BaseIndexT: + raise NotImplementedError + def get_level_values(self, level): """ Return an Index of values for requested level. diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index b4a2b1bd675..d7a3858d0e9 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -17,7 +17,6 @@ Tuple, TypeVar, Union, - cast, ) import cupy @@ -1156,9 +1155,7 @@ def _positions_from_column_names(self, column_names): if name in set(column_names) ] - def _copy_type_metadata( - self, other: Frame, include_index: bool = True - ) -> Frame: + def _copy_type_metadata(self: T, other: T) -> T: """ Copy type metadata from each column of `other` to the corresponding column of `self`. @@ -1171,31 +1168,6 @@ def _copy_type_metadata( name, col._with_type_metadata(other_col.dtype), validate=False ) - if include_index: - if self._index is not None and other._index is not None: - self._index._copy_type_metadata(other._index) # type: ignore - # When other._index is a CategoricalIndex, the current index - # will be a NumericalIndex with an underlying CategoricalColumn - # (the above _copy_type_metadata call will have converted the - # column). Calling cudf.Index on that column generates the - # appropriate index. - if isinstance( - other._index, cudf.core.index.CategoricalIndex - ) and not isinstance( - self._index, cudf.core.index.CategoricalIndex - ): - self._index = cudf.Index( - cast( - cudf.core.index.NumericIndex, self._index - )._column, - name=self._index.name, - ) - elif isinstance( - other._index, cudf.MultiIndex - ) and not isinstance(self._index, cudf.MultiIndex): - self._index = cudf.MultiIndex._from_data( - self._index._data, name=self._index.name - ) return self @_cudf_nvtx_annotate diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 4df324fb12e..c09eef8a312 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -183,9 +183,7 @@ def __init__( # whereas _stop is an upper bound. self._end = self._start + self._step * (len(self._range) - 1) - def _copy_type_metadata( - self, other: Frame, include_index: bool = True - ) -> RangeIndex: + def _copy_type_metadata(self: RangeIndex, other: RangeIndex) -> RangeIndex: # There is no metadata to be copied for RangeIndex since it does not # have an underlying column. return self @@ -894,22 +892,12 @@ def _binaryop( return ret.values return ret + # Override just to make mypy happy. @_cudf_nvtx_annotate def _copy_type_metadata( - self, other: Frame, include_index: bool = True + self: GenericIndex, other: GenericIndex ) -> GenericIndex: - """ - Copy type metadata from each column of `other` to the corresponding - column of `self`. - See `ColumnBase._with_type_metadata` for more information. - """ - for name, col, other_col in zip( - self._data.keys(), self._data.values(), other._data.values() - ): - self._data.set_by_label( - name, col._with_type_metadata(other_col.dtype), validate=False - ) - return self + return super()._copy_type_metadata(other) @property # type: ignore @_cudf_nvtx_annotate diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index df7809c687b..84a5f3f1caf 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -20,6 +20,7 @@ Type, TypeVar, Union, + cast, ) from uuid import uuid4 @@ -886,6 +887,43 @@ def clip(self, lower=None, upper=None, inplace=False, axis=1): output._copy_type_metadata(self, include_index=False) return self._mimic_inplace(output, inplace=inplace) + def _copy_type_metadata( + self: T, other: T, include_index: bool = True + ) -> T: + """ + Copy type metadata from each column of `other` to the corresponding + column of `self`. + See `ColumnBase._with_type_metadata` for more information. + """ + super()._copy_type_metadata(other) + + if include_index: + if self._index is not None and other._index is not None: + self._index._copy_type_metadata(other._index) + # When other._index is a CategoricalIndex, the current index + # will be a NumericalIndex with an underlying CategoricalColumn + # (the above _copy_type_metadata call will have converted the + # column). Calling cudf.Index on that column generates the + # appropriate index. + if isinstance( + other._index, cudf.core.index.CategoricalIndex + ) and not isinstance( + self._index, cudf.core.index.CategoricalIndex + ): + self._index = cudf.Index( + cast( + cudf.core.index.NumericIndex, self._index + )._column, + name=self._index.name, + ) + elif isinstance( + other._index, cudf.MultiIndex + ) and not isinstance(self._index, cudf.MultiIndex): + self._index = cudf.MultiIndex._from_data( + self._index._data, name=self._index.name + ) + return self + @_cudf_nvtx_annotate def interpolate( self, diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index be9ac822653..dbd6f9739ea 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -1803,10 +1803,8 @@ def _intersection(self, other, sort=None): return midx @_cudf_nvtx_annotate - def _copy_type_metadata( - self, other: Frame, include_index: bool = True - ) -> Frame: - res = super()._copy_type_metadata(other, include_index=include_index) + def _copy_type_metadata(self: MultiIndex, other: MultiIndex) -> MultiIndex: + res = super()._copy_type_metadata(other) res._names = other._names return res