Skip to content

Commit

Permalink
Change aggregation class hierarchy to allow per-algorithm type enforc…
Browse files Browse the repository at this point in the history
…ement. (#8052)

Partially addresses  #7106

Fundamentally, this changes the aggregation class hierarchy in the following ways:
- The base `aggregation` class becomes abstract, with the `clone()` and `finalize()` functions being pure virtual.
- Every aggregation type now has a concrete class associated with it, derived from `aggregation`.
- "Intermediate" classes such as `rolling_aggregation` are used to allow individual algorithms to accept only the aggregation types that are valid for them (as opposed to enforcing this internally at runtime).

All of the rolling_window interfaces have been updated to take a `rolling_aggregation`.  Other algorithms such as groupby are not yet converted and still take generic `aggregation` objects.

Marking this as Do Not Merge for now since this is a breaking change with immediate implications for Spark.

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Ashwin Srinath (https://github.com/shwina)

URL: #8052
  • Loading branch information
nvdbaranec authored May 7, 2021
1 parent 5f9dade commit e2c7067
Show file tree
Hide file tree
Showing 21 changed files with 2,569 additions and 1,336 deletions.
122 changes: 78 additions & 44 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ namespace cudf {

// forward declaration
namespace detail {
class simple_aggregations_collector;
class aggregation_finalizer;
} // namespace detail
/**
* @brief Base class for specifying the desired aggregation in an
* @brief Abstract base class for specifying the desired aggregation in an
* `aggregation_request`.
*
* Other kinds of aggregations may derive from this class to encapsulate
* additional information needed to compute the aggregation.
* All aggregations must derive from this class to implement the pure virtual
* functions and potentially encapsulate additional information needed to
* compute the aggregation.
*/
class aggregation {
public:
Expand Down Expand Up @@ -82,109 +84,135 @@ class aggregation {
CUDA ///< CUDA UDF based reduction
};

aggregation() = delete;
aggregation(aggregation::Kind a) : kind{a} {}
Kind kind; ///< The aggregation to perform
virtual ~aggregation() = default;

virtual bool is_equal(aggregation const& other) const { return kind == other.kind; }

virtual size_t do_hash() const { return std::hash<int>{}(kind); }
virtual std::unique_ptr<aggregation> clone() const = 0;

virtual std::unique_ptr<aggregation> clone() const
{
return std::make_unique<aggregation>(*this);
}
// override functions for compound aggregations
virtual std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const = 0;
virtual void finalize(cudf::detail::aggregation_finalizer& finalizer) const = 0;
};

virtual ~aggregation() = default;
/**
* @brief Derived class intended for enforcing operation-specific restrictions
* when calling various cudf functions.
*
* As an example, rolling_window will only accept rolling_aggregation inputs,
* and the appropriate derived classes (sum_aggregation, mean_aggregation, etc)
* derive from this interface to represent these valid options.
*/
class rolling_aggregation : public virtual aggregation {
public:
~rolling_aggregation() = default;

// override functions for compound aggregations
virtual std::vector<aggregation::Kind> get_simple_aggregations(data_type col_type) const;
virtual void finalize(cudf::detail::aggregation_finalizer& finalizer);
protected:
rolling_aggregation() {}
};

enum class udf_type : bool { CUDA, PTX };

/// Factory to create a SUM aggregation
std::unique_ptr<aggregation> make_sum_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_sum_aggregation();

/// Factory to create a PRODUCT aggregation
std::unique_ptr<aggregation> make_product_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_product_aggregation();

/// Factory to create a MIN aggregation
std::unique_ptr<aggregation> make_min_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_min_aggregation();

/// Factory to create a MAX aggregation
std::unique_ptr<aggregation> make_max_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_max_aggregation();

/**
* @brief Factory to create a COUNT aggregation
*
* @param null_handling Indicates if null values will be counted.
*/
std::unique_ptr<aggregation> make_count_aggregation(
null_policy null_handling = null_policy::EXCLUDE);
template <typename Base = aggregation>
std::unique_ptr<Base> make_count_aggregation(null_policy null_handling = null_policy::EXCLUDE);

/// Factory to create a ANY aggregation
std::unique_ptr<aggregation> make_any_aggregation();
/// Factory to create an ANY aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_any_aggregation();

/// Factory to create an ALL aggregation
std::unique_ptr<aggregation> make_all_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_all_aggregation();

/// Factory to create a SUM_OF_SQUARES aggregation
std::unique_ptr<aggregation> make_sum_of_squares_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_sum_of_squares_aggregation();

/// Factory to create a MEAN aggregation
std::unique_ptr<aggregation> make_mean_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_mean_aggregation();

/**
* @brief Factory to create a VARIANCE aggregation
*
* @param ddof Delta degrees of freedom. The divisor used in calculation of
* `variance` is `N - ddof`, where `N` is the population size.
*/
std::unique_ptr<aggregation> make_variance_aggregation(size_type ddof = 1);
template <typename Base = aggregation>
std::unique_ptr<Base> make_variance_aggregation(size_type ddof = 1);

/**
* @brief Factory to create a STD aggregation
*
* @param ddof Delta degrees of freedom. The divisor used in calculation of
* `std` is `N - ddof`, where `N` is the population size.
*/
std::unique_ptr<aggregation> make_std_aggregation(size_type ddof = 1);
template <typename Base = aggregation>
std::unique_ptr<Base> make_std_aggregation(size_type ddof = 1);

/// Factory to create a MEDIAN aggregation
std::unique_ptr<aggregation> make_median_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_median_aggregation();

/**
* @brief Factory to create a QUANTILE aggregation
*
* @param q The desired quantiles
* @param i The desired interpolation
*/
std::unique_ptr<aggregation> make_quantile_aggregation(std::vector<double> const& q,
interpolation i = interpolation::LINEAR);
template <typename Base = aggregation>
std::unique_ptr<Base> make_quantile_aggregation(std::vector<double> const& q,
interpolation i = interpolation::LINEAR);

/**
* @brief Factory to create an `argmax` aggregation
*
* `argmax` returns the index of the maximum element.
*/
std::unique_ptr<aggregation> make_argmax_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_argmax_aggregation();

/**
* @brief Factory to create an `argmin` aggregation
*
* `argmin` returns the index of the minimum element.
*/
std::unique_ptr<aggregation> make_argmin_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_argmin_aggregation();

/**
* @brief Factory to create a `nunique` aggregation
*
* `nunique` returns the number of unique elements.
* @param null_handling Indicates if null values will be counted.
*/
std::unique_ptr<aggregation> make_nunique_aggregation(
null_policy null_handling = null_policy::EXCLUDE);
template <typename Base = aggregation>
std::unique_ptr<Base> make_nunique_aggregation(null_policy null_handling = null_policy::EXCLUDE);

/**
* @brief Factory to create a `nth_element` aggregation
Expand All @@ -199,11 +227,13 @@ std::unique_ptr<aggregation> make_nunique_aggregation(
* @param n index of nth element in each group.
* @param null_handling Indicates to include/exclude nulls during indexing.
*/
std::unique_ptr<aggregation> make_nth_element_aggregation(
template <typename Base = aggregation>
std::unique_ptr<Base> make_nth_element_aggregation(
size_type n, null_policy null_handling = null_policy::INCLUDE);

/// Factory to create a ROW_NUMBER aggregation
std::unique_ptr<aggregation> make_row_number_aggregation();
template <typename Base = aggregation>
std::unique_ptr<Base> make_row_number_aggregation();

/**
* @brief Factory to create a COLLECT_LIST aggregation
Expand All @@ -215,7 +245,8 @@ std::unique_ptr<aggregation> make_row_number_aggregation();
*
* @param null_handling Indicates whether to include/exclude nulls in list elements.
*/
std::unique_ptr<aggregation> make_collect_list_aggregation(
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
Expand All @@ -233,16 +264,18 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);
template <typename Base = aggregation>
std::unique_ptr<Base> make_lag_aggregation(size_type offset);

/// Factory to create a LEAD aggregation
std::unique_ptr<aggregation> make_lead_aggregation(size_type offset);
template <typename Base = aggregation>
std::unique_ptr<Base> make_lead_aggregation(size_type offset);

/**
* @brief Factory to create an aggregation based on UDF for PTX or CUDA
Expand All @@ -253,9 +286,10 @@ std::unique_ptr<aggregation> make_lead_aggregation(size_type offset);
*
* @return aggregation unique pointer housing user_defined_aggregator string.
*/
std::unique_ptr<aggregation> make_udf_aggregation(udf_type type,
std::string const& user_defined_aggregator,
data_type output_type);
template <typename Base = aggregation>
std::unique_ptr<Base> make_udf_aggregation(udf_type type,
std::string const& user_defined_aggregator,
data_type output_type);

/** @} */ // end of group
} // namespace cudf
Loading

0 comments on commit e2c7067

Please sign in to comment.