Skip to content

Commit

Permalink
Expose conditional join size calculation (#8928)
Browse files Browse the repository at this point in the history
Resolves #8918 by providing a new API for getting the output size for conditional joins (except full joins). This PR removes the unnecessary `conditional_join.cuh` header and inlines the logic into the `conditional_join.cu` file where it is used, and adds the new logic into that file as well. The public APIs are now exposed in `conditional_join.hpp`.

Adding the join size calculation also revealed a couple of bugs in the conditional join tests that were hiding a real bug in the conditional join implementation. The main test bug was the use of `std::equal` with the actual result as the first range, so if the actual result was empty, nothing was compared and the check passed even when the expected result was nonempty. This test bug masked a couple of minor errors in the expected outputs encoded in the tests; these are now fixed. It was also hiding a deeper issue where the AST device code always used the left row index to pull data for the left-hand operand of a binary operation, even when that operand was actually a column from the right table. That bug is now fixed as well.

Contributes to #8145.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - MithunR (https://github.com/mythrocks)

URL: #8928
  • Loading branch information
vyasr authored Aug 13, 2021
1 parent 43f9e3b commit fb29071
Show file tree
Hide file tree
Showing 12 changed files with 638 additions and 256 deletions.
42 changes: 27 additions & 15 deletions cpp/include/cudf/ast/detail/expression_evaluator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

#include <rmm/cuda_stream_view.hpp>

#include <thrust/optional.h>

namespace cudf {

namespace ast {
Expand Down Expand Up @@ -286,17 +288,26 @@ struct expression_evaluator {
*/
template <typename Element, CUDF_ENABLE_IF(column_device_view::has_element_accessor<Element>())>
__device__ possibly_null_value_t<Element, has_nulls> resolve_input(
detail::device_data_reference device_data_reference, cudf::size_type row_index) const
detail::device_data_reference device_data_reference,
cudf::size_type left_row_index,
thrust::optional<cudf::size_type> right_row_index = {}) const
{
auto const data_index = device_data_reference.data_index;
auto const ref_type = device_data_reference.reference_type;
// TODO: Everywhere in the code assumes that the table reference is either
// left or right. Should we error-check somewhere to prevent
// table_reference::OUTPUT from being specified?
auto const& table = device_data_reference.table_source == table_reference::LEFT ? left : right;
using ReturnType = possibly_null_value_t<Element, has_nulls>;
using ReturnType = possibly_null_value_t<Element, has_nulls>;
if (ref_type == detail::device_data_reference_type::COLUMN) {
// If we have nullable data, return an empty nullable type with no value if the data is null.
auto const& table =
(device_data_reference.table_source == table_reference::LEFT) ? left : right;
// Note that the code below assumes that a right index has been passed in
// any case where device_data_reference.table_source == table_reference::RIGHT.
// Otherwise, behavior is undefined.
auto const row_index = (device_data_reference.table_source == table_reference::LEFT)
? left_row_index
: *right_row_index;
if constexpr (has_nulls) {
return table.column(data_index).is_valid(row_index)
? ReturnType(table.column(data_index).element<Element>(row_index))
Expand All @@ -322,7 +333,9 @@ struct expression_evaluator {
template <typename Element,
CUDF_ENABLE_IF(not column_device_view::has_element_accessor<Element>())>
__device__ possibly_null_value_t<Element, has_nulls> resolve_input(
detail::device_data_reference device_data_reference, cudf::size_type row_index) const
detail::device_data_reference device_data_reference,
cudf::size_type left_row_index,
thrust::optional<cudf::size_type> right_row_index = {}) const
{
cudf_assert(false && "Unsupported type in resolve_input.");
// Unreachable return used to silence compiler warnings.
Expand Down Expand Up @@ -385,8 +398,8 @@ struct expression_evaluator {
cudf::size_type const output_row_index,
ast_operator const op) const
{
auto const typed_lhs = resolve_input<LHS>(lhs, left_row_index);
auto const typed_rhs = resolve_input<RHS>(rhs, right_row_index);
auto const typed_lhs = resolve_input<LHS>(lhs, left_row_index, right_row_index);
auto const typed_rhs = resolve_input<RHS>(rhs, left_row_index, right_row_index);
ast_operator_dispatcher(op,
binary_expression_output_handler<LHS, RHS>(*this),
output_object,
Expand Down Expand Up @@ -447,19 +460,18 @@ struct expression_evaluator {
cudf::size_type const right_row_index,
cudf::size_type const output_row_index)
{
auto operator_source_index = static_cast<cudf::size_type>(0);
cudf::size_type operator_source_index{0};
for (cudf::size_type operator_index = 0; operator_index < plan.operators.size();
operator_index++) {
++operator_index) {
// Execute operator
auto const op = plan.operators[operator_index];
auto const arity = ast_operator_arity(op);
if (arity == 1) {
// Unary operator
auto const input =
plan.data_references[plan.operator_source_indices[operator_source_index]];
plan.data_references[plan.operator_source_indices[operator_source_index++]];
auto const output =
plan.data_references[plan.operator_source_indices[operator_source_index + 1]];
operator_source_index += arity + 1;
plan.data_references[plan.operator_source_indices[operator_source_index++]];
auto input_row_index =
input.table_source == table_reference::LEFT ? left_row_index : right_row_index;
type_dispatcher(input.data_type,
Expand All @@ -472,12 +484,12 @@ struct expression_evaluator {
op);
} else if (arity == 2) {
// Binary operator
auto const lhs = plan.data_references[plan.operator_source_indices[operator_source_index]];
auto const lhs =
plan.data_references[plan.operator_source_indices[operator_source_index++]];
auto const rhs =
plan.data_references[plan.operator_source_indices[operator_source_index + 1]];
plan.data_references[plan.operator_source_indices[operator_source_index++]];
auto const output =
plan.data_references[plan.operator_source_indices[operator_source_index + 2]];
operator_source_index += arity + 1;
plan.data_references[plan.operator_source_indices[operator_source_index++]];
type_dispatcher(lhs.data_type,
detail::single_dispatch_binary_operator{},
*this,
Expand Down
12 changes: 7 additions & 5 deletions cpp/include/cudf/ast/detail/expression_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

#include <thrust/optional.h>

#include <functional>
#include <numeric>
#include <optional>

namespace cudf {
namespace ast {
Expand Down Expand Up @@ -131,8 +133,8 @@ class expression_parser {
* @param right The right table used for evaluating the abstract syntax tree.
*/
expression_parser(node const& expr,
cudf::table_view left,
cudf::table_view right,
cudf::table_view const& left,
std::optional<std::reference_wrapper<cudf::table_view const>> right,
bool has_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
Expand All @@ -149,11 +151,11 @@ class expression_parser {
* @param table The table used for evaluating the abstract syntax tree.
*/
expression_parser(node const& expr,
cudf::table_view table,
cudf::table_view const& table,
bool has_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
: expression_parser(expr, table, table, has_nulls, stream, mr)
: expression_parser(expr, table, {}, has_nulls, stream, mr)
{
}

Expand Down Expand Up @@ -322,7 +324,7 @@ class expression_parser {
///< owned by this class and persists until it is destroyed.

cudf::table_view const& _left;
cudf::table_view const& _right;
std::optional<std::reference_wrapper<cudf::table_view const>> _right;
cudf::size_type _node_count;
intermediate_counter _intermediate_counter;
bool _has_nulls;
Expand Down
123 changes: 108 additions & 15 deletions cpp/include/cudf/join.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,6 @@ class hash_join {
* Result: {{1}, {0}}
* @endcode
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
Expand All @@ -692,8 +690,9 @@ conditional_inner_join(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
null_equality compare_nulls = null_equality::EQUAL,
std::optional<std::size_t> output_size = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns a pair of row index vectors corresponding to all pairs
Expand Down Expand Up @@ -721,8 +720,6 @@ conditional_inner_join(
* Result: {{0, 1, 2}, {None, 0, None}}
* @endcode
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
Expand All @@ -739,7 +736,8 @@ std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
conditional_left_join(table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
null_equality compare_nulls = null_equality::EQUAL,
std::optional<std::size_t> output_size = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -767,8 +765,6 @@ conditional_left_join(table_view left,
* Result: {{0, 1, 2, None, None}, {None, 0, None, 1, 2}}
* @endcode
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
Expand Down Expand Up @@ -808,8 +804,6 @@ conditional_full_join(table_view left,
* Result: {1}
* @endcode
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
Expand All @@ -826,8 +820,9 @@ std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
null_equality compare_nulls = null_equality::EQUAL,
std::optional<std::size_t> output_size = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns an index vector corresponding to all rows in the left table
Expand All @@ -849,8 +844,6 @@ std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
* Result: {0, 2}
* @endcode
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
Expand All @@ -864,11 +857,111 @@ std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
* `right` .
*/
std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_anti_join(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
std::optional<std::size_t> output_size = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns the exact number of matches (rows) when performing a
* conditional inner join between the specified tables where the predicate
* evaluates to true.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* The returned count has the same type as the optional `output_size`
* argument of `conditional_inner_join` and can be passed to it.
*
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
* @param binary_predicate The condition on which to join.
* @param compare_nulls Whether the equality operator returns true or false for two nulls.
* @param mr Device memory resource used for device memory allocated by this function.
* NOTE(review): this function returns a `std::size_t`, not a table or columns, so
* confirm what (if anything) is actually allocated from `mr`.
*
* @return The number of rows that performing the requested join would produce.
*/
std::size_t conditional_inner_join_size(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns the exact number of matches (rows) when performing a
* conditional left join between the specified tables where the predicate
* evaluates to true.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* NOTE(review): the `conditional_left_join` example shows unmatched left rows
* in its result (paired with None); presumably those rows are included in
* this count as well -- confirm against the implementation.
*
* The returned count has the same type as the optional `output_size`
* argument of `conditional_left_join` and can be passed to it.
*
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
* @param binary_predicate The condition on which to join.
* @param compare_nulls Whether the equality operator returns true or false for two nulls.
* @param mr Device memory resource used for device memory allocated by this function.
* NOTE(review): this function returns a `std::size_t`, not a table or columns, so
* confirm what (if anything) is actually allocated from `mr`.
*
* @return The number of rows that performing the requested join would produce.
*/
std::size_t conditional_left_join_size(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns the exact number of matches (rows) when performing a
* conditional left semi join between the specified tables where the predicate
* evaluates to true.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* NOTE(review): `conditional_left_semi_join` returns a single vector of left
* row indices rather than (left, right) pairs; presumably a NULL predicate
* result simply does not count as a match for the left row -- confirm.
*
* The returned count has the same type as the optional `output_size`
* argument of `conditional_left_semi_join` and can be passed to it.
*
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
* @param binary_predicate The condition on which to join.
* @param compare_nulls Whether the equality operator returns true or false for two nulls.
* @param mr Device memory resource used for device memory allocated by this function.
* NOTE(review): this function returns a `std::size_t`, not a table or columns, so
* confirm what (if anything) is actually allocated from `mr`.
*
* @return The number of rows that performing the requested join would produce.
*/
std::size_t conditional_left_semi_join_size(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns the exact number of matches (rows) when performing a
* conditional left anti join between the specified tables where the predicate
* evaluates to true.
*
* NOTE(review): for an anti join the counted rows are presumably the left rows
* for which the predicate is NOT true for any right row; the "evaluates to
* true" wording above looks copied from the other joins -- confirm and reword.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* NOTE(review): `conditional_left_anti_join` returns a single vector of left
* row indices rather than (left, right) pairs, so how a NULL predicate result
* affects the count should be confirmed against the implementation.
*
* The returned count has the same type as the optional `output_size`
* argument of `conditional_left_anti_join` and can be passed to it.
*
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
* @param binary_predicate The condition on which to join.
* @param compare_nulls Whether the equality operator returns true or false for two nulls.
* @param mr Device memory resource used for device memory allocated by this function.
* NOTE(review): this function returns a `std::size_t`, not a table or columns, so
* confirm what (if anything) is actually allocated from `mr`.
*
* @return The number of rows that performing the requested join would produce.
*/
std::size_t conditional_left_anti_join_size(
table_view left,
table_view right,
ast::expression binary_predicate,
null_equality compare_nulls = null_equality::EQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
/** @} */ // end of group
} // namespace cudf
2 changes: 2 additions & 0 deletions cpp/include/cudf/transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ std::pair<std::unique_ptr<rmm::device_buffer>, size_type> nans_to_nulls(
* This evaluates an expression over a table to produce a new column. Also called an n-ary
* transform.
*
* @throws cudf::logic_error if passed an expression operating on table_reference::RIGHT.
*
* @param table The table used for expression evaluation.
* @param expr The root of the expression tree.
* @param mr Device memory resource.
Expand Down
15 changes: 12 additions & 3 deletions cpp/src/ast/expression_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,18 @@ cudf::size_type expression_parser::visit(column_reference const& expr)
// Increment the node index
_node_count++;
// Resolve node type
auto const data_type = expr.get_table_source() == table_reference::LEFT
? expr.get_data_type(_left)
: expr.get_data_type(_right);
cudf::data_type data_type;
if (expr.get_table_source() == table_reference::LEFT) {
data_type = expr.get_data_type(_left);
} else {
if (_right.has_value()) {
data_type = expr.get_data_type(*_right);
} else {
CUDF_FAIL(
"Your expression contains a reference to the RIGHT table even though it will only be "
"evaluated on a single table (by convention, the LEFT table).");
}
}
// Push data reference
auto const source = detail::device_data_reference(detail::device_data_reference_type::COLUMN,
data_type,
Expand Down
Loading

0 comments on commit fb29071

Please sign in to comment.