Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use instance over is_foo_dtype #14641

Merged
merged 16 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions python/cudf/cudf/_lib/column.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ from typing import Literal

import cupy as cp
import numpy as np
import pandas as pd

import rmm

import cudf
import cudf._lib as libcudf
from cudf._lib import pylibcudf
from cudf.api.types import is_categorical_dtype, is_datetime64tz_dtype
from cudf.core.buffer import (
Buffer,
ExposureTrackedBuffer,
Expand Down Expand Up @@ -344,10 +344,10 @@ cdef class Column:
)

cdef mutable_column_view mutable_view(self) except *:
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
data_dtype = col.dtype
elif is_datetime64tz_dtype(self.dtype):
elif isinstance(self.dtype, pd.DatetimeTZDtype):
col = self
data_dtype = _get_base_dtype(col.dtype)
else:
Expand Down Expand Up @@ -407,10 +407,10 @@ cdef class Column:
return self._view(c_null_count)

cdef column_view _view(self, libcudf_types.size_type null_count) except *:
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
data_dtype = col.dtype
elif is_datetime64tz_dtype(self.dtype):
elif isinstance(self.dtype, pd.DatetimeTZDtype):
col = self
data_dtype = _get_base_dtype(col.dtype)
else:
Expand Down Expand Up @@ -482,7 +482,7 @@ cdef class Column:
# categoricals because cudf supports ordered and unordered categoricals
# while libcudf supports only unordered categoricals (see
# https://github.com/rapidsai/cudf/pull/8567).
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
else:
col = self
Expand Down Expand Up @@ -648,7 +648,7 @@ cdef class Column:
"""
column_owner = isinstance(owner, Column)
mask_owner = owner
if column_owner and is_categorical_dtype(owner.dtype):
if column_owner and isinstance(owner.dtype, cudf.CategoricalDtype):
owner = owner.base_children[0]

size = cv.size()
Expand Down
74 changes: 48 additions & 26 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from functools import singledispatch

from pandas.core.groupby.groupby import DataError

from cudf.api.types import (
is_categorical_dtype,
is_decimal_dtype,
is_interval_dtype,
is_list_dtype,
is_string_dtype,
is_struct_dtype,
)
from cudf.api.types import is_string_dtype
from cudf.core.buffer import acquire_spill_lock
from cudf.core.dtypes import (
CategoricalDtype,
DecimalDtype,
IntervalDtype,
ListDtype,
StructDtype,
)

from libcpp cimport bool
from libcpp.memory cimport unique_ptr
Expand Down Expand Up @@ -73,6 +74,43 @@ _DECIMAL_AGGS = {
ctypedef const scalar constscalar


@singledispatch
def get_valid_aggregation(dtype):
if is_string_dtype(dtype):
return _STRING_AGGS
return "ALL"


@get_valid_aggregation.register
def _(dtype: ListDtype):
return _LIST_AGGS


@get_valid_aggregation.register
def _(dtype: CategoricalDtype):
return _CATEGORICAL_AGGS


@get_valid_aggregation.register
def _(dtype: ListDtype):
return _LIST_AGGS


@get_valid_aggregation.register
def _(dtype: StructDtype):
return _STRUCT_AGGS


@get_valid_aggregation.register
def _(dtype: IntervalDtype):
return _INTERVAL_AGGS


@get_valid_aggregation.register
def _(dtype: DecimalDtype):
return _DECIMAL_AGGS


cdef _agg_result_from_columns(
vector[libcudf_groupby.aggregation_result]& c_result_columns,
set column_included,
Expand Down Expand Up @@ -187,15 +225,7 @@ cdef class GroupBy:
for i, (col, aggs) in enumerate(zip(values, aggregations)):
dtype = col.dtype

valid_aggregations = (
_LIST_AGGS if is_list_dtype(dtype)
else _STRING_AGGS if is_string_dtype(dtype)
else _CATEGORICAL_AGGS if is_categorical_dtype(dtype)
else _STRUCT_AGGS if is_struct_dtype(dtype)
else _INTERVAL_AGGS if is_interval_dtype(dtype)
else _DECIMAL_AGGS if is_decimal_dtype(dtype)
else "ALL"
)
valid_aggregations = get_valid_aggregation(dtype)
included_aggregations_i = []

c_agg_request = move(libcudf_groupby.aggregation_request())
Expand Down Expand Up @@ -258,15 +288,7 @@ cdef class GroupBy:
for i, (col, aggs) in enumerate(zip(values, aggregations)):
dtype = col.dtype

valid_aggregations = (
_LIST_AGGS if is_list_dtype(dtype)
else _STRING_AGGS if is_string_dtype(dtype)
else _CATEGORICAL_AGGS if is_categorical_dtype(dtype)
else _STRUCT_AGGS if is_struct_dtype(dtype)
else _INTERVAL_AGGS if is_interval_dtype(dtype)
else _DECIMAL_AGGS if is_decimal_dtype(dtype)
else "ALL"
)
valid_aggregations = get_valid_aggregation(dtype)
included_aggregations_i = []

c_agg_request = move(libcudf_groupby.scan_request())
Expand Down
10 changes: 5 additions & 5 deletions python/cudf/cudf/_lib/interop.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cpython cimport pycapsule
from libcpp.memory cimport shared_ptr, unique_ptr
Expand All @@ -18,8 +18,8 @@ from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns

from cudf.api.types import is_list_dtype, is_struct_dtype
from cudf.core.buffer import acquire_spill_lock
from cudf.core.dtypes import ListDtype, StructDtype


def from_dlpack(dlpack_capsule):
Expand Down Expand Up @@ -98,7 +98,7 @@ cdef vector[column_metadata] gather_metadata(object cols_dtypes) except *:
if cols_dtypes is not None:
for idx, (col_name, col_dtype) in enumerate(cols_dtypes):
cpp_metadata.push_back(column_metadata(col_name.encode()))
if is_struct_dtype(col_dtype) or is_list_dtype(col_dtype):
if isinstance(col_dtype, (ListDtype, StructDtype)):
_set_col_children_metadata(col_dtype, cpp_metadata[idx])
else:
raise TypeError(
Expand All @@ -113,14 +113,14 @@ cdef _set_col_children_metadata(dtype,

cdef column_metadata element_metadata

if is_struct_dtype(dtype):
if isinstance(dtype, StructDtype):
for name, value in dtype.fields.items():
element_metadata = column_metadata(name.encode())
_set_col_children_metadata(
value, element_metadata
)
col_meta.children_meta.push_back(element_metadata)
elif is_list_dtype(dtype):
elif isinstance(dtype, ListDtype):
col_meta.children_meta.reserve(2)
# Offsets - child 0
col_meta.children_meta.push_back(column_metadata())
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/_lib/io/utils.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cpython.buffer cimport PyBUF_READ
from cpython.memoryview cimport PyMemoryView_FromMemory
Expand All @@ -23,7 +23,7 @@ import errno
import io
import os

from cudf.api.types import is_struct_dtype
from cudf.core.dtypes import StructDtype


# Converts the Python source input to libcudf IO source_info
Expand Down Expand Up @@ -172,7 +172,7 @@ cdef Column update_column_struct_field_names(
)
col.set_base_children(tuple(children))

if is_struct_dtype(col):
if isinstance(col.dtype, StructDtype):
field_names.reserve(len(col.base_children))
for i in range(info.children.size()):
field_names.push_back(info.children[i].name)
Expand Down
21 changes: 8 additions & 13 deletions python/cudf/cudf/_lib/json.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

# cython: boundscheck = False

Expand All @@ -17,6 +17,7 @@ from libcpp.utility cimport move
from libcpp.vector cimport vector

cimport cudf._lib.cpp.io.types as cudf_io_types
from cudf._lib.column cimport Column
from cudf._lib.cpp.io.data_sink cimport data_sink
from cudf._lib.cpp.io.json cimport (
json_reader_options,
Expand All @@ -42,10 +43,6 @@ from cudf._lib.io.utils cimport (
from cudf._lib.types cimport dtype_to_data_type
from cudf._lib.utils cimport data_from_unique_ptr, table_view_from_table

from cudf.api.types import is_list_dtype, is_struct_dtype

from cudf._lib.column cimport Column


cpdef read_json(object filepaths_or_buffers,
object dtype,
Expand Down Expand Up @@ -214,13 +211,12 @@ def write_json(
cdef schema_element _get_cudf_schema_element_from_dtype(object dtype) except *:
cdef schema_element s_element
cdef data_type lib_type
if cudf.api.types.is_categorical_dtype(dtype):
dtype = cudf.dtype(dtype)
if isinstance(dtype, cudf.CategoricalDtype):
raise NotImplementedError(
"CategoricalDtype as dtype is not yet "
"supported in JSON reader"
)

dtype = cudf.dtype(dtype)
lib_type = dtype_to_data_type(dtype)
s_element.type = lib_type
if isinstance(dtype, cudf.StructDtype):
Expand All @@ -237,19 +233,18 @@ cdef schema_element _get_cudf_schema_element_from_dtype(object dtype) except *:


cdef data_type _get_cudf_data_type_from_dtype(object dtype) except *:
if cudf.api.types.is_categorical_dtype(dtype):
dtype = cudf.dtype(dtype)
if isinstance(dtype, cudf.CategoricalDtype):
raise NotImplementedError(
"CategoricalDtype as dtype is not yet "
"supported in JSON reader"
)

dtype = cudf.dtype(dtype)
return dtype_to_data_type(dtype)

cdef _set_col_children_metadata(Column col,
column_name_info& col_meta):
cdef column_name_info child_info
if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
Expand All @@ -258,7 +253,7 @@ cdef _set_col_children_metadata(Column col,
_set_col_children_metadata(
child_col, col_meta.children[i]
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
for i, child_col in enumerate(col.children):
col_meta.children.push_back(child_info)
_set_col_children_metadata(
Expand Down
7 changes: 3 additions & 4 deletions python/cudf/cudf/_lib/orc.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import cudf
from cudf.core.buffer import acquire_spill_lock
Expand Down Expand Up @@ -59,7 +59,6 @@ from cudf._lib.utils cimport data_from_unique_ptr, table_view_from_table
from pyarrow.lib import NativeFile

from cudf._lib.utils import _index_level_name, generate_pandas_metadata
from cudf.api.types import is_list_dtype, is_struct_dtype


cpdef read_raw_orc_statistics(filepath_or_buffer):
Expand Down Expand Up @@ -474,15 +473,15 @@ cdef class ORCWriter:
cdef _set_col_children_metadata(Column col,
column_in_metadata& col_meta,
list_column_as_map=False):
if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
col_meta.child(i).set_name(name.encode())
_set_col_children_metadata(
child_col, col_meta.child(i), list_column_as_map
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
if list_column_as_map:
col_meta.set_list_column_as_map()
_set_col_children_metadata(
Expand Down
21 changes: 7 additions & 14 deletions python/cudf/cudf/_lib/parquet.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

# cython: boundscheck = False

Expand All @@ -18,12 +18,7 @@ import numpy as np

from cython.operator cimport dereference

from cudf.api.types import (
is_decimal_dtype,
is_list_dtype,
is_list_like,
is_struct_dtype,
)
from cudf.api.types import is_list_like

from cudf._lib.utils cimport data_from_unique_ptr

Expand Down Expand Up @@ -220,7 +215,7 @@ cpdef read_parquet(filepaths_or_buffers, columns=None, row_groups=None,

# update the decimal precision of each column
for col in names:
if is_decimal_dtype(df._data[col].dtype):
if isinstance(df._data[col].dtype, cudf.core.dtypes.DecimalDtype):
df._data[col].dtype.precision = (
meta_data_per_column[col]["metadata"]["precision"]
)
Expand Down Expand Up @@ -703,7 +698,7 @@ cdef _set_col_metadata(
# is true.
col_meta.set_nullability(True)

if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
Expand All @@ -713,13 +708,11 @@ cdef _set_col_metadata(
col_meta.child(i),
force_nullable_schema
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
_set_col_metadata(
col.children[1],
col_meta.child(1),
force_nullable_schema
)
else:
if is_decimal_dtype(col):
col_meta.set_decimal_precision(col.dtype.precision)
return
elif isinstance(col.dtype, cudf.core.dtypes.DecimalDtype):
col_meta.set_decimal_precision(col.dtype.precision)
Loading
Loading