diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 74ce6e42d7e..2600926d363 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -40,14 +40,16 @@ namespace cudf { // forward declaration namespace detail { +class simple_aggregations_collector; class aggregation_finalizer; } // namespace detail /** - * @brief Base class for specifying the desired aggregation in an + * @brief Abstract base class for specifying the desired aggregation in an * `aggregation_request`. * - * Other kinds of aggregations may derive from this class to encapsulate - * additional information needed to compute the aggregation. + * All aggregations must derive from this class to implement the pure virtual + * functions and potentially encapsulate additional information needed to + * compute the aggregation. */ class aggregation { public: @@ -82,58 +84,78 @@ class aggregation { CUDA ///< CUDA UDF based reduction }; + aggregation() = delete; aggregation(aggregation::Kind a) : kind{a} {} Kind kind; ///< The aggregation to perform + virtual ~aggregation() = default; virtual bool is_equal(aggregation const& other) const { return kind == other.kind; } - virtual size_t do_hash() const { return std::hash{}(kind); } + virtual std::unique_ptr clone() const = 0; - virtual std::unique_ptr clone() const - { - return std::make_unique(*this); - } + // override functions for compound aggregations + virtual std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const = 0; + virtual void finalize(cudf::detail::aggregation_finalizer& finalizer) const = 0; +}; - virtual ~aggregation() = default; +/** + * @brief Derived class intended for enforcing operation-specific restrictions + * when calling various cudf functions. + * + * As an example, rolling_window will only accept rolling_aggregation inputs, + * and the appropriate derived classes (sum_aggregation, mean_aggregation, etc) + * derive from this interface to represent these valid options. + */ +class rolling_aggregation : public virtual aggregation { + public: + ~rolling_aggregation() = default; - // override functions for compound aggregations - virtual std::vector get_simple_aggregations(data_type col_type) const; - virtual void finalize(cudf::detail::aggregation_finalizer& finalizer); + protected: + rolling_aggregation() {} }; enum class udf_type : bool { CUDA, PTX }; /// Factory to create a SUM aggregation -std::unique_ptr make_sum_aggregation(); +template +std::unique_ptr make_sum_aggregation(); /// Factory to create a PRODUCT aggregation -std::unique_ptr make_product_aggregation(); +template +std::unique_ptr make_product_aggregation(); /// Factory to create a MIN aggregation -std::unique_ptr make_min_aggregation(); +template +std::unique_ptr make_min_aggregation(); /// Factory to create a MAX aggregation -std::unique_ptr make_max_aggregation(); +template +std::unique_ptr make_max_aggregation(); /** * @brief Factory to create a COUNT aggregation * * @param null_handling Indicates if null values will be counted. */ -std::unique_ptr make_count_aggregation( - null_policy null_handling = null_policy::EXCLUDE); +template +std::unique_ptr make_count_aggregation(null_policy null_handling = null_policy::EXCLUDE); -/// Factory to create a ANY aggregation -std::unique_ptr make_any_aggregation(); +/// Factory to create an ANY aggregation +template +std::unique_ptr make_any_aggregation(); /// Factory to create a ALL aggregation -std::unique_ptr make_all_aggregation(); +template +std::unique_ptr make_all_aggregation(); /// Factory to create a SUM_OF_SQUARES aggregation -std::unique_ptr make_sum_of_squares_aggregation(); +template +std::unique_ptr make_sum_of_squares_aggregation(); /// Factory to create a MEAN aggregation -std::unique_ptr make_mean_aggregation(); +template +std::unique_ptr make_mean_aggregation(); /** * @brief Factory to create a VARIANCE aggregation @@ -141,7 +163,8 @@ std::unique_ptr make_mean_aggregation(); * @param ddof Delta degrees of freedom. The divisor used in calculation of * `variance` is `N - ddof`, where `N` is the population size. */ -std::unique_ptr make_variance_aggregation(size_type ddof = 1); +template +std::unique_ptr make_variance_aggregation(size_type ddof = 1); /** * @brief Factory to create a STD aggregation @@ -149,10 +172,12 @@ std::unique_ptr make_variance_aggregation(size_type ddof = 1); * @param ddof Delta degrees of freedom. The divisor used in calculation of * `std` is `N - ddof`, where `N` is the population size. */ -std::unique_ptr make_std_aggregation(size_type ddof = 1); +template +std::unique_ptr make_std_aggregation(size_type ddof = 1); /// Factory to create a MEDIAN aggregation -std::unique_ptr make_median_aggregation(); +template +std::unique_ptr make_median_aggregation(); /** * @brief Factory to create a QUANTILE aggregation @@ -160,22 +185,25 @@ std::unique_ptr make_median_aggregation(); * @param quantiles The desired quantiles * @param interpolation The desired interpolation */ -std::unique_ptr make_quantile_aggregation(std::vector const& q, - interpolation i = interpolation::LINEAR); +template +std::unique_ptr make_quantile_aggregation(std::vector const& q, + interpolation i = interpolation::LINEAR); /** * @brief Factory to create an `argmax` aggregation * * `argmax` returns the index of the maximum element. */ -std::unique_ptr make_argmax_aggregation(); +template +std::unique_ptr make_argmax_aggregation(); /** * @brief Factory to create an `argmin` aggregation * * `argmin` returns the index of the minimum element. */ -std::unique_ptr make_argmin_aggregation(); +template +std::unique_ptr make_argmin_aggregation(); /** * @brief Factory to create a `nunique` aggregation @@ -183,8 +211,8 @@ std::unique_ptr make_argmin_aggregation(); * `nunique` returns the number of unique elements. * @param null_handling Indicates if null values will be counted. */ -std::unique_ptr make_nunique_aggregation( - null_policy null_handling = null_policy::EXCLUDE); +template +std::unique_ptr make_nunique_aggregation(null_policy null_handling = null_policy::EXCLUDE); /** * @brief Factory to create a `nth_element` aggregation @@ -199,11 +227,13 @@ std::unique_ptr make_nunique_aggregation( * @param n index of nth element in each group. * @param null_handling Indicates to include/exclude nulls during indexing. */ -std::unique_ptr make_nth_element_aggregation( +template +std::unique_ptr make_nth_element_aggregation( size_type n, null_policy null_handling = null_policy::INCLUDE); /// Factory to create a ROW_NUMBER aggregation -std::unique_ptr make_row_number_aggregation(); +template +std::unique_ptr make_row_number_aggregation(); /** * @brief Factory to create a COLLECT_LIST aggregation @@ -215,7 +245,8 @@ std::unique_ptr make_row_number_aggregation(); * * @param null_handling Indicates whether to include/exclude nulls in list elements. */ -std::unique_ptr make_collect_list_aggregation( +template +std::unique_ptr make_collect_list_aggregation( null_policy null_handling = null_policy::INCLUDE); /** @@ -233,16 +264,18 @@ std::unique_ptr make_collect_list_aggregation( * @param nans_equal Flag to specify whether NaN values in floating point column should be * considered equal */ -std::unique_ptr make_collect_set_aggregation( - null_policy null_handling = null_policy::INCLUDE, - null_equality nulls_equal = null_equality::EQUAL, - nan_equality nans_equal = nan_equality::UNEQUAL); +template +std::unique_ptr make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, + null_equality nulls_equal = null_equality::EQUAL, + nan_equality nans_equal = nan_equality::UNEQUAL); /// Factory to create a LAG aggregation -std::unique_ptr make_lag_aggregation(size_type offset); +template +std::unique_ptr make_lag_aggregation(size_type offset); /// Factory to create a LEAD aggregation -std::unique_ptr make_lead_aggregation(size_type offset); +template +std::unique_ptr make_lead_aggregation(size_type offset); /** * @brief Factory to create an aggregation base on UDF for PTX or CUDA @@ -253,9 +286,10 @@ std::unique_ptr make_lead_aggregation(size_type offset); * * @return aggregation unique pointer housing user_defined_aggregator string. */ -std::unique_ptr make_udf_aggregation(udf_type type, - std::string const& user_defined_aggregator, - data_type output_type); +template +std::unique_ptr make_udf_aggregation(udf_type type, + std::string const& user_defined_aggregator, + data_type output_type); /** @} */ // end of group } // namespace cudf diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 0bfe6b84ae2..3941d776f75 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -28,252 +28,495 @@ namespace cudf { namespace detail { -// Forward declare compound aggregations. -class mean_aggregation; -class var_aggregation; -class std_aggregation; -class min_aggregation; -class max_aggregation; - // Visitor pattern +class simple_aggregations_collector { // Declares the interface for the simple aggregations + // collector + public: + // Declare overloads for each kind of a agg to dispatch + virtual std::vector> visit(data_type col_type, + aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class sum_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class product_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class min_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class max_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class count_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class any_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class all_aggregation const& agg); + virtual std::vector> visit( + data_type col_type, class sum_of_squares_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class mean_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class var_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class std_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class median_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class quantile_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class argmax_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class argmin_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class nunique_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class nth_element_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class row_number_aggregation const& agg); + virtual std::vector> visit( + data_type col_type, class collect_list_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class collect_set_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class lead_lag_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class udf_aggregation const& agg); +}; + class aggregation_finalizer { // Declares the interface for the finalizer public: // Declare overloads for each kind of a agg to dispatch - virtual void visit(aggregation const& agg) = 0; - virtual void visit(min_aggregation const& agg) = 0; - virtual void visit(max_aggregation const& agg) = 0; - virtual void visit(mean_aggregation const& agg) = 0; - virtual void visit(var_aggregation const& agg) = 0; - virtual void visit(std_aggregation const& agg) = 0; + virtual void visit(aggregation const& agg); + virtual void visit(class sum_aggregation const& agg); + virtual void visit(class product_aggregation const& agg); + virtual void visit(class min_aggregation const& agg); + virtual void visit(class max_aggregation const& agg); + virtual void visit(class count_aggregation const& agg); + virtual void visit(class any_aggregation const& agg); + virtual void visit(class all_aggregation const& agg); + virtual void visit(class sum_of_squares_aggregation const& agg); + virtual void visit(class mean_aggregation const& agg); + virtual void visit(class var_aggregation const& agg); + virtual void visit(class std_aggregation const& agg); + virtual void visit(class median_aggregation const& agg); + virtual void visit(class quantile_aggregation const& agg); + virtual void visit(class argmax_aggregation const& agg); + virtual void visit(class argmin_aggregation const& agg); + virtual void visit(class nunique_aggregation const& agg); + virtual void visit(class nth_element_aggregation const& agg); + virtual void visit(class row_number_aggregation const& agg); + virtual void visit(class collect_list_aggregation const& agg); + virtual void visit(class collect_set_aggregation const& agg); + virtual void visit(class lead_lag_aggregation const& agg); + virtual void visit(class udf_aggregation const& agg); }; /** - * @brief Derived class for specifying a min aggregation + * @brief Derived class for specifying a sum aggregation */ -struct min_aggregation final : aggregation { - min_aggregation() : aggregation{MIN} {} +class sum_aggregation final : public rolling_aggregation { + public: + sum_aggregation() : aggregation(SUM) {} - std::vector get_simple_aggregations(data_type col_type) const override + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override { - if (col_type.id() == type_id::STRING) - return {aggregation::ARGMIN}; - else - return {this->kind}; + return collector.visit(col_type, *this); } - void finalize(aggregation_finalizer& finalizer) override { finalizer.visit(*this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + +/** + * @brief Derived class for specifying a product aggregation + */ +class product_aggregation final : public aggregation { + public: + product_aggregation() : aggregation(PRODUCT) {} std::unique_ptr clone() const override { - return std::make_unique(*this); + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief Derived class for specifying a max aggregation + * @brief Derived class for specifying a min aggregation */ -struct max_aggregation final : aggregation { - max_aggregation() : aggregation{MAX} {} +class min_aggregation final : public rolling_aggregation { + public: + min_aggregation() : aggregation(MIN) {} - std::vector get_simple_aggregations(data_type col_type) const override + std::unique_ptr clone() const override { - if (col_type.id() == type_id::STRING) - return {aggregation::ARGMAX}; - else - return {this->kind}; + return std::make_unique(*this); } - void finalize(aggregation_finalizer& finalizer) override { finalizer.visit(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + +/** + * @brief Derived class for specifying a max aggregation + */ +class max_aggregation final : public rolling_aggregation { + public: + max_aggregation() : aggregation(MAX) {} std::unique_ptr clone() const override { return std::make_unique(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief A wrapper to simplify inheritance of virtual methods from aggregation - * - * Derived aggregations are required to implement operator==() and hash_impl(). - * - * https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern + * @brief Derived class for specifying a count aggregation */ -template -class derived_aggregation : public aggregation { +class count_aggregation final : public rolling_aggregation { public: - derived_aggregation(aggregation::Kind a) : aggregation(a) {} + count_aggregation(aggregation::Kind kind) : aggregation(kind) {} - bool is_equal(aggregation const& other) const override + std::unique_ptr clone() const override { - if (this->aggregation::is_equal(other)) { - // Dispatch to operator== using static polymorphism - return static_cast(*this) == static_cast(other); - } else { - return false; - } + return std::make_unique(*this); } - - size_t do_hash() const override + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override { - // Dispatch to hash_impl() using static polymorphism - return this->aggregation::do_hash() ^ static_cast(*this).hash_impl(); + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + +/** + * @brief Derived class for specifying an any aggregation + */ +class any_aggregation final : public aggregation { + public: + any_aggregation() : aggregation(ANY) {} std::unique_ptr clone() const override { - // Dispatch to copy constructor using static polymorphism - return std::make_unique(static_cast(*this)); + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief Derived class for specifying a quantile aggregation + * @brief Derived class for specifying an all aggregation */ -struct quantile_aggregation final : derived_aggregation { - quantile_aggregation(std::vector const& q, interpolation i) - : derived_aggregation{QUANTILE}, _quantiles{q}, _interpolation{i} +class all_aggregation final : public aggregation { + public: + all_aggregation() : aggregation(ALL) {} + + std::unique_ptr clone() const override { + return std::make_unique(*this); } - std::vector _quantiles; ///< Desired quantile(s) - interpolation _interpolation; ///< Desired interpolation + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; - protected: - friend class derived_aggregation; +/** + * @brief Derived class for specifying a sum_of_squares aggregation + */ +class sum_of_squares_aggregation final : public aggregation { + public: + sum_of_squares_aggregation() : aggregation(SUM_OF_SQUARES) {} - bool operator==(quantile_aggregation const& other) const + std::unique_ptr clone() const override { - return _interpolation == other._interpolation and - std::equal(_quantiles.begin(), _quantiles.end(), other._quantiles.begin()); + return std::make_unique(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; - size_t hash_impl() const +/** + * @brief Derived class for specifying a mean aggregation + */ +class mean_aggregation final : public rolling_aggregation { + public: + mean_aggregation() : aggregation(MEAN) {} + + std::unique_ptr clone() const override { - return std::hash{}(static_cast(_interpolation)) ^ - std::accumulate( - _quantiles.cbegin(), _quantiles.cend(), size_t{0}, [](size_t a, double b) { - return a ^ std::hash{}(b); - }); + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief Derived aggregation class for specifying LEAD/LAG window aggregations + * @brief Derived class for specifying a standard deviation/variance aggregation */ -struct lead_lag_aggregation final : derived_aggregation { - lead_lag_aggregation(Kind kind, size_type offset) - : derived_aggregation{offset < 0 ? (kind == LAG ? LEAD : LAG) : kind}, - row_offset{std::abs(offset)} +class std_var_aggregation : public aggregation { + public: + size_type _ddof; ///< Delta degrees of freedom + + bool is_equal(aggregation const& _other) const override { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return _ddof == other._ddof; } - size_type row_offset; + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } protected: - friend class derived_aggregation; - - bool operator==(lead_lag_aggregation const& rhs) const { return row_offset == rhs.row_offset; } + std_var_aggregation(aggregation::Kind k, size_type ddof) : aggregation(k), _ddof{ddof} + { + CUDF_EXPECTS(k == aggregation::STD or k == aggregation::VARIANCE, + "std_var_aggregation can accept only STD, VARIANCE"); + } - size_t hash_impl() const { return std::hash()(row_offset); } + size_type hash_impl() const { return std::hash{}(_ddof); } }; /** - * @brief Derived class for specifying a mean aggregation + * @brief Derived class for specifying a variance aggregation */ -struct mean_aggregation final : aggregation { - mean_aggregation() : aggregation{MEAN} {} +class var_aggregation final : public std_var_aggregation { + public: + var_aggregation(size_type ddof) : std_var_aggregation{aggregation::VARIANCE, ddof} {} - std::vector get_simple_aggregations(data_type col_type) const override + std::unique_ptr clone() const override { - CUDF_EXPECTS(is_fixed_width(col_type), "MEAN aggregation expects fixed width type"); - return {aggregation::SUM, aggregation::COUNT_VALID}; + return std::make_unique(*this); } - void finalize(aggregation_finalizer& finalizer) override { finalizer.visit(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + +/** + * @brief Derived class for specifying a standard deviation aggregation + */ +class std_aggregation final : public std_var_aggregation { + public: + std_aggregation(size_type ddof) : std_var_aggregation{aggregation::STD, ddof} {} std::unique_ptr clone() const override { - return std::make_unique(*this); + return std::make_unique(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief Derived class for specifying a standard deviation/variance aggregation + * @brief Derived class for specifying a median aggregation */ -struct std_var_aggregation : derived_aggregation { - size_type _ddof; ///< Delta degrees of freedom +class median_aggregation final : public aggregation { + public: + median_aggregation() : aggregation(MEDIAN) {} - virtual std::vector get_simple_aggregations(data_type col_type) const override + std::unique_ptr clone() const override { - return {aggregation::SUM, aggregation::COUNT_VALID}; + return std::make_unique(*this); } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; - protected: - friend class derived_aggregation; +/** + * @brief Derived class for specifying a quantile aggregation + */ +class quantile_aggregation final : public aggregation { + public: + quantile_aggregation(std::vector const& q, interpolation i) + : aggregation{QUANTILE}, _quantiles{q}, _interpolation{i} + { + } + std::vector _quantiles; ///< Desired quantile(s) + interpolation _interpolation; ///< Desired interpolation - bool operator==(std_var_aggregation const& other) const { return _ddof == other._ddof; } + bool is_equal(aggregation const& _other) const override + { + if (!this->aggregation::is_equal(_other)) { return false; } + + auto const& other = dynamic_cast(_other); - size_t hash_impl() const { return std::hash{}(_ddof); } + return _interpolation == other._interpolation && + std::equal(_quantiles.begin(), _quantiles.end(), other._quantiles.begin()); + } - std_var_aggregation(aggregation::Kind k, size_type ddof) : derived_aggregation{k}, _ddof{ddof} + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override { - CUDF_EXPECTS(k == aggregation::STD or k == aggregation::VARIANCE, - "std_var_aggregation can accept only STD, VARIANCE"); + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + private: + size_t hash_impl() const + { + return std::hash{}(static_cast(_interpolation)) ^ + std::accumulate( + _quantiles.cbegin(), _quantiles.cend(), size_t{0}, [](size_t a, double b) { + return a ^ std::hash{}(b); + }); } }; /** - * @brief Derived class for specifying a standard deviation aggregation + * @brief Derived class for specifying an argmax aggregation */ -struct std_aggregation final : std_var_aggregation { - std_aggregation(size_type ddof) : std_var_aggregation{aggregation::STD, ddof} {} - void finalize(aggregation_finalizer& finalizer) override { finalizer.visit(*this); } +class argmax_aggregation final : public rolling_aggregation { + public: + argmax_aggregation() : aggregation(ARGMAX) {} + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** - * @brief Derived class for specifying a variance aggregation + * @brief Derived class for specifying an argmin aggregation */ -struct var_aggregation final : std_var_aggregation { - var_aggregation(size_type ddof) : std_var_aggregation{aggregation::VARIANCE, ddof} {} - void finalize(aggregation_finalizer& finalizer) override { finalizer.visit(*this); } +class argmin_aggregation final : public rolling_aggregation { + public: + argmin_aggregation() : aggregation(ARGMIN) {} + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** * @brief Derived class for specifying a nunique aggregation */ -struct nunique_aggregation final : derived_aggregation { +class nunique_aggregation final : public aggregation { + public: nunique_aggregation(null_policy null_handling) - : derived_aggregation{NUNIQUE}, _null_handling{null_handling} + : aggregation{NUNIQUE}, _null_handling{null_handling} { } - null_policy _null_handling; ///< include or exclude nulls - protected: - friend class derived_aggregation; + null_policy _null_handling; ///< include or exclude nulls - bool operator==(nunique_aggregation const& other) const + bool is_equal(aggregation const& _other) const override { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); return _null_handling == other._null_handling; } + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + private: size_t hash_impl() const { return std::hash{}(static_cast(_null_handling)); } }; /** * @brief Derived class for specifying a nth element aggregation */ -struct nth_element_aggregation final : derived_aggregation { +class nth_element_aggregation final : public aggregation { + public: nth_element_aggregation(size_type n, null_policy null_handling) - : derived_aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling} + : aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling} { } + size_type _n; ///< nth index to return null_policy _null_handling; ///< include or exclude nulls - protected: - friend class derived_aggregation; - - bool operator==(nth_element_aggregation const& other) const + bool is_equal(aggregation const& _other) const override { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); return _n == other._n and _null_handling == other._null_handling; } + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + private: size_t hash_impl() const { return std::hash{}(_n) ^ std::hash{}(static_cast(_null_handling)); @@ -281,92 +524,102 @@ struct nth_element_aggregation final : derived_aggregation { - udf_aggregation(aggregation::Kind type, - std::string const& user_defined_aggregator, - data_type output_type) - : derived_aggregation{type}, - _source{user_defined_aggregator}, - _operator_name{(type == aggregation::PTX) ? "rolling_udf_ptx" : "rolling_udf_cuda"}, - _function_name{"rolling_udf"}, - _output_type{output_type} - { - CUDF_EXPECTS(type == aggregation::PTX or type == aggregation::CUDA, - "udf_aggregation can accept only PTX, CUDA"); - } - std::string const _source; - std::string const _operator_name; - std::string const _function_name; - data_type _output_type; - - protected: - friend class derived_aggregation; +class row_number_aggregation final : public rolling_aggregation { + public: + row_number_aggregation() : aggregation(ROW_NUMBER) {} - bool operator==(udf_aggregation const& other) const + std::unique_ptr clone() const override { - return _source == other._source and _operator_name == other._operator_name and - _function_name == other._function_name and _output_type == other._output_type; + return std::make_unique(*this); } - - size_t hash_impl() const + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override { - return std::hash{}(_source) ^ std::hash{}(_operator_name) ^ - std::hash{}(_function_name) ^ - std::hash{}(static_cast(_output_type.id())); + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } }; /** * @brief Derived aggregation class for specifying COLLECT_LIST aggregation */ -struct collect_list_aggregation final : derived_aggregation { +class collect_list_aggregation final : public rolling_aggregation { + public: explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE) - : derived_aggregation{COLLECT_LIST}, _null_handling{null_handling} + : aggregation{COLLECT_LIST}, _null_handling{null_handling} { } + null_policy _null_handling; ///< include or exclude nulls - protected: - friend class derived_aggregation; + bool is_equal(aggregation const& _other) const override + { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return (_null_handling == other._null_handling); + } - bool operator==(nunique_aggregation const& other) const + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override { - return _null_handling == other._null_handling; + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + private: size_t hash_impl() const { return std::hash{}(static_cast(_null_handling)); } }; /** * @brief Derived aggregation class for specifying COLLECT_SET aggregation */ -struct collect_set_aggregation final : derived_aggregation { +class collect_set_aggregation final : public aggregation { + public: explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, null_equality nulls_equal = null_equality::EQUAL, nan_equality nans_equal = nan_equality::UNEQUAL) - : derived_aggregation{COLLECT_SET}, + : aggregation{COLLECT_SET}, _null_handling{null_handling}, _nulls_equal(nulls_equal), _nans_equal(nans_equal) { } + null_policy _null_handling; ///< include or exclude nulls null_equality _nulls_equal; ///< whether to consider nulls as equal values nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to ///< floating point types) - protected: - friend class derived_aggregation; + bool is_equal(aggregation const& _other) const override + { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return (_null_handling == other._null_handling && _nulls_equal == other._nulls_equal && + _nans_equal == other._nans_equal); + } + + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } - bool operator==(collect_set_aggregation const& other) const + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override { - return _null_handling == other._null_handling && _nulls_equal == other._nulls_equal && - _nans_equal == other._nans_equal; + return collector.visit(col_type, *this); } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + protected: size_t hash_impl() const { return std::hash{}(static_cast(_null_handling) ^ static_cast(_nulls_equal) ^ @@ -374,6 +627,96 @@ struct collect_set_aggregation final : derived_aggregationaggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return (row_offset == other.row_offset); + } + + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + size_type row_offset; + + private: + size_t hash_impl() const { return std::hash()(row_offset); } +}; + +/** + * @brief Derived class for specifying a custom aggregation + * specified in udf + */ +class udf_aggregation final : public rolling_aggregation { + public: + udf_aggregation(aggregation::Kind type, + std::string const& user_defined_aggregator, + data_type output_type) + : aggregation{type}, + _source{user_defined_aggregator}, + _operator_name{(type == aggregation::PTX) ? "rolling_udf_ptx" : "rolling_udf_cuda"}, + _function_name{"rolling_udf"}, + _output_type{output_type} + { + CUDF_EXPECTS(type == aggregation::PTX or type == aggregation::CUDA, + "udf_aggregation can accept only PTX, CUDA"); + } + + bool is_equal(aggregation const& _other) const override + { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return (_source == other._source and _operator_name == other._operator_name and + _function_name == other._function_name and _output_type == other._output_type); + } + + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + std::string const _source; + std::string const _operator_name; + std::string const _function_name; + data_type _output_type; + + protected: + size_t hash_impl() const + { + return std::hash{}(_source) ^ std::hash{}(_operator_name) ^ + std::hash{}(_function_name) ^ + std::hash{}(static_cast(_output_type.id())); + } +}; + /** * @brief Sentinel value used for `ARGMAX` aggregation. * @@ -763,4 +1106,4 @@ constexpr inline bool is_valid_aggregation() bool is_valid_aggregation(data_type source, aggregation::Kind k); } // namespace detail -} // namespace cudf +} // namespace cudf \ No newline at end of file diff --git a/cpp/include/cudf/detail/rolling.hpp b/cpp/include/cudf/detail/rolling.hpp index ec2af220440..2b06d11c5a9 100644 --- a/cpp/include/cudf/detail/rolling.hpp +++ b/cpp/include/cudf/detail/rolling.hpp @@ -33,7 +33,7 @@ namespace detail { * column_view const& preceding_window, * column_view const& following_window, * size_type min_periods, - * std::unique_ptr const& agg, + * rolling_aggregation const& agg, * rmm::mr::device_memory_resource* mr) * * @param stream CUDA stream used for device memory operations and kernel launches. @@ -43,7 +43,7 @@ std::unique_ptr rolling_window( column_view const& preceding_window, column_view const& following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/rolling.hpp b/cpp/include/cudf/rolling.hpp index 8a2498d0163..4fb1b4a7319 100644 --- a/cpp/include/cudf/rolling.hpp +++ b/cpp/include/cudf/rolling.hpp @@ -59,7 +59,7 @@ std::unique_ptr rolling_window( size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -68,7 +68,7 @@ std::unique_ptr rolling_window( * size_type preceding_window, * size_type following_window, * size_type min_periods, - * std::unique_ptr const& agg, + * rolling_aggregation const& agg, * rmm::mr::device_memory_resource* mr) * * @param default_outputs A column of per-row default values to be returned instead @@ -81,7 +81,7 @@ std::unique_ptr rolling_window( size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -197,7 +197,7 @@ std::unique_ptr grouped_rolling_window( size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -207,7 +207,7 @@ std::unique_ptr grouped_rolling_window( * size_type preceding_window, * size_type following_window, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr) */ std::unique_ptr grouped_rolling_window( @@ -216,7 +216,7 @@ std::unique_ptr grouped_rolling_window( window_bounds preceding_window, window_bounds following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -226,7 +226,7 @@ std::unique_ptr grouped_rolling_window( * size_type preceding_window, * size_type following_window, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr) * * @param default_outputs A column of per-row default values to be returned instead @@ -240,7 +240,7 @@ std::unique_ptr grouped_rolling_window( size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -251,7 +251,7 @@ std::unique_ptr grouped_rolling_window( * size_type preceding_window, * size_type following_window, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr) */ std::unique_ptr grouped_rolling_window( @@ -261,7 +261,7 @@ std::unique_ptr grouped_rolling_window( window_bounds preceding_window, window_bounds following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -355,7 +355,7 @@ std::unique_ptr grouped_time_range_rolling_window( size_type preceding_window_in_days, size_type following_window_in_days, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -370,7 +370,7 @@ std::unique_ptr grouped_time_range_rolling_window( * size_type preceding_window_in_days, * size_type following_window_in_days, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr) * * The `preceding_window_in_days` and `following_window_in_days` supports "unbounded" windows, @@ -398,7 +398,7 @@ std::unique_ptr grouped_time_range_rolling_window( window_bounds preceding_window_in_days, window_bounds following_window_in_days, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -517,7 +517,7 @@ std::unique_ptr grouped_range_rolling_window( range_window_bounds const& preceding, range_window_bounds const& following, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -559,7 +559,7 @@ std::unique_ptr rolling_window( column_view const& preceding_window, column_view const& following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of group diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 3a044a42101..3a2215eaa53 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -22,142 +22,489 @@ namespace cudf { -std::vector aggregation::get_simple_aggregations(data_type col_type) const +namespace detail { + +// simple_aggregations_collector ---------------------------------------- + +std::vector> simple_aggregations_collector::visit( + data_type col_type, aggregation const& agg) +{ + std::vector> aggs; + aggs.push_back(agg.clone()); + return aggs; +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, sum_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, product_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, min_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, max_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, count_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, any_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, all_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, sum_of_squares_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, mean_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, var_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, std_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, median_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, quantile_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, argmax_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, argmin_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, nunique_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, nth_element_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, row_number_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, collect_list_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, collect_set_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, lead_lag_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, udf_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +// aggregation_finalizer ---------------------------------------- + +void aggregation_finalizer::visit(aggregation const& agg) {} + +void aggregation_finalizer::visit(sum_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(product_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(min_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(max_aggregation const& agg) { - return {this->kind}; + visit(static_cast(agg)); } -void aggregation::finalize(cudf::detail::aggregation_finalizer& finalizer) + +void aggregation_finalizer::visit(count_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(any_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(all_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(sum_of_squares_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(mean_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(var_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(std_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(median_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(quantile_aggregation const& agg) { - finalizer.visit(*this); + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(argmax_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(argmin_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(nunique_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(nth_element_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(row_number_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(collect_list_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(collect_set_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(lead_lag_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(udf_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +} // namespace detail + +std::vector> aggregation::get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const +{ + return collector.visit(col_type, *this); } /// Factory to create a SUM aggregation -std::unique_ptr make_sum_aggregation() +template +std::unique_ptr make_sum_aggregation() { - return std::make_unique(aggregation::SUM); + return std::make_unique(); } +template std::unique_ptr make_sum_aggregation(); +template std::unique_ptr make_sum_aggregation(); + /// Factory to create a PRODUCT aggregation -std::unique_ptr make_product_aggregation() +template +std::unique_ptr make_product_aggregation() { - return std::make_unique(aggregation::PRODUCT); + return std::make_unique(); } +template std::unique_ptr make_product_aggregation(); + /// Factory to create a MIN aggregation -std::unique_ptr make_min_aggregation() +template +std::unique_ptr make_min_aggregation() { return std::make_unique(); } +template std::unique_ptr make_min_aggregation(); +template std::unique_ptr make_min_aggregation(); + /// Factory to create a MAX aggregation -std::unique_ptr make_max_aggregation() +template +std::unique_ptr make_max_aggregation() { return std::make_unique(); } +template std::unique_ptr make_max_aggregation(); +template std::unique_ptr make_max_aggregation(); + /// Factory to create a COUNT aggregation -std::unique_ptr make_count_aggregation(null_policy null_handling) +template +std::unique_ptr make_count_aggregation(null_policy null_handling) { auto kind = (null_handling == null_policy::INCLUDE) ? aggregation::COUNT_ALL : aggregation::COUNT_VALID; - return std::make_unique(kind); + return std::make_unique(kind); } +template std::unique_ptr make_count_aggregation( + null_policy null_handling); +template std::unique_ptr make_count_aggregation( + null_policy null_handling); + /// Factory to create a ANY aggregation -std::unique_ptr make_any_aggregation() +template +std::unique_ptr make_any_aggregation() { - return std::make_unique(aggregation::ANY); + return std::make_unique(); } +template std::unique_ptr make_any_aggregation(); + /// Factory to create a ALL aggregation -std::unique_ptr make_all_aggregation() +template +std::unique_ptr make_all_aggregation() { - return std::make_unique(aggregation::ALL); + return std::make_unique(); } +template std::unique_ptr make_all_aggregation(); + /// Factory to create a SUM_OF_SQUARES aggregation -std::unique_ptr make_sum_of_squares_aggregation() +template +std::unique_ptr make_sum_of_squares_aggregation() { - return std::make_unique(aggregation::SUM_OF_SQUARES); + return std::make_unique(); } +template std::unique_ptr make_sum_of_squares_aggregation(); + /// Factory to create a MEAN aggregation -std::unique_ptr make_mean_aggregation() +template +std::unique_ptr make_mean_aggregation() { return std::make_unique(); } +template std::unique_ptr make_mean_aggregation(); +template std::unique_ptr make_mean_aggregation(); + /// Factory to create a VARIANCE aggregation -std::unique_ptr make_variance_aggregation(size_type ddof) +template +std::unique_ptr make_variance_aggregation(size_type ddof) { return std::make_unique(ddof); -}; +} +template std::unique_ptr make_variance_aggregation(size_type ddof); + /// Factory to create a STD aggregation -std::unique_ptr make_std_aggregation(size_type ddof) +template +std::unique_ptr make_std_aggregation(size_type ddof) { return std::make_unique(ddof); -}; +} +template std::unique_ptr make_std_aggregation(size_type ddof); + /// Factory to create a MEDIAN aggregation -std::unique_ptr make_median_aggregation() +template +std::unique_ptr make_median_aggregation() { - // TODO I think this should just return a quantile_aggregation? - return std::make_unique(aggregation::MEDIAN); + return std::make_unique(); } +template std::unique_ptr make_median_aggregation(); + /// Factory to create a QUANTILE aggregation -std::unique_ptr make_quantile_aggregation(std::vector const& q, - interpolation i) +template +std::unique_ptr make_quantile_aggregation(std::vector const& q, interpolation i) { return std::make_unique(q, i); } -/// Factory to create a ARGMAX aggregation -std::unique_ptr make_argmax_aggregation() +template std::unique_ptr make_quantile_aggregation( + std::vector const& q, interpolation i); + +/// Factory to create an ARGMAX aggregation +template +std::unique_ptr make_argmax_aggregation() { - return std::make_unique(aggregation::ARGMAX); + return std::make_unique(); } -/// Factory to create a ARGMIN aggregation -std::unique_ptr make_argmin_aggregation() +template std::unique_ptr make_argmax_aggregation(); +template std::unique_ptr make_argmax_aggregation(); + +/// Factory to create an ARGMIN aggregation +template +std::unique_ptr make_argmin_aggregation() { - return std::make_unique(aggregation::ARGMIN); + return std::make_unique(); } -/// Factory to create a NUNIQUE aggregation -std::unique_ptr make_nunique_aggregation(null_policy null_handling) +template std::unique_ptr make_argmin_aggregation(); +template std::unique_ptr make_argmin_aggregation(); + +/// Factory to create an NUNIQUE aggregation +template +std::unique_ptr make_nunique_aggregation(null_policy null_handling) { return std::make_unique(null_handling); } -/// Factory to create a NTH_ELEMENT aggregation -std::unique_ptr make_nth_element_aggregation(size_type n, null_policy null_handling) +template std::unique_ptr make_nunique_aggregation( + null_policy null_handling); + +/// Factory to create an NTH_ELEMENT aggregation +template +std::unique_ptr make_nth_element_aggregation(size_type n, null_policy null_handling) { return std::make_unique(n, null_handling); } +template std::unique_ptr make_nth_element_aggregation( + size_type n, null_policy null_handling); + /// Factory to create a ROW_NUMBER aggregation -std::unique_ptr make_row_number_aggregation() +template +std::unique_ptr make_row_number_aggregation() { - return std::make_unique(aggregation::ROW_NUMBER); + return std::make_unique(); } +template std::unique_ptr make_row_number_aggregation(); +template std::unique_ptr make_row_number_aggregation(); + /// Factory to create a COLLECT_LIST aggregation -std::unique_ptr make_collect_list_aggregation(null_policy null_handling) +template +std::unique_ptr make_collect_list_aggregation(null_policy null_handling) { return std::make_unique(null_handling); } +template std::unique_ptr make_collect_list_aggregation( + null_policy null_handling); +template std::unique_ptr make_collect_list_aggregation( + null_policy null_handling); + /// Factory to create a COLLECT_SET aggregation -std::unique_ptr make_collect_set_aggregation(null_policy null_handling, - null_equality nulls_equal, - nan_equality nans_equal) +template +std::unique_ptr make_collect_set_aggregation(null_policy null_handling, + null_equality nulls_equal, + nan_equality nans_equal) { return std::make_unique(null_handling, nulls_equal, nans_equal); } +template std::unique_ptr make_collect_set_aggregation( + null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); + /// Factory to create a LAG aggregation -std::unique_ptr make_lag_aggregation(size_type offset) +template +std::unique_ptr make_lag_aggregation(size_type offset) { - return std::make_unique(aggregation::LAG, offset); + return std::make_unique(aggregation::LAG, offset); } +template std::unique_ptr make_lag_aggregation(size_type offset); +template std::unique_ptr make_lag_aggregation( + size_type offset); + /// Factory to create a LEAD aggregation -std::unique_ptr make_lead_aggregation(size_type offset) +template +std::unique_ptr make_lead_aggregation(size_type offset) { - return std::make_unique(aggregation::LEAD, offset); + return std::make_unique(aggregation::LEAD, offset); } +template std::unique_ptr make_lead_aggregation(size_type offset); +template std::unique_ptr make_lead_aggregation( + size_type offset); + /// Factory to create a UDF aggregation -std::unique_ptr make_udf_aggregation(udf_type type, - std::string const& user_defined_aggregator, - data_type output_type) +template +std::unique_ptr make_udf_aggregation(udf_type type, + std::string const& user_defined_aggregator, + data_type output_type) { - aggregation* a = + auto* a = new detail::udf_aggregation{type == udf_type::PTX ? aggregation::PTX : aggregation::CUDA, user_defined_aggregator, output_type}; - return std::unique_ptr(a); + return std::unique_ptr(a); } +template std::unique_ptr make_udf_aggregation( + udf_type type, std::string const& user_defined_aggregator, data_type output_type); +template std::unique_ptr make_udf_aggregation( + udf_type type, std::string const& user_defined_aggregator, data_type output_type); namespace detail { namespace { diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 022fefb6428..31b48790861 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -100,6 +100,64 @@ bool constexpr is_hash_aggregation(aggregation::Kind t) return array_contains(hash_aggregations, t); } +class groupby_simple_aggregations_collector final + : public cudf::detail::simple_aggregations_collector { + public: + using cudf::detail::simple_aggregations_collector::visit; + + std::vector> visit(data_type col_type, + cudf::detail::min_aggregation const& agg) override + { + std::vector> aggs; + aggs.push_back(col_type.id() == type_id::STRING ? make_argmin_aggregation() + : make_min_aggregation()); + return aggs; + } + + std::vector> visit(data_type col_type, + cudf::detail::max_aggregation const& agg) override + { + std::vector> aggs; + aggs.push_back(col_type.id() == type_id::STRING ? make_argmax_aggregation() + : make_max_aggregation()); + return aggs; + } + + std::vector> visit( + data_type col_type, cudf::detail::mean_aggregation const& agg) override + { + CUDF_EXPECTS(is_fixed_width(col_type), "MEAN aggregation expects fixed width type"); + std::vector> aggs; + aggs.push_back(make_sum_aggregation()); + // COUNT_VALID + aggs.push_back(make_count_aggregation()); + + return aggs; + } + + std::vector> visit(data_type col_type, + cudf::detail::var_aggregation const& agg) override + { + std::vector> aggs; + aggs.push_back(make_sum_aggregation()); + // COUNT_VALID + aggs.push_back(make_count_aggregation()); + + return aggs; + } + + std::vector> visit(data_type col_type, + cudf::detail::std_aggregation const& agg) override + { + std::vector> aggs; + aggs.push_back(make_sum_aggregation()); + // COUNT_VALID + aggs.push_back(make_count_aggregation()); + + return aggs; + } +}; + template class hash_compound_agg_finalizer final : public cudf::detail::aggregation_finalizer { size_t col_idx; @@ -115,6 +173,8 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final rmm::cuda_stream_view stream; public: + using cudf::detail::aggregation_finalizer::visit; + hash_compound_agg_finalizer(size_t col_idx, column_view col, cudf::detail::result_cache* sparse_results, @@ -153,10 +213,9 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final } // Enables conversion of ARGMIN/ARGMAX into MIN/MAX - auto gather_argminmax(aggregation::Kind const& agg_kind) + auto gather_argminmax(aggregation const& agg) { - auto transformed_agg = std::make_unique(agg_kind); - auto arg_result = to_dense_agg_result(*transformed_agg); + auto arg_result = to_dense_agg_result(agg); // We make a view of ARG(MIN/MAX) result without a null mask and gather // using this map. The values in data buffer of ARG(MIN/MAX) result // corresponding to null values was initialized to ARG(MIN/MAX)_SENTINEL @@ -175,7 +234,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final stream, mr); return std::move(gather_argminmax->release()[0]); - }; + } // Declare overloads for each kind of aggregation to dispatch void visit(cudf::aggregation const& agg) override @@ -187,20 +246,24 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final void visit(cudf::detail::min_aggregation const& agg) override { if (dense_results->has_result(col_idx, agg)) return; - if (result_type.id() == type_id::STRING) - dense_results->add_result(col_idx, agg, gather_argminmax(aggregation::ARGMIN)); - else + if (result_type.id() == type_id::STRING) { + auto transformed_agg = make_argmin_aggregation(); + dense_results->add_result(col_idx, agg, gather_argminmax(*transformed_agg)); + } else { dense_results->add_result(col_idx, agg, to_dense_agg_result(agg)); + } } void visit(cudf::detail::max_aggregation const& agg) override { if (dense_results->has_result(col_idx, agg)) return; - if (result_type.id() == type_id::STRING) - dense_results->add_result(col_idx, agg, gather_argminmax(aggregation::ARGMAX)); - else + if (result_type.id() == type_id::STRING) { + auto transformed_agg = make_argmax_aggregation(); + dense_results->add_result(col_idx, agg, gather_argminmax(*transformed_agg)); + } else { dense_results->add_result(col_idx, agg, to_dense_agg_result(agg)); + } } void visit(cudf::detail::mean_aggregation const& agg) override @@ -259,19 +322,22 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final { if (dense_results->has_result(col_idx, agg)) return; auto var_agg = make_variance_aggregation(agg._ddof); - this->visit(*static_cast(var_agg.get())); + this->visit(*dynamic_cast(var_agg.get())); column_view variance = dense_results->get_result(col_idx, *var_agg); auto result = cudf::detail::unary_operation(variance, unary_operator::SQRT, stream, mr); dense_results->add_result(col_idx, agg, std::move(result)); } }; - // flatten aggs to filter in single pass aggs -std::tuple, std::vector> +std::tuple, + std::vector>, + std::vector> flatten_single_pass_aggs(host_span requests) { std::vector columns; + std::vector> aggs; std::vector agg_kinds; std::vector col_ids; @@ -280,24 +346,30 @@ flatten_single_pass_aggs(host_span requests) auto const& agg_v = request.aggregations; std::unordered_set agg_kinds_set; - auto insert_agg = [&](size_t i, column_view const& request_values, aggregation::Kind k) { - if (agg_kinds_set.insert(k).second) { - agg_kinds.push_back(k); - columns.push_back(request_values); - col_ids.push_back(i); - } - }; + auto insert_agg = + [&](size_t i, column_view const& request_values, std::unique_ptr&& agg) { + if (agg_kinds_set.insert(agg->kind).second) { + agg_kinds.push_back(agg->kind); + aggs.push_back(std::move(agg)); + columns.push_back(request_values); + col_ids.push_back(i); + } + }; auto values_type = cudf::is_dictionary(request.values.type()) ? cudf::dictionary_column_view(request.values).keys().type() : request.values.type(); for (auto&& agg : agg_v) { - for (auto const& agg_s : agg->get_simple_aggregations(values_type)) - insert_agg(i, request.values, agg_s); + groupby_simple_aggregations_collector collector; + + for (auto& agg_s : agg->get_simple_aggregations(values_type, collector)) { + insert_agg(i, request.values, std::move(agg_s)); + } } } - return std::make_tuple(table_view(columns), std::move(agg_kinds), std::move(col_ids)); + return std::make_tuple( + table_view(columns), std::move(agg_kinds), std::move(aggs), std::move(col_ids)); } /** @@ -425,14 +497,14 @@ void compute_single_pass_aggs(table_view const& keys, rmm::cuda_stream_view stream) { // flatten the aggs to a table that can be operated on by aggregate_row - auto const [flattened_values, aggs, col_ids] = flatten_single_pass_aggs(requests); + auto const [flattened_values, agg_kinds, aggs, col_ids] = flatten_single_pass_aggs(requests); // make table that will hold sparse results - table sparse_table = create_sparse_results_table(flattened_values, aggs, stream); + table sparse_table = create_sparse_results_table(flattened_values, agg_kinds, stream); // prepare to launch kernel to do the actual aggregation auto d_sparse_table = mutable_table_device_view::create(sparse_table, stream); auto d_values = table_device_view::create(flattened_values, stream); - auto const d_aggs = cudf::detail::make_device_uvector_async(aggs, stream); + auto const d_aggs = cudf::detail::make_device_uvector_async(agg_kinds, stream); bool skip_key_rows_with_nulls = keys_have_nulls and include_null_keys == null_policy::EXCLUDE; @@ -453,8 +525,7 @@ void compute_single_pass_aggs(table_view const& keys, auto sparse_result_cols = sparse_table.release(); for (size_t i = 0; i < aggs.size(); i++) { // Note that the cache will make a copy of this temporary aggregation - auto agg = std::make_unique(aggs[i]); - sparse_results->add_result(col_ids[i], *agg, std::move(sparse_result_cols[i])); + sparse_results->add_result(col_ids[i], *aggs[i], std::move(sparse_result_cols[i])); } } diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 12f157cd3d9..bf091565b22 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -51,7 +51,7 @@ namespace detail { * memoised sorted and/or grouped values and re-using will save on computation * of these values. */ -struct aggregrate_result_functor final : store_result_functor { +struct aggregate_result_functor final : store_result_functor { using store_result_functor::store_result_functor; template void operator()(aggregation const& agg) @@ -61,7 +61,7 @@ struct aggregrate_result_functor final : store_result_functor { }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -76,7 +76,7 @@ void aggregrate_result_functor::operator()(aggregation } template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -87,7 +87,7 @@ void aggregrate_result_functor::operator()(aggregation c } template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -99,7 +99,7 @@ void aggregrate_result_functor::operator()(aggregation const& }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -111,7 +111,7 @@ void aggregrate_result_functor::operator()(aggregation con }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -126,7 +126,7 @@ void aggregrate_result_functor::operator()(aggregation cons }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -141,7 +141,7 @@ void aggregrate_result_functor::operator()(aggregation cons }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -178,7 +178,7 @@ void aggregrate_result_functor::operator()(aggregation const& }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -215,7 +215,7 @@ void aggregrate_result_functor::operator()(aggregation const& }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -239,11 +239,11 @@ void aggregrate_result_functor::operator()(aggregation const& }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; - auto var_agg = static_cast(agg); + auto var_agg = dynamic_cast(agg); auto mean_agg = make_mean_aggregation(); auto count_agg = make_count_aggregation(); operator()(*mean_agg); @@ -262,11 +262,11 @@ void aggregrate_result_functor::operator()(aggregation co }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; - auto std_agg = static_cast(agg); + auto std_agg = dynamic_cast(agg); auto var_agg = make_variance_aggregation(std_agg._ddof); operator()(*var_agg); column_view var_result = cache.get_result(col_idx, *var_agg); @@ -276,14 +276,14 @@ void aggregrate_result_functor::operator()(aggregation const& }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; auto count_agg = make_count_aggregation(); operator()(*count_agg); column_view group_sizes = cache.get_result(col_idx, *count_agg); - auto quantile_agg = static_cast(agg); + auto quantile_agg = dynamic_cast(agg); auto result = detail::group_quantiles(get_sorted_values(), group_sizes, @@ -297,7 +297,7 @@ void aggregrate_result_functor::operator()(aggregation co }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; @@ -317,11 +317,11 @@ void aggregrate_result_functor::operator()(aggregation cons }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; - auto nunique_agg = static_cast(agg); + auto nunique_agg = dynamic_cast(agg); auto result = detail::group_nunique(get_sorted_values(), helper.group_labels(stream), @@ -334,19 +334,20 @@ void aggregrate_result_functor::operator()(aggregation con }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { if (cache.has_result(col_idx, agg)) return; - auto nth_element_agg = static_cast(agg); + auto nth_element_agg = dynamic_cast(agg); auto count_agg = make_count_aggregation(nth_element_agg._null_handling); - if (count_agg->kind == aggregation::COUNT_VALID) + if (count_agg->kind == aggregation::COUNT_VALID) { operator()(*count_agg); - else if (count_agg->kind == aggregation::COUNT_ALL) + } else if (count_agg->kind == aggregation::COUNT_ALL) { operator()(*count_agg); - else + } else { CUDF_FAIL("Wrong count aggregation kind"); + } column_view group_sizes = cache.get_result(col_idx, *count_agg); cache.add_result(col_idx, @@ -363,10 +364,11 @@ void aggregrate_result_functor::operator()(aggregation } template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { auto null_handling = - static_cast(agg)._null_handling; + dynamic_cast(agg)._null_handling; + agg.do_hash(); CUDF_EXPECTS(null_handling == null_policy::INCLUDE, "null exclusion is not supported on groupby COLLECT_LIST aggregation."); @@ -379,10 +381,10 @@ void aggregrate_result_functor::operator()(aggregatio }; template <> -void aggregrate_result_functor::operator()(aggregation const& agg) +void aggregate_result_functor::operator()(aggregation const& agg) { auto const null_handling = - static_cast(agg)._null_handling; + dynamic_cast(agg)._null_handling; CUDF_EXPECTS(null_handling == null_policy::INCLUDE, "null exclusion is not supported on groupby COLLECT_SET aggregation."); @@ -391,9 +393,9 @@ void aggregrate_result_functor::operator()(aggregation auto const collect_result = detail::group_collect( get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr); auto const nulls_equal = - static_cast(agg)._nulls_equal; + dynamic_cast(agg)._nulls_equal; auto const nans_equal = - static_cast(agg)._nans_equal; + dynamic_cast(agg)._nans_equal; cache.add_result( col_idx, agg, @@ -415,7 +417,7 @@ std::pair, std::vector> groupby::sort for (size_t i = 0; i < requests.size(); i++) { auto store_functor = - detail::aggregrate_result_functor(i, requests[i].values, helper(), cache, stream, mr); + detail::aggregate_result_functor(i, requests[i].values, helper(), cache, stream, mr); for (size_t j = 0; j < requests[i].aggregations.size(); j++) { // TODO (dm): single pass compute all supported reductions cudf::detail::aggregation_dispatcher( diff --git a/cpp/src/reductions/reductions.cpp b/cpp/src/reductions/reductions.cpp index 43dd86b307f..083b0da8cf3 100644 --- a/cpp/src/reductions/reductions.cpp +++ b/cpp/src/reductions/reductions.cpp @@ -58,11 +58,11 @@ struct reduce_dispatch_functor { break; case aggregation::MEAN: return reduction::mean(col, output_dtype, stream, mr); break; case aggregation::VARIANCE: { - auto var_agg = static_cast(agg.get()); + auto var_agg = dynamic_cast(agg.get()); return reduction::variance(col, output_dtype, var_agg->_ddof, stream, mr); } break; case aggregation::STD: { - auto var_agg = static_cast(agg.get()); + auto var_agg = dynamic_cast(agg.get()); return reduction::standard_deviation(col, output_dtype, var_agg->_ddof, stream, mr); } break; case aggregation::MEDIAN: { @@ -73,7 +73,7 @@ struct reduce_dispatch_functor { return get_element(*col_ptr, 0, stream, mr); } break; case aggregation::QUANTILE: { - auto quantile_agg = static_cast(agg.get()); + auto quantile_agg = dynamic_cast(agg.get()); CUDF_EXPECTS(quantile_agg->_quantiles.size() == 1, "Reduction quantile accepts only one quantile value"); auto sorted_indices = sorted_order(table_view{{col}}, {}, {null_order::AFTER}, stream, mr); @@ -89,7 +89,7 @@ struct reduce_dispatch_functor { return get_element(*col_ptr, 0, stream, mr); } break; case aggregation::NUNIQUE: { - auto nunique_agg = static_cast(agg.get()); + auto nunique_agg = dynamic_cast(agg.get()); return make_fixed_width_scalar( detail::distinct_count( col, nunique_agg->_null_handling, nan_policy::NAN_IS_VALID, stream), @@ -97,7 +97,7 @@ struct reduce_dispatch_functor { mr); } break; case aggregation::NTH_ELEMENT: { - auto nth_agg = static_cast(agg.get()); + auto nth_agg = dynamic_cast(agg.get()); return reduction::nth_element(col, nth_agg->_n, nth_agg->_null_handling, stream, mr); } break; default: CUDF_FAIL("Unsupported reduction operator"); diff --git a/cpp/src/rolling/grouped_rolling.cu b/cpp/src/rolling/grouped_rolling.cu index 888d28fd1a5..5702a32536c 100644 --- a/cpp/src/rolling/grouped_rolling.cu +++ b/cpp/src/rolling/grouped_rolling.cu @@ -31,7 +31,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { return grouped_rolling_window(group_keys, @@ -48,7 +48,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, window_bounds preceding_window, window_bounds following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { return grouped_rolling_window(group_keys, @@ -67,7 +67,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { return grouped_rolling_window(group_keys, @@ -88,7 +88,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, window_bounds preceding_window_bounds, window_bounds following_window_bounds, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -155,7 +155,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, return thrust::minimum{}(following_window, (group_end - 1) - idx); }; - if (aggr->kind == aggregation::CUDA || aggr->kind == aggregation::PTX) { + if (aggr.kind == aggregation::CUDA || aggr.kind == aggregation::PTX) { cudf::detail::preceding_window_wrapper grouped_preceding_window{ group_offsets.data(), group_labels.data(), preceding_window}; @@ -192,7 +192,7 @@ std::unique_ptr grouped_rolling_window(table_view const& group_keys, window_bounds preceding_window_bounds, window_bounds following_window_bounds, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { return detail::grouped_rolling_window(group_keys, @@ -309,7 +309,7 @@ std::unique_ptr range_window_ASC(column_view const& input, T following_window, bool following_window_is_unbounded, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -473,7 +473,7 @@ std::unique_ptr range_window_ASC(column_view const& input, T following_window, bool following_window_is_unbounded, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -577,7 +577,7 @@ std::unique_ptr range_window_DESC(column_view const& input, T following_window, bool following_window_is_unbounded, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -668,7 +668,7 @@ std::unique_ptr range_window_DESC(column_view const& input, T following_window, bool following_window_is_unbounded, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -760,7 +760,7 @@ std::unique_ptr range_window_DESC(column_view const& input, auto following_column = expand_to_column(following_calculator, input.size(), stream, mr); - if (aggr->kind == aggregation::CUDA || aggr->kind == aggregation::PTX) { + if (aggr.kind == aggregation::CUDA || aggr.kind == aggregation::PTX) { CUDF_FAIL("Ranged rolling window does NOT (yet) support UDF."); } else { return cudf::detail::rolling_window( @@ -778,7 +778,7 @@ std::unique_ptr grouped_range_rolling_window_impl( range_window_bounds preceding_window, range_window_bounds following_window, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -931,7 +931,7 @@ namespace detail { * range_window_bounds const& preceding, * range_window_bounds const& following, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr ); * * @param stream CUDA stream used for device memory operations and kernel launches. @@ -943,7 +943,7 @@ std::unique_ptr grouped_range_rolling_window(table_view const& group_key range_window_bounds const& preceding, range_window_bounds const& following, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -992,7 +992,7 @@ std::unique_ptr grouped_range_rolling_window(table_view const& group_key * size_type preceding_window_in_days, * size_type following_window_in_days, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr); */ std::unique_ptr grouped_time_range_rolling_window(table_view const& group_keys, @@ -1002,7 +1002,7 @@ std::unique_ptr grouped_time_range_rolling_window(table_view const& grou size_type preceding_window_in_days, size_type following_window_in_days, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { auto preceding = to_range_bounds(preceding_window_in_days, timestamp_column.type()); @@ -1028,7 +1028,7 @@ std::unique_ptr grouped_time_range_rolling_window(table_view const& grou * window_bounds preceding_window_in_days, * window_bounds following_window_in_days, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr); */ std::unique_ptr grouped_time_range_rolling_window(table_view const& group_keys, @@ -1038,7 +1038,7 @@ std::unique_ptr grouped_time_range_rolling_window(table_view const& grou window_bounds preceding_window_in_days, window_bounds following_window_in_days, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { range_window_bounds preceding = @@ -1067,7 +1067,7 @@ std::unique_ptr grouped_time_range_rolling_window(table_view const& grou * range_window_bounds const& preceding, * range_window_bounds const& following, * size_type min_periods, - * std::unique_ptr const& aggr, + * rolling_aggregation const& aggr, * rmm::mr::device_memory_resource* mr ); */ std::unique_ptr grouped_range_rolling_window(table_view const& group_keys, @@ -1077,7 +1077,7 @@ std::unique_ptr grouped_range_rolling_window(table_view const& group_key range_window_bounds const& preceding, range_window_bounds const& following, size_type min_periods, - std::unique_ptr const& aggr, + rolling_aggregation const& aggr, rmm::mr::device_memory_resource* mr) { return detail::grouped_range_rolling_window(group_keys, diff --git a/cpp/src/rolling/rolling.cu b/cpp/src/rolling/rolling.cu index c187b8720b1..63032128c4d 100644 --- a/cpp/src/rolling/rolling.cu +++ b/cpp/src/rolling/rolling.cu @@ -23,7 +23,7 @@ std::unique_ptr rolling_window(column_view const& input, size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr) { auto defaults = @@ -40,7 +40,7 @@ std::unique_ptr rolling_window(column_view const& input, size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -52,7 +52,7 @@ std::unique_ptr rolling_window(column_view const& input, CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()), "Defaults column must be either empty or have as many rows as the input column."); - if (agg->kind == aggregation::CUDA || agg->kind == aggregation::PTX) { + if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { return cudf::detail::rolling_window_udf(input, preceding_window, "cudf::size_type", @@ -82,7 +82,7 @@ std::unique_ptr rolling_window(column_view const& input, column_view const& preceding_window, column_view const& following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -98,7 +98,7 @@ std::unique_ptr rolling_window(column_view const& input, CUDF_EXPECTS(preceding_window.size() == input.size() && following_window.size() == input.size(), "preceding_window/following_window size must match input size"); - if (agg->kind == aggregation::CUDA || agg->kind == aggregation::PTX) { + if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { return cudf::detail::rolling_window_udf(input, preceding_window.begin(), "cudf::size_type*", @@ -130,7 +130,7 @@ std::unique_ptr rolling_window(column_view const& input, size_type preceding_window, size_type following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr) { return detail::rolling_window(input, @@ -148,7 +148,7 @@ std::unique_ptr rolling_window(column_view const& input, column_view const& preceding_window, column_view const& following_window, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::mr::device_memory_resource* mr) { return detail::rolling_window( diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index a26dc4c7120..1192b9cad87 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -504,7 +504,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream) { using Type = device_storage_type_t; @@ -558,7 +558,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, agg_op const& device_agg_op, rmm::cuda_stream_view stream) { @@ -621,7 +621,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -659,7 +659,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -728,7 +728,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -748,7 +748,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding, FollowingWindowIterator following, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, agg_op const& device_agg_op, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -785,7 +785,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -816,7 +816,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -842,7 +842,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -857,7 +857,7 @@ struct rolling_window_launcher { following_window_begin, min_periods, agg, - cudf::DeviceLeadLag{static_cast(agg.get())->row_offset}, + cudf::DeviceLeadLag{dynamic_cast(agg).row_offset}, stream, mr); } @@ -873,7 +873,7 @@ struct rolling_window_launcher { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -883,7 +883,7 @@ struct rolling_window_launcher { default_outputs, preceding_window_begin, following_window_begin, - static_cast(agg.get())->row_offset, + dynamic_cast(agg).row_offset, stream, mr); } @@ -1137,7 +1137,7 @@ struct rolling_window_launcher { PrecedingIter preceding_begin_raw, FollowingIter following_begin_raw, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -1175,7 +1175,7 @@ struct rolling_window_launcher { // If gather_map collects null elements, and null_policy == EXCLUDE, // those elements must be filtered out, and offsets recomputed. - auto null_handling = static_cast(agg.get())->_null_handling; + auto null_handling = dynamic_cast(agg)._null_handling; if (null_handling == null_policy::EXCLUDE && input.has_nulls()) { auto num_child_nulls = count_child_nulls(input, gather_map, stream); if (num_child_nulls != 0) { @@ -1216,11 +1216,11 @@ struct dispatch_rolling { PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - return aggregation_dispatcher(agg->kind, + return aggregation_dispatcher(agg.kind, rolling_window_launcher{}, input, default_outputs, @@ -1243,7 +1243,7 @@ std::unique_ptr rolling_window_udf(column_view const& input, FollowingWindowIterator following_window, std::string const& following_window_str, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -1255,28 +1255,27 @@ std::unique_ptr rolling_window_udf(column_view const& input, min_periods = std::max(min_periods, 0); - auto udf_agg = static_cast(agg.get()); + auto udf_agg = dynamic_cast(agg); - std::string hash = "prog_rolling." + std::to_string(std::hash{}(udf_agg->_source)); + std::string hash = "prog_rolling." + std::to_string(std::hash{}(udf_agg._source)); std::string cuda_source; - switch (udf_agg->kind) { + switch (udf_agg.kind) { case aggregation::Kind::PTX: cuda_source += - cudf::jit::parse_single_function_ptx(udf_agg->_source, - udf_agg->_function_name, - cudf::jit::get_type_name(udf_agg->_output_type), + cudf::jit::parse_single_function_ptx(udf_agg._source, + udf_agg._function_name, + cudf::jit::get_type_name(udf_agg._output_type), {0, 5}); // args 0 and 5 are pointers. break; case aggregation::Kind::CUDA: - cuda_source += - cudf::jit::parse_single_function_cuda(udf_agg->_source, udf_agg->_function_name); + cuda_source += cudf::jit::parse_single_function_cuda(udf_agg._source, udf_agg._function_name); break; default: CUDF_FAIL("Unsupported UDF type."); } std::unique_ptr output = make_numeric_column( - udf_agg->_output_type, input.size(), cudf::mask_state::UNINITIALIZED, stream, mr); + udf_agg._output_type, input.size(), cudf::mask_state::UNINITIALIZED, stream, mr); auto output_view = output->mutable_view(); rmm::device_scalar device_valid_count{0, stream}; @@ -1285,7 +1284,7 @@ std::unique_ptr rolling_window_udf(column_view const& input, jitify2::reflection::Template("cudf::rolling::jit::gpu_rolling_new") // .instantiate(cudf::jit::get_type_name(input.type()), // list of template arguments cudf::jit::get_type_name(output->type()), - udf_agg->_operator_name, + udf_agg._operator_name, preceding_window_str.c_str(), following_window_str.c_str()); @@ -1316,7 +1315,7 @@ std::unique_ptr rolling_window_udf(column_view const& input, * PrecedingWindowIterator preceding_window_begin, * FollowingWindowIterator following_window_begin, * size_type min_periods, - * std::unique_ptr const& agg, + * rolling_aggregation const& agg, * rmm::mr::device_memory_resource* mr) * * @param stream CUDA stream used for device memory operations and kernel launches. @@ -1327,7 +1326,7 @@ std::unique_ptr rolling_window(column_view const& input, PrecedingWindowIterator preceding_window_begin, FollowingWindowIterator following_window_begin, size_type min_periods, - std::unique_ptr const& agg, + rolling_aggregation const& agg, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -1337,10 +1336,10 @@ std::unique_ptr rolling_window(column_view const& input, if (input.is_empty()) return empty_like(input); if (cudf::is_dictionary(input.type())) - CUDF_EXPECTS(agg->kind == aggregation::COUNT_ALL || agg->kind == aggregation::COUNT_VALID || - agg->kind == aggregation::ROW_NUMBER || agg->kind == aggregation::MIN || - agg->kind == aggregation::MAX || agg->kind == aggregation::LEAD || - agg->kind == aggregation::LAG, + CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID || + agg.kind == aggregation::ROW_NUMBER || agg.kind == aggregation::MIN || + agg.kind == aggregation::MAX || agg.kind == aggregation::LEAD || + agg.kind == aggregation::LAG, "Invalid aggregation for dictionary column"); min_periods = std::max(min_periods, 0); @@ -1362,8 +1361,8 @@ std::unique_ptr rolling_window(column_view const& input, if (!cudf::is_dictionary(input.type())) return output; // dictionary column post processing - if (agg->kind == aggregation::COUNT_ALL || agg->kind == aggregation::COUNT_VALID || - agg->kind == aggregation::ROW_NUMBER) + if (agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID || + agg.kind == aggregation::ROW_NUMBER) return output; // output is new dictionary indices (including nulls) diff --git a/cpp/tests/rolling/collect_list_test.cpp b/cpp/tests/rolling/collect_list_test.cpp index de179223d68..8322dd0eee9 100644 --- a/cpp/tests/rolling/collect_list_test.cpp +++ b/cpp/tests/rolling/collect_list_test.cpp @@ -64,7 +64,11 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, prev_column, foll_column, 1, make_collect_list_aggregation()); + rolling_window(input_column, + prev_column, + foll_column, + 1, + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -79,11 +83,15 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); auto const result_fixed_window = - rolling_window(input_column, 2, 1, 1, make_collect_list_aggregation()); + rolling_window(input_column, 2, 1, 1, *make_collect_list_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); auto const result_with_nulls_excluded = - rolling_window(input_column, 2, 1, 1, make_collect_list_aggregation(null_policy::EXCLUDE)); + rolling_window(input_column, + 2, + 1, + 1, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -104,7 +112,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) static_cast(foll_column).size()); auto const result_column_based_window = - rolling_window(input_column, prev_column, foll_column, 0, make_collect_list_aggregation()); + rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ @@ -119,8 +131,12 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); - auto const result_with_nulls_excluded = rolling_window( - input_column, prev_column, foll_column, 0, make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = + rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -137,16 +153,23 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) auto const prev_column = fixed_width_column_wrapper{0, 2, 2, 2, 2, 0}; auto foll_column = fixed_width_column_wrapper{0, 1, 1, 1, 1, 0}; - auto const result = - rolling_window(input_column, prev_column, foll_column, 0, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = rolling_window( - input_column, prev_column, foll_column, 0, make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = + rolling_window(input_column, + prev_column, + foll_column, + 0, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -167,8 +190,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}, @@ -183,7 +209,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -191,8 +217,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) following = 2; min_periods = 4; - auto result_2 = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto result_2 = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {0, 1, 2, 3}, {1, 2, 3, 4}, {2, 3, 4, 5}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -206,7 +235,7 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -231,8 +260,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5}; auto expected_result_child_validity = std::vector{1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1}; @@ -258,14 +290,15 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) { // One result row at each end should be null. // Exclude nulls: No nulls elements for any output list rows. - auto preceding = 2; - auto following = 1; - auto min_periods = 3; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_result_child_values = std::vector{0, 2, 2, 3, 2, 3, 3, 5}; auto expected_result_child = fixed_width_column_wrapper( @@ -290,8 +323,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) auto preceding = 2; auto following = 2; auto min_periods = 4; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5}; auto expected_result_child_validity = std::vector{1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1}; @@ -318,14 +354,15 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) { // First result row, and the last two result rows should be null. // Exclude nulls: No nulls elements for any output list rows. - auto preceding = 2; - auto following = 2; - auto min_periods = 4; - auto const result = rolling_window(input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = + rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_result_child_values = std::vector{0, 2, 3, 2, 3, 2, 3, 5}; auto expected_result_child = fixed_width_column_wrapper( @@ -361,8 +398,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{}, {"0", "1", "2"}, {"1", "2", "3"}, {"2", "3", "4"}, {"3", "4", "5"}, {}}, @@ -377,7 +417,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -385,8 +425,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) following = 2; min_periods = 4; - auto result_2 = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto result_2 = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_2 = lists_column_wrapper{ {{}, {"0", "1", "2", "3"}, {"1", "2", "3", "4"}, {"2", "3", "4", "5"}, {}, {}}, cudf::detail::make_counting_transform_iterator(0, [num_elements](auto i) { @@ -400,7 +443,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2_with_nulls_excluded->view()); @@ -424,8 +467,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) auto preceding = 2; auto following = 1; auto min_periods = 3; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 1, 2, 3, 2, 3, 4, 3, 4, 5}; auto expected_result_child = @@ -451,7 +497,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -462,8 +508,11 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) auto preceding = 2; auto following = 2; auto min_periods = 4; - auto const result = rolling_window( - input_column, preceding, following, min_periods, make_collect_list_aggregation()); + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_result_child_values = std::vector{0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5}; auto expected_result_child = @@ -489,7 +538,7 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) preceding, following, min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); @@ -515,7 +564,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {10, 11}, @@ -530,13 +579,13 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_rolling_window( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -558,12 +607,13 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) { // Nulls included. - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation()); + auto const result = + grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation()); auto expected_child = fixed_width_column_wrapper{ {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, @@ -582,12 +632,13 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) { // Nulls excluded. - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result = grouped_rolling_window( + table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); auto expected_child = fixed_width_column_wrapper{ 10, 10, 12, 12, 13, 12, 13, 14, 13, 14, 20, 20, 22, 22, 23, 22, 23}; @@ -627,7 +678,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {10, 11, 12, 13}, @@ -642,15 +693,15 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -678,7 +729,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto null_at_0 = iterator_with_null_at(0); auto null_at_1 = iterator_with_null_at(1); @@ -697,15 +748,15 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -744,7 +795,7 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {"10", "11", "12", "13"}, @@ -759,15 +810,15 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -793,7 +844,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto null_at_0 = iterator_with_null_at(0); auto null_at_1 = iterator_with_null_at(1); @@ -813,15 +864,15 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -868,7 +919,7 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto expected_numeric_column = fixed_width_column_wrapper{ 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 14, 10, 11, 12, @@ -890,15 +941,15 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - struct_column->view(), - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -928,7 +979,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{10, 11, 12, 13}, @@ -946,15 +997,15 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -984,7 +1035,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPer preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto null_at_1 = iterator_with_null_at(1); @@ -1005,15 +1056,15 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPer CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -1056,7 +1107,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto const expected_result = lists_column_wrapper{ {{"10", "11", "12", "13"}, @@ -1074,15 +1125,15 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } @@ -1110,7 +1161,7 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPer preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto null_at_1 = iterator_with_null_at(1); @@ -1131,15 +1182,15 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPer CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - input_column, - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); // After null exclusion, `11`, `21`, and `null` should not appear. auto const expected_result_with_nulls_excluded = lists_column_wrapper{ @@ -1190,7 +1241,7 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe preceding, following, min_periods, - make_collect_list_aggregation()); + *make_collect_list_aggregation()); auto expected_numeric_column = fixed_width_column_wrapper{ 10, 11, 12, 13, 10, 11, 12, 13, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14}; @@ -1218,15 +1269,15 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); - auto const result_with_nulls_excluded = - grouped_time_range_rolling_window(table_view{std::vector{group_column}}, - time_column, - cudf::order::ASCENDING, - struct_column->view(), - preceding, - following, - min_periods, - make_collect_list_aggregation(null_policy::EXCLUDE)); + auto const result_with_nulls_excluded = grouped_time_range_rolling_window( + table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + *make_collect_list_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } diff --git a/cpp/tests/rolling/grouped_rolling_test.cpp b/cpp/tests/rolling/grouped_rolling_test.cpp index 6f930f99b50..804fd715951 100644 --- a/cpp/tests/rolling/grouped_rolling_test.cpp +++ b/cpp/tests/rolling/grouped_rolling_test.cpp @@ -126,7 +126,7 @@ class GroupedRollingTest : public cudf::test::BaseFixture { size_type const& preceding_window, size_type const& following_window, size_type min_periods, - std::unique_ptr const& op) + cudf::rolling_aggregation const& op) { std::unique_ptr output; @@ -170,28 +170,29 @@ class GroupedRollingTest : public cudf::test::BaseFixture { preceding_window, following_window, min_periods, - cudf::make_min_aggregation()); + *cudf::make_min_aggregation()); run_test_col(keys, input, expected_grouping, preceding_window, following_window, min_periods, - cudf::make_count_aggregation()); + *cudf::make_count_aggregation()); + run_test_col( + keys, + input, + expected_grouping, + preceding_window, + following_window, + min_periods, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); run_test_col(keys, input, expected_grouping, preceding_window, following_window, min_periods, - cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); - run_test_col(keys, - input, - expected_grouping, - preceding_window, - following_window, - min_periods, - cudf::make_max_aggregation()); + *cudf::make_max_aggregation()); if (!cudf::is_timestamp(input.type())) { run_test_col(keys, @@ -200,14 +201,14 @@ class GroupedRollingTest : public cudf::test::BaseFixture { preceding_window, following_window, min_periods, - cudf::make_sum_aggregation()); + *cudf::make_sum_aggregation()); run_test_col(keys, input, expected_grouping, preceding_window, following_window, min_periods, - cudf::make_mean_aggregation()); + *cudf::make_mean_aggregation()); } run_test_col(keys, input, @@ -215,11 +216,11 @@ class GroupedRollingTest : public cudf::test::BaseFixture { preceding_window, following_window, min_periods, - cudf::make_row_number_aggregation()); + *cudf::make_row_number_aggregation()); // >>> test UDFs <<< if (input.type() == cudf::data_type{cudf::type_id::INT32} && !input.has_nulls()) { - auto cuda_udf_agg = cudf::make_udf_aggregation( + auto cuda_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::CUDA, cuda_func, cudf::data_type{cudf::type_id::INT64}); run_test_col(keys, input, @@ -227,9 +228,9 @@ class GroupedRollingTest : public cudf::test::BaseFixture { preceding_window, following_window, min_periods, - cuda_udf_agg); + *cuda_udf_agg); - auto ptx_udf_agg = cudf::make_udf_aggregation( + auto ptx_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::PTX, ptx_func, cudf::data_type{cudf::type_id::INT64}); run_test_col(keys, input, @@ -237,7 +238,7 @@ class GroupedRollingTest : public cudf::test::BaseFixture { preceding_window, following_window, min_periods, - ptx_udf_agg); + *ptx_udf_agg); } } @@ -402,16 +403,15 @@ class GroupedRollingTest : public cudf::test::BaseFixture { CUDF_FAIL("Unsupported combination of type and aggregation"); } - std::unique_ptr create_reference_output( - std::unique_ptr const& op, - cudf::column_view const& input, - std::vector const& group_offsets, - size_type const& preceding_window, - size_type const& following_window, - size_type min_periods) + std::unique_ptr create_reference_output(cudf::rolling_aggregation const& op, + cudf::column_view const& input, + std::vector const& group_offsets, + size_type const& preceding_window, + size_type const& following_window, + size_type min_periods) { // unroll aggregation types - switch (op->kind) { + switch (op.kind) { case cudf::aggregation::SUM: return create_reference_output{grouping_keys_col}}; EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input, 2, 2, -2, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input, 2, 2, -2, *cudf::make_sum_aggregation()), cudf::logic_error); } @@ -493,8 +494,9 @@ TEST_F(GroupedRollingErrorTest, EmptyInput) cudf::test::fixed_width_column_wrapper empty_col{}; std::unique_ptr output; const cudf::table_view grouping_keys{std::vector{}}; - EXPECT_NO_THROW(output = cudf::grouped_rolling_window( - grouping_keys, empty_col, 2, 0, 2, cudf::make_sum_aggregation())); + EXPECT_NO_THROW( + output = cudf::grouped_rolling_window( + grouping_keys, empty_col, 2, 0, 2, *cudf::make_sum_aggregation())); EXPECT_EQ(output->size(), 0); } @@ -519,19 +521,24 @@ TEST_F(GroupedRollingErrorTest, SumTimestampNotSupported) fixed_width_column_wrapper(grouping_keys_vec.begin(), grouping_keys_vec.end())}}; EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input_D, 2, 2, 0, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input_D, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input_s, 2, 2, 0, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input_s, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input_ms, 2, 2, 0, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input_ms, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input_us, 2, 2, 0, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input_us, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); EXPECT_THROW( - cudf::grouped_rolling_window(grouping_keys, input_ns, 2, 2, 0, cudf::make_sum_aggregation()), + cudf::grouped_rolling_window( + grouping_keys, input_ns, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); } @@ -655,10 +662,13 @@ TEST_F(GroupedRollingTestStrings, StringsUnsupportedOperators) const cudf::table_view key_cols{std::vector{ fixed_width_column_wrapper(key_col_vec.begin(), key_col_vec.end())}}; - EXPECT_THROW(cudf::grouped_rolling_window(key_cols, input, 2, 2, 0, cudf::make_sum_aggregation()), - cudf::logic_error); EXPECT_THROW( - cudf::grouped_rolling_window(key_cols, input, 2, 2, 0, cudf::make_mean_aggregation()), + cudf::grouped_rolling_window( + key_cols, input, 2, 2, 0, *cudf::make_sum_aggregation()), + cudf::logic_error); + EXPECT_THROW( + cudf::grouped_rolling_window( + key_cols, input, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); } @@ -674,7 +684,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { size_type const& preceding_window_in_days, size_type const& following_window_in_days, size_type min_periods, - std::unique_ptr const& op) + cudf::rolling_aggregation const& op) { std::unique_ptr output; @@ -734,7 +744,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_min_aggregation()); + *cudf::make_min_aggregation()); run_test_col(keys, timestamp_column, timestamp_order, @@ -743,7 +753,17 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_count_aggregation()); + *cudf::make_count_aggregation()); + run_test_col( + keys, + timestamp_column, + timestamp_order, + input, + expected_grouping, + preceding_window_in_days, + following_window_in_days, + min_periods, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); run_test_col(keys, timestamp_column, timestamp_order, @@ -752,16 +772,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); - run_test_col(keys, - timestamp_column, - timestamp_order, - input, - expected_grouping, - preceding_window_in_days, - following_window_in_days, - min_periods, - cudf::make_max_aggregation()); + *cudf::make_max_aggregation()); if (!cudf::is_timestamp(input.type())) { run_test_col(keys, timestamp_column, @@ -771,7 +782,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_sum_aggregation()); + *cudf::make_sum_aggregation()); run_test_col(keys, timestamp_column, timestamp_order, @@ -780,7 +791,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_mean_aggregation()); + *cudf::make_mean_aggregation()); } run_test_col(keys, timestamp_column, @@ -790,7 +801,7 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { preceding_window_in_days, following_window_in_days, min_periods, - cudf::make_row_number_aggregation()); + *cudf::make_row_number_aggregation()); } private: @@ -1038,18 +1049,17 @@ class GroupedTimeRangeRollingTest : public cudf::test::BaseFixture { CUDF_FAIL("Unsupported combination of type and aggregation"); } - std::unique_ptr create_reference_output( - std::unique_ptr const& op, - cudf::column_view const& timestamp_column, - cudf::order const& timestamp_order, - cudf::column_view const& input, - std::vector const& group_offsets, - size_type const& preceding_window, - size_type const& following_window, - size_type min_periods) + std::unique_ptr create_reference_output(cudf::rolling_aggregation const& op, + cudf::column_view const& timestamp_column, + cudf::order const& timestamp_order, + cudf::column_view const& input, + std::vector const& group_offsets, + size_type const& preceding_window, + size_type const& following_window, + size_type min_periods) { // unroll aggregation types - switch (op->kind) { + switch (op.kind) { case cudf::aggregation::SUM: return create_reference_output()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1287,14 +1298,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountSingleGroupTimestampASCNu auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1315,14 +1327,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountMultiGroupTimestampASCNul auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1343,14 +1356,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountMultiGroupTimestampASCNul auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1372,14 +1386,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountSingleGroupTimestampDESCN auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1402,14 +1417,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountSingleGroupTimestampDESCN auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1430,14 +1446,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountMultiGroupTimestampDESCNu auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1458,14 +1475,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountMultiGroupTimestampDESCNu auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1488,14 +1506,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountSingleGroupAllNullTimesta auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1518,14 +1537,15 @@ TYPED_TEST(TypedNullTimestampTestForRangeQueries, CountMultiGroupAllNullTimestam auto const preceding = 1L; auto const following = 1L; auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - preceding, - following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + preceding, + following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1561,14 +1581,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowSingleGroupTimestam auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1590,14 +1611,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowSingleGroupTimestam auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1620,14 +1642,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1649,14 +1672,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowSingleGroupTimestam auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1678,14 +1702,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowSingleGroupTimestam auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1708,14 +1733,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1737,14 +1763,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowSingleGroupTimestam auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1766,14 +1793,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowSingleGroupTimestam auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1796,14 +1824,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1825,14 +1854,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowSingleGroupTimestam auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1854,14 +1884,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowSingleGroupTimestam auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1884,14 +1915,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1912,14 +1944,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingCountMultiGroupTimestampA auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1940,14 +1973,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingCountMultiGroupTimestampA auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1969,14 +2003,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -1997,14 +2032,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingCountMultiGroupTimestampA auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2025,14 +2061,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingCountMultiGroupTimestampA auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2054,14 +2091,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::ASCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::ASCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2082,14 +2120,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingCountMultiGroupTimestampD auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2110,14 +2149,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingCountMultiGroupTimestampD auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2139,14 +2179,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2167,14 +2208,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingCountMultiGroupTimestampD auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_day_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - one_day_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + one_day_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2195,14 +2237,15 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingCountMultiGroupTimestampD auto const one_day_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - one_day_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + one_day_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2224,14 +2267,15 @@ TYPED_TEST(TypedUnboundedWindowTest, auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_time_range_rolling_window(grouping_keys, - time_col, - cudf::order::DESCENDING, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = cudf::grouped_time_range_rolling_window( + grouping_keys, + time_col, + cudf::order::DESCENDING, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2251,12 +2295,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowSingleGroup) auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_row_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - unbounded_preceding, - one_row_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + unbounded_preceding, + one_row_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2276,12 +2321,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowSingleGroup) auto const one_row_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - one_row_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + one_row_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2301,12 +2347,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingAndFollowingWindowSingleG auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2326,12 +2373,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingWindowMultiGroup) auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const one_row_following = cudf::window_bounds::get(1L); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - unbounded_preceding, - one_row_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + unbounded_preceding, + one_row_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2351,12 +2399,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedFollowingWindowMultiGroup) auto const one_row_preceding = cudf::window_bounds::get(1L); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - one_row_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + one_row_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ @@ -2376,12 +2425,13 @@ TYPED_TEST(TypedUnboundedWindowTest, UnboundedPrecedingAndFollowingWindowMultiGr auto const unbounded_preceding = cudf::window_bounds::unbounded(); auto const unbounded_following = cudf::window_bounds::unbounded(); auto const min_periods = 1L; - auto const output = cudf::grouped_rolling_window(grouping_keys, - agg_col, - unbounded_preceding, - unbounded_following, - min_periods, - cudf::make_count_aggregation()); + auto const output = + cudf::grouped_rolling_window(grouping_keys, + agg_col, + unbounded_preceding, + unbounded_following, + min_periods, + *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(output->view(), fixed_width_column_wrapper{ diff --git a/cpp/tests/rolling/lead_lag_test.cpp b/cpp/tests/rolling/lead_lag_test.cpp index bc71a7acab9..a54fb236f29 100644 --- a/cpp/tests/rolling/lead_lag_test.cpp +++ b/cpp/tests/rolling/lead_lag_test.cpp @@ -74,12 +74,13 @@ TYPED_TEST(TypedLeadLagWindowTest, LeadLagBasics) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, @@ -88,12 +89,13 @@ TYPED_TEST(TypedLeadLagWindowTest, LeadLagBasics) .release() ->view()); - auto lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -118,12 +120,13 @@ TYPED_TEST(TypedLeadLagWindowTest, LeadLagWithNulls) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, @@ -132,12 +135,13 @@ TYPED_TEST(TypedLeadLagWindowTest, LeadLagWithNulls) .release() ->view()); - auto const lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto const lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -166,13 +170,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithDefaults) cudf::make_fixed_width_scalar(detail::fixed_width_type_converter{}(99)); auto const default_outputs = cudf::make_column_from_scalar(*default_value, input_col->size()); - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, fixed_width_column_wrapper{{3, 4, 5, 99, 99, 99, 30, 40, 50, 99, 99, 99}, @@ -180,13 +185,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithDefaults) .release() ->view()); - auto const lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto const lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -216,13 +222,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithDefaultsContainingNulls) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, fixed_width_column_wrapper{{3, 4, 5, 99, 99, -1, 30, 40, 50, 99, 99, -1}, @@ -230,13 +237,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithDefaultsContainingNulls) .release() ->view()); - auto const lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto const lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -265,12 +273,13 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithOutOfRangeOffsets) auto const following = 3; auto const min_periods = 1; - auto lead_30_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(30)); + auto lead_30_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(30)); expect_columns_equivalent( *lead_30_output_col, @@ -279,13 +288,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithOutOfRangeOffsets) .release() ->view()); - auto const lag_20_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(20)); + auto const lag_20_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(20)); expect_columns_equivalent( *lag_20_output_col, @@ -310,21 +320,23 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithZeroOffsets) auto const following = 3; auto const min_periods = 1; - auto lead_0_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(0)); + auto lead_0_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(0)); expect_columns_equivalent(*lead_0_output_col, *input_col); - auto const lag_0_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(0)); + auto const lag_0_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(0)); expect_columns_equivalent(*lag_0_output_col, *input_col); } @@ -348,13 +360,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithNegativeOffsets) auto const following = 3; auto const min_periods = 1; - auto lag_minus_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(-3)); + auto lag_minus_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(-3)); expect_columns_equivalent( *lag_minus_3_output_col, @@ -370,7 +383,7 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithNegativeOffsets) preceding, following, min_periods, - cudf::make_lead_aggregation(-2)); + *cudf::make_lead_aggregation(-2)); expect_columns_equivalent( *lead_minus_2_output_col, @@ -397,25 +410,27 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithNoGrouping) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, fixed_width_column_wrapper{{3, 4, 5, 99, 99, 99}, {1, 1, 1, 1, 1, 1}}.release()->view()); - auto const lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto const lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -443,13 +458,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithAllNullInput) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); expect_columns_equivalent( *lead_3_output_col, fixed_width_column_wrapper{{-1, -1, -1, 99, 99, 99, -1, -1, -1, 99, 99, 99}, @@ -457,13 +473,14 @@ TYPED_TEST(TypedLeadLagWindowTest, TestLeadLagWithAllNullInput) .release() ->view()); - auto const lag_2_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - *default_outputs, - preceding, - following, - min_periods, - cudf::make_lag_aggregation(2)); + auto const lag_2_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + *default_outputs, + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(2)); expect_columns_equivalent( *lag_2_output_col, @@ -497,17 +514,19 @@ TYPED_TEST(TypedLeadLagWindowTest, DefaultValuesWithoutLeadLag) auto const min_periods = 1; auto const assert_aggregation_fails = [&](auto&& aggr) { - EXPECT_THROW(cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - default_outputs->view(), - preceding, - following, - min_periods, - cudf::make_count_aggregation()), - cudf::logic_error); + EXPECT_THROW( + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + default_outputs->view(), + preceding, + following, + min_periods, + *cudf::make_count_aggregation()), + cudf::logic_error); }; - auto aggs = {cudf::make_count_aggregation(), cudf::make_min_aggregation()}; + auto aggs = {cudf::make_count_aggregation(), + cudf::make_min_aggregation()}; std::for_each( aggs.begin(), aggs.end(), [&](auto& agg) { assert_aggregation_fails(std::move(agg)); }); } @@ -546,12 +565,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, NumericListsWithNullsAllOver) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lead_3_output_col->view(), @@ -571,12 +591,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, NumericListsWithNullsAllOver) .release() ->view()); - auto lag_1_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(1)); + auto lag_1_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(1)); expect_columns_equivalent(lag_1_output_col->view(), lcw{{{}, @@ -643,12 +664,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, NumericListsWithDefaults) auto const following = 3; auto const min_periods = 1; - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lead_3_output_col->view(), @@ -668,12 +690,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, NumericListsWithDefaults) .release() ->view()); - auto lag_1_output_col = cudf::grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(1)); + auto lag_1_output_col = + cudf::grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(1)); expect_columns_equivalent(lag_1_output_col->view(), lcw{{{}, @@ -738,12 +761,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, Structs) // Test LEAD(). { - auto lead_3_output_col = cudf::grouped_rolling_window(grouping_keys, - structs_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(3)); + auto lead_3_output_col = + cudf::grouped_rolling_window(grouping_keys, + structs_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(3)); auto expected_lists_col = lcw{{{3, 3, 3}, {{4, 4, 4, 4}, null_at_2}, @@ -772,12 +796,13 @@ TYPED_TEST(TypedNestedLeadLagWindowTest, Structs) // Test LAG() { - auto lag_1_output_col = cudf::grouped_rolling_window(grouping_keys, - structs_col->view(), - preceding, - following, - min_periods, - cudf::make_lag_aggregation(1)); + auto lag_1_output_col = + cudf::grouped_rolling_window(grouping_keys, + structs_col->view(), + preceding, + following, + min_periods, + *cudf::make_lag_aggregation(1)); auto expected_lists_col = lcw{{{}, // null. {0, 0}, {1, 1}, @@ -850,7 +875,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsNoDefaults) preceding, following, min_periods, - cudf::make_lead_aggregation(2)); + *cudf::make_lead_aggregation(2)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lead_2->view(), strings_column_wrapper{ @@ -862,7 +887,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsNoDefaults) preceding, following, min_periods, - cudf::make_lag_aggregation(1)); + *cudf::make_lag_aggregation(1)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lag_1->view(), @@ -918,7 +943,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsWithDefaults) preceding, following, min_periods, - cudf::make_lead_aggregation(2)); + *cudf::make_lead_aggregation(2)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(lead_2->view(), strings_column_wrapper{"A_22", "A_333", @@ -939,7 +964,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsWithDefaults) preceding, following, min_periods, - cudf::make_lag_aggregation(1)); + *cudf::make_lag_aggregation(1)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lag_1->view(), @@ -995,7 +1020,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsWithDefaultsNoGroups) preceding, following, min_periods, - cudf::make_lead_aggregation(2)); + *cudf::make_lead_aggregation(2)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(lead_2->view(), strings_column_wrapper{{"A_22", "A_333", @@ -1017,7 +1042,7 @@ TEST_F(LeadLagNonFixedWidthTest, StringsWithDefaultsNoGroups) preceding, following, min_periods, - cudf::make_lag_aggregation(1)); + *cudf::make_lag_aggregation(1)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( lag_1->view(), @@ -1065,12 +1090,13 @@ TEST_F(LeadLagNonFixedWidthTest, Dictionary) auto const min_periods = 1; { - auto lead_2 = grouped_rolling_window(grouping_keys, - input_col->view(), - preceding, - following, - min_periods, - cudf::make_lead_aggregation(2)); + auto lead_2 = + grouped_rolling_window(grouping_keys, + input_col->view(), + preceding, + following, + min_periods, + *cudf::make_lead_aggregation(2)); auto expected_keys = strings_column_wrapper{input_strings}.release(); auto expected_values = @@ -1089,7 +1115,7 @@ TEST_F(LeadLagNonFixedWidthTest, Dictionary) preceding, following, min_periods, - cudf::make_lag_aggregation(1)); + *cudf::make_lag_aggregation(1)); auto expected_keys = strings_column_wrapper{input_strings}.release(); auto expected_values = diff --git a/cpp/tests/rolling/range_rolling_window_test.cpp b/cpp/tests/rolling/range_rolling_window_test.cpp index 48aa69f5816..03bb7a80a37 100644 --- a/cpp/tests/rolling/range_rolling_window_test.cpp +++ b/cpp/tests/rolling/range_rolling_window_test.cpp @@ -70,7 +70,7 @@ struct window_exec { size_type num_rows() { return gby_column.size(); } - std::unique_ptr operator()(std::unique_ptr const& agg) const + std::unique_ptr operator()(std::unique_ptr const& agg) const { auto const grouping_keys = cudf::table_view{std::vector{gby_column}}; @@ -81,7 +81,7 @@ struct window_exec { range_window_bounds::get(preceding), range_window_bounds::get(following), min_periods, - agg); + *agg); } private: @@ -112,34 +112,36 @@ void verify_results_for_ascending(WindowExecT exec) auto const last_invalid = thrust::make_transform_iterator( thrust::make_counting_iterator(0), [&n_rows](auto i) { return i != (n_rows - 1); }); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation(null_policy::INCLUDE))->view(), - size_col{{1, 2, 2, 3, 2, 3, 3, 4, 4, 1}, all_valid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL( + exec(make_count_aggregation(null_policy::INCLUDE))->view(), + size_col{{1, 2, 2, 3, 2, 3, 3, 4, 4, 1}, all_valid}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation())->view(), size_col{{1, 2, 2, 3, 2, 3, 3, 4, 4, 0}, all_valid}); CUDF_TEST_EXPECT_COLUMNS_EQUAL( - exec(make_sum_aggregation())->view(), + exec(make_sum_aggregation())->view(), fwcw{{0, 12, 12, 12, 8, 17, 17, 18, 18, 1}, last_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_min_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_min_aggregation())->view(), int_col{{0, 4, 4, 2, 2, 3, 3, 1, 1, 1}, last_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_max_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_max_aggregation())->view(), int_col{{0, 8, 8, 6, 6, 9, 9, 9, 9, 1}, last_invalid}); CUDF_TEST_EXPECT_COLUMNS_EQUAL( - exec(make_mean_aggregation())->view(), + exec(make_mean_aggregation())->view(), fwcw{{0.0, 6.0, 6.0, 4.0, 4.0, 17.0 / 3, 17.0 / 3, 4.5, 4.5, 1.0}, last_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(exec(make_collect_list_aggregation())->view(), - lists_col{{{0}, - {8, 4}, - {8, 4}, - {4, 6, 2}, - {6, 2}, - {9, 3, 5}, - {9, 3, 5}, - {9, 3, 5, 1}, - {9, 3, 5, 1}, - {{0}, all_invalid}}, - all_valid}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( - exec(make_collect_list_aggregation(null_policy::EXCLUDE))->view(), + exec(make_collect_list_aggregation())->view(), + lists_col{{{0}, + {8, 4}, + {8, 4}, + {4, 6, 2}, + {6, 2}, + {9, 3, 5}, + {9, 3, 5}, + {9, 3, 5, 1}, + {9, 3, 5, 1}, + {{0}, all_invalid}}, + all_valid}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + exec(make_collect_list_aggregation(null_policy::EXCLUDE))->view(), lists_col{{{0}, {8, 4}, {8, 4}, @@ -189,34 +191,36 @@ void verify_results_for_descending(WindowExecT exec) auto const first_invalid = thrust::make_transform_iterator(thrust::make_counting_iterator(0), [](auto i) { return i != 0; }); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation(null_policy::INCLUDE))->view(), - size_col{{1, 4, 4, 3, 3, 2, 3, 2, 2, 1}, all_valid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL( + exec(make_count_aggregation(null_policy::INCLUDE))->view(), + size_col{{1, 4, 4, 3, 3, 2, 3, 2, 2, 1}, all_valid}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_count_aggregation())->view(), size_col{{0, 4, 4, 3, 3, 2, 3, 2, 2, 1}, all_valid}); CUDF_TEST_EXPECT_COLUMNS_EQUAL( - exec(make_sum_aggregation())->view(), + exec(make_sum_aggregation())->view(), fwcw{{1, 18, 18, 17, 17, 8, 12, 12, 12, 0}, first_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_min_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_min_aggregation())->view(), int_col{{1, 1, 1, 3, 3, 2, 2, 4, 4, 0}, first_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_max_aggregation())->view(), + CUDF_TEST_EXPECT_COLUMNS_EQUAL(exec(make_max_aggregation())->view(), int_col{{1, 9, 9, 9, 9, 6, 6, 8, 8, 0}, first_invalid}); CUDF_TEST_EXPECT_COLUMNS_EQUAL( - exec(make_mean_aggregation())->view(), + exec(make_mean_aggregation())->view(), fwcw{{1.0, 4.5, 4.5, 17.0 / 3, 17.0 / 3, 4.0, 4.0, 6.0, 6.0, 0.0}, first_invalid}); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(exec(make_collect_list_aggregation())->view(), - lists_col{{{{0}, all_invalid}, - {1, 5, 3, 9}, - {1, 5, 3, 9}, - {5, 3, 9}, - {5, 3, 9}, - {2, 6}, - {2, 6, 4}, - {4, 8}, - {4, 8}, - {0}}, - all_valid}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( - exec(make_collect_list_aggregation(null_policy::EXCLUDE))->view(), + exec(make_collect_list_aggregation())->view(), + lists_col{{{{0}, all_invalid}, + {1, 5, 3, 9}, + {1, 5, 3, 9}, + {5, 3, 9}, + {5, 3, 9}, + {2, 6}, + {2, 6, 4}, + {4, 8}, + {4, 8}, + {0}}, + all_valid}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + exec(make_collect_list_aggregation(null_policy::EXCLUDE))->view(), lists_col{{{}, {1, 5, 3, 9}, {1, 5, 3, 9}, @@ -338,7 +342,7 @@ auto do_count_over_window( std::move(preceding), std::move(following), min_periods, - cudf::make_count_aggregation()); + *cudf::make_count_aggregation()); } TYPED_TEST(TypedRangeRollingNullsTest, CountSingleGroupOrderByASCNullsFirst) diff --git a/cpp/tests/rolling/rolling_test.cpp b/cpp/tests/rolling/rolling_test.cpp index b6e2b35e760..33171b269ce 100644 --- a/cpp/tests/rolling/rolling_test.cpp +++ b/cpp/tests/rolling/rolling_test.cpp @@ -53,12 +53,18 @@ TEST_F(RollingStringTest, NoNullStringMinMaxCount) fixed_width_column_wrapper expected_count({3, 4, 4, 4, 4, 4, 4, 3, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto got_min = cudf::rolling_window(input, window[0], window[0], 1, cudf::make_min_aggregation()); - auto got_max = cudf::rolling_window(input, window[0], window[0], 1, cudf::make_max_aggregation()); - auto got_count_valid = - cudf::rolling_window(input, window[0], window[0], 1, cudf::make_count_aggregation()); + auto got_min = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_min_aggregation()); + auto got_max = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_max_aggregation()); + auto got_count_valid = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_count_aggregation()); auto got_count_all = cudf::rolling_window( - input, window[0], window[0], 1, cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + input, + window[0], + window[0], + 1, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, got_min->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, got_max->view()); @@ -83,12 +89,18 @@ TEST_F(RollingStringTest, NullStringMinMaxCount) fixed_width_column_wrapper expected_count_all({3, 4, 4, 4, 4, 4, 4, 3, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto got_min = cudf::rolling_window(input, window[0], window[0], 1, cudf::make_min_aggregation()); - auto got_max = cudf::rolling_window(input, window[0], window[0], 1, cudf::make_max_aggregation()); - auto got_count_valid = - cudf::rolling_window(input, window[0], window[0], 1, cudf::make_count_aggregation()); + auto got_min = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_min_aggregation()); + auto got_max = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_max_aggregation()); + auto got_count_valid = cudf::rolling_window( + input, window[0], window[0], 1, *cudf::make_count_aggregation()); auto got_count_all = cudf::rolling_window( - input, window[0], window[0], 1, cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + input, + window[0], + window[0], + 1, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, got_min->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, got_max->view()); @@ -113,12 +125,18 @@ TEST_F(RollingStringTest, MinPeriods) fixed_width_column_wrapper expected_count_all({3, 4, 4, 4, 4, 4, 4, 3, 2}, {0, 1, 1, 1, 1, 1, 1, 0, 0}); - auto got_min = cudf::rolling_window(input, window[0], window[0], 3, cudf::make_min_aggregation()); - auto got_max = cudf::rolling_window(input, window[0], window[0], 3, cudf::make_max_aggregation()); - auto got_count_valid = - cudf::rolling_window(input, window[0], window[0], 3, cudf::make_count_aggregation()); + auto got_min = cudf::rolling_window( + input, window[0], window[0], 3, *cudf::make_min_aggregation()); + auto got_max = cudf::rolling_window( + input, window[0], window[0], 3, *cudf::make_max_aggregation()); + auto got_count_valid = cudf::rolling_window( + input, window[0], window[0], 3, *cudf::make_count_aggregation()); auto got_count_all = cudf::rolling_window( - input, window[0], window[0], 4, cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + input, + window[0], + window[0], + 4, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, got_min->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, got_max->view()); @@ -134,7 +152,8 @@ TEST_F(RollingStringTest, ZeroWindowSize) fixed_width_column_wrapper expected_count({0, 0, 0, 0, 0, 0, 0, 0, 0}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto got_count = cudf::rolling_window(input, 0, 0, 0, cudf::make_count_aggregation()); + auto got_count = cudf::rolling_window( + input, 0, 0, 0, *cudf::make_count_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count, got_count->view()); } @@ -147,7 +166,7 @@ class RollingTest : public cudf::test::BaseFixture { const std::vector& preceding_window, const std::vector& following_window, size_type min_periods, - std::unique_ptr const& op) + cudf::rolling_aggregation const& op) { std::unique_ptr output; @@ -192,23 +211,39 @@ class RollingTest : public cudf::test::BaseFixture { size_type min_periods) { // test all supported aggregators - run_test_col( - input, preceding_window, following_window, min_periods, cudf::make_min_aggregation()); - run_test_col( - input, preceding_window, following_window, min_periods, cudf::make_count_aggregation()); run_test_col(input, preceding_window, following_window, min_periods, - cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + *cudf::make_min_aggregation()); + run_test_col(input, + preceding_window, + following_window, + min_periods, + *cudf::make_count_aggregation()); run_test_col( - input, preceding_window, following_window, min_periods, cudf::make_max_aggregation()); + input, + preceding_window, + following_window, + min_periods, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + run_test_col(input, + preceding_window, + following_window, + min_periods, + *cudf::make_max_aggregation()); if (not cudf::is_timestamp(input.type())) { - run_test_col( - input, preceding_window, following_window, min_periods, cudf::make_sum_aggregation()); - run_test_col( - input, preceding_window, following_window, min_periods, cudf::make_mean_aggregation()); + run_test_col(input, + preceding_window, + following_window, + min_periods, + *cudf::make_sum_aggregation()); + run_test_col(input, + preceding_window, + following_window, + min_periods, + *cudf::make_mean_aggregation()); } } @@ -329,14 +364,14 @@ class RollingTest : public cudf::test::BaseFixture { } std::unique_ptr create_reference_output( - std::unique_ptr const& op, + cudf::rolling_aggregation const& op, cudf::column_view const& input, std::vector const& preceding_window, std::vector const& following_window, size_type min_periods) { // unroll aggregation types - switch (op->kind) { + switch (op.kind) { case cudf::aggregation::SUM: return create_reference_output col_valid = {1, 1, 1, 0, 1}; fixed_width_column_wrapper input(col_data.begin(), col_data.end(), col_valid.begin()); - EXPECT_THROW(cudf::rolling_window(input, 2, 2, -2, cudf::make_sum_aggregation()), - cudf::logic_error); + EXPECT_THROW( + cudf::rolling_window(input, 2, 2, -2, *cudf::make_sum_aggregation()), + cudf::logic_error); } // window array size mismatch @@ -401,38 +437,54 @@ TEST_F(RollingErrorTest, WindowArraySizeMismatch) fixed_width_column_wrapper four_elements(four.begin(), four.end()); // this runs ok - EXPECT_NO_THROW( - cudf::rolling_window(input, five_elements, five_elements, 1, cudf::make_sum_aggregation())); + EXPECT_NO_THROW(cudf::rolling_window(input, + five_elements, + five_elements, + 1, + *cudf::make_sum_aggregation())); // mismatch for the window array - EXPECT_THROW( - cudf::rolling_window(input, four_elements, five_elements, 1, cudf::make_sum_aggregation()), - cudf::logic_error); + EXPECT_THROW(cudf::rolling_window(input, + four_elements, + five_elements, + 1, + *cudf::make_sum_aggregation()), + cudf::logic_error); // mismatch for the forward window array - EXPECT_THROW( - cudf::rolling_window(input, five_elements, four_elements, 1, cudf::make_sum_aggregation()), - cudf::logic_error); + EXPECT_THROW(cudf::rolling_window(input, + five_elements, + four_elements, + 1, + *cudf::make_sum_aggregation()), + cudf::logic_error); } TEST_F(RollingErrorTest, EmptyInput) { cudf::test::fixed_width_column_wrapper empty_col{}; std::unique_ptr output; - EXPECT_NO_THROW(output = cudf::rolling_window(empty_col, 2, 0, 2, cudf::make_sum_aggregation())); + EXPECT_NO_THROW(output = cudf::rolling_window( + empty_col, 2, 0, 2, *cudf::make_sum_aggregation())); EXPECT_EQ(output->size(), 0); fixed_width_column_wrapper preceding_window{}; fixed_width_column_wrapper following_window{}; - EXPECT_NO_THROW( - output = cudf::rolling_window( - empty_col, preceding_window, following_window, 2, cudf::make_sum_aggregation())); + EXPECT_NO_THROW(output = + cudf::rolling_window(empty_col, + preceding_window, + following_window, + 2, + *cudf::make_sum_aggregation())); EXPECT_EQ(output->size(), 0); fixed_width_column_wrapper nonempty_col{{1, 2, 3}}; - EXPECT_NO_THROW( - output = cudf::rolling_window( - nonempty_col, preceding_window, following_window, 2, cudf::make_sum_aggregation())); + EXPECT_NO_THROW(output = + cudf::rolling_window(nonempty_col, + preceding_window, + following_window, + 2, + *cudf::make_sum_aggregation())); EXPECT_EQ(output->size(), 0); } @@ -445,16 +497,22 @@ TEST_F(RollingErrorTest, SizeMismatch) fixed_width_column_wrapper preceding_window{{1, 1}}; // wrong size fixed_width_column_wrapper following_window{{1, 1, 1}}; EXPECT_THROW( - output = cudf::rolling_window( - nonempty_col, preceding_window, following_window, 2, cudf::make_sum_aggregation()), + output = cudf::rolling_window(nonempty_col, + preceding_window, + following_window, + 2, + *cudf::make_sum_aggregation()), cudf::logic_error); } { fixed_width_column_wrapper preceding_window{{1, 1, 1}}; fixed_width_column_wrapper following_window{{1, 2}}; // wrong size EXPECT_THROW( - output = cudf::rolling_window( - nonempty_col, preceding_window, following_window, 2, cudf::make_sum_aggregation()), + output = cudf::rolling_window(nonempty_col, + preceding_window, + following_window, + 2, + *cudf::make_sum_aggregation()), cudf::logic_error); } } @@ -466,9 +524,13 @@ TEST_F(RollingErrorTest, WindowWrongDtype) fixed_width_column_wrapper preceding_window{{1.0f, 1.0f, 1.0f}}; fixed_width_column_wrapper following_window{{1.0f, 1.0f, 1.0f}}; - EXPECT_THROW(output = cudf::rolling_window( - nonempty_col, preceding_window, following_window, 2, cudf::make_sum_aggregation()), - cudf::logic_error); + EXPECT_THROW( + output = cudf::rolling_window(nonempty_col, + preceding_window, + following_window, + 2, + *cudf::make_sum_aggregation()), + cudf::logic_error); } // incorrect type/aggregation combo: sum of timestamps @@ -486,15 +548,20 @@ TEST_F(RollingErrorTest, SumTimestampNotSupported) fixed_width_column_wrapper input_ns( thrust::make_counting_iterator(0), thrust::make_counting_iterator(size)); - EXPECT_THROW(cudf::rolling_window(input_D, 2, 2, 0, cudf::make_sum_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_D, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_s, 2, 2, 0, cudf::make_sum_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_s, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_ms, 2, 2, 0, cudf::make_sum_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_ms, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_us, 2, 2, 0, cudf::make_sum_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_us, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_ns, 2, 2, 0, cudf::make_sum_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_ns, 2, 2, 0, *cudf::make_sum_aggregation()), cudf::logic_error); } @@ -513,15 +580,20 @@ TEST_F(RollingErrorTest, MeanTimestampNotSupported) fixed_width_column_wrapper input_ns( thrust::make_counting_iterator(0), thrust::make_counting_iterator(size)); - EXPECT_THROW(cudf::rolling_window(input_D, 2, 2, 0, cudf::make_mean_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_D, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_s, 2, 2, 0, cudf::make_mean_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_s, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_ms, 2, 2, 0, cudf::make_mean_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_ms, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_us, 2, 2, 0, cudf::make_mean_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_us, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input_ns, 2, 2, 0, cudf::make_mean_aggregation()), + EXPECT_THROW(cudf::rolling_window( + input_ns, 2, 2, 0, *cudf::make_mean_aggregation()), cudf::logic_error); } @@ -755,22 +827,24 @@ TEST_F(RollingTestStrings, StringsUnsupportedOperators) std::vector window{1}; - EXPECT_THROW(cudf::rolling_window(input, 2, 2, 0, cudf::make_sum_aggregation()), - cudf::logic_error); - EXPECT_THROW(cudf::rolling_window(input, 2, 2, 0, cudf::make_mean_aggregation()), - cudf::logic_error); - EXPECT_THROW(cudf::rolling_window( - input, - 2, - 2, - 0, - cudf::make_udf_aggregation(cudf::udf_type::PTX, std::string{}, cudf::data_type{})), + EXPECT_THROW( + cudf::rolling_window(input, 2, 2, 0, *cudf::make_sum_aggregation()), + cudf::logic_error); + EXPECT_THROW( + cudf::rolling_window(input, 2, 2, 0, *cudf::make_mean_aggregation()), + cudf::logic_error); + EXPECT_THROW(cudf::rolling_window(input, + 2, + 2, + 0, + *cudf::make_udf_aggregation( + cudf::udf_type::PTX, std::string{}, cudf::data_type{})), cudf::logic_error); EXPECT_THROW(cudf::rolling_window(input, 2, 2, 0, - cudf::make_udf_aggregation( + *cudf::make_udf_aggregation( cudf::udf_type::CUDA, std::string{}, cudf::data_type{})), cudf::logic_error); } @@ -891,18 +965,18 @@ TEST_F(RollingTestUdf, StaticWindow) fixed_width_column_wrapper expected{start, start + size, valid}; // Test CUDA UDF - auto cuda_udf_agg = cudf::make_udf_aggregation( + auto cuda_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::CUDA, this->cuda_func, cudf::data_type{cudf::type_id::INT64}); - output = cudf::rolling_window(input, 2, 2, 4, cuda_udf_agg); + output = cudf::rolling_window(input, 2, 2, 4, *cuda_udf_agg); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*output, expected); // Test NUMBA UDF - auto ptx_udf_agg = cudf::make_udf_aggregation( + auto ptx_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::PTX, this->ptx_func, cudf::data_type{cudf::type_id::INT64}); - output = cudf::rolling_window(input, 2, 2, 4, ptx_udf_agg); + output = cudf::rolling_window(input, 2, 2, 4, *ptx_udf_agg); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*output, expected); } @@ -937,18 +1011,18 @@ TEST_F(RollingTestUdf, DynamicWindow) fixed_width_column_wrapper expected{start, start + size, valid}; // Test CUDA UDF - auto cuda_udf_agg = cudf::make_udf_aggregation( + auto cuda_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::CUDA, this->cuda_func, cudf::data_type{cudf::type_id::INT64}); - output = cudf::rolling_window(input, preceding, following, 2, cuda_udf_agg); + output = cudf::rolling_window(input, preceding, following, 2, *cuda_udf_agg); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*output, expected); // Test PTX UDF - auto ptx_udf_agg = cudf::make_udf_aggregation( + auto ptx_udf_agg = cudf::make_udf_aggregation( cudf::udf_type::PTX, this->ptx_func, cudf::data_type{cudf::type_id::INT64}); - output = cudf::rolling_window(input, preceding, following, 2, ptx_udf_agg); + output = cudf::rolling_window(input, preceding, following, 2, *ptx_udf_agg); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*output, expected); } @@ -979,13 +1053,20 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLead) auto const expected_rowno = fw_wrapper{{1, 2, 2, 2, 2, 2}, {1, 1, 1, 1, 1, 1}}; auto const expected_rowno1 = fw_wrapper{{1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; - auto const min = rolling_window(input, 2, 1, 1, make_min_aggregation()); - auto const max = rolling_window(input, 2, 1, 1, make_max_aggregation()); - auto const lag = rolling_window(input, 2, 1, 1, make_lag_aggregation(1)); - auto const lead = rolling_window(input, 2, 1, 1, make_lead_aggregation(1)); - auto const valid = rolling_window(input, 2, 1, 1, make_count_aggregation()); - auto const all = rolling_window(input, 2, 1, 1, make_count_aggregation(null_policy::INCLUDE)); - auto const rowno = rolling_window(input, 2, 1, 1, make_row_number_aggregation()); + auto const min = + rolling_window(input, 2, 1, 1, *make_min_aggregation()); + auto const max = + rolling_window(input, 2, 1, 1, *make_max_aggregation()); + auto const lag = + rolling_window(input, 2, 1, 1, *make_lag_aggregation(1)); + auto const lead = + rolling_window(input, 2, 1, 1, *make_lead_aggregation(1)); + auto const valid = + rolling_window(input, 2, 1, 1, *make_count_aggregation()); + auto const all = rolling_window( + input, 2, 1, 1, *make_count_aggregation(null_policy::INCLUDE)); + auto const rowno = + rolling_window(input, 2, 1, 1, *make_row_number_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, min->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, max->view()); @@ -997,7 +1078,8 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLead) // ROW_NUMBER will always return row 1 if the preceding window is set to a constant 1 for (int following = 1; following < 5; ++following) { - auto const rowno1 = rolling_window(input, 1, following, 1, make_row_number_aggregation()); + auto const rowno1 = rolling_window( + input, 1, following, 1, *make_row_number_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_rowno1, rowno1->view()); } } @@ -1023,14 +1105,22 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLeadNulls) auto const expected_count_all = fw_wrapper{{2, 3, 3, 3, 3, 2}, {1, 1, 1, 1, 1, 1}}; auto const expected_rowno = fw_wrapper{{1, 2, 2, 2, 2, 2}, {1, 1, 1, 1, 1, 1}}; - auto const sum = rolling_window(input, 2, 1, 1, make_sum_aggregation()); - auto const min = rolling_window(input, 2, 1, 1, make_min_aggregation()); - auto const max = rolling_window(input, 2, 1, 1, make_max_aggregation()); - auto const lag = rolling_window(input, 2, 1, 1, make_lag_aggregation(1)); - auto const lead = rolling_window(input, 2, 1, 1, make_lead_aggregation(1)); - auto const valid = rolling_window(input, 2, 1, 1, make_count_aggregation()); - auto const all = rolling_window(input, 2, 1, 1, make_count_aggregation(null_policy::INCLUDE)); - auto const rowno = rolling_window(input, 2, 1, 1, make_row_number_aggregation()); + auto const sum = + rolling_window(input, 2, 1, 1, *make_sum_aggregation()); + auto const min = + rolling_window(input, 2, 1, 1, *make_min_aggregation()); + auto const max = + rolling_window(input, 2, 1, 1, *make_max_aggregation()); + auto const lag = + rolling_window(input, 2, 1, 1, *make_lag_aggregation(1)); + auto const lead = + rolling_window(input, 2, 1, 1, *make_lead_aggregation(1)); + auto const valid = + rolling_window(input, 2, 1, 1, *make_count_aggregation()); + auto const all = rolling_window( + input, 2, 1, 1, *make_count_aggregation(null_policy::INCLUDE)); + auto const rowno = + rolling_window(input, 2, 1, 1, *make_row_number_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_sum, sum->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, min->view()); @@ -1040,13 +1130,6 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLeadNulls) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_val, valid->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_all, all->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_rowno, rowno->view()); - - EXPECT_THROW(rolling_window(input, 2, 1, 1, make_product_aggregation()), cudf::logic_error); - EXPECT_THROW(rolling_window(input, 2, 1, 1, make_mean_aggregation()), cudf::logic_error); - EXPECT_THROW(rolling_window(input, 2, 1, 1, make_variance_aggregation()), cudf::logic_error); - EXPECT_THROW(rolling_window(input, 2, 1, 1, make_std_aggregation()), cudf::logic_error); - EXPECT_THROW(rolling_window(input, 2, 1, 1, make_sum_of_squares_aggregation()), - cudf::logic_error); } class RollingDictionaryTest : public cudf::test::BaseFixture { @@ -1064,10 +1147,16 @@ TEST_F(RollingDictionaryTest, Count) fixed_width_column_wrapper expected_row_number({1, 2, 2, 2, 2, 2, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto got_count_valid = cudf::rolling_window(input, 2, 2, 1, cudf::make_count_aggregation()); - auto got_count_all = - cudf::rolling_window(input, 2, 2, 1, cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); - auto got_row_number = cudf::rolling_window(input, 2, 2, 1, cudf::make_row_number_aggregation()); + auto got_count_valid = cudf::rolling_window( + input, 2, 2, 1, *cudf::make_count_aggregation()); + auto got_count_all = cudf::rolling_window( + input, + 2, + 2, + 1, + *cudf::make_count_aggregation(cudf::null_policy::INCLUDE)); + auto got_row_number = cudf::rolling_window( + input, 2, 2, 1, *cudf::make_row_number_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_val, got_count_valid->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_all, got_count_all->view()); @@ -1086,11 +1175,13 @@ TEST_F(RollingDictionaryTest, MinMax) {"This", "test", "test", "test", "test", "string", "string", "string", "string"}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); - auto got_min_dict = cudf::rolling_window(input, 2, 2, 1, cudf::make_min_aggregation()); - auto got_min = cudf::dictionary::decode(cudf::dictionary_column_view(got_min_dict->view())); + auto got_min_dict = + cudf::rolling_window(input, 2, 2, 1, *cudf::make_min_aggregation()); + auto got_min = cudf::dictionary::decode(cudf::dictionary_column_view(got_min_dict->view())); - auto got_max_dict = cudf::rolling_window(input, 2, 2, 1, cudf::make_max_aggregation()); - auto got_max = cudf::dictionary::decode(cudf::dictionary_column_view(got_max_dict->view())); + auto got_max_dict = + cudf::rolling_window(input, 2, 2, 1, *cudf::make_max_aggregation()); + auto got_max = cudf::dictionary::decode(cudf::dictionary_column_view(got_max_dict->view())); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, got_min->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, got_max->view()); @@ -1106,11 +1197,13 @@ TEST_F(RollingDictionaryTest, LeadLag) cudf::test::strings_column_wrapper expected_lag( {"", "This", "", "", "test", "", "operated", "on", "string"}, {0, 1, 0, 0, 1, 0, 1, 1, 1}); - auto got_lead_dict = cudf::rolling_window(input, 2, 1, 1, cudf::make_lead_aggregation(1)); + auto got_lead_dict = cudf::rolling_window( + input, 2, 1, 1, *cudf::make_lead_aggregation(1)); auto got_lead = cudf::dictionary::decode(cudf::dictionary_column_view(got_lead_dict->view())); - auto got_lag_dict = cudf::rolling_window(input, 2, 2, 1, cudf::make_lag_aggregation(1)); - auto got_lag = cudf::dictionary::decode(cudf::dictionary_column_view(got_lag_dict->view())); + auto got_lag_dict = + cudf::rolling_window(input, 2, 2, 1, *cudf::make_lag_aggregation(1)); + auto got_lag = cudf::dictionary::decode(cudf::dictionary_column_view(got_lag_dict->view())); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lead, got_lead->view()); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lag, got_lag->view()); diff --git a/python/cudf/cudf/_lib/aggregation.pxd b/python/cudf/cudf/_lib/aggregation.pxd index 972f95d5aab..56fa9fdc63e 100644 --- a/python/cudf/cudf/_lib/aggregation.pxd +++ b/python/cudf/cudf/_lib/aggregation.pxd @@ -2,9 +2,14 @@ from libcpp.memory cimport unique_ptr from cudf._lib.cpp.aggregation cimport aggregation +from cudf._lib.cpp.aggregation cimport rolling_aggregation cdef class Aggregation: cdef unique_ptr[aggregation] c_obj +cdef class RollingAggregation: + cdef unique_ptr[rolling_aggregation] c_obj + cdef Aggregation make_aggregation(op, kwargs=*) +cdef RollingAggregation make_rolling_aggregation(op, kwargs=*) diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx index b4d14c4fbc6..cda35025c7e 100644 --- a/python/cudf/cudf/_lib/aggregation.pyx +++ b/python/cudf/cudf/_lib/aggregation.pyx @@ -70,37 +70,43 @@ cdef class Aggregation: @classmethod def sum(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_sum_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_sum_aggregation[aggregation]()) return agg @classmethod def min(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_min_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_min_aggregation[aggregation]()) return agg @classmethod def max(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_max_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_max_aggregation[aggregation]()) return agg @classmethod def idxmin(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_argmin_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_argmin_aggregation[aggregation]()) return agg @classmethod def idxmax(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_argmax_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_argmax_aggregation[aggregation]()) return agg @classmethod def mean(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_mean_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_mean_aggregation[aggregation]()) return agg @classmethod @@ -112,76 +118,87 @@ cdef class Aggregation: c_null_handling = libcudf_types.null_policy.INCLUDE cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_count_aggregation( - c_null_handling - )) + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[aggregation]( + c_null_handling + )) return agg @classmethod def size(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_count_aggregation( - ( - NullHandling.INCLUDE - ) - )) + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[aggregation]( + ( + NullHandling.INCLUDE + ) + )) return agg @classmethod def nunique(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_nunique_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_nunique_aggregation[aggregation]()) return agg @classmethod def nth(cls, libcudf_types.size_type size): cdef Aggregation agg = cls() agg.c_obj = move( - libcudf_aggregation.make_nth_element_aggregation(size) - ) + libcudf_aggregation.make_nth_element_aggregation[aggregation]( + size)) return agg @classmethod def any(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_any_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_any_aggregation[aggregation]()) return agg @classmethod def all(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_all_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_all_aggregation[aggregation]()) return agg @classmethod def product(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_product_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_product_aggregation[aggregation]()) return agg prod = product @classmethod def sum_of_squares(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_sum_of_squares_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_sum_of_squares_aggregation[aggregation]() + ) return agg @classmethod def var(cls, ddof=1): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_variance_aggregation(ddof)) + agg.c_obj = move( + libcudf_aggregation.make_variance_aggregation[aggregation](ddof)) return agg @classmethod def std(cls, ddof=1): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_std_aggregation(ddof)) + agg.c_obj = move( + libcudf_aggregation.make_std_aggregation[aggregation](ddof)) return agg @classmethod def median(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_median_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_median_aggregation[aggregation]()) return agg @classmethod @@ -200,20 +217,23 @@ cdef class Aggregation: ) ) agg.c_obj = move( - libcudf_aggregation.make_quantile_aggregation(c_q, c_interp) + libcudf_aggregation.make_quantile_aggregation[aggregation]( + c_q, c_interp) ) return agg @classmethod def collect(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_collect_list_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_collect_list_aggregation[aggregation]()) return agg @classmethod def unique(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_collect_set_aggregation()) + agg.c_obj = move( + libcudf_aggregation.make_collect_set_aggregation[aggregation]()) return agg @classmethod @@ -244,9 +264,10 @@ cdef class Aggregation: ) out_dtype = libcudf_types.data_type(tid) - agg.c_obj = move(libcudf_aggregation.make_udf_aggregation( - libcudf_aggregation.udf_type.PTX, cpp_str, out_dtype - )) + agg.c_obj = move( + libcudf_aggregation.make_udf_aggregation[aggregation]( + libcudf_aggregation.udf_type.PTX, cpp_str, out_dtype + )) return agg # scan aggregations @@ -259,9 +280,154 @@ cdef class Aggregation: @classmethod def cumcount(cls): cdef Aggregation agg = cls() - agg.c_obj = move(libcudf_aggregation.make_count_aggregation( - libcudf_types.null_policy.INCLUDE - )) + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[aggregation]( + libcudf_types.null_policy.INCLUDE + )) + return agg + +cdef class RollingAggregation: + """A Cython wrapper for rolling window aggregations. + + **This class should never be instantiated using a standard constructor, + only using one of its many factories.** These factories handle mapping + different cudf operations to their libcudf analogs, e.g. + `cudf.DataFrame.idxmin` -> `libcudf.argmin`. Additionally, they perform + any additional configuration needed to translate Python arguments into + their corresponding C++ types (for instance, C++ enumerations used for + flag arguments). The factory approach is necessary to support operations + like `df.agg(lambda x: x.sum())`; such functions are called with this + class as an argument to generation the desired aggregation. + """ + @property + def kind(self): + return AggregationKind(self.c_obj.get()[0].kind).name + + @classmethod + def sum(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_sum_aggregation[rolling_aggregation]()) + return agg + + @classmethod + def min(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_min_aggregation[rolling_aggregation]()) + return agg + + @classmethod + def max(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_max_aggregation[rolling_aggregation]()) + return agg + + @classmethod + def idxmin(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_argmin_aggregation[ + rolling_aggregation]()) + return agg + + @classmethod + def idxmax(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_argmax_aggregation[ + rolling_aggregation]()) + return agg + + @classmethod + def mean(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_mean_aggregation[rolling_aggregation]()) + return agg + + @classmethod + def count(cls, dropna=True): + cdef libcudf_types.null_policy c_null_handling + if dropna: + c_null_handling = libcudf_types.null_policy.EXCLUDE + else: + c_null_handling = libcudf_types.null_policy.INCLUDE + + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[rolling_aggregation]( + c_null_handling + )) + return agg + + @classmethod + def size(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[rolling_aggregation]( + ( + NullHandling.INCLUDE) + )) + return agg + + @classmethod + def collect(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_collect_list_aggregation[ + rolling_aggregation]()) + return agg + + @classmethod + def from_udf(cls, op, *args, **kwargs): + cdef RollingAggregation agg = cls() + + cdef libcudf_types.type_id tid + cdef libcudf_types.data_type out_dtype + cdef string cpp_str + + # Handling UDF type + nb_type = numpy_support.from_dtype(kwargs['dtype']) + type_signature = (nb_type[:],) + compiled_op = cudautils.compile_udf(op, type_signature) + output_np_dtype = np.dtype(compiled_op[1]) + cpp_str = compiled_op[0].encode('UTF-8') + if output_np_dtype not in np_to_cudf_types: + raise TypeError( + "Result of window function has unsupported dtype {}" + .format(op[1]) + ) + tid = ( + ( + ( + np_to_cudf_types[output_np_dtype] + ) + ) + ) + out_dtype = libcudf_types.data_type(tid) + + agg.c_obj = move( + libcudf_aggregation.make_udf_aggregation[rolling_aggregation]( + libcudf_aggregation.udf_type.PTX, cpp_str, out_dtype + )) + return agg + + # scan aggregations + # TODO: update this after adding per algorithm aggregation derived types + # https://github.com/rapidsai/cudf/issues/7106 + cumsum = sum + cummin = min + cummax = max + + @classmethod + def cumcount(cls): + cdef RollingAggregation agg = cls() + agg.c_obj = move( + libcudf_aggregation.make_count_aggregation[rolling_aggregation]( + libcudf_types.null_policy.INCLUDE + )) return agg cdef Aggregation make_aggregation(op, kwargs=None): @@ -301,3 +467,41 @@ cdef Aggregation make_aggregation(op, kwargs=None): else: raise TypeError(f"Unknown aggregation {op}") return agg + +cdef RollingAggregation make_rolling_aggregation(op, kwargs=None): + r""" + Parameters + ---------- + op : str or callable + If callable, must meet one of the following requirements: + + * Is of the form lambda x: x.agg(*args, **kwargs), where + `agg` is the name of a supported aggregation. Used to + to specify aggregations that take arguments, e.g., + `lambda x: x.quantile(0.5)`. + * Is a user defined aggregation function that operates on + group values. In this case, the output dtype must be + specified in the `kwargs` dictionary. + \*\*kwargs : dict, optional + Any keyword arguments to be passed to the op. + + Returns + ------- + RollingAggregation + """ + if kwargs is None: + kwargs = {} + + cdef RollingAggregation agg + if isinstance(op, str): + agg = getattr(RollingAggregation, op)(**kwargs) + elif callable(op): + if op is list: + agg = RollingAggregation.collect() + elif "dtype" in kwargs: + agg = RollingAggregation.from_udf(op, **kwargs) + else: + agg = op(RollingAggregation) + else: + raise TypeError(f"Unknown aggregation {op}") + return agg diff --git a/python/cudf/cudf/_lib/cpp/aggregation.pxd b/python/cudf/cudf/_lib/cpp/aggregation.pxd index e9836c11361..839bdae7427 100644 --- a/python/cudf/cudf/_lib/cpp/aggregation.pxd +++ b/python/cudf/cudf/_lib/cpp/aggregation.pxd @@ -40,55 +40,58 @@ cdef extern from "cudf/aggregation.hpp" namespace "cudf" nogil: CUDA 'cudf::aggregation::CUDA' Kind kind + cdef cppclass rolling_aggregation: + aggregation.Kind kind + ctypedef enum udf_type: CUDA 'cudf::udf_type::CUDA' PTX 'cudf::udf_type::PTX' - cdef unique_ptr[aggregation] make_sum_aggregation() except + + cdef unique_ptr[T] make_sum_aggregation[T]() except + - cdef unique_ptr[aggregation] make_product_aggregation() except + + cdef unique_ptr[T] make_product_aggregation[T]() except + - cdef unique_ptr[aggregation] make_min_aggregation() except + + cdef unique_ptr[T] make_min_aggregation[T]() except + - cdef unique_ptr[aggregation] make_max_aggregation() except + + cdef unique_ptr[T] make_max_aggregation[T]() except + - cdef unique_ptr[aggregation] make_count_aggregation() except + + cdef unique_ptr[T] make_count_aggregation[T]() except + - cdef unique_ptr[aggregation] make_count_aggregation(null_policy) except + + cdef unique_ptr[T] make_count_aggregation[T](null_policy) except + - cdef unique_ptr[aggregation] make_any_aggregation() except + + cdef unique_ptr[T] make_any_aggregation[T]() except + - cdef unique_ptr[aggregation] make_all_aggregation() except + + cdef unique_ptr[T] make_all_aggregation[T]() except + - cdef unique_ptr[aggregation] make_sum_of_squares_aggregation() except + + cdef unique_ptr[T] make_sum_of_squares_aggregation[T]() except + - cdef unique_ptr[aggregation] make_mean_aggregation() except + + cdef unique_ptr[T] make_mean_aggregation[T]() except + - cdef unique_ptr[aggregation] make_variance_aggregation( + cdef unique_ptr[T] make_variance_aggregation[T]( size_type ddof) except + - cdef unique_ptr[aggregation] make_std_aggregation(size_type ddof) except + + cdef unique_ptr[T] make_std_aggregation[T](size_type ddof) except + - cdef unique_ptr[aggregation] make_median_aggregation() except + + cdef unique_ptr[T] make_median_aggregation[T]() except + - cdef unique_ptr[aggregation] make_quantile_aggregation( + cdef unique_ptr[T] make_quantile_aggregation[T]( vector[double] q, interpolation i) except + - cdef unique_ptr[aggregation] make_argmax_aggregation() except + + cdef unique_ptr[T] make_argmax_aggregation[T]() except + - cdef unique_ptr[aggregation] make_argmin_aggregation() except + + cdef unique_ptr[T] make_argmin_aggregation[T]() except + - cdef unique_ptr[aggregation] make_nunique_aggregation() except + + cdef unique_ptr[T] make_nunique_aggregation[T]() except + - cdef unique_ptr[aggregation] make_nth_element_aggregation( + cdef unique_ptr[T] make_nth_element_aggregation[T]( size_type n ) except + - cdef unique_ptr[aggregation] make_collect_list_aggregation() except + + cdef unique_ptr[T] make_collect_list_aggregation[T]() except + - cdef unique_ptr[aggregation] make_collect_set_aggregation() except + + cdef unique_ptr[T] make_collect_set_aggregation[T]() except + - cdef unique_ptr[aggregation] make_udf_aggregation( + cdef unique_ptr[T] make_udf_aggregation[T]( udf_type type, string user_defined_aggregator, data_type output_type) except + diff --git a/python/cudf/cudf/_lib/cpp/rolling.pxd b/python/cudf/cudf/_lib/cpp/rolling.pxd index 9402f1552c3..4ccc0f5ae9b 100644 --- a/python/cudf/cudf/_lib/cpp/rolling.pxd +++ b/python/cudf/cudf/_lib/cpp/rolling.pxd @@ -7,7 +7,7 @@ from cudf._lib.types import np_to_cudf_types, cudf_to_np_types from cudf._lib.cpp.types cimport size_type from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.aggregation cimport aggregation +from cudf._lib.cpp.aggregation cimport rolling_aggregation cdef extern from "cudf/rolling.hpp" namespace "cudf" nogil: @@ -16,11 +16,11 @@ cdef extern from "cudf/rolling.hpp" namespace "cudf" nogil: column_view preceding_window, column_view following_window, size_type min_periods, - unique_ptr[aggregation] agg) except + + rolling_aggregation agg) except + cdef unique_ptr[column] rolling_window( column_view source, size_type preceding_window, size_type following_window, size_type min_periods, - unique_ptr[aggregation] agg) except + + rolling_aggregation agg) except + diff --git a/python/cudf/cudf/_lib/rolling.pyx b/python/cudf/cudf/_lib/rolling.pyx index d67fb431ec4..6fe661a25a5 100644 --- a/python/cudf/cudf/_lib/rolling.pyx +++ b/python/cudf/cudf/_lib/rolling.pyx @@ -8,7 +8,7 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from cudf._lib.column cimport Column -from cudf._lib.aggregation cimport Aggregation, make_aggregation +from cudf._lib.aggregation cimport RollingAggregation, make_rolling_aggregation from cudf._lib.cpp.types cimport size_type from cudf._lib.cpp.column.column cimport column @@ -46,12 +46,13 @@ def rolling(Column source_column, Column pre_column_window, cdef column_view source_column_view = source_column.view() cdef column_view pre_column_window_view cdef column_view fwd_column_window_view - cdef Aggregation cython_agg + cdef RollingAggregation cython_agg if callable(op): - cython_agg = make_aggregation(op, {'dtype': source_column.dtype}) + cython_agg = make_rolling_aggregation( + op, {'dtype': source_column.dtype}) else: - cython_agg = make_aggregation(op) + cython_agg = make_rolling_aggregation(op) if window is None: if center: @@ -68,7 +69,7 @@ def rolling(Column source_column, Column pre_column_window, pre_column_window_view, fwd_column_window_view, c_min_periods, - cython_agg.c_obj) + cython_agg.c_obj.get()[0]) ) else: c_min_periods = min_periods @@ -86,7 +87,7 @@ def rolling(Column source_column, Column pre_column_window, c_window, c_forward_window, c_min_periods, - cython_agg.c_obj) + cython_agg.c_obj.get()[0]) ) return Column.from_unique_ptr(move(c_result))