Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nested struct binop comparison #9452

Closed
wants to merge 79 commits into from
Closed
Show file tree
Hide file tree
Changes from 78 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
262c2a2
struct binop first pass
rwlee Sep 30, 2021
a865fe0
vector-vector nested struct comparison
rwlee Oct 14, 2021
ce4440c
cleanup and simplify core code
rwlee Oct 16, 2021
1472c5f
remove type dispatch and other code cleanup
rwlee Oct 20, 2021
4a31fb6
move struct comparison to compiled binops code
rwlee Oct 20, 2021
ce2d727
improved testing, type checks, and skipped null value calculations
rwlee Oct 21, 2021
10c95f9
cleanup
rwlee Oct 21, 2021
12cd09e
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Oct 21, 2021
b6fa590
fix upmerge issues
rwlee Oct 21, 2021
d64c1f9
fix logic and improve documentation
rwlee Oct 22, 2021
57900da
clean up logic for nulls
rwlee Oct 22, 2021
f170149
remove unecessary call to superimpose parent nulls
rwlee Oct 22, 2021
de129a1
PR fixes
rwlee Oct 27, 2021
5e84e89
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Oct 27, 2021
b632192
pr fixes
rwlee Oct 27, 2021
4266f8c
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Oct 27, 2021
367ec07
restructure struct binop code and other pr fixes
rwlee Nov 3, 2021
5a1f016
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Nov 3, 2021
1f29168
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Nov 3, 2021
1d6263e
full paths for includes
rwlee Nov 9, 2021
97bd5e1
Merge branch 'branch-21.12' into rwlee/struct_col_compare
rwlee Nov 9, 2021
48d0355
move to new TU and remove common code
rwlee Nov 11, 2021
6cf0e16
fix logic errors and push down struct branching
rwlee Nov 11, 2021
9ec2acf
remove deleted file from CMakeLists
rwlee Nov 11, 2021
3016abf
Naming and comment fixes
rwlee Nov 11, 2021
b2a7973
naming
rwlee Nov 11, 2021
191da69
style formatting
rwlee Nov 11, 2021
2cf2b28
merge apply_binary_op and _impl implementation
rwlee Nov 12, 2021
2b634c4
all apply_binary_op calls call apply_binary_op_impl
rwlee Nov 13, 2021
19f1afb
common code path
rwlee Nov 23, 2021
f316a0a
explicit instantiation of struct_compare
rwlee Nov 23, 2021
c684ef1
Merge branch 'branch-22.02' into rwlee/struct_col_compare
rwlee Nov 23, 2021
7f36241
streamline explicit instantiation
rwlee Nov 29, 2021
8cf0660
Merge remote-tracking branch 'pub/branch-22.02' into rwlee/struct_col…
rwlee Nov 29, 2021
2abefd5
remove op argument
rwlee Dec 5, 2021
8cc05e2
documentation
rwlee Dec 6, 2021
1bde152
Merge branch 'branch-22.02' into rwlee/struct_col_compare
rwlee Dec 6, 2021
470acfe
Fix upmerge errors
rwlee Dec 6, 2021
ce21d90
Merge remote-tracking branch 'pub/branch-22.02' into rwlee/struct_col…
rwlee Dec 16, 2021
83fa370
Merge remote-tracking branch 'pub/branch-22.04' into rwlee/struct_col…
rwlee Feb 1, 2022
2b77739
fix new ops from upmerge
rwlee Feb 4, 2022
de09cec
Fix floating point nan handling in struct comparison binops
rwlee Feb 8, 2022
8ad9545
Merge remote-tracking branch 'pub/branch-22.04' into rwlee/struct_col…
rwlee Feb 8, 2022
251d607
fix formatting
rwlee Feb 9, 2022
703aaf8
fix copyright
rwlee Feb 14, 2022
43e451b
fix accidently deletd function
rwlee Feb 14, 2022
9ec4a41
style fix
rwlee Feb 15, 2022
201a89b
copyright fix
rwlee Feb 15, 2022
cc164d6
Merge remote-tracking branch 'pub/branch-22.04' into rwlee/struct_col…
rwlee Feb 15, 2022
1bb1534
fix cmake style
rwlee Feb 15, 2022
6c6c8ab
re-add missing function name
rwlee Feb 16, 2022
42e58ae
style fix
rwlee Feb 16, 2022
475c896
Fix struct equality binop comparisons
rwlee Feb 19, 2022
1dae04a
PR reviews
rwlee Mar 7, 2022
62224cf
Merge remote-tracking branch 'pub/branch-22.04' into rwlee/struct_col…
rwlee Mar 8, 2022
a35600d
refactor row comparison operators into common spaceship operator
rwlee Mar 22, 2022
b6f0397
Merge remote-tracking branch 'pub/branch-22.04' into rwlee/struct_col…
rwlee Mar 22, 2022
fcc1dd2
first pass, test failures
rwlee Mar 29, 2022
5abf2a8
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/struct_col…
rwlee Mar 29, 2022
9d50ac0
Refactor struct binop comparison to use experimental ops
rwlee Apr 16, 2022
8628c24
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/struct_col…
rwlee Apr 16, 2022
a836a96
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/struct_col…
rwlee Apr 18, 2022
a537805
fix performance regression and code cleanup
rwlee May 2, 2022
f7af41f
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/struct_col…
rwlee May 2, 2022
1af4643
fix merge errors
rwlee May 3, 2022
4d929d9
Merge remote-tracking branch 'upstream/branch-22.06' into rwlee/struc…
bdice May 3, 2022
2298988
Revert include changes.
bdice May 3, 2022
bf1c6ee
split off weak ordering row operator changes
rwlee May 4, 2022
5d87db2
device_row_comparator private with friend class
rwlee May 4, 2022
fd716b9
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/row_op_split
rwlee May 4, 2022
2dd2045
device_less conversion to templated struct
rwlee May 6, 2022
7ba960e
fold parameter pack
rwlee May 9, 2022
84833e7
Apply suggestions from code review
rwlee May 10, 2022
a944b4f
Merge remote-tracking branch 'pub/branch-22.06' into rwlee/struct_col…
rwlee May 10, 2022
4d197ea
fix code style
rwlee May 10, 2022
08092fe
Merge branch 'rwlee/row_op_split' of github.com:rwlee/cudf into rwlee…
rwlee May 10, 2022
d8986c5
Merge branch 'rwlee/row_op_split' into rwlee/struct_col_compare
rwlee May 10, 2022
548dcf1
fix code format
rwlee May 10, 2022
1dd1159
Merge remote-tracking branch 'upstream/branch-22.06' into rwlee/struc…
bdice May 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 91 additions & 22 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <cudf/binaryop.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/hashing.hpp>
#include <cudf/detail/iterator.cuh>
Expand Down Expand Up @@ -45,6 +46,7 @@
#include <utility>

