Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return weak orderings from device_row_comparator. #10793

Merged
merged 8 commits into from
May 11, 2022
71 changes: 55 additions & 16 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <utility>

namespace cudf {

namespace experimental {
Comment on lines 47 to 49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
namespace cudf {
namespace experimental {
namespace cudf::experimental {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file has several layers of nested namespaces. For consistency, I would recommend leaving this unchanged.


/**
Expand All @@ -68,16 +69,17 @@ struct dispatch_void_if_nested {
};

namespace row {

namespace lexicographic {

/**
* @brief Computes whether one row is lexicographically *less* than another row.
* @brief Computes the lexicographic comparison between 2 rows.
*
* Lexicographic ordering is determined by:
* - Two rows are compared element by element.
* - The first mismatching element defines which row is lexicographically less
* or greater than the other.
* - If the rows are compared without mismatched elements, the rows are equivalent
*
*
* Lexicographic ordering is exactly equivalent to doing an alphabetical sort of
* two words, for example, `aac` would be *less* than (or precede) `abb`. The
Expand All @@ -88,8 +90,8 @@ namespace lexicographic {
*/
template <typename Nullate>
class device_row_comparator {
// friend class device_less_comparator<Nullate>;
rwlee marked this conversation as resolved.
Show resolved Hide resolved
friend class self_comparator;

/**
* @brief Construct a function object for performing a lexicographic
* comparison between the rows of two tables.
Expand Down Expand Up @@ -145,7 +147,11 @@ class device_row_comparator {
column_device_view rhs,
null_order null_precedence = null_order::BEFORE,
int depth = 0)
: _lhs{lhs}, _rhs{rhs}, _nulls{check_nulls}, _null_precedence{null_precedence}, _depth{depth}
: _lhs{lhs},
_rhs{rhs},
_check_nulls{check_nulls},
_null_precedence{null_precedence},
_depth{depth}
{
}

Expand All @@ -162,7 +168,7 @@ class device_row_comparator {
__device__ cuda::std::pair<weak_ordering, int> operator()(
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept
{
if (_nulls) {
if (_check_nulls) {
bool const lhs_is_null{_lhs.is_null(lhs_element_index)};
bool const rhs_is_null{_rhs.is_null(rhs_element_index)};

Expand Down Expand Up @@ -211,29 +217,30 @@ class device_row_comparator {
++depth;
}

auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth};
auto const comparator = element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth};
return cudf::type_dispatcher<dispatch_void_if_nested>(
lcol.type(), comparator, lhs_element_index, rhs_element_index);
}

private:
column_device_view const _lhs;
column_device_view const _rhs;
Nullate const _nulls;
Nullate const _check_nulls;
null_order const _null_precedence;
int const _depth;
};

public:
/**
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
* lexicographically less than the row at `rhs_index` in the `rhs` table.
* lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table.
*
* @param lhs_index The index of row in the `lhs` table to examine
* @param rhs_index The index of the row in the `rhs` table to examine
* @return `true` if row from the `lhs` table compares less than row in the `rhs` table
* @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs`
* table
*/
__device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept
__device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept
{
int last_null_depth = std::numeric_limits<int>::max();
for (size_type i = 0; i < _lhs.num_columns(); ++i) {
Expand All @@ -248,16 +255,17 @@ class device_row_comparator {

auto const comparator =
element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth};

weak_ordering state;
cuda::std::tie(state, last_null_depth) =
cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index);

if (state == weak_ordering::EQUIVALENT) { continue; }

return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER);
return ascending
? state
: (state == weak_ordering::GREATER ? weak_ordering::LESS : weak_ordering::GREATER);
}
return false;
return weak_ordering::EQUIVALENT;
}

private:
Expand All @@ -269,6 +277,37 @@ class device_row_comparator {
std::optional<device_span<null_order const>> const _null_precedence;
}; // class device_row_comparator

/**
* @brief Wraps and interprets the result of a templated Comparator that returns a weak_ordering.
* Returns true if the weak_ordering matches any of the templated values.
rwlee marked this conversation as resolved.
Show resolved Hide resolved
*
* @tparam Comparator generic comparator that returns a weak_ordering.
* @tparam values weak_ordering parameter pack of orderings to interpret as true
*/
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
  /**
   * @brief Evaluates the wrapped comparator and tests its result against `values`.
   *
   * Declared `const noexcept` so the functor can be invoked through a const object or
   * const reference, mirroring the `const noexcept` call operator of `device_row_comparator`.
   * Indices are taken by value, as they are by the wrapped row comparator.
   *
   * @param lhs First index forwarded to the wrapped comparator
   * @param rhs Second index forwarded to the wrapped comparator
   * @return `true` if the wrapped comparator's result equals any of `values`, `false` otherwise
   */
  __device__ bool operator()(size_type const lhs, size_type const rhs) const noexcept
  {
    weak_ordering const result = comparator(lhs, rhs);
    // Unary right fold over `||`; an empty `values` pack folds to `false`.
    return ((result == values) || ...);
  }

  Comparator comparator;  ///< Wrapped comparator whose call operator returns a weak_ordering
};

/**
* @brief Wraps and interprets the result of device_row_comparator. Returns true if the result is
* weak_ordering::LESS, meaning one row is lexicographically *less* than another row.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
using less_comparator =
weak_ordering_comparator_impl<device_row_comparator<Nullate>, weak_ordering::LESS>;

/**
 * @brief Wraps and interprets the result of device_row_comparator. Returns true if the result is
 * weak_ordering::LESS or weak_ordering::EQUIVALENT.
 *
 * @tparam Nullate A cudf::nullate type describing whether to check for nulls.
 */
template <typename Nullate>
using less_equivalent_comparator = weak_ordering_comparator_impl<device_row_comparator<Nullate>,
weak_ordering::LESS,
weak_ordering::EQUIVALENT>;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be useful to have aliases for the remaining comparators too (i.e., > and >=).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then how about less_equivalent_comparator? We are not using it now, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#9452 uses less_equivalent_comparator.

struct preprocessed_table {
using table_device_view_owner =
std::invoke_result_t<decltype(table_device_view::create), table_view, rmm::cuda_stream_view>;
Expand Down Expand Up @@ -417,10 +456,10 @@ class self_comparator {
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
device_row_comparator<Nullate> device_comparator(Nullate nullate = {}) const
less_comparator<Nullate> device_comparator(Nullate nullate = {}) const
{
return device_row_comparator(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence());
return less_comparator<Nullate>{device_row_comparator<Nullate>(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())};
}

private:
Expand Down