diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 3a6a8cd628df..10b039df8389 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -13,13 +13,22 @@ Array Functions .. function:: any_match(array(T), function(T, boolean)) → boolean - Returns whether at least one element of an array matchs the given predicate. + Returns whether at least one element of an array matches the given predicate. Returns true if one or more elements match the predicate; Returns false if none of the elements matches (a special case is when the array is empty); Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. +.. function:: none_match(array(T), function(T, boolean)) → boolean + + Returns whether no elements of an array match the given predicate. + + Returns true if none of the elements matches the predicate (a special case is when the array is empty); + Returns false if one or more elements match; + Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. + .. function:: array_average(array(double)) -> double Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp index fcd57dfeecea..ba6df446dd31 100644 --- a/velox/functions/prestosql/ArrayMatch.cpp +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -24,9 +24,9 @@ namespace facebook::velox::functions { namespace { -enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 }; +enum class MatchMethod { kAll = 0, kAny = 1, kNone = 2 }; -template +template class MatchFunction : public exec::VectorFunction { private: void apply( @@ -94,8 +94,7 @@ class MatchFunction : public exec::VectorFunction { static FOLLY_ALWAYS_INLINE bool hasError( const ErrorVectorPtr& errors, int idx) { - return errors && idx < errors->size() && - !errors->isNullAt(idx); + return errors && idx < errors->size() && !errors->isNullAt(idx); } private: @@ -109,49 +108,66 @@ class MatchFunction : public exec::VectorFunction { const exec::LocalDecodedVector& bitsDecoder) const { auto size = sizes[row]; auto offset = offsets[row]; - if (mathMethod == MatchMethod::NONE) { - // TODO: Add none_match - } else { - // All or any match. - // All: early exit on non-match. Initial value = true. - // Any: early exit on match. Initial value = false. - auto match = mathMethod == MatchMethod::ALL; - bool result = match; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; - } - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (bitsDecoder->valueAt(idx) == !match) { - result = !result; - break; - } + // All, none, and any match have different and similar logic intertwined in + // terms of the initial value, flip condition, and the result finalization: + // + // Initial value: + // All is true + // Any is false + // None is true + // + // Flip logic: + // All flips once encounter an unmatched element + // Any flips once encounter a matched element + // None flips once encounter a matched element + // + // Result finalization: + // All: ignore the error and null if one or more elements are unmatched and + // return false Any: ignore the error and null if one or more elements + // matched and return true None: ignore the error and null if one or more + // elements matched and return false + auto match = (matchMethod != MatchMethod::kAny); + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; } - if (result != match) { - flatResult->set(row, !match); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (matchMethod == MatchMethod::kAll) { + if (!bitsDecoder->valueAt(idx)) { + match = !match; + break; + } } else { - flatResult->set(row, match); + if (bitsDecoder->valueAt(idx)) { + match = !match; + break; + } } } + + if ((matchMethod == MatchMethod::kAny) == match) { + flatResult->set(row, match); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, match); + } } }; -class AllMatchFunction : public MatchFunction {}; -class AnyMatchFunction : public MatchFunction {}; - -// TODO: add class NoneMatchFunction +class AllMatchFunction : public MatchFunction {}; +class AnyMatchFunction : public MatchFunction {}; +class NoneMatchFunction : public MatchFunction {}; std::vector> signatures() { // array(T), function(T) -> boolean @@ -175,4 +191,9 @@ VELOX_DECLARE_VECTOR_FUNCTION( signatures(), std::make_unique()); +VELOX_DECLARE_VECTOR_FUNCTION( + udf_none_match, + signatures(), + std::make_unique()); + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index 0325d23ed497..ba6aa5fb607c 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -80,6 +80,7 @@ void registerArrayFunctions(const std::string& prefix) { registerArrayConstructor(prefix + "array_constructor"); VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_none_match, prefix + "none_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); VELOX_REGISTER_VECTOR_FUNCTION( udf_array_duplicates, prefix + "array_duplicates"); diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp index c940e90bda80..498b8618725d 100644 --- a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -33,6 +33,17 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { assertEqualVectors(makeNullableFlatVector(expected), result); } + template + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const std::vector>>& input) { + auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); + auto result = + evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + void testExpr( const std::vector>& expected, const std::string& expression, @@ -43,22 +54,22 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { }; TEST_F(ArrayAnyMatchTest, basic) { - auto input = makeNullableArrayVector( - {{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}); + std::vector>> ints{ + {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; std::vector> expectedResult{ true, true, false, false, std::nullopt}; - testExpr(expectedResult, "x > 1", input); + testExpr(expectedResult, "x > 1", ints); expectedResult = {true, false, false, false, true}; - testExpr(expectedResult, "x is null", input); + testExpr(expectedResult, "x is null", ints); - input = makeNullableArrayVector( - {{false, true}, - {false, false}, - {std::nullopt, true}, - {std::nullopt, false}}); + std::vector>> bools{ + {false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}; expectedResult = {true, false, true, std::nullopt}; - testExpr(expectedResult, "x", input); + testExpr(expectedResult, "x", bools); } TEST_F(ArrayAnyMatchTest, complexTypes) { @@ -87,16 +98,16 @@ TEST_F(ArrayAnyMatchTest, complexTypes) { } TEST_F(ArrayAnyMatchTest, strings) { - auto input = makeNullableArrayVector( - {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); + std::vector>> input{ + {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; std::vector> expectedResult{ false, true, true, std::nullopt}; testExpr(expectedResult, "x = 'abc'", input); } TEST_F(ArrayAnyMatchTest, doubles) { - auto input = - makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); + std::vector>> input{ + {}, {0.2}, {3.0, 0}, {std::nullopt}}; std::vector> expectedResult{ false, false, true, std::nullopt}; testExpr(expectedResult, "x > 1.1", input); @@ -105,8 +116,8 @@ TEST_F(ArrayAnyMatchTest, doubles) { TEST_F(ArrayAnyMatchTest, errors) { // No throw and return false if there are unmatched elements except nulls auto expression = "(10 / x) > 2"; - auto input = makeNullableArrayVector( - {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); + std::vector>> input{ + {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; std::vector> expectedResult = {true, true}; testExpr(expectedResult, expression, input); diff --git a/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp new file mode 100644 index 000000000000..9db44da30839 --- /dev/null +++ b/velox/functions/prestosql/tests/ArrayNoneMatchTest.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class ArrayNoneMatchTest : public functions::test::FunctionBaseTest { + protected: + // Evaluate an expression. + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const VectorPtr& input) { + auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); + auto result = evaluate(expression, makeRowVector({input})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + + template + void testExpr( + const std::vector>& expected, + const std::string& lambdaExpr, + const std::vector>>& input) { + auto expression = folly::sformat("none_match(c0, x -> ({}))", lambdaExpr); + auto result = + evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); + assertEqualVectors(makeNullableFlatVector(expected), result); + } + + void testExpr( + const std::vector>& expected, + const std::string& expression, + const RowVectorPtr& input) { + auto result = evaluate(expression, (input)); + assertEqualVectors(makeNullableFlatVector(expected), result); + } +}; + +TEST_F(ArrayNoneMatchTest, basic) { + std::vector>> ints{ + {std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; + std::vector> expectedResult{ + false, false, true, true, std::nullopt}; + testExpr(expectedResult, "x > 1", ints); + + expectedResult = {false, true, true, true, false}; + testExpr(expectedResult, "x is null", ints); + + std::vector>> bools{ + {false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}; + expectedResult = {false, true, false, std::nullopt}; + testExpr(expectedResult, "x", bools); +} + +TEST_F(ArrayNoneMatchTest, complexTypes) { + auto baseVector = + makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3], []], + // [[2, 2], [3, 3], [4, 4], [5, 5]], + // [[], []] + // ] + auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); + std::vector> expectedResult{false, false, true}; + testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); + + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3]], cardinalities is 3 + // [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities is 2 + // [[]], + // null + // ] + arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); + expectedResult = {false, true, true, std::nullopt}; + testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); +} + +TEST_F(ArrayNoneMatchTest, strings) { + std::vector>> input{ + {}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; + std::vector> expectedResult{ + true, false, false, std::nullopt}; + testExpr(expectedResult, "x = 'abc'", input); +} + +TEST_F(ArrayNoneMatchTest, doubles) { + std::vector>> input{ + {}, {0.2}, {3.0, 0}, {std::nullopt}}; + std::vector> expectedResult{ + true, true, false, std::nullopt}; + testExpr(expectedResult, "x > 1.1", input); +} + +TEST_F(ArrayNoneMatchTest, errors) { + // No throw and return false if there are unmatched elements except nulls + auto expression = "(10 / x) > 2"; + std::vector>> input{ + {0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; + std::vector> expectedResult = {false, false}; + testExpr(expectedResult, expression, input); + + // Throw error if others are matched or null + static constexpr std::string_view kErrorMessage{"division by zero"}; + auto errorInput = makeNullableArrayVector( + {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + VELOX_ASSERT_THROW( + testExpr(expectedResult, expression, errorInput), kErrorMessage); + // Rerun using TRY to get right results + auto errorInputRow = makeRowVector({errorInput}); + expectedResult = {false, false, true, std::nullopt, false}; + testExpr( + expectedResult, + "TRY(none_match(c0, x -> ((10 / x) > 2)))", + errorInputRow); + testExpr( + expectedResult, + "none_match(c0, x -> (TRY((10 / x) > 2)))", + errorInputRow); +} + +TEST_F(ArrayNoneMatchTest, conditional) { + // No throw and return false if there are unmatched elements except nulls + auto c0 = makeFlatVector({1, 2, 3, 4, 5}); + auto c1 = makeNullableArrayVector( + {{4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}}); + auto input = makeRowVector({c0, c1}); + std::vector> expectedResult = { + std::nullopt, true, std::nullopt, false, true}; + testExpr( + expectedResult, + "none_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", + input); +} diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index 6f43d4f6f851..181aa1ff9a67 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable( ArrayIntersectTest.cpp ArrayMaxTest.cpp ArrayMinTest.cpp + ArrayNoneMatchTest.cpp ArrayNormalizeTest.cpp ArrayPositionTest.cpp ArraySortTest.cpp