namespace cudf {

namespace experimental {

/**
Expand All @@ -68,28 +70,31 @@ struct dispatch_void_if_nested {
};

namespace row {

namespace lexicographic {

/**
* @brief Computes whether one row is lexicographically *less* than another row.
* @brief Computes the lexicographic comparison between 2 rows.
*
* Lexicographic ordering is determined by:
* - Two rows are compared element by element.
* - The first mismatching element defines which row is lexicographically less
* or greater than the other.
* - If the rows are compared without mismatched elements, the rows are equivalent
*
*
* Lexicographic ordering is exactly equivalent to doing an alphabetical sort of
* two words, for example, `aac` would be *less* than (or precede) `abb`. The
* second letter in both words is the first non-equal letter, and `a < b`, thus
* `aac < abb`.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
* @tparam NanConfig default configuration nans are equal, if set to true triggers specialized IEEE
* 754 compliant nan handling
*/
template <typename Nullate>
template <typename Nullate, bool NanConfig = false>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to pull out the NanConfig and related changes into a separate PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class device_row_comparator {
friend class self_comparator;

// friend class self_comparator;
public: // needs to be removed, pending strict typing for indices
/**
* @brief Construct a function object for performing a lexicographic
* comparison between the rows of two tables.
Expand Down Expand Up @@ -139,13 +144,20 @@ class device_row_comparator {
* @param null_precedence Indicates how null values are ordered with other values
* @param depth The depth of the column if part of a nested column @see
* preprocessed_table::depths
* @param nan_result Specifies what value should be returned if either element is `nan`
*/
__device__ element_comparator(Nullate check_nulls,
column_device_view lhs,
column_device_view rhs,
null_order null_precedence = null_order::BEFORE,
int depth = 0)
: _lhs{lhs}, _rhs{rhs}, _nulls{check_nulls}, _null_precedence{null_precedence}, _depth{depth}
int depth = 0,
weak_ordering nan_result = weak_ordering::EQUIVALENT)
: _lhs{lhs},
_rhs{rhs},
_check_nulls{check_nulls},
_null_precedence{null_precedence},
_depth{depth},
_nan_result{nan_result}
{
}

Expand All @@ -162,7 +174,7 @@ class device_row_comparator {
__device__ cuda::std::pair<weak_ordering, int> operator()(
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept
{
if (_nulls) {
if (_check_nulls) {
bool const lhs_is_null{_lhs.is_null(lhs_element_index)};
bool const rhs_is_null{_rhs.is_null(rhs_element_index)};

Expand All @@ -171,8 +183,12 @@ class device_row_comparator {
}
}

return cuda::std::pair(relational_compare(_lhs.element<Element>(lhs_element_index),
_rhs.element<Element>(rhs_element_index)),
return cuda::std::pair(NanConfig
? relational_compare(_lhs.element<Element>(lhs_element_index),
_rhs.element<Element>(rhs_element_index),
_nan_result)
: relational_compare(_lhs.element<Element>(lhs_element_index),
_rhs.element<Element>(rhs_element_index)),
std::numeric_limits<int>::max());
}

Expand Down Expand Up @@ -211,29 +227,32 @@ class device_row_comparator {
++depth;
}

auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth};
auto const comparator =
element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth, _nan_result};
return cudf::type_dispatcher<dispatch_void_if_nested>(
lcol.type(), comparator, lhs_element_index, rhs_element_index);
}

private:
column_device_view const _lhs;
column_device_view const _rhs;
Nullate const _nulls;
Nullate const _check_nulls;
null_order const _null_precedence;
int const _depth;
weak_ordering _nan_result;
};

public:
/**
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
* lexicographically less than the row at `rhs_index` in the `rhs` table.
* lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table.
*
* @param lhs_index The index of row in the `lhs` table to examine
* @param rhs_index The index of the row in the `rhs` table to examine
* @return `true` if row from the `lhs` table compares less than row in the `rhs` table
* @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs`
* table
*/
__device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept
__device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept
{
int last_null_depth = std::numeric_limits<int>::max();
for (size_type i = 0; i < _lhs.num_columns(); ++i) {
Expand All @@ -247,17 +266,31 @@ class device_row_comparator {
_null_precedence.has_value() ? (*_null_precedence)[i] : null_order::BEFORE;

auto const comparator =
element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth};
NanConfig ? element_comparator{_check_nulls,
_lhs.column(i),
_rhs.column(i),
null_precedence,
depth,
ascending ? weak_ordering::GREATER : weak_ordering::LESS}

: element_comparator{_check_nulls,
_lhs.column(i),
_rhs.column(i),
null_precedence,
depth,
weak_ordering::EQUIVALENT};

weak_ordering state;
cuda::std::tie(state, last_null_depth) =
cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index);

if (state == weak_ordering::EQUIVALENT) { continue; }

return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER);
return ascending
? state
: (state == weak_ordering::GREATER ? weak_ordering::LESS : weak_ordering::GREATER);
}
return false;
return weak_ordering::EQUIVALENT;
}

