Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change aggregation class hierarchy to allow per-algorithm type enforcement. #8052

Merged
merged 15 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 73 additions & 36 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ namespace cudf {

// forward declaration
namespace detail {
class simple_aggregations_collector;
class aggregation_finalizer;
} // namespace detail
/**
* @brief Base class for specifying the desired aggregation in an
* @brief Abstract base class for specifying the desired aggregation in an
* `aggregation_request`.
*
* Other kinds of aggregations may derive from this class to encapsulate
* additional information needed to compute the aggregation.
* All aggregations must derive from this class to implement the pure virtual
* functions and potentially encapsulate additional information needed to
* compute the aggregation.
*/
class aggregation {
public:
Expand Down Expand Up @@ -82,108 +84,136 @@ class aggregation {
CUDA ///< CUDA UDF based reduction
};

aggregation() = delete;
aggregation(aggregation::Kind a) : kind{a} {}
Kind kind; ///< The aggregation to perform
virtual ~aggregation() = default;

virtual bool is_equal(aggregation const& other) const { return kind == other.kind; }

virtual size_t do_hash() const { return std::hash<int>{}(kind); }
virtual std::unique_ptr<aggregation> clone() const = 0;

virtual std::unique_ptr<aggregation> clone() const
{
return std::make_unique<aggregation>(*this);
}
// override functions for compound aggregations
virtual std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const = 0;
virtual void finalize(cudf::detail::aggregation_finalizer& finalizer) = 0;
};

virtual ~aggregation() = default;
/**
* @brief Derived class intended for enforcing operation-specific restrictions
* when calling various cudf functions.
*
* As an example, rolling_window will only accept rolling_aggregation inputs,
* and the appropriate derived classes (sum_aggregation, mean_aggregation, etc)
* derive from this interface to represent these valid options.
*/
class rolling_aggregation : public virtual aggregation {
public:
~rolling_aggregation() = default;

// override functions for compound aggregations
virtual std::vector<aggregation::Kind> get_simple_aggregations(data_type col_type) const;
virtual void finalize(cudf::detail::aggregation_finalizer& finalizer);
protected:
rolling_aggregation(){};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is a ctor needed at all in this type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It gets implicitly called by all the derived classes. eg

class count_aggregation final : public rolling_aggregation {
 public:
  count_aggregation(aggregation::Kind kind) : aggregation(kind) {}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but why isn't the compiler provided one sufficient?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

};

enum class udf_type : bool { CUDA, PTX };

/// Factory to create a SUM aggregation
std::unique_ptr<aggregation> make_sum_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_sum_aggregation();
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved

/// Factory to create a PRODUCT aggregation
std::unique_ptr<aggregation> make_product_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_product_aggregation();

/// Factory to create a MIN aggregation
std::unique_ptr<aggregation> make_min_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_min_aggregation();

/// Factory to create a MAX aggregation
std::unique_ptr<aggregation> make_max_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_max_aggregation();

/**
* @brief Factory to create a COUNT aggregation
*
* @param null_handling Indicates if null values will be counted.
*/
std::unique_ptr<aggregation> make_count_aggregation(
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_count_aggregation(
null_policy null_handling = null_policy::EXCLUDE);

/// Factory to create a ANY aggregation
std::unique_ptr<aggregation> make_any_aggregation();
/// Factory to create an ANY aggregation
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_any_aggregation();

/// Factory to create a ALL aggregation
std::unique_ptr<aggregation> make_all_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_all_aggregation();

/// Factory to create a SUM_OF_SQUARES aggregation
std::unique_ptr<aggregation> make_sum_of_squares_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_sum_of_squares_aggregation();

/// Factory to create a MEAN aggregation
std::unique_ptr<aggregation> make_mean_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_mean_aggregation();

/**
* @brief Factory to create a VARIANCE aggregation
*
* @param ddof Delta degrees of freedom. The divisor used in calculation of
* `variance` is `N - ddof`, where `N` is the population size.
*/
std::unique_ptr<aggregation> make_variance_aggregation(size_type ddof = 1);
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_variance_aggregation(size_type ddof = 1);

/**
* @brief Factory to create a STD aggregation
*
* @param ddof Delta degrees of freedom. The divisor used in calculation of
* `std` is `N - ddof`, where `N` is the population size.
*/
std::unique_ptr<aggregation> make_std_aggregation(size_type ddof = 1);
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_std_aggregation(size_type ddof = 1);

/// Factory to create a MEDIAN aggregation
std::unique_ptr<aggregation> make_median_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_median_aggregation();

/**
* @brief Factory to create a QUANTILE aggregation
*
* @param quantiles The desired quantiles
* @param interpolation The desired interpolation
*/
std::unique_ptr<aggregation> make_quantile_aggregation(std::vector<double> const& q,
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_quantile_aggregation(std::vector<double> const& q,
interpolation i = interpolation::LINEAR);

/**
* @brief Factory to create an `argmax` aggregation
*
* `argmax` returns the index of the maximum element.
*/
std::unique_ptr<aggregation> make_argmax_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_argmax_aggregation();

/**
* @brief Factory to create an `argmin` aggregation
*
* `argmin` returns the index of the minimum element.
*/
std::unique_ptr<aggregation> make_argmin_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_argmin_aggregation();

/**
* @brief Factory to create a `nunique` aggregation
*
* `nunique` returns the number of unique elements.
* @param null_handling Indicates if null values will be counted.
*/
std::unique_ptr<aggregation> make_nunique_aggregation(
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_nunique_aggregation(
null_policy null_handling = null_policy::EXCLUDE);

/**
Expand All @@ -199,11 +229,13 @@ std::unique_ptr<aggregation> make_nunique_aggregation(
* @param n index of nth element in each group.
* @param null_handling Indicates to include/exclude nulls during indexing.
*/
std::unique_ptr<aggregation> make_nth_element_aggregation(
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_nth_element_aggregation(
size_type n, null_policy null_handling = null_policy::INCLUDE);

/// Factory to create a ROW_NUMBER aggregation
std::unique_ptr<aggregation> make_row_number_aggregation();
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_row_number_aggregation();

/**
* @brief Factory to create a COLLECT_LIST aggregation
Expand All @@ -215,7 +247,8 @@ std::unique_ptr<aggregation> make_row_number_aggregation();
*
* @param null_handling Indicates whether to include/exclude nulls in list elements.
*/
std::unique_ptr<aggregation> make_collect_list_aggregation(
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
Expand All @@ -233,16 +266,19 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_lag_aggregation(size_type offset);

/// Factory to create a LEAD aggregation
std::unique_ptr<aggregation> make_lead_aggregation(size_type offset);
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_lead_aggregation(size_type offset);

/**
* @brief Factory to create an aggregation base on UDF for PTX or CUDA
Expand All @@ -253,7 +289,8 @@ std::unique_ptr<aggregation> make_lead_aggregation(size_type offset);
*
* @return aggregation unique pointer housing user_defined_aggregator string.
*/
std::unique_ptr<aggregation> make_udf_aggregation(udf_type type,
template <typename Base = aggregation>
extern std::unique_ptr<Base> make_udf_aggregation(udf_type type,
std::string const& user_defined_aggregator,
data_type output_type);

Expand Down
Loading