From e781eb953a988559147135084bc3f13818eb6aee Mon Sep 17 00:00:00 2001 From: macduan Date: Wed, 22 Mar 2023 00:09:28 +0800 Subject: [PATCH] resolve comments --- velox/docs/functions/presto/array.rst | 2 +- velox/functions/prestosql/ArrayMatch.cpp | 75 ++++++++-------- .../prestosql/tests/ArrayAnyMatchTest.cpp | 90 +++++++------------ 3 files changed, 72 insertions(+), 95 deletions(-) diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 47065a66387b..3a6a8cd628df 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -13,7 +13,7 @@ Array Functions .. function:: any_match(array(T), function(T, boolean)) → boolean - Returns whether at least one element of an array match the given predicate. + Returns whether at least one element of an array matchs the given predicate. Returns true if one or more elements match the predicate; Returns false if none of the elements matches (a special case is when the array is empty); diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index 389631cdb5a8..fcd57dfeecea 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,13 +24,15 @@ namespace facebook::velox::functions { namespace { -template +enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 }; + +template class MatchFunction : public exec::VectorFunction { private: void apply( const SelectivityVector& rows, std::vector& args, - const TypePtr& outputType, + const TypePtr& /*outputType*/, exec::EvalCtx& context, VectorPtr& result) const override { exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); @@ -89,7 +91,6 @@ class MatchFunction : public exec::VectorFunction { } } - protected: static FOLLY_ALWAYS_INLINE bool hasError( const ErrorVectorPtr& errors, int idx) { @@ -108,47 +109,49 @@ class MatchFunction : public exec::VectorFunction { const exec::LocalDecodedVector& bitsDecoder) const { auto size = sizes[row]; auto offset = offsets[row]; - - // All or any match. - // All: early exit on non-match. Initial value = true. - // Any: early exit on match. Initial value = false. - bool result = match; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; + if (mathMethod == MatchMethod::NONE) { + // TODO: Add none_match + } else { + // All or any match. + // All: early exit on non-match. Initial value = true. + // Any: early exit on match. Initial value = false. + auto match = mathMethod == MatchMethod::ALL; + bool result = match; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (bitsDecoder->valueAt(idx) == !match) { + result = !result; + break; + } } - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (bitsDecoder->valueAt(idx) == !match) { - result = !result; - break; + if (result != match) { + flatResult->set(row, !match); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, match); } } - - // Errors for individual array elements should be suppressed only if the - // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'true' for the predicate. - if (result != match) { - flatResult->set(row, !match); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); - } else { - flatResult->set(row, match); - } } }; -class AllMatchFunction : public MatchFunction {}; +class AllMatchFunction : public MatchFunction {}; +class AnyMatchFunction : public MatchFunction {}; -class AnyMatchFunction : public MatchFunction {}; +// TODO: add class NoneMatchFunction std::vector> signatures() { // array(T), function(T) -> boolean diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp index ee8729789267..aee8b50b9ac1 100644 --- a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" @@ -24,19 +25,20 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { protected: // Evaluate an expression. void testExpr( - const VectorPtr& expected, - const std::string& expression, + const std::vector>& expected, + const std::string& lambdaExpr, const VectorPtr& input) { + auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); auto result = evaluate(expression, makeRowVector({input})); - assertEqualVectors(expected, result); + assertEqualVectors(makeNullableFlatVector(expected), result); } void testExpr( - const VectorPtr& expected, + const std::vector>& expected, const std::string& expression, const RowVectorPtr& input) { auto result = evaluate(expression, (input)); - assertEqualVectors(expected, result); + assertEqualVectors(makeNullableFlatVector(expected), result); } }; @@ -47,25 +49,19 @@ TEST_F(ArrayAnyMatchTest, basic) { {2, 3}, {}, {std::nullopt, std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({true, true, true, false, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1))", input); - expectedResult = makeFlatVector({true, false, false, false, true}); - testExpr(expectedResult, "any_match(c0, x -> (x is null))", input); + std::vector> expectedResult{ + true, true, true, false, std::nullopt}; + testExpr(expectedResult, "x > 1", input); + expectedResult = {true, false, false, false, true}; + testExpr(expectedResult, "x is null", input); input = makeNullableArrayVector( {{false, true}, {false, false}, {std::nullopt, true}, {std::nullopt, false}}); - expectedResult = - makeNullableFlatVector({true, false, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> x)", input); - - auto emptyInput = makeArrayVector({{}}); - expectedResult = makeFlatVector(std::vector{false}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1))", emptyInput); - testExpr(expectedResult, "any_match(c0, x -> (x <= 1))", emptyInput); + expectedResult = {true, false, true, std::nullopt}; + testExpr(expectedResult, "x", input); } TEST_F(ArrayAnyMatchTest, complexTypes) { @@ -78,11 +74,8 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { // [[], []] // ] auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); - auto expectedResult = makeNullableFlatVector({true, true, false}); - testExpr( - expectedResult, - "any_match(c0, x -> (cardinality(x) > 0))", - arrayOfArrays); + std::vector> expectedResult{true, true, false}; + testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); // Create an array of array vector using above base vector using offsets. // [ @@ -92,51 +85,32 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { // null // ] arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); - expectedResult = - makeNullableFlatVector({true, false, false, std::nullopt}); - testExpr( - expectedResult, - "any_match(c0, x -> (cardinality(x) > 2))", - arrayOfArrays); -} - -TEST_F(ArrayAnyMatchTest, bigints) { - auto input = makeNullableArrayVector( - {{}, - {2}, - {std::numeric_limits::max()}, - {std::numeric_limits::min()}, - {std::nullopt, std::nullopt}, // return null if all is null - {2, - std::nullopt}, // return null if one or more is null and others matched - {1, std::nullopt, 2}}); // return false if one is not matched - auto expectedResult = makeNullableFlatVector( - {false, true, false, true, std::nullopt, true, true}); - testExpr(expectedResult, "any_match(c0, x -> (x % 2 = 0))", input); + expectedResult = {true, false, false, std::nullopt}; + testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); } TEST_F(ArrayAnyMatchTest, strings) { auto input = makeNullableArrayVector( {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({false, true, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x = 'abc'))", input); + std::vector> expectedResult{ + false, true, true, std::nullopt}; + testExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAnyMatchTest, doubles) { auto input = makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); - auto expectedResult = - makeNullableFlatVector({false, false, true, std::nullopt}); - testExpr(expectedResult, "any_match(c0, x -> (x > 1.1))", input); + std::vector> expectedResult{ + false, false, true, std::nullopt}; + testExpr(expectedResult, "x > 1.1", input); } TEST_F(ArrayAnyMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls - auto expression = "any_match(c0, x -> ((10 / x) > 2))"; + auto expression = "(10 / x) > 2"; auto input = makeNullableArrayVector( {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); - auto expectedResult = makeFlatVector({true, true}); + std::vector> expectedResult = {true, true}; testExpr(expectedResult, expression, input); // Throw error if others are matched or null @@ -146,12 +120,12 @@ TEST_F(ArrayAnyMatchTest, errors) { VELOX_ASSERT_THROW( testExpr(expectedResult, expression, errorInput), kErrorMessage); // Rerun using TRY to get right results - expectedResult = - makeNullableFlatVector({true, true, false, std::nullopt, true}); + auto errorInputRow = makeRowVector({errorInput}); + expectedResult = {true, true, false, std::nullopt, true}; testExpr( - expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInput); + expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow); testExpr( - expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInput); + expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); } TEST_F(ArrayAnyMatchTest, conditional) { @@ -164,8 +138,8 @@ TEST_F(ArrayAnyMatchTest, conditional) { {3, std::nullopt, 0}, {300, 100}}); auto input = makeRowVector({c0, c1}); - auto expectedResult = makeNullableFlatVector( - {std::nullopt, false, std::nullopt, true, false}); + std::vector> expectedResult = { + std::nullopt, false, std::nullopt, true, false}; testExpr( expectedResult, "any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))",