private:
Expand All @@ -269,6 +302,42 @@ class device_row_comparator {
std::optional<device_span<null_order const>> const _null_precedence;
}; // class device_row_comparator

/**
* @brief Wraps and interprets the result of templated Comparator that returns a weak_ordering.
* Returns true if the weak_ordering matches any of the templated values.
*
* Note that this should never be used with only `weak_ordering::EQUIVALENT`.
* An equality comparator should be used instead for optimal performance.
*
* @tparam Comparator generic comparator that returns a weak_ordering.
* @tparam values weak_ordering parameter pack of orderings to interpret as true
*/
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
__device__ bool operator()(size_type const& lhs, size_type const& rhs)
Comment on lines +315 to +317
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can actually static_assert this requirement to be sure it doesn't happen.

Double check that I'm using the right names here.

Suggested change
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
__device__ bool operator()(size_type const& lhs, size_type const& rhs)
template <typename Comparator, weak_ordering... values>
struct weak_ordering_comparator_impl {
static_assert( not (sizeof...(values) == 1 and values == weak_ordering::equivalent), "weak_ordering_comparator should not be used for pure equality comparisons. The `row_equality_comparator` should be used instead");
__device__ bool operator()(size_type const& lhs, size_type const& rhs)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, the compiler complains about values == weak_ordering::equivalent in the context of the static assert. This forces an expansion --> ((weak_ordering::EQUIVALENT == values) && ...), which subsequently makes the sizeof...(values) == 1 check irrelevant.

{
weak_ordering const result = comparator(lhs, rhs);
return ((result == values) || ...);
}
Comparator comparator;
};

