From 0782dd900c0f7d750487ecb336b30c9a8dac4869 Mon Sep 17 00:00:00 2001 From: Arpit Porwal Date: Tue, 13 Jun 2023 10:21:41 -0700 Subject: [PATCH] Added support for complex type (Array, Map) in NeqFunction (#5186) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/5186 Similar to eqFunction,registered function of type , Generic> Reviewed By: laithsakka Differential Revision: D46540217 fbshipit-source-id: 94f05e42153a7312c83fecd483e2ebf1756107b6 --- velox/functions/prestosql/Comparisons.h | 25 +++++++++++++++- .../ComparisonFunctionsRegistration.cpp | 2 ++ .../prestosql/tests/ComparisonsTest.cpp | 30 ++++++++++++++----- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/velox/functions/prestosql/Comparisons.h b/velox/functions/prestosql/Comparisons.h index f3d7c7c7e5d8..2642e4347bf2 100644 --- a/velox/functions/prestosql/Comparisons.h +++ b/velox/functions/prestosql/Comparisons.h @@ -30,7 +30,6 @@ namespace facebook::velox::functions { } \ }; -VELOX_GEN_BINARY_EXPR(NeqFunction, lhs != rhs, bool); VELOX_GEN_BINARY_EXPR(LtFunction, lhs < rhs, bool); VELOX_GEN_BINARY_EXPR(GtFunction, lhs > rhs, bool); VELOX_GEN_BINARY_EXPR(LteFunction, lhs <= rhs, bool); @@ -86,6 +85,30 @@ struct EqFunction { } }; +template +struct NeqFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Used for primitive inputs. + template + void call(bool& out, const TInput& lhs, const TInput& rhs) { + out = (lhs != rhs); + } + + // For arbitrary nested complex types. Can return null. + bool call( + bool& out, + const arg_type>& lhs, + const arg_type>& rhs) { + if (EqFunction().call(out, lhs, rhs)) { + out = !out; + return true; + } else { + return false; + } + } +}; + template struct BetweenFunction { template diff --git a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp index 3e43371646d1..b87fc755f227 100644 --- a/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ComparisonFunctionsRegistration.cpp @@ -27,6 +27,8 @@ void registerComparisonFunctions(const std::string& prefix) { registerNonSimdizableScalar({prefix + "neq"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_neq, prefix + "neq"); + registerFunction, Generic>( + {prefix + "neq"}); registerNonSimdizableScalar({prefix + "lt"}); VELOX_REGISTER_VECTOR_FUNCTION(udf_simd_comparison_lt, prefix + "lt"); diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index 9b4631957bdc..fffb94d3ca18 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/Udf.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -369,22 +370,30 @@ TEST_F(ComparisonsTest, gtLtDecimal) { runAndCompare("c1 <= c0", longDecimalsInputs, expectedGteLte); }; -TEST_F(ComparisonsTest, eqArray) { +TEST_F(ComparisonsTest, eqNeqArray) { auto test = [&](const std::optional>>& array1, const std::optional>>& array2, std::optional expected) { auto vector1 = vectorMaker_.arrayVectorNullable({array1}); auto vector2 = vectorMaker_.arrayVectorNullable({array2}); - auto result = evaluate>( + auto eqResult = evaluate>( "c0 == c1", makeRowVector({vector1, vector2})); - ASSERT_EQ(expected.has_value(), !result->isNullAt(0)); + auto neqResult = evaluate>( + "c0 != c1", makeRowVector({vector1, vector2})); + + ASSERT_EQ(expected.has_value(), !eqResult->isNullAt(0)); + ASSERT_EQ(expected.has_value(), !neqResult->isNullAt(0)); if (expected.has_value()) { - ASSERT_EQ(expected.value(), result->valueAt(0)); + // equals check + ASSERT_EQ(expected.value(), eqResult->valueAt(0)); + // not equal check + ASSERT_EQ(!expected.value(), neqResult->valueAt(0)); } }; + // eq and neq function test test(std::nullopt, std::nullopt, std::nullopt); test(std::nullopt, {{1}}, std::nullopt); test({{1}}, std::nullopt, std::nullopt); @@ -414,7 +423,7 @@ TEST_F(ComparisonsTest, eqArray) { std::nullopt); } -TEST_F(ComparisonsTest, eqMap) { +TEST_F(ComparisonsTest, eqNeqMap) { using map_t = std::optional>>>; auto test = @@ -422,16 +431,21 @@ TEST_F(ComparisonsTest, eqMap) { auto vector1 = makeNullableMapVector({map1}); auto vector2 = makeNullableMapVector({map2}); - auto result = evaluate>( + auto eqResult = evaluate>( "c0 == c1", makeRowVector({vector1, vector2})); - ASSERT_EQ(expected.has_value(), !result->isNullAt(0)); + auto neqResult = evaluate>( + "c0 != c1", makeRowVector({vector1, vector2})); + ASSERT_EQ(expected.has_value(), !eqResult->isNullAt(0)); + ASSERT_EQ(expected.has_value(), !neqResult->isNullAt(0)); if (expected.has_value()) { - ASSERT_EQ(expected.value(), result->valueAt(0)); + ASSERT_EQ(expected.value(), eqResult->valueAt(0)); + ASSERT_EQ(!expected.value(), neqResult->valueAt(0)); } }; + // eq and neq function test test({{{1, 2}, {3, 4}}}, {{{1, 2}, {3, 4}}}, true); // Elements checked in sorted order.