Skip to content

Commit

Permalink
Clean up _copy_type_metadata (NVIDIA#11156)
Browse files Browse the repository at this point in the history
This PR moves index-related logic from `Frame._copy_type_metadata` into the `IndexedFrame` override. It also removes all references to the `include_index` parameter from the corresponding methods of the various index types.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: rapidsai/cudf#11156
  • Loading branch information
vyasr authored Jul 4, 2022
1 parent 9e08c73 commit 41ce35f
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 50 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/_base_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import warnings
from functools import cached_property
from typing import Any, Set
from typing import Any, Set, TypeVar

import pandas as pd

Expand Down Expand Up @@ -64,6 +64,8 @@
Float64Index([1.0, 2.0, 3.0], dtype='float64')
"""

BaseIndexT = TypeVar("BaseIndexT", bound="BaseIndex")


class BaseIndex(Serializable):
"""Base class for all cudf Index types."""
Expand Down Expand Up @@ -100,6 +102,9 @@ def __getitem__(self, key):
def __contains__(self, item):
return item in self._values

def _copy_type_metadata(self: BaseIndexT, other: BaseIndexT) -> BaseIndexT:
raise NotImplementedError

def get_level_values(self, level):
"""
Return an Index of values for requested level.
Expand Down
30 changes: 1 addition & 29 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Tuple,
TypeVar,
Union,
cast,
)

import cupy
Expand Down Expand Up @@ -1156,9 +1155,7 @@ def _positions_from_column_names(self, column_names):
if name in set(column_names)
]

def _copy_type_metadata(
self, other: Frame, include_index: bool = True
) -> Frame:
def _copy_type_metadata(self: T, other: T) -> T:
"""
Copy type metadata from each column of `other` to the corresponding
column of `self`.
Expand All @@ -1171,31 +1168,6 @@ def _copy_type_metadata(
name, col._with_type_metadata(other_col.dtype), validate=False
)

if include_index:
if self._index is not None and other._index is not None:
self._index._copy_type_metadata(other._index) # type: ignore
# When other._index is a CategoricalIndex, the current index
# will be a NumericalIndex with an underlying CategoricalColumn
# (the above _copy_type_metadata call will have converted the
# column). Calling cudf.Index on that column generates the
# appropriate index.
if isinstance(
other._index, cudf.core.index.CategoricalIndex
) and not isinstance(
self._index, cudf.core.index.CategoricalIndex
):
self._index = cudf.Index(
cast(
cudf.core.index.NumericIndex, self._index
)._column,
name=self._index.name,
)
elif isinstance(
other._index, cudf.MultiIndex
) and not isinstance(self._index, cudf.MultiIndex):
self._index = cudf.MultiIndex._from_data(
self._index._data, name=self._index.name
)
return self

@_cudf_nvtx_annotate
Expand Down
20 changes: 4 additions & 16 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,7 @@ def __init__(
# whereas _stop is an upper bound.
self._end = self._start + self._step * (len(self._range) - 1)

def _copy_type_metadata(
self, other: Frame, include_index: bool = True
) -> RangeIndex:
def _copy_type_metadata(self: RangeIndex, other: RangeIndex) -> RangeIndex:
# There is no metadata to be copied for RangeIndex since it does not
# have an underlying column.
return self
Expand Down Expand Up @@ -894,22 +892,12 @@ def _binaryop(
return ret.values
return ret

# Override just to make mypy happy.
@_cudf_nvtx_annotate
def _copy_type_metadata(
self, other: Frame, include_index: bool = True
self: GenericIndex, other: GenericIndex
) -> GenericIndex:
"""
Copy type metadata from each column of `other` to the corresponding
column of `self`.
See `ColumnBase._with_type_metadata` for more information.
"""
for name, col, other_col in zip(
self._data.keys(), self._data.values(), other._data.values()
):
self._data.set_by_label(
name, col._with_type_metadata(other_col.dtype), validate=False
)
return self
return super()._copy_type_metadata(other)

@property # type: ignore
@_cudf_nvtx_annotate
Expand Down
38 changes: 38 additions & 0 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Type,
TypeVar,
Union,
cast,
)
from uuid import uuid4

Expand Down Expand Up @@ -886,6 +887,43 @@ def clip(self, lower=None, upper=None, inplace=False, axis=1):
output._copy_type_metadata(self, include_index=False)
return self._mimic_inplace(output, inplace=inplace)

def _copy_type_metadata(
self: T, other: T, include_index: bool = True
) -> T:
"""
Copy type metadata from each column of `other` to the corresponding
column of `self`.
See `ColumnBase._with_type_metadata` for more information.
"""
super()._copy_type_metadata(other)

if include_index:
if self._index is not None and other._index is not None:
self._index._copy_type_metadata(other._index)
# When other._index is a CategoricalIndex, the current index
# will be a NumericalIndex with an underlying CategoricalColumn
# (the above _copy_type_metadata call will have converted the
# column). Calling cudf.Index on that column generates the
# appropriate index.
if isinstance(
other._index, cudf.core.index.CategoricalIndex
) and not isinstance(
self._index, cudf.core.index.CategoricalIndex
):
self._index = cudf.Index(
cast(
cudf.core.index.NumericIndex, self._index
)._column,
name=self._index.name,
)
elif isinstance(
other._index, cudf.MultiIndex
) and not isinstance(self._index, cudf.MultiIndex):
self._index = cudf.MultiIndex._from_data(
self._index._data, name=self._index.name
)
return self

@_cudf_nvtx_annotate
def interpolate(
self,
Expand Down
6 changes: 2 additions & 4 deletions python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,10 +1803,8 @@ def _intersection(self, other, sort=None):
return midx

@_cudf_nvtx_annotate
def _copy_type_metadata(
self, other: Frame, include_index: bool = True
) -> Frame:
res = super()._copy_type_metadata(other, include_index=include_index)
def _copy_type_metadata(self: MultiIndex, other: MultiIndex) -> MultiIndex:
res = super()._copy_type_metadata(other)
res._names = other._names
return res

Expand Down

0 comments on commit 41ce35f

Please sign in to comment.