/**
* @brief Wraps and interprets the result of device_row_comparator, true if the result is
* weak_ordering::LESS meaning one row is lexicographically *less* than another row.
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate, bool NanConfig = false>
using less_comparator =
weak_ordering_comparator_impl<device_row_comparator<Nullate, NanConfig>, weak_ordering::LESS>;

template <typename Nullate, bool NanConfig = false>
using less_equivalent_comparator =
weak_ordering_comparator_impl<device_row_comparator<Nullate, NanConfig>,
weak_ordering::LESS,
weak_ordering::EQUIVALENT>;

struct preprocessed_table {
using table_device_view_owner =
std::invoke_result_t<decltype(table_device_view::create), table_view, rmm::cuda_stream_view>;
Expand Down Expand Up @@ -416,11 +485,11 @@ class self_comparator {
*
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <typename Nullate>
device_row_comparator<Nullate> device_comparator(Nullate nullate = {}) const
template <typename Nullate, bool NanConfig = false>
less_comparator<Nullate, NanConfig> device_comparator(Nullate nullate = {}) const
{
return device_row_comparator(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence());
return less_comparator<Nullate, NanConfig>{device_row_comparator<Nullate, NanConfig>(
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())};
}

private:
Expand Down
70 changes: 57 additions & 13 deletions cpp/include/cudf/table/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ __device__ weak_ordering relational_compare(Element lhs, Element rhs)
return detail::compare_elements(lhs, rhs);
}

/**
* @brief A specialization for floating-point `Element` type relational comparison
* to derive the order of the elements with respect to `lhs`. Returns specified weak_ordering if
* either value is `nan`, enabling IEEE 754 compliant comparison.
*
* This specialization allows `nan` values to be evaluated as not equal to any other value, while
* also not evaluating as greater or less than
*
* @param lhs first element
* @param rhs second element
* @param nan_result specifies what value should be returned if either element is `nan`
* @return Indicates the relationship between the elements in
* the `lhs` and `rhs` columns.
*/
template <typename Element, std::enable_if_t<std::is_floating_point<Element>::value>* = nullptr>
__device__ weak_ordering relational_compare(Element lhs, Element rhs, weak_ordering nan_result)
{
return isnan(lhs) or isnan(rhs) ? nan_result : detail::compare_elements(lhs, rhs);
}

/**
* @brief Compare the nulls according to null order.
*
Expand All @@ -123,11 +143,14 @@ inline __device__ auto null_compare(bool lhs_is_null, bool rhs_is_null, null_ord
*
* @param[in] lhs first element
* @param[in] rhs second element
* @param nan_result ignored for non-floating point operation
* @return Indicates the relationship between the elements in
* the `lhs` and `rhs` columns.
*/
template <typename Element, std::enable_if_t<not std::is_floating_point_v<Element>>* = nullptr>
__device__ weak_ordering relational_compare(Element lhs, Element rhs)
__device__ weak_ordering relational_compare(Element lhs,
Element rhs,
weak_ordering const nan_result = weak_ordering::GREATER)
{
return detail::compare_elements(lhs, rhs);
}
Expand All @@ -138,12 +161,15 @@ __device__ weak_ordering relational_compare(Element lhs, Element rhs)
*
* @param lhs first element
* @param rhs second element
* @param nan_result specifies what value should be returned if either element is `nan`
* @return `true` if `lhs` == `rhs` else `false`.
*/
template <typename Element, std::enable_if_t<std::is_floating_point_v<Element>>* = nullptr>
__device__ bool equality_compare(Element lhs, Element rhs)
__device__ bool equality_compare(Element lhs,
Element rhs,
nan_equality const nan_result = nan_equality::ALL_EQUAL)
{
if (isnan(lhs) and isnan(rhs)) { return true; }
if (isnan(lhs) and isnan(rhs)) { return nan_result == nan_equality::ALL_EQUAL; }
return lhs == rhs;
}

