From fb2907149777725bae1addb42dfc1ff900199ba4 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 13 Aug 2021 13:09:06 -0700 Subject: [PATCH] Expose conditional join size calculation (#8928) Resolves #8918 by providing a new API for getting the output size for conditional joins (except full joins). This PR removes the unnecessary `conditional_join.cuh` header and inlines the logic into the `conditional_join.cu` file where it is used, and adds the new logic into that file as well. The public APIs are now exposed in `conditional_join.hpp`. Adding the join size calculation also revealed a couple of bugs in the conditional join tests that were hiding a real bug in the conditional join implementation. The main test bug was the use of `std::equal` with the actual result as the first iterator, so if the actual result was empty it was never compared against the expected result (even if it was nonempty). This bug masked a couple of minor errors in the expected outputs encoded in the test. These are now fixed. This bug was also hiding a deeper issue where the AST device code was always using the left row index to pull data for the left hand operand to binary operations, even when the lhs was actually a column from the right table. That bug is now fixed as well. Contributes to #8145. Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Jason Lowe (https://github.com/jlowe) - MithunR (https://github.com/mythrocks) URL: https://github.com/rapidsai/cudf/pull/8928 --- .../cudf/ast/detail/expression_evaluator.cuh | 42 ++- .../cudf/ast/detail/expression_parser.hpp | 12 +- cpp/include/cudf/join.hpp | 123 +++++++- cpp/include/cudf/transform.hpp | 2 + cpp/src/ast/expression_parser.cpp | 15 +- cpp/src/join/conditional_join.cu | 298 +++++++++++++++++- cpp/src/join/conditional_join.cuh | 183 ----------- cpp/src/join/conditional_join.hpp | 78 +++++ cpp/src/join/conditional_join_kernels.cuh | 30 +- cpp/src/transform/compute_column.cu | 1 - cpp/tests/join/conditional_join_tests.cu | 103 +++++- .../cudf/ast/CompiledExpressionTest.java | 7 +- 12 files changed, 638 insertions(+), 256 deletions(-) delete mode 100644 cpp/src/join/conditional_join.cuh create mode 100644 cpp/src/join/conditional_join.hpp diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index 2a3cd059e80..ca2cab96123 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -31,6 +31,8 @@ #include +#include + namespace cudf { namespace ast { @@ -286,17 +288,26 @@ struct expression_evaluator { */ template ())> __device__ possibly_null_value_t resolve_input( - detail::device_data_reference device_data_reference, cudf::size_type row_index) const + detail::device_data_reference device_data_reference, + cudf::size_type left_row_index, + thrust::optional right_row_index = {}) const { auto const data_index = device_data_reference.data_index; auto const ref_type = device_data_reference.reference_type; // TODO: Everywhere in the code assumes that the table reference is either // left or right. Should we error-check somewhere to prevent // table_reference::OUTPUT from being specified? - auto const& table = device_data_reference.table_source == table_reference::LEFT ? left : right; - using ReturnType = possibly_null_value_t; + using ReturnType = possibly_null_value_t; if (ref_type == detail::device_data_reference_type::COLUMN) { // If we have nullable data, return an empty nullable type with no value if the data is null. + auto const& table = + (device_data_reference.table_source == table_reference::LEFT) ? left : right; + // Note that the code below assumes that a right index has been passed in + // any case where device_data_reference.table_source == table_reference::RIGHT. + // Otherwise, behavior is undefined. + auto const row_index = (device_data_reference.table_source == table_reference::LEFT) + ? left_row_index + : *right_row_index; if constexpr (has_nulls) { return table.column(data_index).is_valid(row_index) ? ReturnType(table.column(data_index).element(row_index)) @@ -322,7 +333,9 @@ struct expression_evaluator { template ())> __device__ possibly_null_value_t resolve_input( - detail::device_data_reference device_data_reference, cudf::size_type row_index) const + detail::device_data_reference device_data_reference, + cudf::size_type left_row_index, + thrust::optional right_row_index = {}) const { cudf_assert(false && "Unsupported type in resolve_input."); // Unreachable return used to silence compiler warnings. @@ -385,8 +398,8 @@ struct expression_evaluator { cudf::size_type const output_row_index, ast_operator const op) const { - auto const typed_lhs = resolve_input(lhs, left_row_index); - auto const typed_rhs = resolve_input(rhs, right_row_index); + auto const typed_lhs = resolve_input(lhs, left_row_index, right_row_index); + auto const typed_rhs = resolve_input(rhs, left_row_index, right_row_index); ast_operator_dispatcher(op, binary_expression_output_handler(*this), output_object, @@ -447,19 +460,18 @@ struct expression_evaluator { cudf::size_type const right_row_index, cudf::size_type const output_row_index) { - auto operator_source_index = static_cast(0); + cudf::size_type operator_source_index{0}; for (cudf::size_type operator_index = 0; operator_index < plan.operators.size(); - operator_index++) { + ++operator_index) { // Execute operator auto const op = plan.operators[operator_index]; auto const arity = ast_operator_arity(op); if (arity == 1) { // Unary operator auto const input = - plan.data_references[plan.operator_source_indices[operator_source_index]]; + plan.data_references[plan.operator_source_indices[operator_source_index++]]; auto const output = - plan.data_references[plan.operator_source_indices[operator_source_index + 1]]; - operator_source_index += arity + 1; + plan.data_references[plan.operator_source_indices[operator_source_index++]]; auto input_row_index = input.table_source == table_reference::LEFT ? left_row_index : right_row_index; type_dispatcher(input.data_type, @@ -472,12 +484,12 @@ struct expression_evaluator { op); } else if (arity == 2) { // Binary operator - auto const lhs = plan.data_references[plan.operator_source_indices[operator_source_index]]; + auto const lhs = + plan.data_references[plan.operator_source_indices[operator_source_index++]]; auto const rhs = - plan.data_references[plan.operator_source_indices[operator_source_index + 1]]; + plan.data_references[plan.operator_source_indices[operator_source_index++]]; auto const output = - plan.data_references[plan.operator_source_indices[operator_source_index + 2]]; - operator_source_index += arity + 1; + plan.data_references[plan.operator_source_indices[operator_source_index++]]; type_dispatcher(lhs.data_type, detail::single_dispatch_binary_operator{}, *this, diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index bb42bfbc631..9eca250b898 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -23,7 +23,9 @@ #include +#include #include +#include namespace cudf { namespace ast { @@ -131,8 +133,8 @@ class expression_parser { * @param right The right table used for evaluating the abstract syntax tree. */ expression_parser(node const& expr, - cudf::table_view left, - cudf::table_view right, + cudf::table_view const& left, + std::optional> right, bool has_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -149,11 +151,11 @@ class expression_parser { * @param table The table used for evaluating the abstract syntax tree. */ expression_parser(node const& expr, - cudf::table_view table, + cudf::table_view const& table, bool has_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) - : expression_parser(expr, table, table, has_nulls, stream, mr) + : expression_parser(expr, table, {}, has_nulls, stream, mr) { } @@ -322,7 +324,7 @@ class expression_parser { ///< owned by this class and persists until it is destroyed. cudf::table_view const& _left; - cudf::table_view const& _right; + std::optional> _right; cudf::size_type _node_count; intermediate_counter _intermediate_counter; bool _has_nulls; diff --git a/cpp/include/cudf/join.hpp b/cpp/include/cudf/join.hpp index d0d2083b85b..dbafa95ee77 100644 --- a/cpp/include/cudf/join.hpp +++ b/cpp/include/cudf/join.hpp @@ -673,8 +673,6 @@ class hash_join { * Result: {{1}, {0}} * @endcode * - * @throw cudf::logic_error if number of elements in `left_keys` or `right_keys` - * mismatch. * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. * * @param left The left table @@ -692,8 +690,9 @@ conditional_inner_join( table_view left, table_view right, ast::expression binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + null_equality compare_nulls = null_equality::EQUAL, + std::optional output_size = {}, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** * @brief Returns a pair of row index vectors corresponding to all pairs @@ -721,8 +720,6 @@ conditional_inner_join( * Result: {{0, 1, 2}, {None, 0, None}} * @endcode * - * @throw cudf::logic_error if number of elements in `left_keys` or `right_keys` - * mismatch. * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. * * @param left The left table @@ -739,7 +736,8 @@ std::pair>, conditional_left_join(table_view left, table_view right, ast::expression binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, + null_equality compare_nulls = null_equality::EQUAL, + std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -767,8 +765,6 @@ conditional_left_join(table_view left, * Result: {{0, 1, 2, None, None}, {None, 0, None, 1, 2}} * @endcode * - * @throw cudf::logic_error if number of elements in `left_keys` or `right_keys` - * mismatch. * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. * * @param left The left table @@ -808,8 +804,6 @@ conditional_full_join(table_view left, * Result: {1} * @endcode * - * @throw cudf::logic_error if number of elements in `left_keys` or `right_keys` - * mismatch. * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. * * @param left The left table @@ -826,8 +820,9 @@ std::unique_ptr> conditional_left_semi_join( table_view left, table_view right, ast::expression binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + null_equality compare_nulls = null_equality::EQUAL, + std::optional output_size = {}, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** * @brief Returns an index vector corresponding to all rows in the left table @@ -849,8 +844,6 @@ std::unique_ptr> conditional_left_semi_join( * Result: {0, 2} * @endcode * - * @throw cudf::logic_error if number of elements in `left_keys` or `right_keys` - * mismatch. * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. * * @param left The left table @@ -864,11 +857,111 @@ std::unique_ptr> conditional_left_semi_join( * `right` . */ std::unique_ptr> conditional_left_anti_join( + table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls = null_equality::EQUAL, + std::optional output_size = {}, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Returns the exact number of matches (rows) when performing a + * conditional inner join between the specified tables where the predicate + * evaluates to true. + * + * If the provided predicate returns NULL for a pair of rows + * (left, right), that pair is not included in the output. + * + * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. + * + * @param left The left table + * @param right The right table + * @param binary_predicate The condition on which to join. + * @param compare_nulls Whether the equality operator returns true or false for two nulls. + * @param mr Device memory resource used to allocate the returned table and columns' device memory + * + * @return The size that would result from performing the requested join. + */ +std::size_t conditional_inner_join_size( + table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls = null_equality::EQUAL, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Returns the exact number of matches (rows) when performing a + * conditional left join between the specified tables where the predicate + * evaluates to true. + * + * If the provided predicate returns NULL for a pair of rows + * (left, right), that pair is not included in the output. + * + * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. + * + * @param left The left table + * @param right The right table + * @param binary_predicate The condition on which to join. + * @param compare_nulls Whether the equality operator returns true or false for two nulls. + * @param mr Device memory resource used to allocate the returned table and columns' device memory + * + * @return The size that would result from performing the requested join. + */ +std::size_t conditional_left_join_size( + table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls = null_equality::EQUAL, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Returns the exact number of matches (rows) when performing a + * conditional left semi join between the specified tables where the predicate + * evaluates to true. + * + * If the provided predicate returns NULL for a pair of rows + * (left, right), that pair is not included in the output. + * + * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. + * + * @param left The left table + * @param right The right table + * @param binary_predicate The condition on which to join. + * @param compare_nulls Whether the equality operator returns true or false for two nulls. + * @param mr Device memory resource used to allocate the returned table and columns' device memory + * + * @return The size that would result from performing the requested join. + */ +std::size_t conditional_left_semi_join_size( table_view left, table_view right, ast::expression binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); +/** + * @brief Returns the exact number of matches (rows) when performing a + * conditional left anti join between the specified tables where the predicate + * evaluates to true. + * + * If the provided predicate returns NULL for a pair of rows + * (left, right), that pair is not included in the output. + * + * @throw cudf::logic_error if the binary predicate outputs a non-boolean result. + * + * @param left The left table + * @param right The right table + * @param binary_predicate The condition on which to join. + * @param compare_nulls Whether the equality operator returns true or false for two nulls. + * @param mr Device memory resource used to allocate the returned table and columns' device memory + * + * @return The size that would result from performing the requested join. + */ +std::size_t conditional_left_anti_join_size( + table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls = null_equality::EQUAL, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of group } // namespace cudf diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index 6cf62d1c684..cf391b2b23d 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -81,6 +81,8 @@ std::pair, size_type> nans_to_nulls( * This evaluates an expression over a table to produce a new column. Also called an n-ary * transform. * + * @throws cudf::logic_error if passed an expression operating on table_reference::RIGHT. + * * @param table The table used for expression evaluation. * @param expr The root of the expression tree. * @param mr Device memory resource. diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index 66d72fbb454..760f47a5045 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -100,9 +100,18 @@ cudf::size_type expression_parser::visit(column_reference const& expr) // Increment the node index _node_count++; // Resolve node type - auto const data_type = expr.get_table_source() == table_reference::LEFT - ? expr.get_data_type(_left) - : expr.get_data_type(_right); + cudf::data_type data_type; + if (expr.get_table_source() == table_reference::LEFT) { + data_type = expr.get_data_type(_left); + } else { + if (_right.has_value()) { + data_type = expr.get_data_type(*_right); + } else { + CUDF_FAIL( + "Your expression contains a reference to the RIGHT table even though it will only be " + "evaluated on a single table (by convention, the LEFT table)."); + } + } // Push data reference auto const source = detail::device_data_reference(detail::device_data_reference_type::COLUMN, data_type, diff --git a/cpp/src/join/conditional_join.cu b/cpp/src/join/conditional_join.cu index 1538780db5e..ee076d80140 100644 --- a/cpp/src/join/conditional_join.cu +++ b/cpp/src/join/conditional_join.cu @@ -13,31 +13,235 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include +#include +#include +#include #include #include +#include #include +#include +#include +#include +#include +#include #include +#include + namespace cudf { namespace detail { std::pair>, std::unique_ptr>> -conditional_join(table_view left, - table_view right, +conditional_join(table_view const& left, + table_view const& right, ast::expression binary_predicate, null_equality compare_nulls, - join_kind JoinKind, + join_kind join_type, + std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - CUDF_FUNC_RANGE(); - return get_conditional_join_indices( - left, right, binary_predicate, compare_nulls, JoinKind, stream, mr); + // We can immediately filter out cases where the right table is empty. In + // some cases, we return all the rows of the left table with a corresponding + // null index for the right table; in others, we return an empty output. + if (right.num_rows() == 0) { + switch (join_type) { + // Left, left anti, and full (which are effectively left because we are + // guaranteed that left has more rows than right) all return a all the + // row indices from left with a corresponding NULL from the right. + case join_kind::LEFT_JOIN: + case join_kind::LEFT_ANTI_JOIN: + case join_kind::FULL_JOIN: return get_trivial_left_join_indices(left, stream); + // Inner and left semi joins return empty output because no matches can exist. + case join_kind::INNER_JOIN: + case join_kind::LEFT_SEMI_JOIN: + return std::make_pair(std::make_unique>(0, stream, mr), + std::make_unique>(0, stream, mr)); + } + } + + // Prepare output column. Whether or not the output column is nullable is + // determined by whether any of the columns in the input table are nullable. + // If none of the input columns actually contain nulls, we can still use the + // non-nullable version of the expression evaluation code path for + // performance, so we capture that information as well. + auto const nullable = cudf::nullable(left) || cudf::nullable(right); + auto const has_nulls = nullable && (cudf::has_nulls(left) || cudf::has_nulls(right)); + + auto const parser = + ast::detail::expression_parser{binary_predicate, left, right, has_nulls, stream, mr}; + CUDF_EXPECTS(parser.output_type().id() == type_id::BOOL8, + "The expression must produce a boolean output."); + + auto left_table = table_device_view::create(left, stream); + auto right_table = table_device_view::create(right, stream); + + // Allocate storage for the counter used to get the size of the join output + detail::grid_1d config(left_table->num_rows(), DEFAULT_JOIN_BLOCK_SIZE); + auto const shmem_size_per_block = + parser.device_expression_data.shmem_per_thread * config.num_threads_per_block; + join_kind kernel_join_type = join_type == join_kind::FULL_JOIN ? join_kind::LEFT_JOIN : join_type; + + // If the join size was not provided as an input, compute it here. + std::size_t join_size; + if (output_size.has_value()) { + join_size = *output_size; + } else { + rmm::device_scalar size(0, stream, mr); + CHECK_CUDA(stream.value()); + if (has_nulls) { + compute_conditional_join_output_size + <<>>( + *left_table, + *right_table, + kernel_join_type, + compare_nulls, + parser.device_expression_data, + size.data()); + } else { + compute_conditional_join_output_size + <<>>( + *left_table, + *right_table, + kernel_join_type, + compare_nulls, + parser.device_expression_data, + size.data()); + } + CHECK_CUDA(stream.value()); + join_size = size.value(stream); + } + + // If the output size will be zero, we can return immediately. + if (join_size == 0) { + return std::make_pair(std::make_unique>(0, stream, mr), + std::make_unique>(0, stream, mr)); + } + + rmm::device_scalar write_index(0, stream); + + auto left_indices = std::make_unique>(join_size, stream, mr); + auto right_indices = std::make_unique>(join_size, stream, mr); + + auto const& join_output_l = left_indices->data(); + auto const& join_output_r = right_indices->data(); + if (has_nulls) { + conditional_join + <<>>( + *left_table, + *right_table, + kernel_join_type, + compare_nulls, + join_output_l, + join_output_r, + write_index.data(), + parser.device_expression_data, + join_size); + } else { + conditional_join + <<>>( + *left_table, + *right_table, + kernel_join_type, + compare_nulls, + join_output_l, + join_output_r, + write_index.data(), + parser.device_expression_data, + join_size); + } + + CHECK_CUDA(stream.value()); + + auto join_indices = std::make_pair(std::move(left_indices), std::move(right_indices)); + + // For full joins, get the indices in the right table that were not joined to + // by any row in the left table. + if (join_type == join_kind::FULL_JOIN) { + auto complement_indices = detail::get_left_join_indices_complement( + join_indices.second, left.num_rows(), right.num_rows(), stream, mr); + join_indices = detail::concatenate_vector_pairs(join_indices, complement_indices, stream); + } + return join_indices; +} + +std::size_t compute_conditional_join_output_size(table_view const& left, + table_view const& right, + ast::expression binary_predicate, + null_equality compare_nulls, + join_kind join_type, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // We can immediately filter out cases where the right table is empty. In + // some cases, we return all the rows of the left table with a corresponding + // null index for the right table; in others, we return an empty output. + if (right.num_rows() == 0) { + switch (join_type) { + // Left, left anti, and full (which are effectively left because we are + // guaranteed that left has more rows than right) all return a all the + // row indices from left with a corresponding NULL from the right. + case join_kind::LEFT_JOIN: + case join_kind::LEFT_ANTI_JOIN: + case join_kind::FULL_JOIN: return left.num_rows(); + // Inner and left semi joins return empty output because no matches can exist. + case join_kind::INNER_JOIN: + case join_kind::LEFT_SEMI_JOIN: return 0; + } + } + + // Prepare output column. Whether or not the output column is nullable is + // determined by whether any of the columns in the input table are nullable. + // If none of the input columns actually contain nulls, we can still use the + // non-nullable version of the expression evaluation code path for + // performance, so we capture that information as well. + auto const nullable = cudf::nullable(left) || cudf::nullable(right); + auto const has_nulls = nullable && (cudf::has_nulls(left) || cudf::has_nulls(right)); + + auto const parser = + ast::detail::expression_parser{binary_predicate, left, right, has_nulls, stream, mr}; + CUDF_EXPECTS(parser.output_type().id() == type_id::BOOL8, + "The expression must produce a boolean output."); + + auto left_table = table_device_view::create(left, stream); + auto right_table = table_device_view::create(right, stream); + + // Allocate storage for the counter used to get the size of the join output + rmm::device_scalar size(0, stream, mr); + CHECK_CUDA(stream.value()); + detail::grid_1d config(left_table->num_rows(), DEFAULT_JOIN_BLOCK_SIZE); + auto const shmem_size_per_block = + parser.device_expression_data.shmem_per_thread * config.num_threads_per_block; + + // Determine number of output rows without actually building the output to simply + // find what the size of the output will be. + assert(join_type != join_kind::FULL_JOIN); + if (has_nulls) { + compute_conditional_join_output_size + <<>>( + *left_table, + *right_table, + join_type, + compare_nulls, + parser.device_expression_data, + size.data()); + } else { + compute_conditional_join_output_size + <<>>( + *left_table, + *right_table, + join_type, + compare_nulls, + parser.device_expression_data, + size.data()); + } + CHECK_CUDA(stream.value()); + + return size.value(stream); } } // namespace detail @@ -48,13 +252,16 @@ conditional_inner_join(table_view left, table_view right, ast::expression binary_predicate, null_equality compare_nulls, + std::optional output_size, rmm::mr::device_memory_resource* mr) { + CUDF_FUNC_RANGE(); return detail::conditional_join(left, right, binary_predicate, compare_nulls, detail::join_kind::INNER_JOIN, + output_size, rmm::cuda_stream_default, mr); } @@ -65,13 +272,16 @@ conditional_left_join(table_view left, table_view right, ast::expression binary_predicate, null_equality compare_nulls, + std::optional output_size, rmm::mr::device_memory_resource* mr) { + CUDF_FUNC_RANGE(); return detail::conditional_join(left, right, binary_predicate, compare_nulls, detail::join_kind::LEFT_JOIN, + output_size, rmm::cuda_stream_default, mr); } @@ -84,11 +294,13 @@ conditional_full_join(table_view left, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { + CUDF_FUNC_RANGE(); return detail::conditional_join(left, right, binary_predicate, compare_nulls, detail::join_kind::FULL_JOIN, + {}, rmm::cuda_stream_default, mr); } @@ -98,13 +310,16 @@ std::unique_ptr> conditional_left_semi_join( table_view right, ast::expression binary_predicate, null_equality compare_nulls, + std::optional output_size, rmm::mr::device_memory_resource* mr) { + CUDF_FUNC_RANGE(); return std::move(detail::conditional_join(left, right, binary_predicate, compare_nulls, detail::join_kind::LEFT_SEMI_JOIN, + output_size, rmm::cuda_stream_default, mr) .first); @@ -115,16 +330,83 @@ std::unique_ptr> conditional_left_anti_join( table_view right, ast::expression binary_predicate, null_equality compare_nulls, + std::optional output_size, rmm::mr::device_memory_resource* mr) { + CUDF_FUNC_RANGE(); return std::move(detail::conditional_join(left, right, binary_predicate, compare_nulls, detail::join_kind::LEFT_ANTI_JOIN, + output_size, rmm::cuda_stream_default, mr) .first); } +std::size_t conditional_inner_join_size(table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::compute_conditional_join_output_size(left, + right, + binary_predicate, + compare_nulls, + detail::join_kind::INNER_JOIN, + rmm::cuda_stream_default, + mr); +} + +std::size_t conditional_left_join_size(table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::compute_conditional_join_output_size(left, + right, + binary_predicate, + compare_nulls, + detail::join_kind::LEFT_JOIN, + rmm::cuda_stream_default, + mr); +} + +std::size_t conditional_left_semi_join_size(table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return std::move(detail::compute_conditional_join_output_size(left, + right, + binary_predicate, + compare_nulls, + detail::join_kind::LEFT_SEMI_JOIN, + rmm::cuda_stream_default, + mr)); +} + +std::size_t conditional_left_anti_join_size(table_view left, + table_view right, + ast::expression binary_predicate, + null_equality compare_nulls, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return std::move(detail::compute_conditional_join_output_size(left, + right, + binary_predicate, + compare_nulls, + detail::join_kind::LEFT_ANTI_JOIN, + rmm::cuda_stream_default, + mr)); +} + } // namespace cudf diff --git a/cpp/src/join/conditional_join.cuh b/cpp/src/join/conditional_join.cuh deleted file mode 100644 index 3d5af7d0657..00000000000 --- a/cpp/src/join/conditional_join.cuh +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace cudf { -namespace detail { - -/** - * @brief Computes the join operation between two tables and returns the - * output indices of left and right table as a combined table - * - * @param left Table of left columns to join - * @param right Table of right columns to join - * tables have been flipped, meaning the output indices should also be flipped - * @param JoinKind The type of join to be performed - * @param compare_nulls Controls whether null join-key values should match or not. - * @param stream CUDA stream used for device memory operations and kernel launches - * - * @return Join output indices vector pair - */ -std::pair>, - std::unique_ptr>> -get_conditional_join_indices(table_view const& left, - table_view const& right, - ast::expression binary_predicate, - null_equality compare_nulls, - join_kind JoinKind, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - // We can immediately filter out cases where the right table is empty. In - // some cases, we return all the rows of the left table with a corresponding - // null index for the right table; in others, we return an empty output. - if (right.num_rows() == 0) { - switch (JoinKind) { - // Left, left anti, and full (which are effectively left because we are - // guaranteed that left has more rows than right) all return a all the - // row indices from left with a corresponding NULL from the right. - case join_kind::LEFT_JOIN: - case join_kind::LEFT_ANTI_JOIN: - case join_kind::FULL_JOIN: return get_trivial_left_join_indices(left, stream); - // Inner and left semi joins return empty output because no matches can exist. - case join_kind::INNER_JOIN: - case join_kind::LEFT_SEMI_JOIN: - return std::make_pair(std::make_unique>(0, stream, mr), - std::make_unique>(0, stream, mr)); - } - } - - // Prepare output column. Whether or not the output column is nullable is - // determined by whether any of the columns in the input table are nullable. - // If none of the input columns actually contain nulls, we can still use the - // non-nullable version of the expression evaluation code path for - // performance, so we capture that information as well. - auto const nullable = cudf::nullable(left) || cudf::nullable(right); - auto const has_nulls = nullable && (cudf::has_nulls(left) || cudf::has_nulls(right)); - - auto const parser = - ast::detail::expression_parser{binary_predicate, left, right, has_nulls, stream, mr}; - CUDF_EXPECTS(parser.output_type().id() == type_id::BOOL8, - "The expression must produce a boolean output."); - - auto left_table = table_device_view::create(left, stream); - auto right_table = table_device_view::create(right, stream); - - // Allocate storage for the counter used to get the size of the join output - rmm::device_scalar size(0, stream, mr); - CHECK_CUDA(stream.value()); - constexpr int block_size{DEFAULT_JOIN_BLOCK_SIZE}; - detail::grid_1d config(left_table->num_rows(), block_size); - auto const shmem_size_per_block = - parser.device_expression_data.shmem_per_thread * config.num_threads_per_block; - - // Determine number of output rows without actually building the output to simply - // find what the size of the output will be. - join_kind KernelJoinKind = JoinKind == join_kind::FULL_JOIN ? join_kind::LEFT_JOIN : JoinKind; - if (has_nulls) { - compute_conditional_join_output_size - <<>>( - *left_table, - *right_table, - KernelJoinKind, - compare_nulls, - parser.device_expression_data, - size.data()); - } else { - compute_conditional_join_output_size - <<>>( - *left_table, - *right_table, - KernelJoinKind, - compare_nulls, - parser.device_expression_data, - size.data()); - } - CHECK_CUDA(stream.value()); - - size_type const join_size = size.value(stream); - - // If the output size will be zero, we can return immediately. - if (join_size == 0) { - return std::make_pair(std::make_unique>(0, stream, mr), - std::make_unique>(0, stream, mr)); - } - - rmm::device_scalar write_index(0, stream); - - auto left_indices = std::make_unique>(join_size, stream, mr); - auto right_indices = std::make_unique>(join_size, stream, mr); - - auto const& join_output_l = left_indices->data(); - auto const& join_output_r = right_indices->data(); - if (has_nulls) { - conditional_join - <<>>( - *left_table, - *right_table, - KernelJoinKind, - compare_nulls, - join_output_l, - join_output_r, - write_index.data(), - parser.device_expression_data, - join_size); - } else { - conditional_join - <<>>( - *left_table, - *right_table, - KernelJoinKind, - compare_nulls, - join_output_l, - join_output_r, - write_index.data(), - parser.device_expression_data, - join_size); - } - - CHECK_CUDA(stream.value()); - - auto join_indices = std::make_pair(std::move(left_indices), std::move(right_indices)); - - // For full joins, get the indices in the right table that were not joined to - // by any row in the left table. - if (JoinKind == join_kind::FULL_JOIN) { - auto complement_indices = detail::get_left_join_indices_complement( - join_indices.second, left.num_rows(), right.num_rows(), stream, mr); - join_indices = detail::concatenate_vector_pairs(join_indices, complement_indices, stream); - } - return join_indices; -} - -} // namespace detail - -} // namespace cudf diff --git a/cpp/src/join/conditional_join.hpp b/cpp/src/join/conditional_join.hpp new file mode 100644 index 00000000000..b5b49815381 --- /dev/null +++ b/cpp/src/join/conditional_join.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "join_common_utils.hpp" + +#include +#include + +#include +#include + +#include + +namespace cudf { +namespace detail { + +/** + * @brief Computes the join operation between two tables and returns the + * output indices of left and right table as a combined table + * + * @param left Table of left columns to join + * @param right Table of right columns to join + * tables have been flipped, meaning the output indices should also be flipped + * @param JoinKind The type of join to be performed + * @param compare_nulls Controls whether null join-key values should match or not. + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Join output indices vector pair + */ +std::pair>, + std::unique_ptr>> +conditional_join(table_view const& left, + table_view const& right, + ast::expression binary_predicate, + null_equality compare_nulls, + join_kind JoinKind, + std::optional output_size = {}, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Computes the size of a join operation between two tables without + * materializing the result and returns the total size value. + * + * @param left Table of left columns to join + * @param right Table of right columns to join + * tables have been flipped, meaning the output indices should also be flipped + * @param JoinKind The type of join to be performed + * @param compare_nulls Controls whether null join-key values should match or not. + * @param stream CUDA stream used for device memory operations and kernel launches + * + * @return Join output indices vector pair + */ +std::size_t compute_conditional_join_output_size( + table_view const& left, + table_view const& right, + ast::expression binary_predicate, + null_equality compare_nulls, + join_kind JoinKind, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +} // namespace detail +} // namespace cudf diff --git a/cpp/src/join/conditional_join_kernels.cuh b/cpp/src/join/conditional_join_kernels.cuh index 6f84a1c0fc0..9fcc7bf5cfb 100644 --- a/cpp/src/join/conditional_join_kernels.cuh +++ b/cpp/src/join/conditional_join_kernels.cuh @@ -40,7 +40,7 @@ namespace detail { * * @param[in] left_table The left table * @param[in] right_table The right table - * @param[in] JoinKind The type of join to be performed + * @param[in] join_type The type of join to be performed * @param[in] compare_nulls Controls whether null join-key values should match or not. * @param[in] device_expression_data Container of device data required to evaluate the desired * expression. @@ -50,10 +50,10 @@ template __global__ void compute_conditional_join_output_size( table_device_view left_table, table_device_view right_table, - join_kind JoinKind, + join_kind join_type, null_equality compare_nulls, ast::detail::expression_device_view device_expression_data, - cudf::size_type* output_size) + std::size_t* output_size) { // The (required) extern storage of the shared memory array leads to // conflicting declarations between different templates. The easiest @@ -65,7 +65,7 @@ __global__ void compute_conditional_join_output_size( auto thread_intermediate_storage = &intermediate_storage[threadIdx.x * device_expression_data.num_intermediates]; - cudf::size_type thread_counter(0); + std::size_t thread_counter{0}; cudf::size_type const left_start_idx = threadIdx.x + blockIdx.x * blockDim.x; cudf::size_type const left_stride = blockDim.x * gridDim.x; cudf::size_type const left_num_rows = left_table.num_rows(); @@ -81,15 +81,15 @@ __global__ void compute_conditional_join_output_size( auto output_dest = cudf::ast::detail::value_expression_result(); evaluator.evaluate(output_dest, left_row_index, right_row_index, 0); if (output_dest.is_valid() && output_dest.value()) { - if ((JoinKind != join_kind::LEFT_ANTI_JOIN) && - !(JoinKind == join_kind::LEFT_SEMI_JOIN && found_match)) { + if ((join_type != join_kind::LEFT_ANTI_JOIN) && + !(join_type == join_kind::LEFT_SEMI_JOIN && found_match)) { ++thread_counter; } found_match = true; } } - if ((JoinKind == join_kind::LEFT_JOIN || JoinKind == join_kind::LEFT_ANTI_JOIN || - JoinKind == join_kind::FULL_JOIN) && + if ((join_type == join_kind::LEFT_JOIN || join_type == join_kind::LEFT_ANTI_JOIN || + join_type == join_kind::FULL_JOIN) && (!found_match)) { ++thread_counter; } @@ -97,7 +97,7 @@ __global__ void compute_conditional_join_output_size( using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - cudf::size_type block_counter = BlockReduce(temp_storage).Sum(thread_counter); + std::size_t block_counter = BlockReduce(temp_storage).Sum(thread_counter); // Add block counter to global counter if (threadIdx.x == 0) atomicAdd(output_size, block_counter); @@ -115,7 +115,7 @@ __global__ void compute_conditional_join_output_size( * * @param[in] left_table The left table * @param[in] right_table The right table - * @param[in] JoinKind The type of join to be performed + * @param[in] join_type The type of join to be performed * @param compare_nulls Controls whether null join-key values should match or not. * @param[out] join_output_l The left result of the join operation * @param[out] join_output_r The right result of the join operation @@ -128,7 +128,7 @@ __global__ void compute_conditional_join_output_size( template __global__ void conditional_join(table_device_view left_table, table_device_view right_table, - join_kind JoinKind, + join_kind join_type, null_equality compare_nulls, cudf::size_type* join_output_l, cudf::size_type* join_output_r, @@ -181,8 +181,8 @@ __global__ void conditional_join(table_device_view left_table, // that the current logic relies on the fact that we process all right // table rows for a single left table row on a single thread so that no // synchronization of found_match is required). - if ((JoinKind != join_kind::LEFT_ANTI_JOIN) && - !(JoinKind == join_kind::LEFT_SEMI_JOIN && found_match)) { + if ((join_type != join_kind::LEFT_ANTI_JOIN) && + !(join_type == join_kind::LEFT_SEMI_JOIN && found_match)) { add_pair_to_cache(left_row_index, right_row_index, current_idx_shared, @@ -214,8 +214,8 @@ __global__ void conditional_join(table_device_view left_table, // Left, left anti, and full joins all require saving left columns that // aren't present in the right. - if ((JoinKind == join_kind::LEFT_JOIN || JoinKind == join_kind::LEFT_ANTI_JOIN || - JoinKind == join_kind::FULL_JOIN) && + if ((join_type == join_kind::LEFT_JOIN || join_type == join_kind::LEFT_ANTI_JOIN || + join_type == join_kind::FULL_JOIN) && (!found_match)) { add_pair_to_cache(left_row_index, static_cast(JoinNoneValue), diff --git a/cpp/src/transform/compute_column.cu b/cpp/src/transform/compute_column.cu index 1d4cde10306..cd8196e555c 100644 --- a/cpp/src/transform/compute_column.cu +++ b/cpp/src/transform/compute_column.cu @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/tests/join/conditional_join_tests.cu b/cpp/tests/join/conditional_join_tests.cu index 57abdf17aa6..e16e1ec7de8 100644 --- a/cpp/tests/join/conditional_join_tests.cu +++ b/cpp/tests/join/conditional_join_tests.cu @@ -154,8 +154,10 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { // resulting column views will be referencing potentially invalid memory. auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = this->parse_input(left_data, right_data); - auto result = this->join(left, right, predicate); + auto result_size = this->join_size(left, right, predicate); + EXPECT_TRUE(result_size == expected_outputs.size()); + auto result = this->join(left, right, predicate); std::vector> result_pairs; for (size_t i = 0; i < result.first->size(); ++i) { // Note: Not trying to be terribly efficient here since these tests are @@ -167,7 +169,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { std::sort(result_pairs.begin(), result_pairs.end()); std::sort(expected_outputs.begin(), expected_outputs.end()); - EXPECT_TRUE(std::equal(result_pairs.begin(), result_pairs.end(), expected_outputs.begin())); + EXPECT_TRUE(std::equal(expected_outputs.begin(), expected_outputs.end(), result_pairs.begin())); } void test_nulls(std::vector, std::vector>> left_data, @@ -179,8 +181,10 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { // resulting column views will be referencing potentially invalid memory. auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = this->parse_input(left_data, right_data); - auto result = this->join(left, right, predicate); + auto result_size = this->join_size(left, right, predicate); + EXPECT_TRUE(result_size == expected_outputs.size()); + auto result = this->join(left, right, predicate); std::vector> result_pairs; for (size_t i = 0; i < result.first->size(); ++i) { // Note: Not trying to be terribly efficient here since these tests are @@ -192,7 +196,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { std::sort(result_pairs.begin(), result_pairs.end()); std::sort(expected_outputs.begin(), expected_outputs.end()); - EXPECT_TRUE(std::equal(result_pairs.begin(), result_pairs.end(), expected_outputs.begin())); + EXPECT_TRUE(std::equal(expected_outputs.begin(), expected_outputs.end(), result_pairs.begin())); } /* @@ -238,7 +242,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { thrust::sort(thrust::device, reference_pairs.begin(), reference_pairs.end()); EXPECT_TRUE(thrust::equal( - thrust::device, result_pairs.begin(), result_pairs.end(), reference_pairs.begin())); + thrust::device, reference_pairs.begin(), reference_pairs.end(), result_pairs.begin())); } /** @@ -250,6 +254,15 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { std::unique_ptr>> join(cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) = 0; + /** + * This method must be implemented by subclasses for specific types of joins. + * It should be a simply forwarding of arguments to the appropriate cudf + * conditional join size computation API. + */ + virtual std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) = 0; + /** * This method must be implemented by subclasses for specific types of joins. * It should be a simply forwarding of arguments to the appropriate cudf @@ -272,6 +285,13 @@ struct ConditionalInnerJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_inner_join(left, right, predicate); } + std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) override + { + return cudf::conditional_inner_join_size(left, right, predicate); + } + std::pair>, std::unique_ptr>> reference_join(cudf::table_view left, cudf::table_view right) override @@ -384,6 +404,20 @@ TYPED_TEST(ConditionalInnerJoinTest, TestComplexConditionMultipleColumns) {{4, 0}, {5, 0}, {6, 0}, {7, 0}}); }; +TYPED_TEST(ConditionalInnerJoinTest, TestSymmetry) +{ + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); + auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_1, col_ref_0); + auto expression_reverse = + cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); + + this->test( + {{0, 1, 2}}, {{1, 2, 3}}, expression, {{0, 0}, {0, 1}, {0, 2}, {1, 1}, {1, 2}, {2, 2}}); + this->test( + {{0, 1, 2}}, {{1, 2, 3}}, expression_reverse, {{0, 0}, {0, 1}, {0, 2}, {1, 1}, {1, 2}, {2, 2}}); +}; + TYPED_TEST(ConditionalInnerJoinTest, TestCompareRandomToHash) { // Generate columns of 10 repeats of the integer range [0, 10), then merge @@ -418,7 +452,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnTwoNullsRowAllEqual) TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnTwoNullsNoOutputRowAllEqual) { - this->test_nulls({{{0, 1}, {0, 1}}}, {{{0, 0}, {1, 1}}}, left_zero_eq_right_zero, {{}, {}}); + this->test_nulls({{{0, 1}, {0, 1}}}, {{{0, 0}, {1, 1}}}, left_zero_eq_right_zero, {}); }; /** @@ -433,6 +467,13 @@ struct ConditionalLeftJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_left_join(left, right, predicate); } + std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) override + { + return cudf::conditional_left_join_size(left, right, predicate); + } + std::pair>, std::unique_ptr>> reference_join(cudf::table_view left, cudf::table_view right) override @@ -489,6 +530,16 @@ struct ConditionalFullJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_full_join(left, right, predicate); } + std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) override + { + // Full joins don't actually support size calculations, but to support a + // uniform testing framework we just calculate it from the result of doing + // the join. + return cudf::conditional_full_join(left, right, predicate).first->size(); + } + std::pair>, std::unique_ptr>> reference_join(cudf::table_view left, cudf::table_view right) override @@ -499,6 +550,19 @@ struct ConditionalFullJoinTest : public ConditionalJoinPairReturnTest { TYPED_TEST_CASE(ConditionalFullJoinTest, cudf::test::IntegralTypesNotBool); +TYPED_TEST(ConditionalFullJoinTest, TestOneColumnNoneEqual) +{ + this->test({{0, 1, 2}}, + {{3, 4, 5}}, + left_zero_eq_right_zero, + {{0, JoinNoneValue}, + {1, JoinNoneValue}, + {2, JoinNoneValue}, + {JoinNoneValue, 0}, + {JoinNoneValue, 1}, + {JoinNoneValue, 2}}); +}; + TYPED_TEST(ConditionalFullJoinTest, TestTwoColumnThreeRowSomeEqual) { this->test({{0, 1, 2}, {10, 20, 30}}, @@ -551,8 +615,10 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { { auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = this->parse_input(left_data, right_data); - auto result = this->join(left, right, predicate); + auto result_size = this->join_size(left, right, predicate); + EXPECT_TRUE(result_size == expected_outputs.size()); + auto result = this->join(left, right, predicate); std::vector resulting_indices; for (size_t i = 0; i < result->size(); ++i) { // Note: Not trying to be terribly efficient here since these tests are @@ -597,6 +663,15 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { virtual std::unique_ptr> join( cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) = 0; + /** + * This method must be implemented by subclasses for specific types of joins. + * It should be a simply forwarding of arguments to the appropriate cudf + * conditional join size computation API. + */ + virtual std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) = 0; + /** * This method must be implemented by subclasses for specific types of joins. * It should be a simply forwarding of arguments to the appropriate cudf @@ -617,6 +692,13 @@ struct ConditionalLeftSemiJoinTest : public ConditionalJoinSingleReturnTest { return cudf::conditional_left_semi_join(left, right, predicate); } + std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) override + { + return cudf::conditional_left_semi_join_size(left, right, predicate); + } + std::unique_ptr> reference_join( cudf::table_view left, cudf::table_view right) override { @@ -668,6 +750,13 @@ struct ConditionalLeftAntiJoinTest : public ConditionalJoinSingleReturnTest { return cudf::conditional_left_anti_join(left, right, predicate); } + std::size_t join_size(cudf::table_view left, + cudf::table_view right, + cudf::ast::expression predicate) override + { + return cudf::conditional_left_anti_join_size(left, right, predicate); + } + std::unique_ptr> reference_join( cudf::table_view left, cudf::table_view right) override { diff --git a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java index 5a64fd6ab09..177abe9d6e3 100644 --- a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java @@ -61,13 +61,12 @@ public void testColumnReferenceTransform() { @Test public void testInvalidColumnReferenceTransform() { - // verify attempting to reference an invalid table remaps to the only valid table + // Verify that computeColumn throws when passed an expression operating on TableReference.RIGHT. UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, new ColumnReference(1, TableReference.RIGHT)); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); - CompiledExpression compiledExpr = expr.compile(); - ColumnVector actual = compiledExpr.computeColumn(t)) { - assertColumnsAreEqual(t.getColumn(1), actual); + CompiledExpression compiledExpr = expr.compile()) { + Assertions.assertThrows(CudfException.class, () -> compiledExpr.computeColumn(t).close()); } }