From 0134a1e5ddf4b538e00e5194e29b13e8eacf3b1b Mon Sep 17 00:00:00 2001 From: macduan Date: Tue, 21 Mar 2023 11:50:10 +0800 Subject: [PATCH 1/6] add any_match function --- velox/docs/functions/presto/array.rst | 9 + velox/functions/prestosql/ArrayAllMatch.cpp | 141 ----------- velox/functions/prestosql/ArrayMatch.cpp | 236 ++++++++++++++++++ velox/functions/prestosql/CMakeLists.txt | 2 +- .../ArrayFunctionsRegistration.cpp | 1 + .../prestosql/tests/ArrayAnyMatchTest.cpp | 173 +++++++++++++ .../functions/prestosql/tests/CMakeLists.txt | 1 + 7 files changed, 421 insertions(+), 142 deletions(-) delete mode 100644 velox/functions/prestosql/ArrayAllMatch.cpp create mode 100644 velox/functions/prestosql/ArrayMatch.cpp create mode 100644 velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index e01178cf37e7..a37f8b2eb398 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -11,6 +11,15 @@ Array Functions Returns NULL if the predicate function returns NULL for one or more elements and true for all other elements. Throws an exception if the predicate fails for one or more elements and returns true or NULL for the rest. +.. function:: any_match(array(T), function(T, boolean)) → boolean + + Returns whether any elements of an array match the given predicate. + + Returns true if one or more elements match the predicate; + Returns false if none of the elements matches (a special case is when the array is empty); + Returns NULL NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. + .. function:: array_average(array(double)) -> double Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. diff --git a/velox/functions/prestosql/ArrayAllMatch.cpp b/velox/functions/prestosql/ArrayAllMatch.cpp deleted file mode 100644 index d716ec325c79..000000000000 --- a/velox/functions/prestosql/ArrayAllMatch.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/expression/EvalCtx.h" -#include "velox/expression/Expr.h" -#include "velox/expression/VectorFunction.h" -#include "velox/functions/lib/LambdaFunctionUtil.h" -#include "velox/functions/lib/RowsTranslationUtil.h" -#include "velox/vector/FunctionVector.h" - -namespace facebook::velox::functions { -namespace { - -class AllMatchFunction : public exec::VectorFunction { - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /*outputType*/, - exec::EvalCtx& context, - VectorPtr& result) const override { - exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); - auto& decodedArray = *arrayDecoder.get(); - - auto flatArray = flattenArray(rows, args[0], decodedArray); - auto offsets = flatArray->rawOffsets(); - auto sizes = flatArray->rawSizes(); - - std::vector lambdaArgs = {flatArray->elements()}; - auto numElements = flatArray->elements()->size(); - - SelectivityVector finalSelection; - if (!context.isFinalSelection()) { - finalSelection = toElementRows( - numElements, *context.finalSelection(), flatArray.get()); - } - - VectorPtr matchBits; - auto elementToTopLevelRows = getElementToTopLevelRows( - numElements, rows, flatArray.get(), context.pool()); - - // Loop over lambda functions and apply these to elements of the base array, - // in most cases there will be only one function and the loop will run once. - context.ensureWritable(rows, BOOLEAN(), result); - auto flatResult = result->asFlatVector(); - exec::LocalDecodedVector bitsDecoder(context); - auto it = args[1]->asUnchecked()->iterator(&rows); - - while (auto entry = it.next()) { - ErrorVectorPtr elementErrors; - auto elementRows = - toElementRows(numElements, *entry.rows, flatArray.get()); - auto wrapCapture = toWrapCapture( - numElements, entry.callable, *entry.rows, flatArray); - entry.callable->applyNoThrow( - elementRows, - finalSelection, - wrapCapture, - &context, - lambdaArgs, - elementErrors, - &matchBits); - - bitsDecoder.get()->decode(*matchBits, elementRows); - entry.rows->applyToSelected([&](vector_size_t row) { - auto size = sizes[row]; - auto offset = offsets[row]; - auto allMatch = true; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; - } - - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (!bitsDecoder->valueAt(idx)) { - allMatch = false; - break; - } - } - - // Errors for individual array elements should be suppressed only if the - // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'false' for the predicate. - if (!allMatch) { - flatResult->set(row, false); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); - } else { - flatResult->set(row, true); - } - }); - } - } - - static std::vector> signatures() { - // array(T), function(T) -> boolean - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .returnType("boolean") - .argumentType("array(T)") - .argumentType("function(T, boolean)") - .build()}; - } - - private: - FOLLY_ALWAYS_INLINE bool hasError( - const ErrorVectorPtr& elementErrors, - int idx) const { - return elementErrors && idx < elementErrors->size() && - !elementErrors->isNullAt(idx); - } -}; -} // namespace - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_all_match, - AllMatchFunction::signatures(), - std::make_unique()); - -} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp new file mode 100644 index 000000000000..704391bb96d2 --- /dev/null +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/vector/FunctionVector.h" + +namespace facebook::velox::functions { +namespace { + +class MatchFunction : public exec::VectorFunction { + private: + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); + auto& decodedArray = *arrayDecoder.get(); + + auto flatArray = flattenArray(rows, args[0], decodedArray); + auto offsets = flatArray->rawOffsets(); + auto sizes = flatArray->rawSizes(); + + std::vector lambdaArgs = {flatArray->elements()}; + auto numElements = flatArray->elements()->size(); + + SelectivityVector finalSelection; + if (!context.isFinalSelection()) { + finalSelection = toElementRows( + numElements, *context.finalSelection(), flatArray.get()); + } + + VectorPtr matchBits; + auto elementToTopLevelRows = getElementToTopLevelRows( + numElements, rows, flatArray.get(), context.pool()); + + // Loop over lambda functions and apply these to elements of the base array, + // in most cases there will be only one function and the loop will run once. + context.ensureWritable(rows, BOOLEAN(), result); + auto flatResult = result->asFlatVector(); + exec::LocalDecodedVector bitsDecoder(context); + auto it = args[1]->asUnchecked()->iterator(&rows); + + while (auto entry = it.next()) { + ErrorVectorPtr elementErrors; + auto elementRows = + toElementRows(numElements, *entry.rows, flatArray.get()); + auto wrapCapture = toWrapCapture( + numElements, entry.callable, *entry.rows, flatArray); + entry.callable->applyNoThrow( + elementRows, + finalSelection, + wrapCapture, + &context, + lambdaArgs, + elementErrors, + &matchBits); + + bitsDecoder.get()->decode(*matchBits, elementRows); + entry.rows->applyToSelected([&](vector_size_t row) { + applyInternal( + flatResult, + context, + row, + offsets, + sizes, + elementErrors, + bitsDecoder); + }); + } + } + + protected: + static FOLLY_ALWAYS_INLINE bool hasError( + const ErrorVectorPtr& elementErrors, + int idx) { + return elementErrors && idx < elementErrors->size() && + !elementErrors->isNullAt(idx); + } + + private: + virtual void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const = 0; +}; + +class AllMatchFunction : public MatchFunction { + private: + void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const override { + auto size = sizes[row]; + auto offset = offsets[row]; + auto allMatch = true; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (!bitsDecoder->valueAt(idx)) { + allMatch = false; + break; + } + } + + // Errors for individual array elements should be suppressed only if the + // outcome can be decided by some other array element, e.g. if there is + // another element that returns 'false' for the predicate. + if (!allMatch) { + flatResult->set(row, false); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, true); + } + } + + public: + static std::vector> signatures() { + // array(T), function(T) -> boolean + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("boolean") + .argumentType("array(T)") + .argumentType("function(T, boolean)") + .build()}; + } +}; + +class AnyMatchFunction : public MatchFunction { + private: + void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const override { + auto size = sizes[row]; + auto offset = offsets[row]; + auto nonMatch = true; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (bitsDecoder->valueAt(idx)) { + nonMatch = false; + break; + } + } + + // Errors for individual array elements should be suppressed only if the + // outcome can be decided by some other array element, e.g. if there is + // another element that returns 'true' for the predicate. + if (!nonMatch) { + flatResult->set(row, true); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, false); + } + } + + public: + static std::vector> signatures() { + // array(T), function(T) -> boolean + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("boolean") + .argumentType("array(T)") + .argumentType("function(T, boolean)") + .build()}; + } +}; + +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_all_match, + AllMatchFunction::signatures(), + std::make_unique()); + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_any_match, + AnyMatchFunction::signatures(), + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 9240e4ad200f..1d6898c1ddf6 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -17,12 +17,12 @@ add_subdirectory(window) add_library( velox_functions_prestosql_impl - ArrayAllMatch.cpp ArrayConstructor.cpp ArrayContains.cpp ArrayDistinct.cpp ArrayDuplicates.cpp ArrayIntersectExcept.cpp + ArrayMatch.cpp ArrayPosition.cpp ArrayShuffle.cpp ArraySort.cpp diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index dfffebb71096..0325d23ed497 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -79,6 +79,7 @@ inline void registerArrayNormalizeFunctions(const std::string& prefix) { void registerArrayFunctions(const std::string& prefix) { registerArrayConstructor(prefix + "array_constructor"); VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); VELOX_REGISTER_VECTOR_FUNCTION( udf_array_duplicates, prefix + "array_duplicates"); diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp new file mode 100644 index 000000000000..ee8729789267 --- /dev/null +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { + protected: + // Evaluate an expression. + void testExpr( + const VectorPtr& expected, + const std::string& expression, + const VectorPtr& input) { + auto result = evaluate(expression, makeRowVector({input})); + assertEqualVectors(expected, result); + } + + void testExpr( + const VectorPtr& expected, + const std::string& expression, + const RowVectorPtr& input) { + auto result = evaluate(expression, (input)); + assertEqualVectors(expected, result); + } +}; + +TEST_F(ArrayAnyMatchTest, basic) { + auto input = makeNullableArrayVector( + {{std::nullopt, 2, 3}, + {-1, 3}, + {2, 3}, + {}, + {std::nullopt, std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({true, true, true, false, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1))", input); + expectedResult = makeFlatVector({true, false, false, false, true}); + testExpr(expectedResult, "any_match(c0, x -> (x is null))", input); + + input = makeNullableArrayVector( + {{false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}); + expectedResult = + makeNullableFlatVector({true, false, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> x)", input); + + auto emptyInput = makeArrayVector({{}}); + expectedResult = makeFlatVector(std::vector{false}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1))", emptyInput); + testExpr(expectedResult, "any_match(c0, x -> (x <= 1))", emptyInput); +} + +TEST_F(ArrayAnyMatchTest, complexTypes) { + auto baseVector = + makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3], []], + // [[2, 2], [3, 3], [4, 4], [5, 5]], + // [[], []] + // ] + auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); + auto expectedResult = makeNullableFlatVector({true, true, false}); + testExpr( + expectedResult, + "any_match(c0, x -> (cardinality(x) > 0))", + arrayOfArrays); + + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3]], cardinalities is 3 + // [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities is 2 + // [[]], + // null + // ] + arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); + expectedResult = + makeNullableFlatVector({true, false, false, std::nullopt}); + testExpr( + expectedResult, + "any_match(c0, x -> (cardinality(x) > 2))", + arrayOfArrays); +} + +TEST_F(ArrayAnyMatchTest, bigints) { + auto input = makeNullableArrayVector( + {{}, + {2}, + {std::numeric_limits::max()}, + {std::numeric_limits::min()}, + {std::nullopt, std::nullopt}, // return null if all is null + {2, + std::nullopt}, // return null if one or more is null and others matched + {1, std::nullopt, 2}}); // return false if one is not matched + auto expectedResult = makeNullableFlatVector( + {false, true, false, true, std::nullopt, true, true}); + testExpr(expectedResult, "any_match(c0, x -> (x % 2 = 0))", input); +} + +TEST_F(ArrayAnyMatchTest, strings) { + auto input = makeNullableArrayVector( + {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({false, true, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x = 'abc'))", input); +} + +TEST_F(ArrayAnyMatchTest, doubles) { + auto input = + makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({false, false, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1.1))", input); +} + +TEST_F(ArrayAnyMatchTest, errors) { + // No throw and return false if there are unmatched elements except nulls + auto expression = "any_match(c0, x -> ((10 / x) > 2))"; + auto input = makeNullableArrayVector( + {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); + auto expectedResult = makeFlatVector({true, true}); + testExpr(expectedResult, expression, input); + + // Throw error if others are matched or null + static constexpr std::string_view kErrorMessage{"division by zero"}; + auto errorInput = makeNullableArrayVector( + {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + VELOX_ASSERT_THROW( + testExpr(expectedResult, expression, errorInput), kErrorMessage); + // Rerun using TRY to get right results + expectedResult = + makeNullableFlatVector({true, true, false, std::nullopt, true}); + testExpr( + expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInput); + testExpr( + expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInput); +} + +TEST_F(ArrayAnyMatchTest, conditional) { + // No throw and return false if there are unmatched elements except nulls + auto c0 = makeFlatVector({1, 2, 3, 4, 5}); + auto c1 = makeNullableArrayVector( + {{4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}}); + auto input = makeRowVector({c0, c1}); + auto expectedResult = makeNullableFlatVector( + {std::nullopt, false, std::nullopt, true, false}); + testExpr( + expectedResult, + "any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", + input); +} diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index c47130c97308..6f43d4f6f851 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( velox_functions_test ArithmeticTest.cpp ArrayAllMatchTest.cpp + ArrayAnyMatchTest.cpp ArrayAverageTest.cpp ArrayCombinationsTest.cpp ArrayConstructorTest.cpp From 83dcfab516dae8ab6ca90928fd9ac40db36ebe44 Mon Sep 17 00:00:00 2001 From: macduan Date: Tue, 21 Mar 2023 20:31:00 +0800 Subject: [PATCH 2/6] resolve comments --- velox/docs/functions/presto/array.rst | 4 +- velox/functions/prestosql/ArrayMatch.cpp | 119 ++++++----------------- 2 files changed, 31 insertions(+), 92 deletions(-) diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index a37f8b2eb398..47065a66387b 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -13,11 +13,11 @@ Array Functions .. function:: any_match(array(T), function(T, boolean)) → boolean - Returns whether any elements of an array match the given predicate. + Returns whether at least one element of an array match the given predicate. Returns true if one or more elements match the predicate; Returns false if none of the elements matches (a special case is when the array is empty); - Returns NULL NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. .. function:: array_average(array(double)) -> double diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index 704391bb96d2..389631cdb5a8 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,6 +24,7 @@ namespace facebook::velox::functions { namespace { +template class MatchFunction : public exec::VectorFunction { private: void apply( @@ -90,24 +91,12 @@ class MatchFunction : public exec::VectorFunction { protected: static FOLLY_ALWAYS_INLINE bool hasError( - const ErrorVectorPtr& elementErrors, + const ErrorVectorPtr& errors, int idx) { - return elementErrors && idx < elementErrors->size() && - !elementErrors->isNullAt(idx); + return errors && idx < errors->size() && + !errors->isNullAt(idx); } - private: - virtual void applyInternal( - FlatVector* flatResult, - exec::EvalCtx& context, - vector_size_t row, - const vector_size_t* offsets, - const vector_size_t* sizes, - const ErrorVectorPtr& elementErrors, - const exec::LocalDecodedVector& bitsDecoder) const = 0; -}; - -class AllMatchFunction : public MatchFunction { private: void applyInternal( FlatVector* flatResult, @@ -116,10 +105,14 @@ class AllMatchFunction : public MatchFunction { const vector_size_t* offsets, const vector_size_t* sizes, const ErrorVectorPtr& elementErrors, - const exec::LocalDecodedVector& bitsDecoder) const override { + const exec::LocalDecodedVector& bitsDecoder) const { auto size = sizes[row]; auto offset = offsets[row]; - auto allMatch = true; + + // All or any match. + // All: early exit on non-match. Initial value = true. + // Any: early exit on match. Initial value = false. + bool result = match; auto hasNull = false; std::exception_ptr errorPtr{nullptr}; for (auto i = 0; i < size; ++i) { @@ -132,105 +125,51 @@ class AllMatchFunction : public MatchFunction { if (bitsDecoder->isNullAt(idx)) { hasNull = true; - } else if (!bitsDecoder->valueAt(idx)) { - allMatch = false; + } else if (bitsDecoder->valueAt(idx) == !match) { + result = !result; break; } } // Errors for individual array elements should be suppressed only if the // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'false' for the predicate. - if (!allMatch) { - flatResult->set(row, false); + // another element that returns 'true' for the predicate. + if (result != match) { + flatResult->set(row, !match); } else if (errorPtr) { context.setError(row, errorPtr); } else if (hasNull) { flatResult->setNull(row, true); } else { - flatResult->set(row, true); + flatResult->set(row, match); } } - - public: - static std::vector> signatures() { - // array(T), function(T) -> boolean - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .returnType("boolean") - .argumentType("array(T)") - .argumentType("function(T, boolean)") - .build()}; - } }; -class AnyMatchFunction : public MatchFunction { - private: - void applyInternal( - FlatVector* flatResult, - exec::EvalCtx& context, - vector_size_t row, - const vector_size_t* offsets, - const vector_size_t* sizes, - const ErrorVectorPtr& elementErrors, - const exec::LocalDecodedVector& bitsDecoder) const override { - auto size = sizes[row]; - auto offset = offsets[row]; - auto nonMatch = true; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; - } +class AllMatchFunction : public MatchFunction {}; - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (bitsDecoder->valueAt(idx)) { - nonMatch = false; - break; - } - } +class AnyMatchFunction : public MatchFunction {}; - // Errors for individual array elements should be suppressed only if the - // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'true' for the predicate. - if (!nonMatch) { - flatResult->set(row, true); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); - } else { - flatResult->set(row, false); - } - } - - public: - static std::vector> signatures() { - // array(T), function(T) -> boolean - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .returnType("boolean") - .argumentType("array(T)") - .argumentType("function(T, boolean)") - .build()}; - } -}; +std::vector> signatures() { + // array(T), function(T) -> boolean + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("boolean") + .argumentType("array(T)") + .argumentType("function(T, boolean)") + .build()}; +} } // namespace VELOX_DECLARE_VECTOR_FUNCTION( udf_all_match, - AllMatchFunction::signatures(), + signatures(), std::make_unique()); VELOX_DECLARE_VECTOR_FUNCTION( udf_any_match, - AnyMatchFunction::signatures(), + signatures(), std::make_unique()); } // namespace facebook::velox::functions From 4bb611d3dd8593d688772160f55eb255916007cd Mon Sep 17 00:00:00 2001 From: macduan Date: Wed, 22 Mar 2023 00:09:28 +0800 Subject: [PATCH 3/6] resolve comments --- velox/docs/functions/presto/array.rst | 2 +- velox/functions/prestosql/ArrayMatch.cpp | 75 +++++++------- .../prestosql/tests/ArrayAnyMatchTest.cpp | 97 +++++++------------ 3 files changed, 74 insertions(+), 100 deletions(-) diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 47065a66387b..3a6a8cd628df 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -13,7 +13,7 @@ Array Functions .. function:: any_match(array(T), function(T, boolean)) → boolean - Returns whether at least one element of an array match the given predicate. + Returns whether at least one element of an array matchs the given predicate. Returns true if one or more elements match the predicate; Returns false if none of the elements matches (a special case is when the array is empty); diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index 389631cdb5a8..fcd57dfeecea 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,13 +24,15 @@ namespace facebook::velox::functions { namespace { -template +enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 }; + +template class MatchFunction : public exec::VectorFunction { private: void apply( const SelectivityVector& rows, std::vector& args, - const TypePtr& outputType, + const TypePtr& /*outputType*/, exec::EvalCtx& context, VectorPtr& result) const override { exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); @@ -89,7 +91,6 @@ class MatchFunction : public exec::VectorFunction { } } - protected: static FOLLY_ALWAYS_INLINE bool hasError( const ErrorVectorPtr& errors, int idx) { @@ -108,47 +109,49 @@ class MatchFunction : public exec::VectorFunction { const exec::LocalDecodedVector& bitsDecoder) const { auto size = sizes[row]; auto offset = offsets[row]; - - // All or any match. - // All: early exit on non-match. Initial value = true. - // Any: early exit on match. Initial value = false. - bool result = match; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; + if (mathMethod == MatchMethod::NONE) { + // TODO: Add none_match + } else { + // All or any match. + // All: early exit on non-match. Initial value = true. + // Any: early exit on match. Initial value = false. + auto match = mathMethod == MatchMethod::ALL; + bool result = match; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (bitsDecoder->valueAt(idx) == !match) { + result = !result; + break; + } } - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (bitsDecoder->valueAt(idx) == !match) { - result = !result; - break; + if (result != match) { + flatResult->set(row, !match); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, match); } } - - // Errors for individual array elements should be suppressed only if the - // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'true' for the predicate. - if (result != match) { - flatResult->set(row, !match); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); - } else { - flatResult->set(row, match); - } } }; -class AllMatchFunction : public MatchFunction {}; +class AllMatchFunction : public MatchFunction {}; +class AnyMatchFunction : public MatchFunction {}; -class AnyMatchFunction : public MatchFunction {}; +// TODO: add class NoneMatchFunction std::vector> signatures() { // array(T), function(T) -> boolean diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp index ee8729789267..c940e90bda80 100644 --- a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -24,48 +25,40 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { protected: // Evaluate an expression. void testExpr( - const VectorPtr& expected, - const std::string& expression, + const std::vector>& expected, + const std::string& lambdaExpr, const VectorPtr& input) { + auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); auto result = evaluate(expression, makeRowVector({input})); - assertEqualVectors(expected, result); + assertEqualVectors(makeNullableFlatVector(expected), result); } void testExpr( - const VectorPtr& expected, + const std::vector>& expected, const std::string& expression, const RowVectorPtr& input) { auto result = evaluate(expression, (input)); - assertEqualVectors(expected, result); + assertEqualVectors(makeNullableFlatVector(expected), result); } }; TEST_F(ArrayAnyMatchTest, basic) { auto input = makeNullableArrayVector( - {{std::nullopt, 2, 3}, - {-1, 3}, - {2, 3}, - {}, - {std::nullopt, std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({true, true, true, false, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1))", input); - expectedResult = makeFlatVector({true, false, false, false, true}); - testExpr(expectedResult, "any_match(c0, x -> (x is null))", input); + {{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}); + std::vector> expectedResult{ + true, true, false, false, std::nullopt}; + testExpr(expectedResult, "x > 1", input); + + expectedResult = {true, false, false, false, true}; + testExpr(expectedResult, "x is null", input); input = makeNullableArrayVector( {{false, true}, {false, false}, {std::nullopt, true}, {std::nullopt, false}}); - expectedResult = - makeNullableFlatVector({true, false, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> x)", input); - - auto emptyInput = makeArrayVector({{}}); - expectedResult = makeFlatVector(std::vector{false}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1))", emptyInput); - testExpr(expectedResult, "any_match(c0, x -> (x <= 1))", emptyInput); + expectedResult = {true, false, true, std::nullopt}; + testExpr(expectedResult, "x", input); } TEST_F(ArrayAnyMatchTest, complexTypes) { @@ -78,11 +71,8 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { // [[], []] // ] auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); - auto expectedResult = makeNullableFlatVector({true, true, false}); - testExpr( - expectedResult, - "any_match(c0, x -> (cardinality(x) > 0))", - arrayOfArrays); + std::vector> expectedResult{true, true, false}; + testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); // Create an array of array vector using above base vector using offsets. // [ @@ -92,51 +82,32 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { // null // ] arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); - expectedResult = - makeNullableFlatVector({true, false, false, std::nullopt}); - testExpr( - expectedResult, - "any_match(c0, x -> (cardinality(x) > 2))", - arrayOfArrays); -} - -TEST_F(ArrayAnyMatchTest, bigints) { - auto input = makeNullableArrayVector( - {{}, - {2}, - {std::numeric_limits::max()}, - {std::numeric_limits::min()}, - {std::nullopt, std::nullopt}, // return null if all is null - {2, - std::nullopt}, // return null if one or more is null and others matched - {1, std::nullopt, 2}}); // return false if one is not matched - auto expectedResult = makeNullableFlatVector( - {false, true, false, true, std::nullopt, true, true}); - testExpr(expectedResult, "any_match(c0, x -> (x % 2 = 0))", input); + expectedResult = {true, false, false, std::nullopt}; + testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); } TEST_F(ArrayAnyMatchTest, strings) { auto input = makeNullableArrayVector( {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({false, true, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x = 'abc'))", input); + std::vector> expectedResult{ + false, true, true, std::nullopt}; + testExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAnyMatchTest, doubles) { auto input = makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({false, false, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1.1))", input); + std::vector> expectedResult{ + false, false, true, std::nullopt}; + testExpr(expectedResult, "x > 1.1", input); } TEST_F(ArrayAnyMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls - auto expression = "any_match(c0, x -> ((10 / x) > 2))"; + auto expression = "(10 / x) > 2"; auto input = makeNullableArrayVector( {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); - auto expectedResult = makeFlatVector({true, true}); + std::vector> expectedResult = {true, true}; testExpr(expectedResult, expression, input); // Throw error if others are matched or null @@ -146,12 +117,12 @@ TEST_F(ArrayAnyMatchTest, errors) { VELOX_ASSERT_THROW( testExpr(expectedResult, expression, errorInput), kErrorMessage); // Rerun using TRY to get right results - expectedResult = - makeNullableFlatVector({true, true, false, std::nullopt, true}); + auto errorInputRow = makeRowVector({errorInput}); + expectedResult = {true, true, false, std::nullopt, true}; testExpr( - expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInput); + expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow); testExpr( - expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInput); + expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); } TEST_F(ArrayAnyMatchTest, conditional) { @@ -164,8 +135,8 @@ TEST_F(ArrayAnyMatchTest, conditional) { {3, std::nullopt, 0}, {300, 100}}); auto input = makeRowVector({c0, c1}); - auto expectedResult = makeNullableFlatVector( - {std::nullopt, false, std::nullopt, true, false}); + std::vector> expectedResult = { + std::nullopt, false, std::nullopt, true, false}; testExpr( expectedResult, "any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", From 1941a7979b418aae3b394eee5128764db1e3f156 Mon Sep 17 00:00:00 2001 From: macduan Date: Wed, 22 Mar 2023 11:20:32 +0800 Subject: [PATCH 4/6] resolve comments --- velox/docs/functions/presto/array.rst | 11 +- velox/functions/prestosql/ArrayMatch.cpp | 97 ++++++----- .../ArrayFunctionsRegistration.cpp | 1 + .../prestosql/tests/ArrayAnyMatchTest.cpp | 43 +++-- .../prestosql/tests/ArrayNoneMatchTest.cpp | 159 ++++++++++++++++++ .../functions/prestosql/tests/CMakeLists.txt | 1 + 6 files changed, 257 insertions(+), 55 deletions(-) create mode 100644 velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 3a6a8cd628df..10b039df8389 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -13,13 +13,22 @@ Array Functions .. function:: any_match(array(T), function(T, boolean)) → boolean - Returns whether at least one element of an array matchs the given predicate. + Returns whether at least one element of an array matches the given predicate. Returns true if one or more elements match the predicate; Returns false if none of the elements matches (a special case is when the array is empty); Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. +.. function:: none_match(array(T), function(T, boolean)) → boolean + + Returns whether no elements of an array match the given predicate. + + Returns true if none of the elements matches the predicate (a special case is when the array is empty); + Returns false if one or more elements match; + Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. + .. function:: array_average(array(double)) -> double Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index fcd57dfeecea..ba6df446dd31 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,9 +24,9 @@ namespace facebook::velox::functions { namespace { -enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 }; +enum class MatchMethod { kAll = 0, kAny = 1, kNone = 2 }; -template +template class MatchFunction : public exec::VectorFunction { private: void apply( @@ -94,8 +94,7 @@ class MatchFunction : public exec::VectorFunction { static FOLLY_ALWAYS_INLINE bool hasError( const ErrorVectorPtr& errors, int idx) { - return errors && idx < errors->size() && - !errors->isNullAt(idx); + return errors && idx < errors->size() && !errors->isNullAt(idx); } private: @@ -109,49 +108,66 @@ class MatchFunction : public exec::VectorFunction { const exec::LocalDecodedVector& bitsDecoder) const { auto size = sizes[row]; auto offset = offsets[row]; - if (mathMethod == MatchMethod::NONE) { - // TODO: Add none_match - } else { - // All or any match. - // All: early exit on non-match. Initial value = true. - // Any: early exit on match. Initial value = false. - auto match = mathMethod == MatchMethod::ALL; - bool result = match; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; - } - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (bitsDecoder->valueAt(idx) == !match) { - result = !result; - break; - } + // All, none, and any match have different and similar logic intertwined in + // terms of the initial value, flip condition, and the result finalization: + // + // Initial value: + // All is true + // Any is false + // None is true + // + // Flip logic: + // All flips once encounter an unmatched element + // Any flips once encounter a matched element + // None flips once encounter a matched element + // + // Result finalization: + // All: ignore the error and null if one or more elements are unmatched and + // return false Any: ignore the error and null if one or more elements + // matched and return true None: ignore the error and null if one or more + // elements matched and return false + auto match = (matchMethod != MatchMethod::kAny); + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; } - if (result != match) { - flatResult->set(row, !match); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (matchMethod == MatchMethod::kAll) { + if (!bitsDecoder->valueAt(idx)) { + match = !match; + break; + } } else { - flatResult->set(row, match); + if (bitsDecoder->valueAt(idx)) { + match = !match; + break; + } } } + + if ((matchMethod == MatchMethod::kAny) == match) { + flatResult->set(row, match); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, match); + } } }; -class AllMatchFunction : public MatchFunction {}; -class AnyMatchFunction : public MatchFunction {}; - -// TODO: add class NoneMatchFunction +class AllMatchFunction : public MatchFunction {}; +class AnyMatchFunction : public MatchFunction {}; +class NoneMatchFunction : public MatchFunction {}; std::vector> signatures() { // array(T), function(T) -> boolean @@ -175,4 +191,9 @@ VELOX_DECLARE_VECTOR_FUNCTION( signatures(), std::make_unique()); +VELOX_DECLARE_VECTOR_FUNCTION( + udf_none_match, + signatures(), + std::make_unique()); + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index 0325d23ed497..ba6aa5fb607c 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -80,6 +80,7 @@ void registerArrayFunctions(const std::string& prefix) { registerArrayConstructor(prefix + "array_constructor"); VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_none_match, prefix + "none_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); VELOX_REGISTER_VECTOR_FUNCTION( udf_array_duplicates, prefix + "array_duplicates"); diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp index c940e90bda80..498b8618725d 100644 --- a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -33,6 +33,17 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { assertEqualVectors(makeNullableFlatVector(expected), result); } + template + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const std::vector>>& input) { + auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); + auto result = + evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + void testExpr( const std::vector>& expected, const std::string& expression, @@ -43,22 +54,22 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { }; TEST_F(ArrayAnyMatchTest, basic) { - auto input = makeNullableArrayVector( - {{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}); + std::vector>> ints{ + {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; std::vector> expectedResult{ true, true, false, false, std::nullopt}; - testExpr(expectedResult, "x > 1", input); + testExpr(expectedResult, "x > 1", ints); expectedResult = {true, false, false, false, true}; - testExpr(expectedResult, "x is null", input); + testExpr(expectedResult, "x is null", ints); - input = makeNullableArrayVector( - {{false, true}, - {false, false}, - {std::nullopt, true}, - {std::nullopt, false}}); + std::vector>> bools{ + {false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}; expectedResult = {true, false, true, std::nullopt}; - testExpr(expectedResult, "x", input); + testExpr(expectedResult, "x", bools); } TEST_F(ArrayAnyMatchTest, complexTypes) { @@ -87,16 +98,16 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { } TEST_F(ArrayAnyMatchTest, strings) { - auto input = makeNullableArrayVector( - {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); + std::vector>> input{ + {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; std::vector> expectedResult{ false, true, true, std::nullopt}; testExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAnyMatchTest, doubles) { - auto input = - makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); + std::vector>> input{ + {}, {0.2}, {3.0, 0}, {std::nullopt}}; std::vector> expectedResult{ false, false, true, std::nullopt}; testExpr(expectedResult, "x > 1.1", input); @@ -105,8 +116,8 @@ TEST_F(ArrayAnyMatchTest, doubles) { TEST_F(ArrayAnyMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls auto expression = "(10 / x) > 2"; - auto input = makeNullableArrayVector( - {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); + std::vector>> input{ + {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; std::vector> expectedResult = {true, true}; testExpr(expectedResult, expression, input); diff --git a/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp new file mode 100644 index 000000000000..9db44da30839 --- /dev/null +++ b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class ArrayNoneMatchTest : public functions::test::FunctionBaseTest { + protected: + // Evaluate an expression. + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const VectorPtr& input) { + auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); + auto result = evaluate(expression, makeRowVector({input})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + + template + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const std::vector>>& input) { + auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); + auto result = + evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + + void testExpr( + const std::vector>& expected, + const std::string& expression, + const RowVectorPtr& input) { + auto result = evaluate(expression, (input)); + assertEqualVectors(makeNullableFlatVector(expected), result); + } +}; + +TEST_F(ArrayNoneMatchTest, basic) { + std::vector>> ints{ + {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; + std::vector> expectedResult{ + false, false, true, true, std::nullopt}; + testExpr(expectedResult, "x > 1", ints); + + expectedResult = {false, true, true, true, false}; + testExpr(expectedResult, "x is null", ints); + + std::vector>> bools{ + {false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}; + expectedResult = {false, true, false, std::nullopt}; + testExpr(expectedResult, "x", bools); +} + +TEST_F(ArrayNoneMatchTest, complexTypes) { + auto baseVector = + makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3], []], + // [[2, 2], [3, 3], [4, 4], [5, 5]], + // [[], []] + // ] + auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); + std::vector> expectedResult{false, false, true}; + testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); + + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3]], cardinalities is 3 + // [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities is 2 + // [[]], + // null + // ] + arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); + expectedResult = {false, true, true, std::nullopt}; + testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); +} + +TEST_F(ArrayNoneMatchTest, strings) { + std::vector>> input{ + {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; + std::vector> expectedResult{ + true, false, false, std::nullopt}; + testExpr(expectedResult, "x = 'abc'", input); +} + +TEST_F(ArrayNoneMatchTest, doubles) { + std::vector>> input{ + {}, {0.2}, {3.0, 0}, {std::nullopt}}; + std::vector> expectedResult{ + true, true, false, std::nullopt}; + testExpr(expectedResult, "x > 1.1", input); +} + +TEST_F(ArrayNoneMatchTest, errors) { + // No throw and return false if there are unmatched elements except nulls + auto expression = "(10 / x) > 2"; + std::vector>> input{ + {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; + std::vector> expectedResult = {false, false}; + testExpr(expectedResult, expression, input); + + // Throw error if others are matched or null + static constexpr std::string_view kErrorMessage{"division by zero"}; + auto errorInput = makeNullableArrayVector( + {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + VELOX_ASSERT_THROW( + testExpr(expectedResult, expression, errorInput), kErrorMessage); + // Rerun using TRY to get right results + auto errorInputRow = makeRowVector({errorInput}); + expectedResult = {false, false, true, std::nullopt, false}; + testExpr( + expectedResult, + "TRY(none_match(c0, x -> ((10 / x) > 2)))", + errorInputRow); + testExpr( + expectedResult, + "none_match(c0, x -> (TRY((10 / x) > 2)))", + errorInputRow); +} + +TEST_F(ArrayNoneMatchTest, conditional) { + // No throw and return false if there are unmatched elements except nulls + auto c0 = makeFlatVector({1, 2, 3, 4, 5}); + auto c1 = makeNullableArrayVector( + {{4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}}); + auto input = makeRowVector({c0, c1}); + std::vector> expectedResult = { + std::nullopt, true, std::nullopt, false, true}; + testExpr( + expectedResult, + "none_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", + input); +} diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index 6f43d4f6f851..181aa1ff9a67 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable( ArrayIntersectTest.cpp ArrayMaxTest.cpp ArrayMinTest.cpp + ArrayNoneMatchTest.cpp ArrayNormalizeTest.cpp ArrayPositionTest.cpp ArraySortTest.cpp From 13d267e3afdcfa45714e17eb79875594e4275d60 Mon Sep 17 00:00:00 2001 From: macduan Date: Wed, 22 Mar 2023 20:27:12 +0800 Subject: [PATCH 5/6] use bool flags --- velox/functions/prestosql/ArrayMatch.cpp | 31 +++++++++--------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index ba6df446dd31..0dfa834e190d 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,9 +24,7 @@ namespace facebook::velox::functions { namespace { -enum class MatchMethod { kAll = 0, kAny = 1, kNone = 2 }; - -template +template class MatchFunction : public exec::VectorFunction { private: void apply( @@ -127,7 +125,7 @@ class MatchFunction : public exec::VectorFunction { // return false Any: ignore the error and null if one or more elements // matched and return true None: ignore the error and null if one or more // elements matched and return false - auto match = (matchMethod != MatchMethod::kAny); + bool result = initialValue; auto hasNull = false; std::exception_ptr errorPtr{nullptr}; for (auto i = 0; i < size; ++i) { @@ -140,34 +138,27 @@ class MatchFunction : public exec::VectorFunction { if (bitsDecoder->isNullAt(idx)) { hasNull = true; - } else if (matchMethod == MatchMethod::kAll) { - if (!bitsDecoder->valueAt(idx)) { - match = !match; - break; - } - } else { - if (bitsDecoder->valueAt(idx)) { - match = !match; - break; - } + } else if (bitsDecoder->valueAt(idx) == earlyReturn) { + result = !result; + break; } } - if ((matchMethod == MatchMethod::kAny) == match) { - flatResult->set(row, match); + if (result != initialValue) { + flatResult->set(row, !initialValue); } else if (errorPtr) { context.setError(row, errorPtr); } else if (hasNull) { flatResult->setNull(row, true); } else { - flatResult->set(row, match); + flatResult->set(row, initialValue); } } }; -class AllMatchFunction : public MatchFunction {}; -class AnyMatchFunction : public MatchFunction {}; -class NoneMatchFunction : public MatchFunction {}; +class AllMatchFunction : public MatchFunction {}; +class AnyMatchFunction : public MatchFunction {}; +class NoneMatchFunction : public MatchFunction {}; std::vector> signatures() { // array(T), function(T) -> boolean From 650b04776e02ceb548118140b34c6572ce796bb6 Mon Sep 17 00:00:00 2001 From: macduan Date: Wed, 22 Mar 2023 23:17:44 +0800 Subject: [PATCH 6/6] resolve comments --- velox/functions/prestosql/ArrayMatch.cpp | 80 ++++-- .../prestosql/tests/ArrayAllMatchTest.cpp | 258 ++++++++++-------- .../prestosql/tests/ArrayAnyMatchTest.cpp | 161 +++++++---- .../prestosql/tests/ArrayNoneMatchTest.cpp | 156 +++++++---- 4 files changed, 430 insertions(+), 225 deletions(-) diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index 0dfa834e190d..7e53eadd8e90 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -95,7 +95,6 @@ class MatchFunction : public exec::VectorFunction { return errors && idx < errors->size() && !errors->isNullAt(idx); } - private: void applyInternal( FlatVector* flatResult, exec::EvalCtx& context, @@ -107,24 +106,69 @@ class MatchFunction : public exec::VectorFunction { auto size = sizes[row]; auto offset = offsets[row]; - // All, none, and any match have different and similar logic intertwined in - // terms of the initial value, flip condition, and the result finalization: + // all_match, none_match and any_match need to loop over predicate results + // for element arrays and check for results, nulls and errors. + // These loops can be generalized using two booleans. + // + // Here is what the individual loops look like. + // + //---- kAll ---- + // bool allMatch = true + // + // loop: + // if not match: + // allMatch = false; + // break; + // + // if (!allMatch) -> false + // else if hasError -> error + // else if hasNull -> null + // else -> true + // + //---- kAny ---- + // + // bool anyMatch = false + // + // loop: + // if match: + // anyMatch = true; + // break; + // + // if (anyMatch) -> true + // else if hasError -> error + // else if hasNull -> null + // else -> false + // + //---- kNone ---- + // + // bool noneMatch = true; + // + // loop: + // if match: + // noneMatch = false; + // break; + // + // if (!noneMatch) -> false + // else if hasError -> error + // else if hasNull -> null + // else -> true + // + // To generalize these loops, we use initialValue and earlyReturn booleans + // like so: + // + //--- generic loop --- // - // Initial value: - // All is true - // Any is false - // None is true + // bool result = initialValue // - // Flip logic: - // All flips once encounter an unmatched element - // Any flips once encounter a matched element - // None flips once encounter a matched element + // loop: + // if match == earlyReturn: + // result = false; + // break; // - // Result finalization: - // All: ignore the error and null if one or more elements are unmatched and - // return false Any: ignore the error and null if one or more elements - // matched and return true None: ignore the error and null if one or more - // elements matched and return false + // if (result != initialValue) -> result + // else if hasError -> error + // else if hasNull -> null + // else -> result bool result = initialValue; auto hasNull = false; std::exception_ptr errorPtr{nullptr}; @@ -145,13 +189,13 @@ class MatchFunction : public exec::VectorFunction { } if (result != initialValue) { - flatResult->set(row, !initialValue); + flatResult->set(row, result); } else if (errorPtr) { context.setError(row, errorPtr); } else if (hasNull) { flatResult->setNull(row, true); } else { - flatResult->set(row, initialValue); + flatResult->set(row, result); } } }; diff --git a/velox/functions/prestosql/tests/ArrayAllMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAllMatchTest.cpp index 5d6523180dd6..3fdc671c509c 100644 --- a/velox/functions/prestosql/tests/ArrayAllMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAllMatchTest.cpp @@ -20,49 +20,86 @@ using namespace facebook::velox; using namespace facebook::velox::test; -class ArrayAllMatchTest : public functions::test::FunctionBaseTest {}; +class ArrayAllMatchTest : public functions::test::FunctionBaseTest { + protected: + void testAllMatchExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const VectorPtr& input) { + auto expression = fmt::format("all_match(c0, x -> ({}))", lambdaExpr); + testAllMatchExpr(expected, expression, makeRowVector({input})); + } + + template + void testAllMatchExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const std::vector>>& input) { + auto expression = fmt::format("all_match(c0, x -> ({}))", lambdaExpr); + testAllMatchExpr( + expected, expression, makeRowVector({makeNullableArrayVector(input)})); + } + + void testAllMatchExpr( + const std::vector>& expected, + const std::string& expression, + const RowVectorPtr& input) { + auto result = evaluate(expression, (input)); + assertEqualVectors(makeNullableFlatVector(expected), result); + } +}; TEST_F(ArrayAllMatchTest, basic) { - auto arrayVector = makeNullableArrayVector( - {{std::nullopt, 2, 3}, - {-1, 3}, - {2, 3}, - {}, - {std::nullopt, std::nullopt}}); - auto input = makeRowVector({arrayVector}); - auto result = evaluate("all_match(c0, x -> (x > 1))", input); - auto expectedResult = makeNullableFlatVector( - {std::nullopt, false, true, true, std::nullopt}); - assertEqualVectors(expectedResult, result); - - result = evaluate("all_match(c0, x -> (x is null))", input); - expectedResult = makeFlatVector({false, false, false, true, true}); - assertEqualVectors(expectedResult, result); - - arrayVector = makeNullableArrayVector( - {{false, true}, - {true, true}, - {std::nullopt, true}, - {std::nullopt, false}}); - input = makeRowVector({arrayVector}); - result = evaluate("all_match(c0, x -> x)", input); - expectedResult = - makeNullableFlatVector({false, true, std::nullopt, false}); - assertEqualVectors(expectedResult, result); - - auto emptyInput = makeRowVector({makeArrayVector({{}})}); - result = evaluate("all_match(c0, x -> (x > 1))", emptyInput); - expectedResult = makeFlatVector(std::vector{true}); - assertEqualVectors(expectedResult, result); - - result = evaluate("all_match(c0, x -> (x < 1))", emptyInput); - expectedResult = makeFlatVector(std::vector{true}); - assertEqualVectors(expectedResult, result); + std::vector>> ints{ + {std::nullopt, 2, 3}, + {-1, 3}, + {2, 3}, + {}, + {std::nullopt, std::nullopt}, + }; + std::vector> expectedResult{ + std::nullopt, + false, + true, + true, + std::nullopt, + }; + testAllMatchExpr(expectedResult, "x > 1", ints); + + expectedResult = { + false, + false, + false, + true, + true, + }; + testAllMatchExpr(expectedResult, "x is null", ints); + + std::vector>> bools{ + {false, true}, + {true, true}, + {std::nullopt, true}, + {std::nullopt, false}, + }; + + expectedResult = { + false, + true, + std::nullopt, + false, + }; + testAllMatchExpr(expectedResult, "x", bools); } TEST_F(ArrayAllMatchTest, complexTypes) { - auto baseVector = - makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + auto baseVector = makeArrayVector({ + {1, 2, 3}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 5}, + {}, + }); // Create an array of array vector using above base vector using offsets. // [ // [[1, 2, 3]], @@ -70,10 +107,12 @@ TEST_F(ArrayAllMatchTest, complexTypes) { // [[]] // ] auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); - auto input = makeRowVector({arrayOfArrays}); - auto result = evaluate("all_match(c0, x -> (cardinality(x) > 0))", input); - auto expectedResult = makeNullableFlatVector({true, true, false}); - assertEqualVectors(expectedResult, result); + std::vector> expectedResult{ + true, + true, + false, + }; + testAllMatchExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); // Create an array of array vector using above base vector using offsets. // [ @@ -83,92 +122,95 @@ TEST_F(ArrayAllMatchTest, complexTypes) { // null // ] arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); - input = makeRowVector({arrayOfArrays}); - result = evaluate("all_match(c0, x -> (cardinality(x) > 2))", input); - expectedResult = - makeNullableFlatVector({true, false, false, std::nullopt}); - assertEqualVectors(expectedResult, result); -} - -TEST_F(ArrayAllMatchTest, bigints) { - auto arrayVector = makeNullableArrayVector( - {{}, - {2}, - {std::numeric_limits::max()}, - {std::numeric_limits::min()}, - {std::nullopt, std::nullopt}, // return null if all is null - {2, - std::nullopt}, // return null if one or more is null and others matched - {1, std::nullopt, 2}}); // return false if one is not matched - auto input = makeRowVector({arrayVector}); - auto result = evaluate("all_match(c0, x -> (x % 2 = 0))", input); - - auto expectedResult = makeNullableFlatVector( - {true, true, false, true, std::nullopt, std::nullopt, false}); - assertEqualVectors(expectedResult, result); + expectedResult = {true, false, false, std::nullopt}; + testAllMatchExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); } TEST_F(ArrayAllMatchTest, strings) { - auto input = makeRowVector({makeNullableArrayVector( - {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}})}); - auto result = evaluate("all_match(c0, x -> (x = 'abc'))", input); - - auto expectedResult = - makeNullableFlatVector({true, true, false, std::nullopt}); - assertEqualVectors(expectedResult, result); + std::vector>> input{ + {}, + {"abc"}, + {"ab", "abc"}, + {std::nullopt}, + }; + std::vector> expectedResult{ + true, true, false, std::nullopt}; + testAllMatchExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAllMatchTest, doubles) { - auto input = makeRowVector( - {makeNullableArrayVector({{}, {1.2}, {3.0, 0}, {std::nullopt}})}); - auto result = evaluate("all_match(c0, x -> (x > 1.1))", input); - - auto expectedResult = - makeNullableFlatVector({true, true, false, std::nullopt}); - assertEqualVectors(expectedResult, result); + std::vector>> input{ + {}, + {1.2}, + {3.0, 0}, + {std::nullopt}, + }; + std::vector> expectedResult{ + true, + true, + false, + std::nullopt, + }; + testAllMatchExpr(expectedResult, "x > 1.1", input); } TEST_F(ArrayAllMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls - auto input = makeRowVector({makeNullableArrayVector( - {{0, 2, 0, 5, 0}, {5, std::nullopt, 0}})}); - auto result = evaluate("all_match(c0, x -> ((10 / x) > 2))", input); - auto expectedResult = makeFlatVector({false, false}); - assertEqualVectors(expectedResult, result); + auto expression = "(10 / x) > 2"; + std::vector>> input{ + {0, 2, 0, 5, 0}, + {2, 5, std::nullopt, 0}, + }; + testAllMatchExpr({false, false}, expression, input); // Throw error if others are matched or null static constexpr std::string_view kErrorMessage{"division by zero"}; - auto errorInput = makeRowVector({makeNullableArrayVector( - {{1, 0}, {2}, {6}, {1, 0, std::nullopt}, {10, std::nullopt}})}); - VELOX_ASSERT_THROW( - evaluate("all_match(c0, x -> ((10 / x) > 2))", errorInput), - kErrorMessage); + auto errorInput = makeNullableArrayVector({ + {1, 0}, + {2}, + {6}, + {1, 0, std::nullopt}, + {10, std::nullopt}, + }); + VELOX_ASSERT_THROW( - evaluate("all_match(c0, x -> ((10 / x) > 2))", errorInput), + testAllMatchExpr({false, false}, "(10 / x) > 2", errorInput), kErrorMessage); // Rerun using TRY to get right results - expectedResult = makeNullableFlatVector( - {std::nullopt, true, false, std::nullopt, false}); - result = evaluate("TRY(all_match(c0, x -> ((10 / x) > 2)))", errorInput); - assertEqualVectors(expectedResult, result); - - result = evaluate("all_match(c0, x -> (TRY((10 / x) > 2)))", errorInput); - assertEqualVectors(expectedResult, result); + auto errorInputRow = makeRowVector({errorInput}); + std::vector> expectedResult{ + std::nullopt, + true, + false, + std::nullopt, + false, + }; + testAllMatchExpr( + expectedResult, "TRY(all_match(c0, x -> ((10 / x) > 2)))", errorInputRow); + testAllMatchExpr( + expectedResult, "all_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); } TEST_F(ArrayAllMatchTest, conditional) { // No throw and return false if there are unmatched elements except nulls auto c0 = makeFlatVector({1, 2, 3, 4, 5}); - auto c1 = makeNullableArrayVector( - {{100, std::nullopt}, - {500, 120}, - {std::nullopt}, - {5, std::nullopt, 0}, - {3, 1}}); - auto result = evaluate( + auto c1 = makeNullableArrayVector({ + {100, std::nullopt}, + {500, 120}, + {std::nullopt}, + {5, std::nullopt, 0}, + {3, 1}, + }); + auto input = makeRowVector({c0, c1}); + std::vector> expectedResult{ + false, + true, + std::nullopt, + false, + true, + }; + testAllMatchExpr( + expectedResult, "all_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", - makeRowVector({c0, c1})); - auto expectedResult = - makeNullableFlatVector({false, true, std::nullopt, false, true}); - assertEqualVectors(expectedResult, result); + input); } diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp index 498b8618725d..8d0c0c3f7dcc 100644 --- a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -23,28 +22,24 @@ using namespace facebook::velox::test; class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { protected: - // Evaluate an expression. - void testExpr( + void testAnyMatchExpr( const std::vector>& expected, const std::string& lambdaExpr, const VectorPtr& input) { - auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); - auto result = evaluate(expression, makeRowVector({input})); - assertEqualVectors(makeNullableFlatVector(expected), result); + auto expression = fmt::format("any_match(c0, x -> ({}))", lambdaExpr); + testAnyMatchExpr(expected, expression, makeRowVector({input})); } template - void testExpr( + void testAnyMatchExpr( const std::vector>& expected, const std::string& lambdaExpr, const std::vector>>& input) { - auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); - auto result = - evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); - assertEqualVectors(makeNullableFlatVector(expected), result); + auto expression = fmt::format("any_match(c0, x -> ({}))", lambdaExpr); + testAnyMatchExpr(expected, lambdaExpr, makeNullableArrayVector(input)); } - void testExpr( + void testAnyMatchExpr( const std::vector>& expected, const std::string& expression, const RowVectorPtr& input) { @@ -55,35 +50,67 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { TEST_F(ArrayAnyMatchTest, basic) { std::vector>> ints{ - {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; + {std::nullopt, 2, 0}, + {-1, 3}, + {-2, -3}, + {}, + {0, std::nullopt}, + }; std::vector> expectedResult{ - true, true, false, false, std::nullopt}; - testExpr(expectedResult, "x > 1", ints); + true, + true, + false, + false, + std::nullopt, + }; + testAnyMatchExpr(expectedResult, "x > 1", ints); - expectedResult = {true, false, false, false, true}; - testExpr(expectedResult, "x is null", ints); + expectedResult = { + true, + false, + false, + false, + true, + }; + testAnyMatchExpr(expectedResult, "x is null", ints); std::vector>> bools{ {false, true}, {false, false}, {std::nullopt, true}, - {std::nullopt, false}}; - expectedResult = {true, false, true, std::nullopt}; - testExpr(expectedResult, "x", bools); + {std::nullopt, false}, + }; + expectedResult = { + true, + false, + true, + std::nullopt, + }; + testAnyMatchExpr(expectedResult, "x", bools); } TEST_F(ArrayAnyMatchTest, complexTypes) { - auto baseVector = - makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + auto baseVector = makeArrayVector({ + {1, 2, 3}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 5}, + {}, + }); // Create an array of array vector using above base vector using offsets. // [ - // [[1, 2, 3], []], + // [[1, 2, 3]], // [[2, 2], [3, 3], [4, 4], [5, 5]], - // [[], []] + // [[]] // ] auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); - std::vector> expectedResult{true, true, false}; - testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); + std::vector> expectedResult{ + true, + true, + false, + }; + testAnyMatchExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); // Create an array of array vector using above base vector using offsets. // [ @@ -93,62 +120,96 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { // null // ] arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); - expectedResult = {true, false, false, std::nullopt}; - testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); + expectedResult = { + true, + false, + false, + std::nullopt, + }; + testAnyMatchExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); } TEST_F(ArrayAnyMatchTest, strings) { std::vector>> input{ - {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; + {}, + {"abc"}, + {"ab", "abc"}, + {std::nullopt}, + }; std::vector> expectedResult{ - false, true, true, std::nullopt}; - testExpr(expectedResult, "x = 'abc'", input); + false, + true, + true, + std::nullopt, + }; + testAnyMatchExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAnyMatchTest, doubles) { std::vector>> input{ - {}, {0.2}, {3.0, 0}, {std::nullopt}}; + {}, + {0.2}, + {3.0, 0}, + {std::nullopt}, + }; std::vector> expectedResult{ - false, false, true, std::nullopt}; - testExpr(expectedResult, "x > 1.1", input); + false, + false, + true, + std::nullopt, + }; + testAnyMatchExpr(expectedResult, "x > 1.1", input); } TEST_F(ArrayAnyMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls auto expression = "(10 / x) > 2"; std::vector>> input{ - {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; - std::vector> expectedResult = {true, true}; - testExpr(expectedResult, expression, input); + {0, 2, 0, 5, 0}, + {2, 5, std::nullopt, 0}, + }; + testAnyMatchExpr({true, true}, expression, input); // Throw error if others are matched or null static constexpr std::string_view kErrorMessage{"division by zero"}; - auto errorInput = makeNullableArrayVector( - {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + auto errorInput = makeNullableArrayVector({ + {1, 0}, + {2}, + {6}, + {10, 9, 0, std::nullopt}, + {0, std::nullopt, 1}, + }); VELOX_ASSERT_THROW( - testExpr(expectedResult, expression, errorInput), kErrorMessage); + testAnyMatchExpr({true, true}, expression, errorInput), kErrorMessage); // Rerun using TRY to get right results auto errorInputRow = makeRowVector({errorInput}); - expectedResult = {true, true, false, std::nullopt, true}; - testExpr( + std::vector> expectedResult{ + true, + true, + false, + std::nullopt, + true, + }; + testAnyMatchExpr( expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow); - testExpr( + testAnyMatchExpr( expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); } TEST_F(ArrayAnyMatchTest, conditional) { // No throw and return false if there are unmatched elements except nulls auto c0 = makeFlatVector({1, 2, 3, 4, 5}); - auto c1 = makeNullableArrayVector( - {{4, 100, std::nullopt}, - {50, 12}, - {std::nullopt}, - {3, std::nullopt, 0}, - {300, 100}}); + auto c1 = makeNullableArrayVector({ + {4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}, + }); auto input = makeRowVector({c0, c1}); std::vector> expectedResult = { std::nullopt, false, std::nullopt, true, false}; - testExpr( + testAnyMatchExpr( expectedResult, "any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", input); diff --git a/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp index 9db44da30839..57e1b27e5826 100644 --- a/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -23,28 +22,25 @@ using namespace facebook::velox::test; class ArrayNoneMatchTest : public functions::test::FunctionBaseTest { protected: - // Evaluate an expression. - void testExpr( + void testNoneMatchExpr( const std::vector>& expected, const std::string& lambdaExpr, const VectorPtr& input) { - auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); - auto result = evaluate(expression, makeRowVector({input})); - assertEqualVectors(makeNullableFlatVector(expected), result); + auto expression = fmt::format("none_match(c0, x -> ({}))", lambdaExpr); + testNoneMatchExpr(expected, expression, makeRowVector({input})); } template - void testExpr( + void testNoneMatchExpr( const std::vector>& expected, const std::string& lambdaExpr, const std::vector>>& input) { - auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); - auto result = - evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); - assertEqualVectors(makeNullableFlatVector(expected), result); + auto expression = fmt::format("none_match(c0, x -> ({}))", lambdaExpr); + testNoneMatchExpr( + expected, expression, makeRowVector({makeNullableArrayVector(input)})); } - void testExpr( + void testNoneMatchExpr( const std::vector>& expected, const std::string& expression, const RowVectorPtr& input) { @@ -55,35 +51,63 @@ class ArrayNoneMatchTest : public functions::test::FunctionBaseTest { TEST_F(ArrayNoneMatchTest, basic) { std::vector>> ints{ - {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; + {std::nullopt, 2, 0}, + {-1, 3}, + {-2, -3}, + {}, + {0, std::nullopt}, + }; std::vector> expectedResult{ - false, false, true, true, std::nullopt}; - testExpr(expectedResult, "x > 1", ints); + false, + false, + true, + true, + std::nullopt, + }; + testNoneMatchExpr(expectedResult, "x > 1", ints); - expectedResult = {false, true, true, true, false}; - testExpr(expectedResult, "x is null", ints); + expectedResult = { + false, + true, + true, + true, + false, + }; + testNoneMatchExpr(expectedResult, "x is null", ints); std::vector>> bools{ {false, true}, {false, false}, {std::nullopt, true}, - {std::nullopt, false}}; - expectedResult = {false, true, false, std::nullopt}; - testExpr(expectedResult, "x", bools); + {std::nullopt, false}, + }; + expectedResult = { + false, + true, + false, + std::nullopt, + }; + testNoneMatchExpr(expectedResult, "x", bools); } TEST_F(ArrayNoneMatchTest, complexTypes) { - auto baseVector = - makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + auto baseVector = makeArrayVector({ + {1, 2, 3}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 5}, + {}, + }); // Create an array of array vector using above base vector using offsets. // [ - // [[1, 2, 3], []], + // [[1, 2, 3]], // [[2, 2], [3, 3], [4, 4], [5, 5]], - // [[], []] + // [[]] // ] auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); std::vector> expectedResult{false, false, true}; - testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); + testNoneMatchExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); // Create an array of array vector using above base vector using offsets. // [ @@ -94,47 +118,75 @@ TEST_F(ArrayNoneMatchTest, complexTypes) { // ] arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); expectedResult = {false, true, true, std::nullopt}; - testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); + testNoneMatchExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); } TEST_F(ArrayNoneMatchTest, strings) { std::vector>> input{ - {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; + {}, + {"abc"}, + {"ab", "abc"}, + {std::nullopt}, + }; std::vector> expectedResult{ - true, false, false, std::nullopt}; - testExpr(expectedResult, "x = 'abc'", input); + true, + false, + false, + std::nullopt, + }; + testNoneMatchExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayNoneMatchTest, doubles) { std::vector>> input{ - {}, {0.2}, {3.0, 0}, {std::nullopt}}; + {}, + {0.2}, + {3.0, 0}, + {std::nullopt}, + }; std::vector> expectedResult{ - true, true, false, std::nullopt}; - testExpr(expectedResult, "x > 1.1", input); + true, + true, + false, + std::nullopt, + }; + testNoneMatchExpr(expectedResult, "x > 1.1", input); } TEST_F(ArrayNoneMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls auto expression = "(10 / x) > 2"; std::vector>> input{ - {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; - std::vector> expectedResult = {false, false}; - testExpr(expectedResult, expression, input); + {0, 2, 0, 5, 0}, + {2, 5, std::nullopt, 0}, + }; + testNoneMatchExpr({false, false}, expression, input); // Throw error if others are matched or null static constexpr std::string_view kErrorMessage{"division by zero"}; - auto errorInput = makeNullableArrayVector( - {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + auto errorInput = makeNullableArrayVector({ + {1, 0}, + {2}, + {6}, + {10, 9, 0, std::nullopt}, + {0, std::nullopt, 1}, + }); VELOX_ASSERT_THROW( - testExpr(expectedResult, expression, errorInput), kErrorMessage); + testNoneMatchExpr({false, false}, expression, errorInput), kErrorMessage); // Rerun using TRY to get right results auto errorInputRow = makeRowVector({errorInput}); - expectedResult = {false, false, true, std::nullopt, false}; - testExpr( + std::vector> expectedResult = { + false, + false, + true, + std::nullopt, + false, + }; + testNoneMatchExpr( expectedResult, "TRY(none_match(c0, x -> ((10 / x) > 2)))", errorInputRow); - testExpr( + testNoneMatchExpr( expectedResult, "none_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); @@ -143,16 +195,22 @@ TEST_F(ArrayNoneMatchTest, errors) { TEST_F(ArrayNoneMatchTest, conditional) { // No throw and return false if there are unmatched elements except nulls auto c0 = makeFlatVector({1, 2, 3, 4, 5}); - auto c1 = makeNullableArrayVector( - {{4, 100, std::nullopt}, - {50, 12}, - {std::nullopt}, - {3, std::nullopt, 0}, - {300, 100}}); + auto c1 = makeNullableArrayVector({ + {4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}, + }); auto input = makeRowVector({c0, c1}); std::vector> expectedResult = { - std::nullopt, true, std::nullopt, false, true}; - testExpr( + std::nullopt, + true, + std::nullopt, + false, + true, + }; + testNoneMatchExpr( expectedResult, "none_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", input);