From b2164c2b432f42aa07130fbfc63115f2fb303b02 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 8 Feb 2024 09:05:39 -0800 Subject: [PATCH] Implement rolling in pylibcudf (#14982) Contributes to #13921 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/14982 --- .../user_guide/api_docs/pylibcudf/index.rst | 1 + .../user_guide/api_docs/pylibcudf/rolling.rst | 6 + python/cudf/cudf/_lib/aggregation.pxd | 16 - python/cudf/cudf/_lib/aggregation.pyx | 327 +++--------------- python/cudf/cudf/_lib/cpp/aggregation.pxd | 2 - python/cudf/cudf/_lib/cpp/rolling.pxd | 6 +- python/cudf/cudf/_lib/groupby.pyx | 2 +- .../cudf/cudf/_lib/pylibcudf/CMakeLists.txt | 2 +- python/cudf/cudf/_lib/pylibcudf/__init__.pxd | 3 + python/cudf/cudf/_lib/pylibcudf/__init__.py | 3 + .../cudf/cudf/_lib/pylibcudf/aggregation.pxd | 3 + .../cudf/cudf/_lib/pylibcudf/aggregation.pyx | 8 + python/cudf/cudf/_lib/pylibcudf/rolling.pxd | 19 + python/cudf/cudf/_lib/pylibcudf/rolling.pyx | 73 ++++ python/cudf/cudf/_lib/reduce.pyx | 2 +- python/cudf/cudf/_lib/rolling.pyx | 71 ++-- python/cudf/cudf/_lib/sort.pyx | 15 +- python/cudf/cudf/core/indexed_frame.py | 2 +- 18 files changed, 187 insertions(+), 374 deletions(-) create mode 100644 docs/cudf/source/user_guide/api_docs/pylibcudf/rolling.rst delete mode 100644 python/cudf/cudf/_lib/aggregation.pxd create mode 100644 python/cudf/cudf/_lib/pylibcudf/rolling.pxd create mode 100644 python/cudf/cudf/_lib/pylibcudf/rolling.pyx diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst index 4772d654a3c..91b84d29ddf 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/index.rst @@ -16,6 +16,7 @@ This page provides API documentation for pylibcudf. groupby join reduce + rolling scalar table types diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/rolling.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/rolling.rst new file mode 100644 index 00000000000..0817d117a94 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/rolling.rst @@ -0,0 +1,6 @@ +======= +rolling +======= + +.. automodule:: cudf._lib.pylibcudf.rolling + :members: diff --git a/python/cudf/cudf/_lib/aggregation.pxd b/python/cudf/cudf/_lib/aggregation.pxd deleted file mode 100644 index 7a2a2b022fb..00000000000 --- a/python/cudf/cudf/_lib/aggregation.pxd +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. - -from libcpp.memory cimport unique_ptr - -from cudf._lib cimport pylibcudf -from cudf._lib.cpp.aggregation cimport rolling_aggregation - - -cdef class RollingAggregation: - cdef unique_ptr[rolling_aggregation] c_obj - -cdef class Aggregation: - cdef pylibcudf.aggregation.Aggregation c_obj - -cdef RollingAggregation make_rolling_aggregation(op, kwargs=*) -cdef Aggregation make_aggregation(op, kwargs=*) diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx index 036c922e128..de3cbb07c37 100644 --- a/python/cudf/cudf/_lib/aggregation.pyx +++ b/python/cudf/cudf/_lib/aggregation.pyx @@ -1,253 +1,31 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. -from enum import Enum, IntEnum - import pandas as pd - -from libcpp.string cimport string -from libcpp.utility cimport move - -from cudf._lib.types import SUPPORTED_NUMPY_TO_LIBCUDF_TYPES, NullHandling -from cudf.utils import cudautils - -from cudf._lib.types cimport ( - underlying_type_t_null_policy, - underlying_type_t_type_id, -) - from numba.np import numpy_support -cimport cudf._lib.cpp.aggregation as libcudf_aggregation -cimport cudf._lib.cpp.types as libcudf_types -from cudf._lib.cpp.aggregation cimport underlying_type_t_correlation_type - import cudf - -from cudf._lib cimport pylibcudf - from cudf._lib import pylibcudf +from cudf._lib.types import SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES +from cudf.utils import cudautils +_agg_name_map = { + "COUNT_VALID": "COUNT", + "COUNT_ALL": "SIZE", + "VARIANCE": "VAR", + "NTH_ELEMENT": "NTH", + "COLLECT_LIST": "COLLECT", + "COLLECT_SET": "UNIQUE", +} -class AggregationKind(Enum): - SUM = libcudf_aggregation.aggregation.Kind.SUM - PRODUCT = libcudf_aggregation.aggregation.Kind.PRODUCT - MIN = libcudf_aggregation.aggregation.Kind.MIN - MAX = libcudf_aggregation.aggregation.Kind.MAX - COUNT = libcudf_aggregation.aggregation.Kind.COUNT_VALID - SIZE = libcudf_aggregation.aggregation.Kind.COUNT_ALL - ANY = libcudf_aggregation.aggregation.Kind.ANY - ALL = libcudf_aggregation.aggregation.Kind.ALL - SUM_OF_SQUARES = libcudf_aggregation.aggregation.Kind.SUM_OF_SQUARES - MEAN = libcudf_aggregation.aggregation.Kind.MEAN - VAR = libcudf_aggregation.aggregation.Kind.VARIANCE - STD = libcudf_aggregation.aggregation.Kind.STD - MEDIAN = libcudf_aggregation.aggregation.Kind.MEDIAN - QUANTILE = libcudf_aggregation.aggregation.Kind.QUANTILE - ARGMAX = libcudf_aggregation.aggregation.Kind.ARGMAX - ARGMIN = libcudf_aggregation.aggregation.Kind.ARGMIN - NUNIQUE = libcudf_aggregation.aggregation.Kind.NUNIQUE - NTH = libcudf_aggregation.aggregation.Kind.NTH_ELEMENT - RANK = libcudf_aggregation.aggregation.Kind.RANK - COLLECT = libcudf_aggregation.aggregation.Kind.COLLECT_LIST - UNIQUE = libcudf_aggregation.aggregation.Kind.COLLECT_SET - PTX = libcudf_aggregation.aggregation.Kind.PTX - CUDA = libcudf_aggregation.aggregation.Kind.CUDA - CORRELATION = libcudf_aggregation.aggregation.Kind.CORRELATION - COVARIANCE = libcudf_aggregation.aggregation.Kind.COVARIANCE - - -class CorrelationType(IntEnum): - PEARSON = ( - - libcudf_aggregation.correlation_type.PEARSON - ) - KENDALL = ( - - libcudf_aggregation.correlation_type.KENDALL - ) - SPEARMAN = ( - - libcudf_aggregation.correlation_type.SPEARMAN - ) - - -class RankMethod(IntEnum): - FIRST = libcudf_aggregation.rank_method.FIRST - AVERAGE = libcudf_aggregation.rank_method.AVERAGE - MIN = libcudf_aggregation.rank_method.MIN - MAX = libcudf_aggregation.rank_method.MAX - DENSE = libcudf_aggregation.rank_method.DENSE - - -cdef class RollingAggregation: - """A Cython wrapper for rolling window aggregations. - - **This class should never be instantiated using a standard constructor, - only using one of its many factories.** These factories handle mapping - different cudf operations to their libcudf analogs, e.g. - `cudf.DataFrame.idxmin` -> `libcudf.argmin`. Additionally, they perform - any additional configuration needed to translate Python arguments into - their corresponding C++ types (for instance, C++ enumerations used for - flag arguments). The factory approach is necessary to support operations - like `df.agg(lambda x: x.sum())`; such functions are called with this - class as an argument to generation the desired aggregation. - """ - @property - def kind(self): - return AggregationKind(self.c_obj.get()[0].kind).name - - @classmethod - def sum(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_sum_aggregation[rolling_aggregation]()) - return agg - - @classmethod - def min(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_min_aggregation[rolling_aggregation]()) - return agg - - @classmethod - def max(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_max_aggregation[rolling_aggregation]()) - return agg - - @classmethod - def idxmin(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_argmin_aggregation[ - rolling_aggregation]()) - return agg - - @classmethod - def idxmax(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_argmax_aggregation[ - rolling_aggregation]()) - return agg - - @classmethod - def mean(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_mean_aggregation[rolling_aggregation]()) - return agg - - @classmethod - def var(cls, ddof=1): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_variance_aggregation[rolling_aggregation]( - ddof - ) - ) - return agg - - @classmethod - def std(cls, ddof=1): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_std_aggregation[rolling_aggregation](ddof) - ) - return agg - - @classmethod - def count(cls, dropna=True): - cdef libcudf_types.null_policy c_null_handling - if dropna: - c_null_handling = libcudf_types.null_policy.EXCLUDE - else: - c_null_handling = libcudf_types.null_policy.INCLUDE - - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_count_aggregation[rolling_aggregation]( - c_null_handling - )) - return agg - - @classmethod - def size(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_count_aggregation[rolling_aggregation]( - ( - NullHandling.INCLUDE) - )) - return agg - - @classmethod - def collect(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_collect_list_aggregation[ - rolling_aggregation](libcudf_types.null_policy.INCLUDE)) - return agg - - @classmethod - def from_udf(cls, op, *args, **kwargs): - cdef RollingAggregation agg = cls() - - cdef libcudf_types.type_id tid - cdef libcudf_types.data_type out_dtype - cdef string cpp_str - - # Handling UDF type - nb_type = numpy_support.from_dtype(kwargs['dtype']) - type_signature = (nb_type[:],) - compiled_op = cudautils.compile_udf(op, type_signature) - output_np_dtype = cudf.dtype(compiled_op[1]) - cpp_str = compiled_op[0].encode('UTF-8') - if output_np_dtype not in SUPPORTED_NUMPY_TO_LIBCUDF_TYPES: - raise TypeError( - "Result of window function has unsupported dtype {}" - .format(op[1]) - ) - tid = ( - ( - ( - SUPPORTED_NUMPY_TO_LIBCUDF_TYPES[output_np_dtype] - ) - ) - ) - out_dtype = libcudf_types.data_type(tid) - - agg.c_obj = move( - libcudf_aggregation.make_udf_aggregation[rolling_aggregation]( - libcudf_aggregation.udf_type.PTX, cpp_str, out_dtype - )) - return agg - - # scan aggregations - # TODO: update this after adding per algorithm aggregation derived types - # https://github.com/rapidsai/cudf/issues/7106 - cumsum = sum - cummin = min - cummax = max - @classmethod - def cumcount(cls): - cdef RollingAggregation agg = cls() - agg.c_obj = move( - libcudf_aggregation.make_count_aggregation[rolling_aggregation]( - libcudf_types.null_policy.INCLUDE - )) - return agg - -cdef class Aggregation: - def __init__(self, pylibcudf.aggregation.Aggregation agg): +class Aggregation: + def __init__(self, agg): self.c_obj = agg @property def kind(self): - return AggregationKind(int(self.c_obj.kind())).name + name = self.c_obj.kind().name + return _agg_name_map.get(name, name) @classmethod def sum(cls): @@ -295,7 +73,7 @@ cdef class Aggregation: return cls(pylibcudf.aggregation.nunique(pylibcudf.types.NullPolicy.EXCLUDE)) @classmethod - def nth(cls, libcudf_types.size_type size): + def nth(cls, size): return cls(pylibcudf.aggregation.nth_element(size)) @classmethod @@ -350,7 +128,7 @@ cdef class Aggregation: ) @classmethod - def corr(cls, method, libcudf_types.size_type min_periods): + def corr(cls, method, min_periods): return cls(pylibcudf.aggregation.correlation( pylibcudf.aggregation.CorrelationType[method.upper()], min_periods @@ -358,11 +136,7 @@ cdef class Aggregation: )) @classmethod - def cov( - cls, - libcudf_types.size_type min_periods, - libcudf_types.size_type ddof=1 - ): + def cov(cls, min_periods, ddof=1): return cls(pylibcudf.aggregation.covariance( min_periods, ddof @@ -403,46 +177,26 @@ cdef class Aggregation: def all(cls): return cls(pylibcudf.aggregation.all()) + # Rolling aggregations + @classmethod + def from_udf(cls, op, *args, **kwargs): + # Handling UDF type + nb_type = numpy_support.from_dtype(kwargs['dtype']) + type_signature = (nb_type[:],) + ptx_code, output_dtype = cudautils.compile_udf(op, type_signature) + output_np_dtype = cudf.dtype(output_dtype) + if output_np_dtype not in SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES: + raise TypeError(f"Result of window function has unsupported dtype {op[1]}") -cdef RollingAggregation make_rolling_aggregation(op, kwargs=None): - r""" - Parameters - ---------- - op : str or callable - If callable, must meet one of the following requirements: - - * Is of the form lambda x: x.agg(*args, **kwargs), where - `agg` is the name of a supported aggregation. Used to - to specify aggregations that take arguments, e.g., - `lambda x: x.quantile(0.5)`. - * Is a user defined aggregation function that operates on - group values. In this case, the output dtype must be - specified in the `kwargs` dictionary. - \*\*kwargs : dict, optional - Any keyword arguments to be passed to the op. - - Returns - ------- - RollingAggregation - """ - if kwargs is None: - kwargs = {} + return cls( + pylibcudf.aggregation.udf( + ptx_code, + pylibcudf.DataType(SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES[output_np_dtype]), + ) + ) - cdef RollingAggregation agg - if isinstance(op, str): - agg = getattr(RollingAggregation, op)(**kwargs) - elif callable(op): - if op is list: - agg = RollingAggregation.collect() - elif "dtype" in kwargs: - agg = RollingAggregation.from_udf(op, **kwargs) - else: - agg = op(RollingAggregation) - else: - raise TypeError(f"Unknown aggregation {op}") - return agg -cdef Aggregation make_aggregation(op, kwargs=None): +def make_aggregation(op, kwargs=None): r""" Parameters ---------- @@ -466,16 +220,13 @@ cdef Aggregation make_aggregation(op, kwargs=None): if kwargs is None: kwargs = {} - cdef Aggregation agg if isinstance(op, str): - agg = getattr(Aggregation, op)(**kwargs) + return getattr(Aggregation, op)(**kwargs) elif callable(op): if op is list: - agg = Aggregation.collect() + return Aggregation.collect() elif "dtype" in kwargs: - agg = Aggregation.from_udf(op, **kwargs) + return Aggregation.from_udf(op, **kwargs) else: - agg = op(Aggregation) - else: - raise TypeError(f"Unknown aggregation {op}") - return agg + return op(Aggregation) + raise TypeError(f"Unknown aggregation {op}") diff --git a/python/cudf/cudf/_lib/cpp/aggregation.pxd b/python/cudf/cudf/_lib/cpp/aggregation.pxd index 16f48b30a50..91b9d7d024f 100644 --- a/python/cudf/cudf/_lib/cpp/aggregation.pxd +++ b/python/cudf/cudf/_lib/cpp/aggregation.pxd @@ -16,8 +16,6 @@ from cudf._lib.cpp.types cimport ( size_type, ) -ctypedef int32_t underlying_type_t_correlation_type -ctypedef int32_t underlying_type_t_rank_method cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil: diff --git a/python/cudf/cudf/_lib/cpp/rolling.pxd b/python/cudf/cudf/_lib/cpp/rolling.pxd index df2e833edc2..6b620e3a4c0 100644 --- a/python/cudf/cudf/_lib/cpp/rolling.pxd +++ b/python/cudf/cudf/_lib/cpp/rolling.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr @@ -16,11 +16,11 @@ cdef extern from "cudf/rolling.hpp" namespace "cudf" nogil: column_view preceding_window, column_view following_window, size_type min_periods, - rolling_aggregation agg) except + + rolling_aggregation& agg) except + cdef unique_ptr[column] rolling_window( column_view source, size_type preceding_window, size_type following_window, size_type min_periods, - rolling_aggregation agg) except + + rolling_aggregation& agg) except + diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx index 8384d5231b7..05300a41009 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -18,11 +18,11 @@ from cudf._lib.utils cimport columns_from_pylibcudf_table from cudf._lib.scalar import as_device_scalar -from cudf._lib.aggregation cimport make_aggregation from cudf._lib.cpp.replace cimport replace_policy from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib import pylibcudf +from cudf._lib.aggregation import make_aggregation # The sets below define the possible aggregations that can be performed on # different dtypes. These strings must be elements of the AggregationKind enum. diff --git a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt index 6144fd07ac0..5eb0e5cdf82 100644 --- a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt +++ b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt @@ -14,7 +14,7 @@ set(cython_sources aggregation.pyx binaryop.pyx column.pyx copying.pyx gpumemoryview.pyx groupby.pyx interop.pyx - join.pyx reduce.pyx scalar.pyx table.pyx types.pyx unary.pyx utils.pyx + join.pyx reduce.pyx rolling.pyx scalar.pyx table.pyx types.pyx unary.pyx utils.pyx ) set(linked_libraries cudf::cudf) rapids_cython_create_modules( diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd index 74afa2dbacd..df65e893b68 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd @@ -9,6 +9,8 @@ from . cimport ( interop, join, reduce, + rolling, + types, unary, ) from .column cimport Column @@ -33,5 +35,6 @@ __all__ = [ "join", "unary", "reduce", + "rolling", "types", ] diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.py b/python/cudf/cudf/_lib/pylibcudf/__init__.py index 96663d365a8..52dded12071 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.py +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.py @@ -8,6 +8,8 @@ interop, join, reduce, + rolling, + types, unary, ) from .column import Column @@ -31,5 +33,6 @@ "join", "unary", "reduce", + "rolling", "types", ] diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd index 1b7da5a5532..a9491793b88 100644 --- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pxd @@ -11,6 +11,7 @@ from cudf._lib.cpp.aggregation cimport ( rank_method, rank_percentage, reduce_aggregation, + rolling_aggregation, scan_aggregation, ) from cudf._lib.cpp.types cimport ( @@ -30,6 +31,7 @@ ctypedef groupby_aggregation * gba_ptr ctypedef groupby_scan_aggregation * gbsa_ptr ctypedef reduce_aggregation * ra_ptr ctypedef scan_aggregation * sa_ptr +ctypedef rolling_aggregation * roa_ptr cdef class Aggregation: @@ -42,6 +44,7 @@ cdef class Aggregation: ) except * cdef const reduce_aggregation* view_underlying_as_reduce(self) except * cdef const scan_aggregation* view_underlying_as_scan(self) except * + cdef const rolling_aggregation* view_underlying_as_rolling(self) except * @staticmethod cdef Aggregation from_libcudf(unique_ptr[aggregation] agg) diff --git a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx index 0020a0c681d..fe7daea38bf 100644 --- a/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/aggregation.pyx @@ -36,6 +36,7 @@ from cudf._lib.cpp.aggregation cimport ( rank_method, rank_percentage, reduce_aggregation, + rolling_aggregation, scan_aggregation, ) from cudf._lib.cpp.types cimport ( @@ -124,6 +125,13 @@ cdef class Aggregation: self._unsupported_agg_error("scan") return agg_cast + cdef const rolling_aggregation* view_underlying_as_rolling(self) except *: + """View the underlying aggregation as a rolling_aggregation.""" + cdef rolling_aggregation *agg_cast = dynamic_cast[roa_ptr](self.c_obj.get()) + if agg_cast is NULL: + self._unsupported_agg_error("rolling") + return agg_cast + @staticmethod cdef Aggregation from_libcudf(unique_ptr[aggregation] agg): """Create a Python Aggregation from a libcudf aggregation.""" diff --git a/python/cudf/cudf/_lib/pylibcudf/rolling.pxd b/python/cudf/cudf/_lib/pylibcudf/rolling.pxd new file mode 100644 index 00000000000..88d683c0c35 --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/rolling.pxd @@ -0,0 +1,19 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cudf._lib.cpp.types cimport size_type + +from .aggregation cimport Aggregation +from .column cimport Column + +ctypedef fused WindowType: + Column + size_type + + +cpdef Column rolling_window( + Column source, + WindowType preceding_window, + WindowType following_window, + size_type min_periods, + Aggregation agg, +) diff --git a/python/cudf/cudf/_lib/pylibcudf/rolling.pyx b/python/cudf/cudf/_lib/pylibcudf/rolling.pyx new file mode 100644 index 00000000000..8a1d83911ca --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/rolling.pyx @@ -0,0 +1,73 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cython.operator cimport dereference +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move + +from cudf._lib.cpp cimport rolling as cpp_rolling +from cudf._lib.cpp.aggregation cimport rolling_aggregation +from cudf._lib.cpp.column.column cimport column +from cudf._lib.cpp.types cimport size_type + +from .aggregation cimport Aggregation +from .column cimport Column + + +cpdef Column rolling_window( + Column source, + WindowType preceding_window, + WindowType following_window, + size_type min_periods, + Aggregation agg, +): + """Perform a rolling window operation on a column + + For details, see ``cudf::rolling_window`` documentation. + + Parameters + ---------- + source : Column + The column to perform the rolling window operation on. + preceding_window : Union[Column, size_type] + The column containing the preceding window sizes or a scalar value + indicating the sizes of all windows. + following_window : Union[Column, size_type] + The column containing the following window sizes or a scalar value + indicating the sizes of all windows. + min_periods : int + The minimum number of periods to include in the result. + agg : Aggregation + The aggregation to perform. + + Returns + ------- + Column + The result of the rolling window operation. + """ + cdef unique_ptr[column] result + # TODO: Consider making all the conversion functions nogil functions that + # reclaim the GIL internally for just the necessary scope like column.view() + cdef const rolling_aggregation *c_agg = agg.view_underlying_as_rolling() + if WindowType is Column: + with nogil: + result = move( + cpp_rolling.rolling_window( + source.view(), + preceding_window.view(), + following_window.view(), + min_periods, + dereference(c_agg), + ) + ) + else: + with nogil: + result = move( + cpp_rolling.rolling_window( + source.view(), + preceding_window, + following_window, + min_periods, + dereference(c_agg), + ) + ) + return Column.from_libcudf(move(result)) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 5767cc8eee1..56bfa0ba332 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -3,12 +3,12 @@ import cudf from cudf.core.buffer import acquire_spill_lock -from cudf._lib.aggregation cimport make_aggregation from cudf._lib.column cimport Column from cudf._lib.scalar cimport DeviceScalar from cudf._lib.types cimport dtype_to_pylibcudf_type, is_decimal_type_id from cudf._lib import pylibcudf +from cudf._lib.aggregation import make_aggregation @acquire_spill_lock() diff --git a/python/cudf/cudf/_lib/rolling.pyx b/python/cudf/cudf/_lib/rolling.pyx index 8c4751e3084..5439e70fdce 100644 --- a/python/cudf/cudf/_lib/rolling.pyx +++ b/python/cudf/cudf/_lib/rolling.pyx @@ -1,16 +1,11 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. from cudf.core.buffer import acquire_spill_lock -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - -from cudf._lib.aggregation cimport RollingAggregation, make_rolling_aggregation from cudf._lib.column cimport Column -from cudf._lib.cpp.column.column cimport column -from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.rolling cimport rolling_window as cpp_rolling_window -from cudf._lib.cpp.types cimport size_type + +from cudf._lib import pylibcudf +from cudf._lib.aggregation import make_aggregation @acquire_spill_lock() @@ -41,20 +36,6 @@ def rolling(Column source_column, ------- A Column with rolling calculations """ - cdef size_type c_min_periods = min_periods - cdef size_type c_window = 0 - cdef size_type c_forward_window = 0 - cdef unique_ptr[column] c_result - cdef column_view source_column_view = source_column.view() - cdef column_view pre_column_window_view - cdef column_view fwd_column_window_view - cdef RollingAggregation cython_agg - - if callable(op): - cython_agg = make_rolling_aggregation( - op, {'dtype': source_column.dtype}) - else: - cython_agg = make_rolling_aggregation(op, agg_params) if window is None: if center: @@ -62,34 +43,24 @@ def rolling(Column source_column, raise NotImplementedError( "center is not implemented for offset-based windows" ) - pre_column_window_view = pre_column_window.view() - fwd_column_window_view = fwd_column_window.view() - with nogil: - c_result = move( - cpp_rolling_window( - source_column_view, - pre_column_window_view, - fwd_column_window_view, - c_min_periods, - cython_agg.c_obj.get()[0]) - ) + pre = pre_column_window.to_pylibcudf(mode="read") + fwd = fwd_column_window.to_pylibcudf(mode="read") else: - c_min_periods = min_periods if center: - c_window = (window // 2) + 1 - c_forward_window = window - (c_window) + pre = (window // 2) + 1 + fwd = window - (pre) else: - c_window = window - c_forward_window = 0 - - with nogil: - c_result = move( - cpp_rolling_window( - source_column_view, - c_window, - c_forward_window, - c_min_periods, - cython_agg.c_obj.get()[0]) - ) + pre = window + fwd = 0 - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf( + pylibcudf.rolling.rolling_window( + source_column.to_pylibcudf(mode="read"), + pre, + fwd, + min_periods, + make_aggregation( + op, {'dtype': source_column.dtype} if callable(op) else agg_params + ).c_obj, + ) + ) diff --git a/python/cudf/cudf/_lib/sort.pyx b/python/cudf/cudf/_lib/sort.pyx index b80ea9c7fdc..e230dffbf3c 100644 --- a/python/cudf/cudf/_lib/sort.pyx +++ b/python/cudf/cudf/_lib/sort.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. from itertools import repeat @@ -10,10 +10,7 @@ from libcpp.utility cimport move, pair from libcpp.vector cimport vector from cudf._lib.column cimport Column -from cudf._lib.cpp.aggregation cimport ( - rank_method, - underlying_type_t_rank_method, -) +from cudf._lib.cpp.aggregation cimport rank_method from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.search cimport lower_bound, upper_bound @@ -414,16 +411,12 @@ def digitize(list source_columns, list bins, bool right=False): @acquire_spill_lock() -def rank_columns(list source_columns, object method, str na_option, +def rank_columns(list source_columns, rank_method method, str na_option, bool ascending, bool pct ): """ Compute numerical data ranks (1 through n) of each column in the dataframe """ - cdef rank_method c_rank_method = < rank_method > ( - < underlying_type_t_rank_method > method - ) - cdef cpp_order column_order = ( cpp_order.ASCENDING if ascending @@ -464,7 +457,7 @@ def rank_columns(list source_columns, object method, str na_option, c_results.push_back(move( rank( c_view, - c_rank_method, + method, column_order, c_null_handling, null_precedence, diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index bc24216cade..8e43000d0a8 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -6113,7 +6113,7 @@ def rank( if method not in {"average", "min", "max", "first", "dense"}: raise KeyError(method) - method_enum = libcudf.aggregation.RankMethod[method.upper()] + method_enum = libcudf.pylibcudf.aggregation.RankMethod[method.upper()] if na_option not in {"keep", "top", "bottom"}: raise ValueError( "na_option must be one of 'keep', 'top', or 'bottom'"