diff --git a/cpp/benchmarks/join/conditional_join_benchmark.cu b/cpp/benchmarks/join/conditional_join_benchmark.cu index 71b90685fb9..141e726027b 100644 --- a/cpp/benchmarks/join/conditional_join_benchmark.cu +++ b/cpp/benchmarks/join/conditional_join_benchmark.cu @@ -20,6 +20,10 @@ template class ConditionalJoin : public cudf::benchmark { }; +// For compatibility with the shared logic for equality (hash) joins, all of +// the join lambdas defined by these macros accept a null_equality parameter +// but ignore it (don't forward it to the underlying join implementation) +// because conditional joins do not use this parameter. #define CONDITIONAL_INNER_JOIN_BENCHMARK_DEFINE(name, key_type, payload_type, nullable) \ BENCHMARK_TEMPLATE_DEFINE_F(ConditionalJoin, name, key_type, payload_type) \ (::benchmark::State & st) \ @@ -28,7 +32,7 @@ class ConditionalJoin : public cudf::benchmark { cudf::table_view const& right, \ cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ - return cudf::conditional_inner_join(left, right, binary_pred, compare_nulls); \ + return cudf::conditional_inner_join(left, right, binary_pred); \ }; \ constexpr bool is_conditional = true; \ BM_join(st, join); \ @@ -47,7 +51,7 @@ CONDITIONAL_INNER_JOIN_BENCHMARK_DEFINE(conditional_inner_join_64bit_nulls, int6 cudf::table_view const& right, \ cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ - return cudf::conditional_left_join(left, right, binary_pred, compare_nulls); \ + return cudf::conditional_left_join(left, right, binary_pred); \ }; \ constexpr bool is_conditional = true; \ BM_join(st, join); \ @@ -66,7 +70,7 @@ CONDITIONAL_LEFT_JOIN_BENCHMARK_DEFINE(conditional_left_join_64bit_nulls, int64_ cudf::table_view const& right, \ cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ - return cudf::conditional_inner_join(left, right, binary_pred, compare_nulls); \ + return cudf::conditional_full_join(left, right, binary_pred); \ }; \ 
constexpr bool is_conditional = true; \ BM_join(st, join); \ @@ -85,7 +89,7 @@ CONDITIONAL_FULL_JOIN_BENCHMARK_DEFINE(conditional_full_join_64bit_nulls, int64_ cudf::table_view const& right, \ cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ - return cudf::conditional_left_anti_join(left, right, binary_pred, compare_nulls); \ + return cudf::conditional_left_anti_join(left, right, binary_pred); \ }; \ constexpr bool is_conditional = true; \ BM_join(st, join); \ @@ -116,7 +120,7 @@ CONDITIONAL_LEFT_ANTI_JOIN_BENCHMARK_DEFINE(conditional_left_anti_join_64bit_nul cudf::table_view const& right, \ cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ - return cudf::conditional_left_semi_join(left, right, binary_pred, compare_nulls); \ + return cudf::conditional_left_semi_join(left, right, binary_pred); \ }; \ constexpr bool is_conditional = true; \ BM_join(st, join); \ diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index e3266c2ed47..b6c47afe19d 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -235,19 +235,13 @@ struct expression_evaluator { * @param plan The collection of device references representing the expression to evaluate. * @param thread_intermediate_storage Pointer to this thread's portion of shared memory for * storing intermediates. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. 
*/ __device__ expression_evaluator(table_device_view const& left, table_device_view const& right, expression_device_view const& plan, - IntermediateDataType* thread_intermediate_storage, - null_equality compare_nulls = null_equality::EQUAL) - : left(left), - right(right), - plan(plan), - thread_intermediate_storage(thread_intermediate_storage), - compare_nulls(compare_nulls) + IntermediateDataType* thread_intermediate_storage) + : left(left), right(right), plan(plan), thread_intermediate_storage(thread_intermediate_storage) { } @@ -258,17 +252,14 @@ struct expression_evaluator { * @param plan The collection of device references representing the expression to evaluate. * @param thread_intermediate_storage Pointer to this thread's portion of shared memory for * storing intermediates. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. */ __device__ expression_evaluator(table_device_view const& table, expression_device_view const& plan, - IntermediateDataType* thread_intermediate_storage, - null_equality compare_nulls = null_equality::EQUAL) + IntermediateDataType* thread_intermediate_storage) : left(table), right(table), plan(plan), - thread_intermediate_storage(thread_intermediate_storage), - compare_nulls(compare_nulls) + thread_intermediate_storage(thread_intermediate_storage) { } @@ -603,32 +594,28 @@ struct expression_evaluator { * @param input Input to the operation. * @param output Output data reference. */ - template < - ast_operator op, - typename OutputType, - std::enable_if_t, Input>>* = nullptr> + template , + possibly_null_value_t>>* = nullptr> __device__ void operator()(OutputType& output_object, cudf::size_type const output_row_index, possibly_null_value_t const input, detail::device_data_reference const output) const { - using OperatorFunctor = detail::operator_functor; - using Out = cuda::std::invoke_result_t; - if constexpr (has_nulls) { - auto const result = input.has_value() - ? 
possibly_null_value_t(OperatorFunctor{}(*input)) - : possibly_null_value_t(); - this->template resolve_output(output_object, output, output_row_index, result); - } else { - this->template resolve_output( - output_object, output, output_row_index, OperatorFunctor{}(input)); - } + // The output data type is the same whether or not nulls are present, so + // pull from the non-nullable operator. + using Out = cuda::std::invoke_result_t, Input>; + this->template resolve_output( + output_object, output, output_row_index, detail::operator_functor{}(input)); } - template < - ast_operator op, - typename OutputType, - std::enable_if_t, Input>>* = nullptr> + template , + possibly_null_value_t>>* = nullptr> __device__ void operator()(OutputType& output_object, cudf::size_type const output_row_index, possibly_null_value_t const input, @@ -665,50 +652,31 @@ struct expression_evaluator { */ template , LHS, RHS>>* = nullptr> + std::enable_if_t, + possibly_null_value_t, + possibly_null_value_t>>* = + nullptr> __device__ void operator()(OutputType& output_object, cudf::size_type const output_row_index, possibly_null_value_t const lhs, possibly_null_value_t const rhs, detail::device_data_reference const output) const { - using OperatorFunctor = detail::operator_functor; - using Out = cuda::std::invoke_result_t; - if constexpr (has_nulls) { - if constexpr (op == ast_operator::EQUAL) { - // Special handling of the equality operator based on what kind - // of null handling was requested. - possibly_null_value_t result; - if (!lhs.has_value() && !rhs.has_value()) { - // Case 1: Both null, so the output is based on compare_nulls. - result = possibly_null_value_t(this->evaluator.compare_nulls == - null_equality::EQUAL); - } else if (lhs.has_value() && rhs.has_value()) { - // Case 2: Neither is null, so the output is given by the operation. 
- result = possibly_null_value_t(OperatorFunctor{}(*lhs, *rhs)); - } else { - // Case 3: One value is null, while the other is not, so we simply propagate nulls. - result = possibly_null_value_t(); - } - this->template resolve_output(output_object, output, output_row_index, result); - } else { - // Default behavior for all other operators is to propagate nulls. - auto result = (lhs.has_value() && rhs.has_value()) - ? possibly_null_value_t(OperatorFunctor{}(*lhs, *rhs)) - : possibly_null_value_t(); - this->template resolve_output(output_object, output, output_row_index, result); - } - } else { - this->template resolve_output( - output_object, output, output_row_index, OperatorFunctor{}(lhs, rhs)); - } + // The output data type is the same whether or not nulls are present, so + // pull from the non-nullable operator. + using Out = cuda::std::invoke_result_t, LHS, RHS>; + this->template resolve_output(output_object, + output, + output_row_index, + detail::operator_functor{}(lhs, rhs)); } template , LHS, RHS>>* = nullptr> + !detail::is_valid_binary_op, + possibly_null_value_t, + possibly_null_value_t>>* = nullptr> __device__ void operator()(OutputType& output_object, cudf::size_type const output_row_index, possibly_null_value_t const lhs, @@ -726,8 +694,6 @@ struct expression_evaluator { IntermediateDataType* thread_intermediate_storage; ///< The shared memory store of intermediates produced during ///< evaluation. - null_equality - compare_nulls; ///< Whether the equality operator returns true or false for two nulls. 
}; } // namespace detail diff --git a/cpp/include/cudf/ast/detail/operators.hpp b/cpp/include/cudf/ast/detail/operators.hpp index 00723004a9f..19df8d8e7b6 100644 --- a/cpp/include/cudf/ast/detail/operators.hpp +++ b/cpp/include/cudf/ast/detail/operators.hpp @@ -84,6 +84,9 @@ CUDA_HOST_DEVICE_CALLABLE constexpr void ast_operator_dispatcher(ast_operator op case ast_operator::EQUAL: f.template operator()(std::forward(args)...); break; + case ast_operator::NULL_EQUAL: + f.template operator()(std::forward(args)...); + break; case ast_operator::NOT_EQUAL: f.template operator()(std::forward(args)...); break; @@ -111,9 +114,15 @@ CUDA_HOST_DEVICE_CALLABLE constexpr void ast_operator_dispatcher(ast_operator op case ast_operator::LOGICAL_AND: f.template operator()(std::forward(args)...); break; + case ast_operator::NULL_LOGICAL_AND: + f.template operator()(std::forward(args)...); + break; case ast_operator::LOGICAL_OR: f.template operator()(std::forward(args)...); break; + case ast_operator::NULL_LOGICAL_OR: + f.template operator()(std::forward(args)...); + break; case ast_operator::IDENTITY: f.template operator()(std::forward(args)...); break; @@ -207,12 +216,12 @@ CUDA_HOST_DEVICE_CALLABLE constexpr void ast_operator_dispatcher(ast_operator op * * @tparam op AST operator. 
*/ -template +template struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -223,7 +232,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -234,7 +243,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -245,7 +254,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -256,7 +265,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -268,7 +277,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -280,7 +289,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -373,7 +382,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -383,8 +392,14 @@ struct operator_functor { } }; +// Alias NULL_EQUAL = EQUAL in the non-nullable case. 
+template <> +struct operator_functor + : public operator_functor { +}; + template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -395,7 +410,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -406,7 +421,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -417,7 +432,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -428,7 +443,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -439,7 +454,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -450,7 +465,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -461,7 +476,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -472,7 +487,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -482,8 +497,14 @@ struct operator_functor { } }; +// Alias NULL_LOGICAL_AND = LOGICAL_AND in the non-nullable case. +template <> +struct operator_functor + : public operator_functor { +}; + template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{2}; template @@ -493,8 +514,14 @@ struct operator_functor { } }; +// Alias NULL_LOGICAL_OR = LOGICAL_OR in the non-nullable case. 
+template <> +struct operator_functor + : public operator_functor { +}; + template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -505,7 +532,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -516,7 +543,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -527,7 +554,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -538,7 +565,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -549,7 +576,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -560,7 +587,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -571,7 +598,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -582,7 +609,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -593,7 +620,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -604,7 +631,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -615,7 +642,7 @@ struct operator_functor { }; 
template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -626,7 +653,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template ::value>* = nullptr> @@ -637,7 +664,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -648,7 +675,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -659,7 +686,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -670,7 +697,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -681,7 +708,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -692,7 +719,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -703,7 +730,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; // Only accept signed or unsigned types (both require is_arithmetic to be true) @@ -721,7 +748,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -732,7 +759,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -743,7 +770,7 @@ struct operator_functor { }; template <> -struct operator_functor { +struct operator_functor { static constexpr auto arity{1}; template @@ -753,6 +780,104 @@ struct operator_functor { } }; +/* + * 
The default specialization of nullable operators is to fall back to the non-nullable + * implementation + */ +template +struct operator_functor { + using NonNullOperator = operator_functor; + static constexpr auto arity = NonNullOperator::arity; + + template * = nullptr> + CUDA_DEVICE_CALLABLE auto operator()(LHS const lhs, RHS const rhs) + -> possibly_null_value_t + { + using Out = possibly_null_value_t; + return (lhs.has_value() && rhs.has_value()) ? Out{NonNullOperator{}(*lhs, *rhs)} : Out{}; + } + + template * = nullptr> + CUDA_DEVICE_CALLABLE auto operator()(Input const input) + -> possibly_null_value_t + { + using Out = possibly_null_value_t; + return input.has_value() ? Out{NonNullOperator{}(*input)} : Out{}; + } +}; + +// NULL_EQUAL(null, null) is true, NULL_EQUAL(null, valid) is false, and NULL_EQUAL(valid, valid) == +// EQUAL(valid, valid) +template <> +struct operator_functor { + using NonNullOperator = operator_functor; + static constexpr auto arity = NonNullOperator::arity; + + template + CUDA_DEVICE_CALLABLE auto operator()(LHS const lhs, RHS const rhs) + -> possibly_null_value_t + { + // Case 1: Neither is null, so the output is given by the operation. + if (lhs.has_value() && rhs.has_value()) { return {NonNullOperator{}(*lhs, *rhs)}; } + // Case 2: Two nulls compare equal. + if (!lhs.has_value() && !rhs.has_value()) { return {true}; } + // Case 3: One value is null, while the other is not, so we return false. + return {false}; + } +}; + +///< NULL_LOGICAL_AND(null, null) is null, NULL_LOGICAL_AND(null, true) is null, +///< NULL_LOGICAL_AND(null, false) is false, and NULL_LOGICAL_AND(valid, valid) == +///< LOGICAL_AND(valid, valid) +template <> +struct operator_functor { + using NonNullOperator = operator_functor; + static constexpr auto arity = NonNullOperator::arity; + + template + CUDA_DEVICE_CALLABLE auto operator()(LHS const lhs, RHS const rhs) + -> possibly_null_value_t + { + // Case 1: Neither is null, so the output is given by the operation. 
+ if (lhs.has_value() && rhs.has_value()) { return {NonNullOperator{}(*lhs, *rhs)}; } + // Case 2: Two nulls return null. + if (!lhs.has_value() && !rhs.has_value()) { return {}; } + // Case 3: One value is null, while the other is not. If it's true we return null, otherwise we + // return false. + auto const& valid_element = lhs.has_value() ? lhs : rhs; + if (*valid_element) { return {}; } + return {false}; + } +}; + +///< NULL_LOGICAL_OR(null, null) is null, NULL_LOGICAL_OR(null, true) is true, NULL_LOGICAL_OR(null, +///< false) is null, and NULL_LOGICAL_OR(valid, valid) == LOGICAL_OR(valid, valid) +template <> +struct operator_functor { + using NonNullOperator = operator_functor; + static constexpr auto arity = NonNullOperator::arity; + + template + CUDA_DEVICE_CALLABLE auto operator()(LHS const lhs, RHS const rhs) + -> possibly_null_value_t + { + // Case 1: Neither is null, so the output is given by the operation. + if (lhs.has_value() && rhs.has_value()) { return {NonNullOperator{}(*lhs, *rhs)}; } + // Case 2: Two nulls return null. + if (!lhs.has_value() && !rhs.has_value()) { return {}; } + // Case 3: One value is null, while the other is not. If it's true we return true, otherwise we + // return null. + auto const& valid_element = lhs.has_value() ? lhs : rhs; + if (*valid_element) { return {true}; } + return {}; + } +}; + /** * @brief Functor used to single-type-dispatch binary operators. * @@ -812,10 +937,12 @@ struct type_dispatch_binary_op { Ts&&... args) { // Single dispatch (assume lhs_type == rhs_type) - type_dispatcher(lhs_type, - detail::single_dispatch_binary_operator_types>{}, - std::forward(f), - std::forward(args)...); + type_dispatcher( + lhs_type, + // Always dispatch to the non-null operator for the purpose of type determination. 
+ detail::single_dispatch_binary_operator_types>{}, + std::forward(f), + std::forward(args)...); } }; @@ -881,10 +1008,12 @@ struct type_dispatch_unary_op { template CUDA_HOST_DEVICE_CALLABLE void operator()(cudf::data_type input_type, F&& f, Ts&&... args) { - type_dispatcher(input_type, - detail::dispatch_unary_operator_types>{}, - std::forward(f), - std::forward(args)...); + type_dispatcher( + input_type, + // Always dispatch to the non-null operator for the purpose of type determination. + detail::dispatch_unary_operator_types>{}, + std::forward(f), + std::forward(args)...); } }; @@ -1005,7 +1134,8 @@ struct arity_functor { template CUDA_HOST_DEVICE_CALLABLE void operator()(cudf::size_type& result) { - result = operator_functor::arity; + // Arity is not dependent on null handling, so just use the false implementation here. + result = operator_functor::arity; } }; diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 865c4b9fbbd..5454f9a2b95 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -55,27 +55,38 @@ struct expression { */ enum class ast_operator { // Binary operators - ADD, ///< operator + - SUB, ///< operator - - MUL, ///< operator * - DIV, ///< operator / using common type of lhs and rhs - TRUE_DIV, ///< operator / after promoting type to floating point - FLOOR_DIV, ///< operator / after promoting to 64 bit floating point and then - ///< flooring the result - MOD, ///< operator % - PYMOD, ///< operator % but following python's sign rules for negatives - POW, ///< lhs ^ rhs - EQUAL, ///< operator == - NOT_EQUAL, ///< operator != - LESS, ///< operator < - GREATER, ///< operator > - LESS_EQUAL, ///< operator <= - GREATER_EQUAL, ///< operator >= - BITWISE_AND, ///< operator & - BITWISE_OR, ///< operator | - BITWISE_XOR, ///< operator ^ - LOGICAL_AND, ///< operator && - LOGICAL_OR, ///< operator || + ADD, ///< operator + + SUB, ///< operator - + MUL, ///< operator * + DIV, 
///< operator / using common type of lhs and rhs + TRUE_DIV, ///< operator / after promoting type to floating point + FLOOR_DIV, ///< operator / after promoting to 64 bit floating point and then + ///< flooring the result + MOD, ///< operator % + PYMOD, ///< operator % using Python's sign rules for negatives + POW, ///< lhs ^ rhs + EQUAL, ///< operator == + NULL_EQUAL, ///< operator == with Spark rules: NULL_EQUAL(null, null) is true, NULL_EQUAL(null, + ///< valid) is false, and + ///< NULL_EQUAL(valid, valid) == EQUAL(valid, valid) + NOT_EQUAL, ///< operator != + LESS, ///< operator < + GREATER, ///< operator > + LESS_EQUAL, ///< operator <= + GREATER_EQUAL, ///< operator >= + BITWISE_AND, ///< operator & + BITWISE_OR, ///< operator | + BITWISE_XOR, ///< operator ^ + LOGICAL_AND, ///< operator && + NULL_LOGICAL_AND, ///< operator && with Spark rules: NULL_LOGICAL_AND(null, null) is null, + ///< NULL_LOGICAL_AND(null, true) is + ///< null, NULL_LOGICAL_AND(null, false) is false, and NULL_LOGICAL_AND(valid, + ///< valid) == LOGICAL_AND(valid, valid) + LOGICAL_OR, ///< operator || + NULL_LOGICAL_OR, ///< operator || with Spark rules: NULL_LOGICAL_OR(null, null) is null, + ///< NULL_LOGICAL_OR(null, true) is true, + ///< NULL_LOGICAL_OR(null, false) is null, and NULL_LOGICAL_OR(valid, valid) == + ///< LOGICAL_OR(valid, valid) // Unary operators IDENTITY, ///< Identity function SIN, ///< Trigonometric sine diff --git a/cpp/include/cudf/join.hpp b/cpp/include/cudf/join.hpp index 483cd75c739..c55c6d49028 100644 --- a/cpp/include/cudf/join.hpp +++ b/cpp/include/cudf/join.hpp @@ -678,7 +678,6 @@ class hash_join { * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. 
* @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return A pair of vectors [`left_indices`, `right_indices`] that can be used to construct @@ -690,7 +689,6 @@ conditional_inner_join( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -725,7 +723,6 @@ conditional_inner_join( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return A pair of vectors [`left_indices`, `right_indices`] that can be used to construct @@ -736,7 +733,6 @@ std::pair>, conditional_left_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -770,7 +766,6 @@ conditional_left_join(table_view const& left, * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. 
* @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return A pair of vectors [`left_indices`, `right_indices`] that can be used to construct @@ -781,7 +776,6 @@ std::pair>, conditional_full_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -809,7 +803,6 @@ conditional_full_join(table_view const& left, * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return A vector `left_indices` that can be used to construct the result of @@ -820,7 +813,6 @@ std::unique_ptr> conditional_left_semi_join( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -849,7 +841,6 @@ std::unique_ptr> conditional_left_semi_join( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. 
* @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return A vector `left_indices` that can be used to construct the result of @@ -860,7 +851,6 @@ std::unique_ptr> conditional_left_anti_join( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -877,7 +867,6 @@ std::unique_ptr> conditional_left_anti_join( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return The size that would result from performing the requested join. @@ -886,7 +875,6 @@ std::size_t conditional_inner_join_size( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -902,7 +890,6 @@ std::size_t conditional_inner_join_size( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return The size that would result from performing the requested join. 
@@ -911,7 +898,6 @@ std::size_t conditional_left_join_size( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -927,7 +913,6 @@ std::size_t conditional_left_join_size( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return The size that would result from performing the requested join. @@ -936,7 +921,6 @@ std::size_t conditional_left_semi_join_size( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -952,7 +936,6 @@ std::size_t conditional_left_semi_join_size( * @param left The left table * @param right The right table * @param binary_predicate The condition on which to join. - * @param compare_nulls Whether the equality operator returns true or false for two nulls. * @param mr Device memory resource used to allocate the returned table and columns' device memory * * @return The size that would result from performing the requested join. 
@@ -961,7 +944,6 @@ std::size_t conditional_left_anti_join_size( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of group } // namespace cudf diff --git a/cpp/src/join/conditional_join.cu b/cpp/src/join/conditional_join.cu index f409d626935..1f49ee749ec 100644 --- a/cpp/src/join/conditional_join.cu +++ b/cpp/src/join/conditional_join.cu @@ -39,7 +39,6 @@ std::pair>, conditional_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, join_kind join_type, std::optional output_size, rmm::cuda_stream_view stream, @@ -109,21 +108,11 @@ conditional_join(table_view const& left, if (has_nulls) { compute_conditional_join_output_size <<>>( - *left_table, - *right_table, - kernel_join_type, - compare_nulls, - parser.device_expression_data, - size.data()); + *left_table, *right_table, kernel_join_type, parser.device_expression_data, size.data()); } else { compute_conditional_join_output_size <<>>( - *left_table, - *right_table, - kernel_join_type, - compare_nulls, - parser.device_expression_data, - size.data()); + *left_table, *right_table, kernel_join_type, parser.device_expression_data, size.data()); } CHECK_CUDA(stream.value()); join_size = size.value(stream); @@ -153,7 +142,6 @@ conditional_join(table_view const& left, *left_table, *right_table, kernel_join_type, - compare_nulls, join_output_l, join_output_r, write_index.data(), @@ -165,7 +153,6 @@ conditional_join(table_view const& left, *left_table, *right_table, kernel_join_type, - compare_nulls, join_output_l, join_output_r, write_index.data(), @@ -190,7 +177,6 @@ conditional_join(table_view const& left, std::size_t compute_conditional_join_output_size(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality 
compare_nulls, join_kind join_type, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -252,21 +238,11 @@ std::size_t compute_conditional_join_output_size(table_view const& left, if (has_nulls) { compute_conditional_join_output_size <<>>( - *left_table, - *right_table, - join_type, - compare_nulls, - parser.device_expression_data, - size.data()); + *left_table, *right_table, join_type, parser.device_expression_data, size.data()); } else { compute_conditional_join_output_size <<>>( - *left_table, - *right_table, - join_type, - compare_nulls, - parser.device_expression_data, - size.data()); + *left_table, *right_table, join_type, parser.device_expression_data, size.data()); } CHECK_CUDA(stream.value()); @@ -280,7 +256,6 @@ std::pair>, conditional_inner_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) { @@ -288,7 +263,6 @@ conditional_inner_join(table_view const& left, return detail::conditional_join(left, right, binary_predicate, - compare_nulls, detail::join_kind::INNER_JOIN, output_size, rmm::cuda_stream_default, @@ -300,7 +274,6 @@ std::pair>, conditional_left_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) { @@ -308,7 +281,6 @@ conditional_left_join(table_view const& left, return detail::conditional_join(left, right, binary_predicate, - compare_nulls, detail::join_kind::LEFT_JOIN, output_size, rmm::cuda_stream_default, @@ -320,25 +292,17 @@ std::pair>, conditional_full_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::conditional_join(left, - right, - binary_predicate, - compare_nulls, - detail::join_kind::FULL_JOIN, - 
{}, - rmm::cuda_stream_default, - mr); + return detail::conditional_join( + left, right, binary_predicate, detail::join_kind::FULL_JOIN, {}, rmm::cuda_stream_default, mr); } std::unique_ptr> conditional_left_semi_join( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) { @@ -346,7 +310,6 @@ std::unique_ptr> conditional_left_semi_join( return std::move(detail::conditional_join(left, right, binary_predicate, - compare_nulls, detail::join_kind::LEFT_SEMI_JOIN, output_size, rmm::cuda_stream_default, @@ -358,7 +321,6 @@ std::unique_ptr> conditional_left_anti_join( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) { @@ -366,7 +328,6 @@ std::unique_ptr> conditional_left_anti_join( return std::move(detail::conditional_join(left, right, binary_predicate, - compare_nulls, detail::join_kind::LEFT_ANTI_JOIN, output_size, rmm::cuda_stream_default, @@ -377,46 +338,32 @@ std::unique_ptr> conditional_left_anti_join( std::size_t conditional_inner_join_size(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::compute_conditional_join_output_size(left, - right, - binary_predicate, - compare_nulls, - detail::join_kind::INNER_JOIN, - rmm::cuda_stream_default, - mr); + return detail::compute_conditional_join_output_size( + left, right, binary_predicate, detail::join_kind::INNER_JOIN, rmm::cuda_stream_default, mr); } std::size_t conditional_left_join_size(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return 
detail::compute_conditional_join_output_size(left, - right, - binary_predicate, - compare_nulls, - detail::join_kind::LEFT_JOIN, - rmm::cuda_stream_default, - mr); + return detail::compute_conditional_join_output_size( + left, right, binary_predicate, detail::join_kind::LEFT_JOIN, rmm::cuda_stream_default, mr); } std::size_t conditional_left_semi_join_size(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); return std::move(detail::compute_conditional_join_output_size(left, right, binary_predicate, - compare_nulls, detail::join_kind::LEFT_SEMI_JOIN, rmm::cuda_stream_default, mr)); @@ -425,14 +372,12 @@ std::size_t conditional_left_semi_join_size(table_view const& left, std::size_t conditional_left_anti_join_size(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); return std::move(detail::compute_conditional_join_output_size(left, right, binary_predicate, - compare_nulls, detail::join_kind::LEFT_ANTI_JOIN, rmm::cuda_stream_default, mr)); diff --git a/cpp/src/join/conditional_join.hpp b/cpp/src/join/conditional_join.hpp index 5a3fe887838..911cbd222a0 100644 --- a/cpp/src/join/conditional_join.hpp +++ b/cpp/src/join/conditional_join.hpp @@ -36,7 +36,6 @@ namespace detail { * @param right Table of right columns to join * tables have been flipped, meaning the output indices should also be flipped * @param JoinKind The type of join to be performed - * @param compare_nulls Controls whether null join-key values should match or not. 
* @param stream CUDA stream used for device memory operations and kernel launches * * @return Join output indices vector pair @@ -46,7 +45,6 @@ std::pair>, conditional_join(table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, join_kind JoinKind, std::optional output_size = {}, rmm::cuda_stream_view stream = rmm::cuda_stream_default, @@ -60,7 +58,6 @@ conditional_join(table_view const& left, * @param right Table of right columns to join * tables have been flipped, meaning the output indices should also be flipped * @param JoinKind The type of join to be performed - * @param compare_nulls Controls whether null join-key values should match or not. * @param stream CUDA stream used for device memory operations and kernel launches * * @return Join output indices vector pair @@ -69,7 +66,6 @@ std::size_t compute_conditional_join_output_size( table_view const& left, table_view const& right, ast::expression const& binary_predicate, - null_equality compare_nulls, join_kind JoinKind, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/join/conditional_join_kernels.cuh b/cpp/src/join/conditional_join_kernels.cuh index 9fcc7bf5cfb..2ad7c6ad8b8 100644 --- a/cpp/src/join/conditional_join_kernels.cuh +++ b/cpp/src/join/conditional_join_kernels.cuh @@ -41,7 +41,6 @@ namespace detail { * @param[in] left_table The left table * @param[in] right_table The right table * @param[in] join_type The type of join to be performed - * @param[in] compare_nulls Controls whether null join-key values should match or not. * @param[in] device_expression_data Container of device data required to evaluate the desired * expression. 
* @param[out] output_size The resulting output size @@ -51,7 +50,6 @@ __global__ void compute_conditional_join_output_size( table_device_view left_table, table_device_view right_table, join_kind join_type, - null_equality compare_nulls, ast::detail::expression_device_view device_expression_data, std::size_t* output_size) { @@ -72,7 +70,7 @@ __global__ void compute_conditional_join_output_size( cudf::size_type const right_num_rows = right_table.num_rows(); auto evaluator = cudf::ast::detail::expression_evaluator( - left_table, right_table, device_expression_data, thread_intermediate_storage, compare_nulls); + left_table, right_table, device_expression_data, thread_intermediate_storage); for (cudf::size_type left_row_index = left_start_idx; left_row_index < left_num_rows; left_row_index += left_stride) { @@ -116,7 +114,6 @@ __global__ void compute_conditional_join_output_size( * @param[in] left_table The left table * @param[in] right_table The right table * @param[in] join_type The type of join to be performed - * @param compare_nulls Controls whether null join-key values should match or not. 
* @param[out] join_output_l The left result of the join operation * @param[out] join_output_r The right result of the join operation * @param[in,out] current_idx A global counter used by threads to coordinate @@ -129,7 +126,6 @@ template ( - left_table, right_table, device_expression_data, thread_intermediate_storage, compare_nulls); + left_table, right_table, device_expression_data, thread_intermediate_storage); if (left_row_index < left_num_rows) { bool found_match = false; diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 0bac0484637..175918a0846 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -109,6 +109,22 @@ TEST_F(TransformTest, BasicAddition) cudf::test::expect_columns_equal(expected, result->view(), verbosity); } +TEST_F(TransformTest, BasicEquality) +{ + auto c_0 = column_wrapper{3, 20, 1, 50}; + auto c_1 = column_wrapper{3, 7, 1, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_0, col_ref_1); + + auto expected = column_wrapper{true, false, true, false}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + TEST_F(TransformTest, BasicAdditionLarge) { auto a = thrust::make_counting_iterator(0); @@ -460,6 +476,69 @@ TEST_F(TransformTest, PyMod) cudf::test::expect_columns_equal(expected, result->view(), verbosity); } +TEST_F(TransformTest, BasicEqualityNullEqualNoNulls) +{ + auto c_0 = column_wrapper{3, 20, 1, 50}; + auto c_1 = column_wrapper{3, 7, 1, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::NULL_EQUAL, col_ref_0, col_ref_1); + + auto 
expected = column_wrapper{true, false, true, false}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + +TEST_F(TransformTest, BasicEqualityNormalEqualWithNulls) +{ + auto c_0 = column_wrapper{{3, 20, 1, 50}, {1, 1, 0, 0}}; + auto c_1 = column_wrapper{{3, 7, 1, 0}, {1, 1, 0, 0}}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_0, col_ref_1); + + auto expected = column_wrapper{{true, false, true, true}, {1, 1, 0, 0}}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + +TEST_F(TransformTest, BasicEqualityNulls) +{ + auto c_0 = column_wrapper{{3, 20, 1, 2, 50}, {1, 1, 0, 1, 0}}; + auto c_1 = column_wrapper{{3, 7, 1, 2, 0}, {1, 1, 1, 0, 0}}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::NULL_EQUAL, col_ref_0, col_ref_1); + + auto expected = column_wrapper{{true, false, false, false, true}, {1, 1, 1, 1, 1}}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + +TEST_F(TransformTest, UnaryNotNulls) +{ + auto c_0 = column_wrapper{{3, 0, 0, 50}, {0, 0, 1, 1}}; + auto table = cudf::table_view{{c_0}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + + auto expression = cudf::ast::operation(cudf::ast::ast_operator::NOT, col_ref_0); + + auto result = cudf::compute_column(table, expression); + auto expected = column_wrapper{{false, true, true, false}, {0, 0, 1, 1}}; + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + TEST_F(TransformTest, 
BasicAdditionNulls) { auto c_0 = column_wrapper{{3, 20, 1, 50}, {0, 0, 1, 1}}; @@ -502,4 +581,44 @@ TEST_F(TransformTest, BasicAdditionLargeNulls) cudf::test::expect_columns_equal(expected, result->view(), verbosity); } +TEST_F(TransformTest, NullLogicalAnd) +{ + auto c_0 = column_wrapper{{false, false, true, true, false, false, true, true}, + {1, 1, 1, 1, 1, 0, 0, 0}}; + auto c_1 = column_wrapper{{false, true, false, true, true, true, false, true}, + {1, 1, 1, 1, 0, 1, 1, 0}}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = + cudf::ast::operation(cudf::ast::ast_operator::NULL_LOGICAL_AND, col_ref_0, col_ref_1); + + auto expected = column_wrapper{{false, false, false, true, false, false, false, true}, + {1, 1, 1, 1, 1, 0, 1, 0}}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + +TEST_F(TransformTest, NullLogicalOr) +{ + auto c_0 = column_wrapper{{false, false, true, true, false, false, true, true}, + {1, 1, 1, 1, 1, 0, 1, 0}}; + auto c_1 = column_wrapper{{false, true, false, true, true, true, false, true}, + {1, 1, 1, 1, 0, 1, 0, 0}}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto expression = + cudf::ast::operation(cudf::ast::ast_operator::NULL_LOGICAL_OR, col_ref_0, col_ref_1); + + auto expected = column_wrapper{{false, true, true, true, false, true, true, true}, + {1, 1, 1, 1, 0, 1, 1, 0}}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + CUDF_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/join/conditional_join_tests.cu b/cpp/tests/join/conditional_join_tests.cu index 6c73fb67d7e..f6d76f9ea70 100644 --- a/cpp/tests/join/conditional_join_tests.cu +++ 
b/cpp/tests/join/conditional_join_tests.cu @@ -31,13 +31,23 @@ #include #include +#include #include #include #include -// Defining expressions for AST evaluation is currently a bit tedious, so we -// define some standard nodes here that can be easily reused elsewhere. namespace { +using PairJoinReturn = std::pair>, + std::unique_ptr>>; +using SingleJoinReturn = std::unique_ptr>; +using NullMaskVector = std::vector; + +template +using ColumnVector = std::vector>; + +template +using NullableColumnVector = std::vector, NullMaskVector>>; + constexpr cudf::size_type JoinNoneValue = std::numeric_limits::min(); // TODO: how to test if this isn't public? @@ -48,6 +58,62 @@ const auto col_ref_right_0 = cudf::ast::column_reference(0, cudf::ast::table_ref // Common expressions. auto left_zero_eq_right_zero = cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_left_0, col_ref_right_0); + +// Generate a single pair of left/right non-nullable columns of random data +// suitable for testing a join against a reference join implementation. +template +std::pair, std::vector> gen_random_repeated_columns(unsigned int N = 10000, + unsigned int num_repeats = 10) +{ + // Generate columns of num_repeats repeats of the integer range [0, num_unique), + // then merge a shuffled version and compare to hash join. 
+ unsigned int num_unique = N / num_repeats; + + std::vector left(N); + std::vector right(N); + + for (unsigned int i = 0; i < num_repeats; ++i) { + std::iota( + std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); + std::iota( + std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::shuffle(left.begin(), left.end(), gen); + std::shuffle(right.begin(), right.end(), gen); + return std::make_pair(std::move(left), std::move(right)); +} + +// Generate a single pair of left/right nullable columns of random data +// suitable for testing a join against a reference join implementation. +template +std::pair, std::vector>, + std::pair, std::vector>> +gen_random_nullable_repeated_columns(unsigned int N = 10000, unsigned int num_repeats = 10) +{ + auto [left, right] = gen_random_repeated_columns(N, num_repeats); + + std::vector left_nulls(N); + std::vector right_nulls(N); + + // Seed with a real random value, if available + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> uniform_dist(0, 1); + + std::generate(left_nulls.begin(), left_nulls.end(), [&uniform_dist, &gen]() { + return uniform_dist(gen) > 0.5; + }); + std::generate(right_nulls.begin(), right_nulls.end(), [&uniform_dist, &gen]() { + return uniform_dist(gen) > 0.5; + }); + + return std::make_pair(std::make_pair(std::move(left), std::move(left_nulls)), + std::make_pair(std::move(right), std::move(right_nulls))); +} + } // namespace /** @@ -59,66 +125,39 @@ struct ConditionalJoinTest : public cudf::test::BaseFixture { * Convenience utility for parsing initializer lists of input data into * suitable inputs for tables. 
*/ + template std::tuple>, std::vector>, std::vector, std::vector, cudf::table_view, cudf::table_view> - parse_input(std::vector> left_data, std::vector> right_data) + parse_input(std::vector left_data, std::vector right_data) { + auto wrapper_generator = [](U& v) { + if constexpr (std::is_same_v>) { + return cudf::test::fixed_width_column_wrapper(v.begin(), v.end()); + } else if constexpr (std::is_same_v, std::vector>>) { + return cudf::test::fixed_width_column_wrapper( + v.first.begin(), v.first.end(), v.second.begin()); + } + throw std::runtime_error("Invalid input to parse_input."); + return cudf::test::fixed_width_column_wrapper(); + }; + // Note that we need to maintain the column wrappers otherwise the // resulting column views will be referencing potentially invalid memory. std::vector> left_wrappers; - std::vector> right_wrappers; - std::vector left_columns; - std::vector right_columns; - for (auto v : left_data) { - left_wrappers.push_back(cudf::test::fixed_width_column_wrapper(v.begin(), v.end())); + left_wrappers.push_back(wrapper_generator(v)); left_columns.push_back(left_wrappers.back()); } - for (auto v : right_data) { - right_wrappers.push_back(cudf::test::fixed_width_column_wrapper(v.begin(), v.end())); - right_columns.push_back(right_wrappers.back()); - } - - return std::make_tuple(std::move(left_wrappers), - std::move(right_wrappers), - std::move(left_columns), - std::move(right_columns), - cudf::table_view(left_columns), - cudf::table_view(right_columns)); - } - - std::tuple>, - std::vector>, - std::vector, - std::vector, - cudf::table_view, - cudf::table_view> - parse_input(std::vector, std::vector>> left_data, - std::vector, std::vector>> right_data) - { - // Note that we need to maintain the column wrappers otherwise the - // resulting column views will be referencing potentially invalid memory. 
- std::vector> left_wrappers; std::vector> right_wrappers; - - std::vector left_columns; std::vector right_columns; - - for (auto v : left_data) { - left_wrappers.push_back(cudf::test::fixed_width_column_wrapper( - v.first.begin(), v.first.end(), v.second.begin())); - left_columns.push_back(left_wrappers.back()); - } - for (auto v : right_data) { - right_wrappers.push_back(cudf::test::fixed_width_column_wrapper( - v.first.begin(), v.first.end(), v.second.begin())); + right_wrappers.push_back(wrapper_generator(v)); right_columns.push_back(right_wrappers.back()); } @@ -142,15 +181,11 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { * the provided predicate and verify that the outputs match the expected * outputs (up to order). */ - void test(std::vector> left_data, - std::vector> right_data, - cudf::ast::operation predicate, - std::vector> expected_outputs) + void _test(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate, + std::vector> expected_outputs) { - // Note that we need to maintain the column wrappers otherwise the - // resulting column views will be referencing potentially invalid memory. - auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = - this->parse_input(left_data, right_data); auto result_size = this->join_size(left, right, predicate); EXPECT_TRUE(result_size == expected_outputs.size()); @@ -169,8 +204,25 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { EXPECT_TRUE(std::equal(expected_outputs.begin(), expected_outputs.end(), result_pairs.begin())); } - void test_nulls(std::vector, std::vector>> left_data, - std::vector, std::vector>> right_data, + /* + * Perform a join of tables constructed from two input data sets according to + * the provided predicate and verify that the outputs match the expected + * outputs (up to order). 
+ */ + void test(ColumnVector left_data, + ColumnVector right_data, + cudf::ast::operation predicate, + std::vector> expected_outputs) + { + // Note that we need to maintain the column wrappers otherwise the + // resulting column views will be referencing potentially invalid memory. + auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = + this->parse_input(left_data, right_data); + this->_test(left, right, predicate, expected_outputs); + } + + void test_nulls(NullableColumnVector left_data, + NullableColumnVector right_data, cudf::ast::operation predicate, std::vector> expected_outputs) { @@ -178,22 +230,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { // resulting column views will be referencing potentially invalid memory. auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = this->parse_input(left_data, right_data); - auto result_size = this->join_size(left, right, predicate); - EXPECT_TRUE(result_size == expected_outputs.size()); - - auto result = this->join(left, right, predicate); - std::vector> result_pairs; - for (size_t i = 0; i < result.first->size(); ++i) { - // Note: Not trying to be terribly efficient here since these tests are - // small, otherwise a batch copy to host before constructing the tuples - // would be important. - result_pairs.push_back({result.first->element(i, rmm::cuda_stream_default), - result.second->element(i, rmm::cuda_stream_default)}); - } - std::sort(result_pairs.begin(), result_pairs.end()); - std::sort(expected_outputs.begin(), expected_outputs.end()); - - EXPECT_TRUE(std::equal(expected_outputs.begin(), expected_outputs.end(), result_pairs.begin())); + this->_test(left, right, predicate, expected_outputs); } /* @@ -201,18 +238,8 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { * an equality predicate on all corresponding columns and verify that the outputs match the * expected outputs (up to order). 
*/ - void compare_to_hash_join(std::vector> left_data, - std::vector> right_data) + void _compare_to_hash_join(PairJoinReturn const& result, PairJoinReturn const& reference) { - // Note that we need to maintain the column wrappers otherwise the - // resulting column views will be referencing potentially invalid memory. - auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = - this->parse_input(left_data, right_data); - // TODO: Generalize this to support multiple columns by automatically - // constructing the appropriate expression. - auto result = this->join(left, right, left_zero_eq_right_zero); - auto reference = this->reference_join(left, right); - thrust::device_vector> result_pairs( result.first->size()); thrust::device_vector> reference_pairs( @@ -242,14 +269,48 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { thrust::device, reference_pairs.begin(), reference_pairs.end(), result_pairs.begin())); } + void compare_to_hash_join(ColumnVector left_data, ColumnVector right_data) + { + // Note that we need to maintain the column wrappers otherwise the + // resulting column views will be referencing potentially invalid memory. + auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = + this->parse_input(left_data, right_data); + auto result = this->join(left, right, left_zero_eq_right_zero); + auto reference = this->reference_join(left, right); + this->_compare_to_hash_join(result, reference); + } + + void compare_to_hash_join_nulls(NullableColumnVector left_data, + NullableColumnVector right_data) + { + // Note that we need to maintain the column wrappers otherwise the + // resulting column views will be referencing potentially invalid memory. + auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = + this->parse_input(left_data, right_data); + + // Test comparing nulls as equal (the default for ref joins, uses NULL_EQUAL for AST + // expression). 
+ auto predicate = + cudf::ast::operation(cudf::ast::ast_operator::NULL_EQUAL, col_ref_left_0, col_ref_right_0); + auto result = this->join(left, right, predicate); + auto reference = this->reference_join(left, right); + this->_compare_to_hash_join(result, reference); + + // Test comparing nulls as equal (null_equality::UNEQUAL for ref joins, uses EQUAL for AST + // expression). + result = this->join(left, right, left_zero_eq_right_zero); + reference = this->reference_join(left, right, cudf::null_equality::UNEQUAL); + this->_compare_to_hash_join(result, reference); + } + /** * This method must be implemented by subclasses for specific types of joins. * It should be a simply forwarding of arguments to the appropriate cudf * conditional join API. */ - virtual std::pair>, - std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) = 0; + virtual PairJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. @@ -265,9 +326,10 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { * It should be a simply forwarding of arguments to the appropriate cudf * hash join API for comparison with conditional joins. 
*/ - virtual std::pair>, - std::unique_ptr>> - reference_join(cudf::table_view left, cudf::table_view right) = 0; + virtual PairJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) = 0; }; /** @@ -275,9 +337,9 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { */ template struct ConditionalInnerJoinTest : public ConditionalJoinPairReturnTest { - std::pair>, - std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override + PairJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) override { return cudf::conditional_inner_join(left, right, predicate); } @@ -289,11 +351,12 @@ struct ConditionalInnerJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_inner_join_size(left, right, predicate); } - std::pair>, - std::unique_ptr>> - reference_join(cudf::table_view left, cudf::table_view right) override + PairJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) override { - return cudf::inner_join(left, right); + return cudf::inner_join(left, right, compare_nulls); } }; @@ -422,30 +485,16 @@ TYPED_TEST(ConditionalInnerJoinTest, TestSymmetry) TYPED_TEST(ConditionalInnerJoinTest, TestCompareRandomToHash) { - // Generate columns of 10 repeats of the integer range [0, 10), then merge - // a shuffled version and compare to hash join. 
- unsigned int N = 10000; - unsigned int num_repeats = 10; - unsigned int num_unique = N / num_repeats; - - std::vector left(N); - std::vector right(N); - - for (unsigned int i = 0; i < num_repeats; ++i) { - std::iota( - std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); - std::iota( - std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::shuffle(left.begin(), left.end(), gen); - std::shuffle(right.begin(), right.end(), gen); - + auto [left, right] = gen_random_repeated_columns(); this->compare_to_hash_join({left}, {right}); }; +TYPED_TEST(ConditionalInnerJoinTest, TestCompareRandomToHashNulls) +{ + auto [left, right] = gen_random_nullable_repeated_columns(); + this->compare_to_hash_join_nulls({left}, {right}); +}; + TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnTwoNullsRowAllEqual) { this->test_nulls( @@ -462,9 +511,9 @@ TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnTwoNullsNoOutputRowAllEqual) */ template struct ConditionalLeftJoinTest : public ConditionalJoinPairReturnTest { - std::pair>, - std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override + PairJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) override { return cudf::conditional_left_join(left, right, predicate); } @@ -476,11 +525,12 @@ struct ConditionalLeftJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_left_join_size(left, right, predicate); } - std::pair>, - std::unique_ptr>> - reference_join(cudf::table_view left, cudf::table_view right) override + PairJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) override { - return cudf::left_join(left, right); + return cudf::left_join(left, right, compare_nulls); } }; @@ -501,27 +551,13 
@@ TYPED_TEST(ConditionalLeftJoinTest, TestOneColumnLeftEmpty) TYPED_TEST(ConditionalLeftJoinTest, TestCompareRandomToHash) { - // Generate columns of 10 repeats of the integer range [0, 10), then merge - // a shuffled version and compare to hash join. - unsigned int N = 10000; - unsigned int num_repeats = 10; - unsigned int num_unique = N / num_repeats; - - std::vector left(N); - std::vector right(N); - - for (unsigned int i = 0; i < num_repeats; ++i) { - std::iota( - std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); - std::iota( - std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::shuffle(left.begin(), left.end(), gen); - std::shuffle(right.begin(), right.end(), gen); + auto [left, right] = gen_random_repeated_columns(); + this->compare_to_hash_join({left}, {right}); +}; +TYPED_TEST(ConditionalLeftJoinTest, TestCompareRandomToHashNulls) +{ + auto [left, right] = gen_random_repeated_columns(); this->compare_to_hash_join({left}, {right}); }; @@ -530,9 +566,9 @@ TYPED_TEST(ConditionalLeftJoinTest, TestCompareRandomToHash) */ template struct ConditionalFullJoinTest : public ConditionalJoinPairReturnTest { - std::pair>, - std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override + PairJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) override { return cudf::conditional_full_join(left, right, predicate); } @@ -547,11 +583,12 @@ struct ConditionalFullJoinTest : public ConditionalJoinPairReturnTest { return cudf::conditional_full_join(left, right, predicate).first->size(); } - std::pair>, - std::unique_ptr>> - reference_join(cudf::table_view left, cudf::table_view right) override + PairJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = 
cudf::null_equality::EQUAL) override { - return cudf::full_join(left, right); + return cudf::full_join(left, right, compare_nulls); } }; @@ -588,27 +625,13 @@ TYPED_TEST(ConditionalFullJoinTest, TestTwoColumnThreeRowSomeEqual) TYPED_TEST(ConditionalFullJoinTest, TestCompareRandomToHash) { - // Generate columns of 10 repeats of the integer range [0, 10), then merge - // a shuffled version and compare to hash join. - unsigned int N = 10000; - unsigned int num_repeats = 10; - unsigned int num_unique = N / num_repeats; - - std::vector left(N); - std::vector right(N); - - for (unsigned int i = 0; i < num_repeats; ++i) { - std::iota( - std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); - std::iota( - std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::shuffle(left.begin(), left.end(), gen); - std::shuffle(right.begin(), right.end(), gen); + auto [left, right] = gen_random_repeated_columns(); + this->compare_to_hash_join({left}, {right}); +}; +TYPED_TEST(ConditionalFullJoinTest, TestCompareRandomToHashNulls) +{ + auto [left, right] = gen_random_repeated_columns(); this->compare_to_hash_join({left}, {right}); }; @@ -623,8 +646,8 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { * the provided predicate and verify that the outputs match the expected * outputs (up to order). 
*/ - void test(std::vector> left_data, - std::vector> right_data, + void test(ColumnVector left_data, + ColumnVector right_data, cudf::ast::operation predicate, std::vector expected_outputs) { @@ -647,27 +670,46 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { std::equal(resulting_indices.begin(), resulting_indices.end(), expected_outputs.begin())); } + void _compare_to_hash_join(std::unique_ptr> const& result, + std::unique_ptr> const& reference) + { + thrust::sort(thrust::device, result->begin(), result->end()); + thrust::sort(thrust::device, reference->begin(), reference->end()); + EXPECT_TRUE(thrust::equal(thrust::device, result->begin(), result->end(), reference->begin())); + } + /* * Perform a join of tables constructed from two input data sets according to * an equality predicate on all corresponding columns and verify that the outputs match the * expected outputs (up to order). */ - void compare_to_hash_join(std::vector> left_data, - std::vector> right_data) + void compare_to_hash_join(ColumnVector left_data, ColumnVector right_data) { // Note that we need to maintain the column wrappers otherwise the // resulting column views will be referencing potentially invalid memory. auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = this->parse_input(left_data, right_data); - // TODO: Generalize this to support multiple columns by automatically - // constructing the appropriate expression. 
auto result = this->join(left, right, left_zero_eq_right_zero); auto reference = this->reference_join(left, right); + this->_compare_to_hash_join(result, reference); + } - thrust::sort(thrust::device, result->begin(), result->end()); - thrust::sort(thrust::device, reference->begin(), reference->end()); + void compare_to_hash_join_nulls(NullableColumnVector left_data, + NullableColumnVector right_data) + { + // Note that we need to maintain the column wrappers otherwise the + // resulting column views will be referencing potentially invalid memory. + auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = + this->parse_input(left_data, right_data); + auto predicate = + cudf::ast::operation(cudf::ast::ast_operator::NULL_EQUAL, col_ref_left_0, col_ref_right_0); + auto result = this->join(left, right, predicate); + auto reference = this->reference_join(left, right); + this->_compare_to_hash_join(result, reference); - EXPECT_TRUE(thrust::equal(thrust::device, result->begin(), result->end(), reference->begin())); + result = this->join(left, right, left_zero_eq_right_zero); + reference = this->reference_join(left, right, cudf::null_equality::UNEQUAL); + this->_compare_to_hash_join(result, reference); } /** @@ -675,8 +717,9 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { * It should be a simply forwarding of arguments to the appropriate cudf * conditional join API. */ - virtual std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) = 0; + virtual SingleJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. @@ -692,8 +735,10 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { * It should be a simply forwarding of arguments to the appropriate cudf * hash join API for comparison with conditional joins. 
*/ - virtual std::unique_ptr> reference_join( - cudf::table_view left, cudf::table_view right) = 0; + virtual SingleJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) = 0; }; /** @@ -701,8 +746,9 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { */ template struct ConditionalLeftSemiJoinTest : public ConditionalJoinSingleReturnTest { - std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override + SingleJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) override { return cudf::conditional_left_semi_join(left, right, predicate); } @@ -714,10 +760,12 @@ struct ConditionalLeftSemiJoinTest : public ConditionalJoinSingleReturnTest { return cudf::conditional_left_semi_join_size(left, right, predicate); } - std::unique_ptr> reference_join( - cudf::table_view left, cudf::table_view right) override + SingleJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) override { - return cudf::left_semi_join(left, right); + return cudf::left_semi_join(left, right, compare_nulls); } }; @@ -730,37 +778,24 @@ TYPED_TEST(ConditionalLeftSemiJoinTest, TestTwoColumnThreeRowSomeEqual) TYPED_TEST(ConditionalLeftSemiJoinTest, TestCompareRandomToHash) { - // Generate columns of 10 repeats of the integer range [0, 10), then merge - // a shuffled version and compare to hash join. 
- unsigned int N = 10000; - unsigned int num_repeats = 10; - unsigned int num_unique = N / num_repeats; - - std::vector left(N); - std::vector right(N); - - for (unsigned int i = 0; i < num_repeats; ++i) { - std::iota( - std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); - std::iota( - std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::shuffle(left.begin(), left.end(), gen); - std::shuffle(right.begin(), right.end(), gen); - + auto [left, right] = gen_random_repeated_columns(); this->compare_to_hash_join({left}, {right}); }; +TYPED_TEST(ConditionalLeftSemiJoinTest, TestCompareRandomToHashNulls) +{ + auto [left, right] = gen_random_nullable_repeated_columns(); + this->compare_to_hash_join_nulls({left}, {right}); +}; + /** * Tests of left anti joins. */ template struct ConditionalLeftAntiJoinTest : public ConditionalJoinSingleReturnTest { - std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override + SingleJoinReturn join(cudf::table_view left, + cudf::table_view right, + cudf::ast::operation predicate) override { return cudf::conditional_left_anti_join(left, right, predicate); } @@ -772,10 +807,12 @@ struct ConditionalLeftAntiJoinTest : public ConditionalJoinSingleReturnTest { return cudf::conditional_left_anti_join_size(left, right, predicate); } - std::unique_ptr> reference_join( - cudf::table_view left, cudf::table_view right) override + SingleJoinReturn reference_join( + cudf::table_view left, + cudf::table_view right, + cudf::null_equality compare_nulls = cudf::null_equality::EQUAL) override { - return cudf::left_anti_join(left, right); + return cudf::left_anti_join(left, right, compare_nulls); } }; @@ -788,26 +825,12 @@ TYPED_TEST(ConditionalLeftAntiJoinTest, TestTwoColumnThreeRowSomeEqual) TYPED_TEST(ConditionalLeftAntiJoinTest, TestCompareRandomToHash) { - // 
Generate columns of 10 repeats of the integer range [0, 10), then merge - // a shuffled version and compare to hash join. - unsigned int N = 10000; - unsigned int num_repeats = 10; - unsigned int num_unique = N / num_repeats; - - std::vector left(N); - std::vector right(N); - - for (unsigned int i = 0; i < num_repeats; ++i) { - std::iota( - std::next(left.begin(), num_unique * i), std::next(left.begin(), num_unique * (i + 1)), 0); - std::iota( - std::next(right.begin(), num_unique * i), std::next(right.begin(), num_unique * (i + 1)), 0); - } - - std::random_device rd; - std::mt19937 gen(rd()); - std::shuffle(left.begin(), left.end(), gen); - std::shuffle(right.begin(), right.end(), gen); - + auto [left, right] = gen_random_repeated_columns(); this->compare_to_hash_join({left}, {right}); }; + +TYPED_TEST(ConditionalLeftAntiJoinTest, TestCompareRandomToHashNulls) +{ + auto [left, right] = gen_random_nullable_repeated_columns(); + this->compare_to_hash_join_nulls({left}, {right}); +}; diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index eeb2d308f1a..2744728fb44 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -569,64 +569,50 @@ private static native long[] leftAntiJoinGatherMap(long leftKeys, long rightKeys boolean compareNullsEqual) throws CudfException; private static native long conditionalLeftJoinRowCount(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalLeftJoinGatherMaps(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalLeftJoinGatherMapsWithCount(long leftTable, long rightTable, long condition, - boolean compareNullsEqual, long rowCount) throws CudfException; private static native 
long conditionalInnerJoinRowCount(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalInnerJoinGatherMaps(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalInnerJoinGatherMapsWithCount(long leftTable, long rightTable, long condition, - boolean compareNullsEqual, long rowCount) throws CudfException; private static native long[] conditionalFullJoinGatherMaps(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalFullJoinGatherMapsWithCount(long leftTable, long rightTable, long condition, - boolean compareNullsEqual, long rowCount) throws CudfException; private static native long conditionalLeftSemiJoinRowCount(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalLeftSemiJoinGatherMap(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalLeftSemiJoinGatherMapWithCount(long leftTable, long rightTable, long condition, - boolean compareNullsEqual, long rowCount) throws CudfException; private static native long conditionalLeftAntiJoinRowCount(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] conditionalLeftAntiJoinGatherMap(long leftTable, long rightTable, - long condition, - boolean compareNullsEqual) throws CudfException; + long condition) throws CudfException; private static native long[] 
conditionalLeftAntiJoinGatherMapWithCount(long leftTable, long rightTable, long condition, - boolean compareNullsEqual, long rowCount) throws CudfException; private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; @@ -2148,13 +2134,11 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash, long outputRowCount) { * the left table, and the table argument represents the columns from the right table. * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return row count for the join result */ - public long conditionalLeftJoinRowCount(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public long conditionalLeftJoinRowCount(Table rightTable, CompiledExpression condition) { return conditionalLeftJoinRowCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); } /** @@ -2166,15 +2150,13 @@ public long conditionalLeftJoinRowCount(Table rightTable, CompiledExpression con * It is the responsibility of the caller to close the resulting gather map instances. 
* @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { long[] gatherMapData = conditionalLeftJoinGatherMaps(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); return buildJoinGatherMaps(gatherMapData); } @@ -2191,17 +2173,15 @@ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, * in undefined behavior. * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @param outputRowCount number of output rows in the join result * @return left and right table gather maps */ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual, long outputRowCount) { long[] gatherMapData = conditionalLeftJoinGatherMapsWithCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual, outputRowCount); + condition.getNativeHandle(), outputRowCount); return buildJoinGatherMaps(gatherMapData); } @@ -2293,14 +2273,12 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash, long outputRowCount) * the left table, and the table argument represents the columns from the right table. 
* @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return row count for the join result */ public long conditionalInnerJoinRowCount(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { return conditionalInnerJoinRowCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); } /** @@ -2312,15 +2290,13 @@ public long conditionalInnerJoinRowCount(Table rightTable, * It is the responsibility of the caller to close the resulting gather map instances. * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { long[] gatherMapData = conditionalInnerJoinGatherMaps(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); return buildJoinGatherMaps(gatherMapData); } @@ -2337,17 +2313,15 @@ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, * in undefined behavior. 
* @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @param outputRowCount number of output rows in the join result * @return left and right table gather maps */ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual, long outputRowCount) { long[] gatherMapData = conditionalInnerJoinGatherMapsWithCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual, outputRowCount); + condition.getNativeHandle(), outputRowCount); return buildJoinGatherMaps(gatherMapData); } @@ -2447,15 +2421,13 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) { * It is the responsibility of the caller to close the resulting gather map instances. * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ public GatherMap[] conditionalFullJoinGatherMaps(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { long[] gatherMapData = conditionalFullJoinGatherMaps(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); return buildJoinGatherMaps(gatherMapData); } @@ -2493,14 +2465,12 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua * the left table, and the table argument represents the columns from the right table. 
* @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return row count for the join result */ public long conditionalLeftSemiJoinRowCount(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { return conditionalLeftSemiJoinRowCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); } /** @@ -2512,15 +2482,13 @@ public long conditionalLeftSemiJoinRowCount(Table rightTable, * It is the responsibility of the caller to close the resulting gather map instance. * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return left table gather map */ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { long[] gatherMapData = conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); return buildSemiJoinGatherMap(gatherMapData); } @@ -2537,17 +2505,15 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, * in undefined behavior. 
* @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @param outputRowCount number of output rows in the join result * @return left table gather map */ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual, long outputRowCount) { long[] gatherMapData = conditionalLeftSemiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual, outputRowCount); + condition.getNativeHandle(), outputRowCount); return buildSemiJoinGatherMap(gatherMapData); } @@ -2578,14 +2544,12 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua * the left table, and the table argument represents the columns from the right table. * @param rightTable the right side table of the join in the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return row count for the join result */ public long conditionalLeftAntiJoinRowCount(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { return conditionalLeftAntiJoinRowCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); } /** @@ -2597,15 +2561,13 @@ public long conditionalLeftAntiJoinRowCount(Table rightTable, * It is the responsibility of the caller to close the resulting gather map instance. 
* @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @return left table gather map */ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, - CompiledExpression condition, - boolean compareNullsEqual) { + CompiledExpression condition) { long[] gatherMapData = conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual); + condition.getNativeHandle()); return buildSemiJoinGatherMap(gatherMapData); } @@ -2622,17 +2584,15 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, * in undefined behavior. * @param rightTable the right side table of the join * @param condition conditional expression to evaluate during the join - * @param compareNullsEqual true if null key values should match otherwise false * @param outputRowCount number of output rows in the join result * @return left table gather map */ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual, long outputRowCount) { long[] gatherMapData = conditionalLeftAntiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), - condition.getNativeHandle(), compareNullsEqual, outputRowCount); + condition.getNativeHandle(), outputRowCount); return buildSemiJoinGatherMap(gatherMapData); } diff --git a/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java index 595badb14b6..be902658371 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java +++ b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java @@ -23,26 +23,29 @@ * NOTE: This must be kept in sync with `jni_to_binary_operator` in CompiledExpression.cpp! 
*/ public enum BinaryOperator { - ADD(0), // operator + - SUB(1), // operator - - MUL(2), // operator * - DIV(3), // operator / using common type of lhs and rhs - TRUE_DIV(4), // operator / after promoting type to floating point - FLOOR_DIV(5), // operator / after promoting to 64 bit floating point and then flooring the result - MOD(6), // operator % - PYMOD(7), // operator % but following python's sign rules for negatives - POW(8), // lhs ^ rhs - EQUAL(9), // operator == - NOT_EQUAL(10), // operator != - LESS(11), // operator < - GREATER(12), // operator > - LESS_EQUAL(13), // operator <= - GREATER_EQUAL(14), // operator >= - BITWISE_AND(15), // operator & - BITWISE_OR(16), // operator | - BITWISE_XOR(17), // operator ^ - LOGICAL_AND(18), // operator && - LOGICAL_OR(19); // operator || + ADD(0), // operator + + SUB(1), // operator - + MUL(2), // operator * + DIV(3), // operator / using common type of lhs and rhs + TRUE_DIV(4), // operator / after promoting type to floating point + FLOOR_DIV(5), // operator / after promoting to 64 bit floating point and then flooring the result + MOD(6), // operator % + PYMOD(7), // operator % using Python's sign rules for negatives + POW(8), // lhs ^ rhs + EQUAL(9), // operator == + NULL_EQUAL(10), // operator == using Spark rules for null inputs + NOT_EQUAL(11), // operator != + LESS(12), // operator < + GREATER(13), // operator > + LESS_EQUAL(14), // operator <= + GREATER_EQUAL(15), // operator >= + BITWISE_AND(16), // operator & + BITWISE_OR(17), // operator | + BITWISE_XOR(18), // operator ^ + LOGICAL_AND(19), // operator && + NULL_LOGICAL_AND(20), // operator && using Spark rules for null inputs + LOGICAL_OR(21), // operator || + NULL_LOGICAL_OR(22); // operator || using Spark rules for null inputs private final byte nativeId; diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp index 470464f35c8..4b378905a43 100644 --- a/java/src/main/native/src/CompiledExpression.cpp 
+++ b/java/src/main/native/src/CompiledExpression.cpp @@ -165,16 +165,19 @@ cudf::ast::ast_operator jni_to_binary_operator(jbyte jni_op_value) { case 7: return cudf::ast::ast_operator::PYMOD; case 8: return cudf::ast::ast_operator::POW; case 9: return cudf::ast::ast_operator::EQUAL; - case 10: return cudf::ast::ast_operator::NOT_EQUAL; - case 11: return cudf::ast::ast_operator::LESS; - case 12: return cudf::ast::ast_operator::GREATER; - case 13: return cudf::ast::ast_operator::LESS_EQUAL; - case 14: return cudf::ast::ast_operator::GREATER_EQUAL; - case 15: return cudf::ast::ast_operator::BITWISE_AND; - case 16: return cudf::ast::ast_operator::BITWISE_OR; - case 17: return cudf::ast::ast_operator::BITWISE_XOR; - case 18: return cudf::ast::ast_operator::LOGICAL_AND; - case 19: return cudf::ast::ast_operator::LOGICAL_OR; + case 10: return cudf::ast::ast_operator::NULL_EQUAL; + case 11: return cudf::ast::ast_operator::NOT_EQUAL; + case 12: return cudf::ast::ast_operator::LESS; + case 13: return cudf::ast::ast_operator::GREATER; + case 14: return cudf::ast::ast_operator::LESS_EQUAL; + case 15: return cudf::ast::ast_operator::GREATER_EQUAL; + case 16: return cudf::ast::ast_operator::BITWISE_AND; + case 17: return cudf::ast::ast_operator::BITWISE_OR; + case 18: return cudf::ast::ast_operator::BITWISE_XOR; + case 19: return cudf::ast::ast_operator::LOGICAL_AND; + case 20: return cudf::ast::ast_operator::NULL_LOGICAL_AND; + case 21: return cudf::ast::ast_operator::LOGICAL_OR; + case 22: return cudf::ast::ast_operator::NULL_LOGICAL_OR; default: throw std::invalid_argument("unexpected JNI AST binary operator value"); } } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 2bb56565f7a..92e96b0de61 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -830,7 +830,7 @@ jlongArray hash_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_h // Generate gather maps needed to manifest 
the result of a conditional join between two tables. template jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_table, - jlong j_condition, jboolean compare_nulls_equal, T join_func) { + jlong j_condition, T join_func) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); @@ -839,25 +839,22 @@ jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_ auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; return gather_maps_to_java( - env, join_func(*left_table, *right_table, condition->get_top_expression(), nulleq)); + env, join_func(*left_table, *right_table, condition->get_top_expression())); } CATCH_STD(env, NULL); } // Generate a gather map needed to manifest the result of a semi/anti join between two tables. template -jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, - jboolean compare_nulls_equal, T join_func) { +jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, T join_func) { JNI_NULL_CHECK(env, j_left_keys, "left_table is null", NULL); JNI_NULL_CHECK(env, j_right_keys, "right_table is null", NULL); try { cudf::jni::auto_set_device(env); auto left_keys = reinterpret_cast(j_left_keys); auto right_keys = reinterpret_cast(j_right_keys); - auto nulleq = compare_nulls_equal ? 
cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - return gather_map_to_java(env, join_func(*left_keys, *right_keys, nulleq)); + return gather_map_to_java(env, join_func(*left_keys, *right_keys)); } CATCH_STD(env, NULL); } @@ -866,8 +863,7 @@ jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_ // between two tables. template jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_right_table, - jlong j_condition, jboolean compare_nulls_equal, - T join_func) { + jlong j_condition, T join_func) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); @@ -876,9 +872,8 @@ jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_ auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? 
cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; return gather_map_to_java( - env, join_func(*left_table, *right_table, condition->get_top_expression(), nulleq)); + env, join_func(*left_table, *right_table, condition->get_top_expression())); } CATCH_STD(env, NULL); } @@ -2043,9 +2038,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMapsWit }); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount(JNIEnv *env, jclass, + jlong j_left_table, + jlong j_right_table, + jlong j_condition) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); JNI_NULL_CHECK(env, j_condition, "condition is null", 0); @@ -2054,34 +2050,32 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount( auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? 
cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; auto row_count = cudf::conditional_left_join_size(*left_table, *right_table, - condition->get_top_expression(), nulleq); + condition->get_top_expression()); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { return cudf::jni::cond_join_gather_maps( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_join(left, right, cond_expr, nulleq); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_join(left, right, cond_expr); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMapsWithCount( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal, jlong j_row_count) { + jlong j_row_count) { auto row_count = static_cast(j_row_count); return cudf::jni::cond_join_gather_maps( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_join(left, right, cond_expr, nulleq, row_count); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_join(left, right, cond_expr, row_count); }); } @@ -2133,9 +2127,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMapsWi }); } -JNIEXPORT jlong JNICALL 
Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount(JNIEnv *env, jclass, + jlong j_left_table, + jlong j_right_table, + jlong j_condition) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); JNI_NULL_CHECK(env, j_condition, "condition is null", 0); @@ -2144,34 +2139,32 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount( auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; auto row_count = cudf::conditional_inner_join_size(*left_table, *right_table, - condition->get_top_expression(), nulleq); + condition->get_top_expression()); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { return cudf::jni::cond_join_gather_maps( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_inner_join(left, right, cond_expr, nulleq); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_inner_join(left, right, cond_expr); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMapsWithCount( JNIEnv *env, jclass, jlong j_left_table, jlong 
j_right_table, jlong j_condition, - jboolean compare_nulls_equal, jlong j_row_count) { + jlong j_row_count) { auto row_count = static_cast(j_row_count); return cudf::jni::cond_join_gather_maps( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_inner_join(left, right, cond_expr, nulleq, row_count); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_inner_join(left, right, cond_expr, row_count); }); } @@ -2224,28 +2217,27 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMapsWit } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { return cudf::jni::cond_join_gather_maps( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_full_join(left, right, cond_expr, nulleq); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_full_join(left, right, cond_expr); }); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( - JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap(JNIEnv *env, jclass, + jlong j_left_keys, + jlong j_right_keys) { return cudf::jni::join_gather_single_map( - env, j_left_keys, j_right_keys, compare_nulls_equal, - [](cudf::table_view const &left, 
cudf::table_view const &right, cudf::null_equality nulleq) { - return cudf::left_semi_join(left, right, nulleq); + env, j_left_keys, j_right_keys, + [](cudf::table_view const &left, cudf::table_view const &right) { + return cudf::left_semi_join(left, right); }); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinRowCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); JNI_NULL_CHECK(env, j_condition, "condition is null", 0); @@ -2254,49 +2246,47 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinRowCoun auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? 
cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; auto row_count = cudf::conditional_left_semi_join_size(*left_table, *right_table, - condition->get_top_expression(), nulleq); + condition->get_top_expression()); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMap( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { return cudf::jni::cond_join_gather_single_map( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_semi_join(left, right, cond_expr); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMapWithCount( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal, jlong j_row_count) { + jlong j_row_count) { auto row_count = static_cast(j_row_count); return cudf::jni::cond_join_gather_single_map( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq, row_count); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_semi_join(left, right, cond_expr, row_count); }); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( - JNIEnv *env, jclass, jlong 
j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap(JNIEnv *env, jclass, + jlong j_left_keys, + jlong j_right_keys) { return cudf::jni::join_gather_single_map( - env, j_left_keys, j_right_keys, compare_nulls_equal, - [](cudf::table_view const &left, cudf::table_view const &right, cudf::null_equality nulleq) { - return cudf::left_anti_join(left, right, nulleq); + env, j_left_keys, j_right_keys, + [](cudf::table_view const &left, cudf::table_view const &right) { + return cudf::left_anti_join(left, right); }); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinRowCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); JNI_NULL_CHECK(env, j_condition, "condition is null", 0); @@ -2305,34 +2295,32 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinRowCoun auto left_table = reinterpret_cast(j_left_table); auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); - auto nulleq = compare_nulls_equal ? 
cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; auto row_count = cudf::conditional_left_anti_join_size(*left_table, *right_table, - condition->get_top_expression(), nulleq); + condition->get_top_expression()); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMap( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition) { return cudf::jni::cond_join_gather_single_map( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_anti_join(left, right, cond_expr); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMapWithCount( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, - jboolean compare_nulls_equal, jlong j_row_count) { + jlong j_row_count) { auto row_count = static_cast(j_row_count); return cudf::jni::cond_join_gather_single_map( - env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + env, j_left_table, j_right_table, j_condition, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { - return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq, row_count); + cudf::ast::expression const &cond_expr) { + return cudf::conditional_left_anti_join(left, right, cond_expr, row_count); }); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java 
index cc030c392cb..cd1e433d07b 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1608,7 +1608,7 @@ void testConditionalLeftJoinGatherMaps() { .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1622,7 +1622,7 @@ void testConditionalLeftJoinGatherMaps() { @Test void testConditionalLeftJoinGatherMapsNulls() { final int inv = Integer.MIN_VALUE; - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1636,7 +1636,7 @@ void testConditionalLeftJoinGatherMapsNulls() { .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1662,9 +1662,9 @@ void testConditionalLeftJoinGatherMapsWithCount() { .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftJoinRowCount(right, condition, false); + long rowCount = left.conditionalLeftJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); - GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, false, rowCount); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, rowCount); try { verifyJoinGatherMaps(maps, expected); } 
finally { @@ -1678,7 +1678,7 @@ void testConditionalLeftJoinGatherMapsWithCount() { @Test void testConditionalLeftJoinGatherMapsNullsWithCount() { final int inv = Integer.MIN_VALUE; - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1692,9 +1692,9 @@ void testConditionalLeftJoinGatherMapsNullsWithCount() { .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftJoinRowCount(right, condition, true); + long rowCount = left.conditionalLeftJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); - GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, true, rowCount); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, rowCount); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1853,7 +1853,7 @@ void testConditionalInnerJoinGatherMaps() { .column(0, 1, 3, 0, 1, 1, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1864,11 +1864,39 @@ void testConditionalInnerJoinGatherMaps() { } } + // Test non-null-supporting equality at least once. 
@Test - void testConditionalInnerJoinGatherMapsNulls() { + void testConditionalInnerJoinGatherMapsEqual() { BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 9) // left + .column(2, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalInnerJoinGatherMapsNulls() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) .build(); @@ -1880,7 +1908,7 @@ void testConditionalInnerJoinGatherMapsNulls() { .column(2, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1905,9 +1933,9 @@ void testConditionalInnerJoinGatherMapsWithCount() { .column(0, 1, 3, 0, 1, 1, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalInnerJoinRowCount(right, condition, false); + long rowCount = left.conditionalInnerJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); - GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, false, rowCount); + GatherMap[] maps = 
left.conditionalInnerJoinGatherMaps(right, condition, rowCount); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1920,7 +1948,7 @@ void testConditionalInnerJoinGatherMapsWithCount() { @Test void testConditionalInnerJoinGatherMapsNullsWithCount() { - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1934,9 +1962,9 @@ void testConditionalInnerJoinGatherMapsNullsWithCount() { .column(2, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalInnerJoinRowCount(right, condition, true); + long rowCount = left.conditionalInnerJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); - GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, true, rowCount); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, rowCount); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -2102,7 +2130,7 @@ void testConditionalFullJoinGatherMaps() { .column( 2, 4, 5, inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -2116,7 +2144,7 @@ void testConditionalFullJoinGatherMaps() { @Test void testConditionalFullJoinGatherMapsNulls() { final int inv = Integer.MIN_VALUE; - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2130,7 +2158,7 @@ void 
testConditionalFullJoinGatherMapsNulls() { .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -2182,14 +2210,14 @@ void testConditionalLeftSemiJoinGatherMap() { .column(2, 5, 7, 9) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition, false)) { + GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition)) { verifySemiJoinGatherMap(map, expected); } } @Test void testConditionalLeftSemiJoinGatherMapNulls() { - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2202,7 +2230,7 @@ void testConditionalLeftSemiJoinGatherMapNulls() { .column(2, 7, 8, 9) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition, true)) { + GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition)) { verifySemiJoinGatherMap(map, expected); } } @@ -2220,10 +2248,10 @@ void testConditionalLeftSemiJoinGatherMapWithCount() { .column(2, 5, 7, 9) // left .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition, false); + long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); try (GatherMap map = - left.conditionalLeftSemiJoinGatherMap(right, condition, false, rowCount)) { + left.conditionalLeftSemiJoinGatherMap(right, condition, rowCount)) { 
verifySemiJoinGatherMap(map, expected); } } @@ -2231,7 +2259,7 @@ void testConditionalLeftSemiJoinGatherMapWithCount() { @Test void testConditionalLeftSemiJoinGatherMapNullsWithCount() { - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2244,10 +2272,10 @@ void testConditionalLeftSemiJoinGatherMapNullsWithCount() { .column(2, 7, 8, 9) // left .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition, true); + long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); try (GatherMap map = - left.conditionalLeftSemiJoinGatherMap(right, condition, true, rowCount)) { + left.conditionalLeftSemiJoinGatherMap(right, condition, rowCount)) { verifySemiJoinGatherMap(map, expected); } } @@ -2294,14 +2322,14 @@ void testConditionalLeftAntiJoinGatherMap() { .column(0, 1, 3, 4, 6, 8) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition, false)) { + GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition)) { verifySemiJoinGatherMap(map, expected); } } @Test void testConditionalAntiSemiJoinGatherMapNulls() { - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2314,7 +2342,7 @@ void testConditionalAntiSemiJoinGatherMapNulls() { .column(0, 1, 3, 4, 5, 6) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition, true)) { + 
GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition)) { verifySemiJoinGatherMap(map, expected); } } @@ -2332,10 +2360,10 @@ void testConditionalLeftAntiJoinGatherMapWithCount() { .column(0, 1, 3, 4, 6, 8) // left .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition, false); + long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); try (GatherMap map = - left.conditionalLeftAntiJoinGatherMap(right, condition, false, rowCount)) { + left.conditionalLeftAntiJoinGatherMap(right, condition, rowCount)) { verifySemiJoinGatherMap(map, expected); } } @@ -2343,7 +2371,7 @@ void testConditionalLeftAntiJoinGatherMapWithCount() { @Test void testConditionalAntiSemiJoinGatherMapNullsWithCount() { - BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.NULL_EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2356,10 +2384,10 @@ void testConditionalAntiSemiJoinGatherMapNullsWithCount() { .column(0, 1, 3, 4, 5, 6) // left .build(); CompiledExpression condition = expr.compile()) { - long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition, true); + long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition); assertEquals(expected.getRowCount(), rowCount); try (GatherMap map = - left.conditionalLeftAntiJoinGatherMap(right, condition, true, rowCount)) { + left.conditionalLeftAntiJoinGatherMap(right, condition, rowCount)) { verifySemiJoinGatherMap(map, expected); } } diff --git a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java index 77f24b1e9f3..2fb6792b409 100644 --- a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java +++ 
b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java @@ -452,7 +452,7 @@ private static Stream createBinaryComparisonOperationParams() { Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 }; return Stream.of( // nulls compare as equal by default - Arguments.of(BinaryOperator.EQUAL, in1, in2, Arrays.asList(false, false, true, false, true)), + Arguments.of(BinaryOperator.NULL_EQUAL, in1, in2, Arrays.asList(false, false, true, false, true)), Arguments.of(BinaryOperator.NOT_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> !a.equals(b))), Arguments.of(BinaryOperator.LESS, in1, in2, mapArray(in1, in2, (a, b) -> a < b)), Arguments.of(BinaryOperator.GREATER, in1, in2, mapArray(in1, in2, (a, b) -> a > b)), @@ -502,11 +502,13 @@ void testBinaryBitwiseOperationTransform(BinaryOperator op, Integer[] in1, Integ } private static Stream createBinaryBooleanOperationParams() { - Boolean[] in1 = new Boolean[] { false, true, null, true, false }; - Boolean[] in2 = new Boolean[] { true, null, null, true, false }; + Boolean[] in1 = new Boolean[] { false, true, false, null, true, false }; + Boolean[] in2 = new Boolean[] { true, null, null, null, true, false }; return Stream.of( Arguments.of(BinaryOperator.LOGICAL_AND, in1, in2, mapArray(in1, in2, (a, b) -> a && b)), - Arguments.of(BinaryOperator.LOGICAL_OR, in1, in2, mapArray(in1, in2, (a, b) -> a || b))); + Arguments.of(BinaryOperator.LOGICAL_OR, in1, in2, mapArray(in1, in2, (a, b) -> a || b)), + Arguments.of(BinaryOperator.NULL_LOGICAL_AND, in1, in2, Arrays.asList(false, null, false, null, true, false)), + Arguments.of(BinaryOperator.NULL_LOGICAL_OR, in1, in2, Arrays.asList(true, true, null, null, true, false))); } @ParameterizedTest