Skip to content

Commit

Permalink
Return weak orderings from device_row_comparator. (#10793)
Browse files Browse the repository at this point in the history
This PR changes the experimental `device_row_comparator` to return `weak_ordering` instead of `bool`.

Originally part of PR #9452. Aids PR #10730, which builds strongly-typed two table comparators and should return a `weak_ordering`.

Authors:
  - Ryan Lee (https://github.com/rwlee)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Nghia Truong (https://github.com/ttnghia)

URL: #10793
  • Loading branch information
rwlee authored May 11, 2022
1 parent 0cc29a0 commit 325fa77
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <utility>

namespace cudf {

namespace experimental {

/**
Expand All @@ -68,16 +69,17 @@ struct dispatch_void_if_nested {
};

namespace row {

namespace lexicographic {

/**
* @brief Computes whether one row is lexicographically *less* than another row.
* @brief Computes the lexicographic comparison between 2 rows.
*
* Lexicographic ordering is determined by:
* - Two rows are compared element by element.
* - The first mismatching element defines which row is lexicographically less
* or greater than the other.
* - If the rows are compared without mismatched elements, the rows are equivalent
*
*
* Lexicographic ordering is exactly equivalent to doing an alphabetical sort of
* two words, for example, `aac` would be *less* than (or precede) `abb`. The
Expand All @@ -89,7 +91,6 @@ namespace lexicographic {
template <typename Nullate>
class device_row_comparator {
friend class self_comparator;

/**
* @brief Construct a function object for performing a lexicographic
* comparison between the rows of two tables.
Expand Down Expand Up @@ -145,7 +146,11 @@ class device_row_comparator {
column_device_view rhs,
null_order null_precedence = null_order::BEFORE,
int depth = 0)
: _lhs{lhs}, _rhs{rhs}, _nulls{check_nulls}, _null_precedence{null_precedence}, _depth{depth}
: _lhs{lhs},
_rhs{rhs},
_check_nulls{check_nulls},
_null_precedence{null_precedence},
_depth{depth}
{
}

Expand All @@ -162,7 +167,7 @@ class device_row_comparator {
__device__ cuda::std::pair<weak_ordering, int> operator()(
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept
{
if (_nulls) {
if (_check_nulls) {
bool const lhs_is_null{_lhs.is_null(lhs_element_index)};
bool const rhs_is_null{_rhs.is_null(rhs_element_index)};

Expand Down Expand Up @@ -211,29 +216,30 @@ class device_row_comparator {
++depth;
}

auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth};
auto const comparator = element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth};
return cudf::type_dispatcher<dispatch_void_if_nested>(
lcol.type(), comparator, lhs_element_index, rhs_element_index);
}

private:
column_device_view const _lhs;
column_device_view const _rhs;
Nullate const _nulls;
Nullate const _check_nulls;
null_order const _null_precedence;
int const _depth;
};

public:
/**
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
* lexicographically less than the row at `rhs_index` in the `rhs` table.
* lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table.
*
* @param lhs_index The index of row in the `lhs` table to examine
* @param rhs_index The index of the row in the `rhs` table to examine
* @return `true` if row from the `lhs` table compares less than row in the `rhs` table
* @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs`
* table
*/
__device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept
__device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept
{
int last_null_depth = std::numeric_limits<int>::max();
for (size_type i = 0; i < _lhs.num_columns(); ++i) {
Expand All @@ -248,16 +254,17 @@ class device_row_comparator {

auto const comparator =
element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth};

weak_ordering state;
cuda::std::tie(state, last_null_depth) =
cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index);

if (state == weak_ordering::EQUIVALENT) { continue; }

return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER);
return ascending
? state
: (state == weak_ordering::GREATER ? weak_ordering::LESS : weak_ordering::GREATER);
}
return false;
return weak_ordering::EQUIVALENT;
}

private:
Expand All @@ -269,6 +276,41 @@ class device_row_comparator {
std::optional<device_span<null_order const>> const _null_precedence;
}; // class device_row_comparator

/**
* @brief Wraps and interprets the result of templated Comparator that returns a weak_ordering.
* Returns true if the weak_ordering matches any of the templated values.
*
* Note that this should never be used with only `weak_ordering::EQUIVALENT`.
* An equality comparator should be used instead for optimal performance.
*
* @tparam Comparator generic comparator that returns a weak_ordering.
* @tparam values weak_ordering parameter pack of orderings to interpret as true
*/
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
__device__ bool operator()(size_type const& lhs, size_type const& rhs)
{
weak_ordering const result = comparator(lhs, rhs);
return ((result == values) || ...);
}
Comparator comparator;
};

/**
* @brief Wraps and interprets the result of device_row_comparator, true if the result is
* weak_ordering::LESS meaning one row is lexicographically *less* than another row.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
using less_comparator =
weak_ordering_comparator_impl<device_row_comparator<Nullate>, weak_ordering::LESS>;

template <typename Nullate>
using less_equivalent_comparator = weak_ordering_comparator_impl<device_row_comparator<Nullate>,
weak_ordering::LESS,
weak_ordering::EQUIVALENT>;

struct preprocessed_table {
using table_device_view_owner =
std::invoke_result_t<decltype(table_device_view::create), table_view, rmm::cuda_stream_view>;
Expand Down Expand Up @@ -417,10 +459,10 @@ class self_comparator {
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
device_row_comparator<Nullate> device_comparator(Nullate nullate = {}) const
less_comparator<Nullate> device_comparator(Nullate nullate = {}) const
{
return device_row_comparator(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence());
return less_comparator<Nullate>{device_row_comparator<Nullate>(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())};
}

private:
Expand Down

0 comments on commit 325fa77

Please sign in to comment.