diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index e01178cf37e7..a37f8b2eb398 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -11,6 +11,15 @@ Array Functions Returns NULL if the predicate function returns NULL for one or more elements and true for all other elements. Throws an exception if the predicate fails for one or more elements and returns true or NULL for the rest. +.. function:: any_match(array(T), function(T, boolean)) → boolean + + Returns whether any elements of an array match the given predicate. + + Returns true if one or more elements match the predicate; + Returns false if none of the elements matches (a special case is when the array is empty); + Returns NULL NULL if the predicate function returns NULL for one or more elements and false for all other elements. + Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest. + .. function:: array_average(array(double)) -> double Returns the average of all non-null elements of the array. If there are no non-null elements, returns null. diff --git a/velox/functions/prestosql/ArrayAllMatch.cpp b/velox/functions/prestosql/ArrayAllMatch.cpp deleted file mode 100644 index d716ec325c79..000000000000 --- a/velox/functions/prestosql/ArrayAllMatch.cpp +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/expression/EvalCtx.h" -#include "velox/expression/Expr.h" -#include "velox/expression/VectorFunction.h" -#include "velox/functions/lib/LambdaFunctionUtil.h" -#include "velox/functions/lib/RowsTranslationUtil.h" -#include "velox/vector/FunctionVector.h" - -namespace facebook::velox::functions { -namespace { - -class AllMatchFunction : public exec::VectorFunction { - public: - void apply( - const SelectivityVector& rows, - std::vector& args, - const TypePtr& /*outputType*/, - exec::EvalCtx& context, - VectorPtr& result) const override { - exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); - auto& decodedArray = *arrayDecoder.get(); - - auto flatArray = flattenArray(rows, args[0], decodedArray); - auto offsets = flatArray->rawOffsets(); - auto sizes = flatArray->rawSizes(); - - std::vector lambdaArgs = {flatArray->elements()}; - auto numElements = flatArray->elements()->size(); - - SelectivityVector finalSelection; - if (!context.isFinalSelection()) { - finalSelection = toElementRows( - numElements, *context.finalSelection(), flatArray.get()); - } - - VectorPtr matchBits; - auto elementToTopLevelRows = getElementToTopLevelRows( - numElements, rows, flatArray.get(), context.pool()); - - // Loop over lambda functions and apply these to elements of the base array, - // in most cases there will be only one function and the loop will run once. - context.ensureWritable(rows, BOOLEAN(), result); - auto flatResult = result->asFlatVector(); - exec::LocalDecodedVector bitsDecoder(context); - auto it = args[1]->asUnchecked()->iterator(&rows); - - while (auto entry = it.next()) { - ErrorVectorPtr elementErrors; - auto elementRows = - toElementRows(numElements, *entry.rows, flatArray.get()); - auto wrapCapture = toWrapCapture( - numElements, entry.callable, *entry.rows, flatArray); - entry.callable->applyNoThrow( - elementRows, - finalSelection, - wrapCapture, - &context, - lambdaArgs, - elementErrors, - &matchBits); - - bitsDecoder.get()->decode(*matchBits, elementRows); - entry.rows->applyToSelected([&](vector_size_t row) { - auto size = sizes[row]; - auto offset = offsets[row]; - auto allMatch = true; - auto hasNull = false; - std::exception_ptr errorPtr{nullptr}; - for (auto i = 0; i < size; ++i) { - auto idx = offset + i; - if (hasError(elementErrors, idx)) { - errorPtr = *std::static_pointer_cast( - elementErrors->valueAt(idx)); - continue; - } - - if (bitsDecoder->isNullAt(idx)) { - hasNull = true; - } else if (!bitsDecoder->valueAt(idx)) { - allMatch = false; - break; - } - } - - // Errors for individual array elements should be suppressed only if the - // outcome can be decided by some other array element, e.g. if there is - // another element that returns 'false' for the predicate. - if (!allMatch) { - flatResult->set(row, false); - } else if (errorPtr) { - context.setError(row, errorPtr); - } else if (hasNull) { - flatResult->setNull(row, true); - } else { - flatResult->set(row, true); - } - }); - } - } - - static std::vector> signatures() { - // array(T), function(T) -> boolean - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .returnType("boolean") - .argumentType("array(T)") - .argumentType("function(T, boolean)") - .build()}; - } - - private: - FOLLY_ALWAYS_INLINE bool hasError( - const ErrorVectorPtr& elementErrors, - int idx) const { - return elementErrors && idx < elementErrors->size() && - !elementErrors->isNullAt(idx); - } -}; -} // namespace - -VELOX_DECLARE_VECTOR_FUNCTION( - udf_all_match, - AllMatchFunction::signatures(), - std::make_unique()); - -} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/ArrayMatch.cpp b/velox/functions/prestosql/ArrayMatch.cpp new file mode 100644 index 000000000000..704391bb96d2 --- /dev/null +++ b/velox/functions/prestosql/ArrayMatch.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/LambdaFunctionUtil.h" +#include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/vector/FunctionVector.h" + +namespace facebook::velox::functions { +namespace { + +class MatchFunction : public exec::VectorFunction { + private: + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); + auto& decodedArray = *arrayDecoder.get(); + + auto flatArray = flattenArray(rows, args[0], decodedArray); + auto offsets = flatArray->rawOffsets(); + auto sizes = flatArray->rawSizes(); + + std::vector lambdaArgs = {flatArray->elements()}; + auto numElements = flatArray->elements()->size(); + + SelectivityVector finalSelection; + if (!context.isFinalSelection()) { + finalSelection = toElementRows( + numElements, *context.finalSelection(), flatArray.get()); + } + + VectorPtr matchBits; + auto elementToTopLevelRows = getElementToTopLevelRows( + numElements, rows, flatArray.get(), context.pool()); + + // Loop over lambda functions and apply these to elements of the base array, + // in most cases there will be only one function and the loop will run once. + context.ensureWritable(rows, BOOLEAN(), result); + auto flatResult = result->asFlatVector(); + exec::LocalDecodedVector bitsDecoder(context); + auto it = args[1]->asUnchecked()->iterator(&rows); + + while (auto entry = it.next()) { + ErrorVectorPtr elementErrors; + auto elementRows = + toElementRows(numElements, *entry.rows, flatArray.get()); + auto wrapCapture = toWrapCapture( + numElements, entry.callable, *entry.rows, flatArray); + entry.callable->applyNoThrow( + elementRows, + finalSelection, + wrapCapture, + &context, + lambdaArgs, + elementErrors, + &matchBits); + + bitsDecoder.get()->decode(*matchBits, elementRows); + entry.rows->applyToSelected([&](vector_size_t row) { + applyInternal( + flatResult, + context, + row, + offsets, + sizes, + elementErrors, + bitsDecoder); + }); + } + } + + protected: + static FOLLY_ALWAYS_INLINE bool hasError( + const ErrorVectorPtr& elementErrors, + int idx) { + return elementErrors && idx < elementErrors->size() && + !elementErrors->isNullAt(idx); + } + + private: + virtual void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const = 0; +}; + +class AllMatchFunction : public MatchFunction { + private: + void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const override { + auto size = sizes[row]; + auto offset = offsets[row]; + auto allMatch = true; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (!bitsDecoder->valueAt(idx)) { + allMatch = false; + break; + } + } + + // Errors for individual array elements should be suppressed only if the + // outcome can be decided by some other array element, e.g. if there is + // another element that returns 'false' for the predicate. + if (!allMatch) { + flatResult->set(row, false); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, true); + } + } + + public: + static std::vector> signatures() { + // array(T), function(T) -> boolean + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("boolean") + .argumentType("array(T)") + .argumentType("function(T, boolean)") + .build()}; + } +}; + +class AnyMatchFunction : public MatchFunction { + private: + void applyInternal( + FlatVector* flatResult, + exec::EvalCtx& context, + vector_size_t row, + const vector_size_t* offsets, + const vector_size_t* sizes, + const ErrorVectorPtr& elementErrors, + const exec::LocalDecodedVector& bitsDecoder) const override { + auto size = sizes[row]; + auto offset = offsets[row]; + auto nonMatch = true; + auto hasNull = false; + std::exception_ptr errorPtr{nullptr}; + for (auto i = 0; i < size; ++i) { + auto idx = offset + i; + if (hasError(elementErrors, idx)) { + errorPtr = *std::static_pointer_cast( + elementErrors->valueAt(idx)); + continue; + } + + if (bitsDecoder->isNullAt(idx)) { + hasNull = true; + } else if (bitsDecoder->valueAt(idx)) { + nonMatch = false; + break; + } + } + + // Errors for individual array elements should be suppressed only if the + // outcome can be decided by some other array element, e.g. if there is + // another element that returns 'true' for the predicate. + if (!nonMatch) { + flatResult->set(row, true); + } else if (errorPtr) { + context.setError(row, errorPtr); + } else if (hasNull) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, false); + } + } + + public: + static std::vector> signatures() { + // array(T), function(T) -> boolean + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("boolean") + .argumentType("array(T)") + .argumentType("function(T, boolean)") + .build()}; + } +}; + +} // namespace + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_all_match, + AllMatchFunction::signatures(), + std::make_unique()); + +VELOX_DECLARE_VECTOR_FUNCTION( + udf_any_match, + AnyMatchFunction::signatures(), + std::make_unique()); + +} // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 9240e4ad200f..1d6898c1ddf6 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -17,12 +17,12 @@ add_subdirectory(window) add_library( velox_functions_prestosql_impl - ArrayAllMatch.cpp ArrayConstructor.cpp ArrayContains.cpp ArrayDistinct.cpp ArrayDuplicates.cpp ArrayIntersectExcept.cpp + ArrayMatch.cpp ArrayPosition.cpp ArrayShuffle.cpp ArraySort.cpp diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index dfffebb71096..0325d23ed497 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -79,6 +79,7 @@ inline void registerArrayNormalizeFunctions(const std::string& prefix) { void registerArrayFunctions(const std::string& prefix) { registerArrayConstructor(prefix + "array_constructor"); VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match"); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); VELOX_REGISTER_VECTOR_FUNCTION( udf_array_duplicates, prefix + "array_duplicates"); diff --git a/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp new file mode 100644 index 000000000000..ee8729789267 --- /dev/null +++ b/velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { + protected: + // Evaluate an expression. + void testExpr( + const VectorPtr& expected, + const std::string& expression, + const VectorPtr& input) { + auto result = evaluate(expression, makeRowVector({input})); + assertEqualVectors(expected, result); + } + + void testExpr( + const VectorPtr& expected, + const std::string& expression, + const RowVectorPtr& input) { + auto result = evaluate(expression, (input)); + assertEqualVectors(expected, result); + } +}; + +TEST_F(ArrayAnyMatchTest, basic) { + auto input = makeNullableArrayVector( + {{std::nullopt, 2, 3}, + {-1, 3}, + {2, 3}, + {}, + {std::nullopt, std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({true, true, true, false, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1))", input); + expectedResult = makeFlatVector({true, false, false, false, true}); + testExpr(expectedResult, "any_match(c0, x -> (x is null))", input); + + input = makeNullableArrayVector( + {{false, true}, + {false, false}, + {std::nullopt, true}, + {std::nullopt, false}}); + expectedResult = + makeNullableFlatVector({true, false, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> x)", input); + + auto emptyInput = makeArrayVector({{}}); + expectedResult = makeFlatVector(std::vector{false}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1))", emptyInput); + testExpr(expectedResult, "any_match(c0, x -> (x <= 1))", emptyInput); +} + +TEST_F(ArrayAnyMatchTest, complexTypes) { + auto baseVector = + makeArrayVector({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3], []], + // [[2, 2], [3, 3], [4, 4], [5, 5]], + // [[], []] + // ] + auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); + auto expectedResult = makeNullableFlatVector({true, true, false}); + testExpr( + expectedResult, + "any_match(c0, x -> (cardinality(x) > 0))", + arrayOfArrays); + + // Create an array of array vector using above base vector using offsets. + // [ + // [[1, 2, 3]], cardinalities is 3 + // [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities is 2 + // [[]], + // null + // ] + arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); + expectedResult = + makeNullableFlatVector({true, false, false, std::nullopt}); + testExpr( + expectedResult, + "any_match(c0, x -> (cardinality(x) > 2))", + arrayOfArrays); +} + +TEST_F(ArrayAnyMatchTest, bigints) { + auto input = makeNullableArrayVector( + {{}, + {2}, + {std::numeric_limits::max()}, + {std::numeric_limits::min()}, + {std::nullopt, std::nullopt}, // return null if all is null + {2, + std::nullopt}, // return null if one or more is null and others matched + {1, std::nullopt, 2}}); // return false if one is not matched + auto expectedResult = makeNullableFlatVector( + {false, true, false, true, std::nullopt, true, true}); + testExpr(expectedResult, "any_match(c0, x -> (x % 2 = 0))", input); +} + +TEST_F(ArrayAnyMatchTest, strings) { + auto input = makeNullableArrayVector( + {{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({false, true, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x = 'abc'))", input); +} + +TEST_F(ArrayAnyMatchTest, doubles) { + auto input = + makeNullableArrayVector({{}, {0.2}, {3.0, 0}, {std::nullopt}}); + auto expectedResult = + makeNullableFlatVector({false, false, true, std::nullopt}); + testExpr(expectedResult, "any_match(c0, x -> (x > 1.1))", input); +} + +TEST_F(ArrayAnyMatchTest, errors) { + // No throw and return false if there are unmatched elements except nulls + auto expression = "any_match(c0, x -> ((10 / x) > 2))"; + auto input = makeNullableArrayVector( + {{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}); + auto expectedResult = makeFlatVector({true, true}); + testExpr(expectedResult, expression, input); + + // Throw error if others are matched or null + static constexpr std::string_view kErrorMessage{"division by zero"}; + auto errorInput = makeNullableArrayVector( + {{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); + VELOX_ASSERT_THROW( + testExpr(expectedResult, expression, errorInput), kErrorMessage); + // Rerun using TRY to get right results + expectedResult = + makeNullableFlatVector({true, true, false, std::nullopt, true}); + testExpr( + expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInput); + testExpr( + expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInput); +} + +TEST_F(ArrayAnyMatchTest, conditional) { + // No throw and return false if there are unmatched elements except nulls + auto c0 = makeFlatVector({1, 2, 3, 4, 5}); + auto c1 = makeNullableArrayVector( + {{4, 100, std::nullopt}, + {50, 12}, + {std::nullopt}, + {3, std::nullopt, 0}, + {300, 100}}); + auto input = makeRowVector({c0, c1}); + auto expectedResult = makeNullableFlatVector( + {std::nullopt, false, std::nullopt, true, false}); + testExpr( + expectedResult, + "any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", + input); +} diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index c47130c97308..6f43d4f6f851 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( velox_functions_test ArithmeticTest.cpp ArrayAllMatchTest.cpp + ArrayAnyMatchTest.cpp ArrayAverageTest.cpp ArrayCombinationsTest.cpp ArrayConstructorTest.cpp