Expand All @@ -153,10 +179,13 @@ __device__ bool equality_compare(Element lhs, Element rhs)
*
* @param lhs first element
* @param rhs second element
* @param nan_result ignored for non-floating point operation
* @return `true` if `lhs` == `rhs` else `false`.
*/
template <typename Element, std::enable_if_t<not std::is_floating_point_v<Element>>* = nullptr>
__device__ bool equality_compare(Element const lhs, Element const rhs)
__device__ bool equality_compare(Element const lhs,
Element const rhs,
nan_equality const nan_result = nan_equality::ALL_EQUAL)
{
return lhs == rhs;
}
Expand All @@ -179,13 +208,19 @@ class element_equality_comparator {
* @param lhs The column containing the first element
* @param rhs The column containing the second element (may be the same as lhs)
* @param nulls_are_equal Indicates if two null elements are treated as equivalent
* @param nan_result specifies what value should be returned if either element is `nan`
*/
__host__ __device__
element_equality_comparator(Nullate has_nulls,
column_device_view lhs,
column_device_view rhs,
null_equality nulls_are_equal = null_equality::EQUAL)
: lhs{lhs}, rhs{rhs}, nulls{has_nulls}, nulls_are_equal{nulls_are_equal}
null_equality nulls_are_equal = null_equality::EQUAL,
nan_equality nans_are_equal = nan_equality::ALL_EQUAL)
: lhs{lhs},
rhs{rhs},
nulls{has_nulls},
nulls_are_equal{nulls_are_equal},
nans_are_equal{nans_are_equal}
{
}

Expand All @@ -212,7 +247,8 @@ class element_equality_comparator {
}

return equality_compare(lhs.element<Element>(lhs_element_index),
rhs.element<Element>(rhs_element_index));
rhs.element<Element>(rhs_element_index),
nans_are_equal);
}

template <typename Element,
Expand All @@ -227,6 +263,7 @@ class element_equality_comparator {
column_device_view rhs;
Nullate nulls;
null_equality nulls_are_equal;
nan_equality nans_are_equal;
};

template <typename Nullate>
Expand All @@ -235,19 +272,25 @@ class row_equality_comparator {
row_equality_comparator(Nullate has_nulls,
table_device_view lhs,
table_device_view rhs,
null_equality nulls_are_equal = null_equality::EQUAL)
: lhs{lhs}, rhs{rhs}, nulls{has_nulls}, nulls_are_equal{nulls_are_equal}
null_equality nulls_are_equal = null_equality::EQUAL,
nan_equality nans_are_equal = nan_equality::ALL_EQUAL)
: lhs{lhs},
rhs{rhs},
nulls{has_nulls},
nulls_are_equal{nulls_are_equal},
nans_are_equal{nans_are_equal}
{
CUDF_EXPECTS(lhs.num_columns() == rhs.num_columns(), "Mismatched number of columns.");
}

__device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept
{
auto equal_elements = [=](column_device_view l, column_device_view r) {
return cudf::type_dispatcher(l.type(),
element_equality_comparator{nulls, l, r, nulls_are_equal},
lhs_row_index,
rhs_row_index);
return cudf::type_dispatcher(
l.type(),
element_equality_comparator{nulls, l, r, nulls_are_equal, nans_are_equal},
lhs_row_index,
rhs_row_index);
};

return thrust::equal(thrust::seq, lhs.begin(), lhs.end(), rhs.begin(), equal_elements);
Expand All @@ -258,6 +301,7 @@ class row_equality_comparator {
table_device_view rhs;
Nullate nulls;
null_equality nulls_are_equal;
nan_equality nans_are_equal;
};

/**
Expand Down
Loading