Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add any_match and none_match Presto functions #4327

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@ Array Functions
Returns NULL if the predicate function returns NULL for one or more elements and true for all other elements.
Throws an exception if the predicate fails for one or more elements and returns true or NULL for the rest.

.. function:: any_match(array(T), function(T, boolean)) → boolean

Returns whether at least one element of an array matches the given predicate.

Returns true if one or more elements match the predicate;
Returns false if none of the elements matches (a special case is when the array is empty);
Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements.
Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest.

.. function:: none_match(array(T), function(T, boolean)) → boolean

Returns whether no elements of an array match the given predicate.

Returns true if none of the elements matches the predicate (a special case is when the array is empty);
Returns false if one or more elements match;
Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements.
Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest.

.. function:: array_average(array(double)) -> double

Returns the average of all non-null elements of the array. If there are no non-null elements, returns null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
namespace facebook::velox::functions {
namespace {

class AllMatchFunction : public exec::VectorFunction {
public:
template <bool initialValue, bool earlyReturn>
class MatchFunction : public exec::VectorFunction {
private:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
Expand Down Expand Up @@ -76,66 +77,114 @@ class AllMatchFunction : public exec::VectorFunction {

bitsDecoder.get()->decode(*matchBits, elementRows);
entry.rows->applyToSelected([&](vector_size_t row) {
auto size = sizes[row];
auto offset = offsets[row];
auto allMatch = true;
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
}

if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (!bitsDecoder->valueAt<bool>(idx)) {
allMatch = false;
break;
}
}

// Errors for individual array elements should be suppressed only if the
// outcome can be decided by some other array element, e.g. if there is
// another element that returns 'false' for the predicate.
if (!allMatch) {
flatResult->set(row, false);
} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
} else {
flatResult->set(row, true);
}
applyInternal(
flatResult,
context,
row,
offsets,
sizes,
elementErrors,
bitsDecoder);
});
}
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), function(T) -> boolean
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("boolean")
.argumentType("array(T)")
.argumentType("function(T, boolean)")
.build()};
static FOLLY_ALWAYS_INLINE bool hasError(
const ErrorVectorPtr& errors,
int idx) {
return errors && idx < errors->size() && !errors->isNullAt(idx);
}

private:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "private:" is not needed as there is already one above it.

FOLLY_ALWAYS_INLINE bool hasError(
void applyInternal(
FlatVector<bool>* flatResult,
exec::EvalCtx& context,
vector_size_t row,
const vector_size_t* offsets,
const vector_size_t* sizes,
const ErrorVectorPtr& elementErrors,
int idx) const {
return elementErrors && idx < elementErrors->size() &&
!elementErrors->isNullAt(idx);
const exec::LocalDecodedVector& bitsDecoder) const {
auto size = sizes[row];
auto offset = offsets[row];

// All, none, and any match have different and similar logic intertwined in
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this comment is easy to understand.

Would it be helpful to show pseudo-code for different loops along with a generic loop?

all_match, none_match and any_match need to loop over predicate results for element arrays and check for results, nulls and errors. These loops can be generalized using two booleans.

Here is what the individual loops look like.

---- kAll ----
bool allMatch = true

loop:
   if not match:
    allMatch = false;
    break;

if (!allMatch) -> false
else if hasError -> error
else if hasNull -> null
else -> true

---- kAny ----

bool anyMatch = false

loop:
   if match:
    anyMatch = true;
    break;

if (anyMatch) -> true
else if hasError -> error
else if hasNull -> null
else -> false

---- kNone ----

bool noneMatch = true;

loop:
   if match:
    noneMatch = false;
    break;

if (!noneMatch) -> false
else if hasError -> error
else if hasNull -> null
else -> true

To generalize these loops, we use initialValue and earlyReturn booleans like so:

--- generic loop ---

bool result = initialValue

loop:
   if match == earlyReturn:
    result = false;
    break;

if (result != initialValue) -> result
else if hasError -> error
else if hasNull -> null
else -> result

// terms of the initial value, flip condition, and the result finalization:
//
// Initial value:
// All is true
// Any is false
// None is true
//
// Flip logic:
// All flips once encounter an unmatched element
// Any flips once encounter a matched element
// None flips once encounter a matched element
//
// Result finalization:
// All: ignore the error and null if one or more elements are unmatched and
// return false Any: ignore the error and null if one or more elements
// matched and return true None: ignore the error and null if one or more
// elements matched and return false
bool result = initialValue;
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
}

if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (bitsDecoder->valueAt<bool>(idx) == earlyReturn) {
result = !result;
break;
}
}

if (result != initialValue) {
flatResult->set(row, !initialValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: perhaps, it would be more readable to do

flatResult->set(row, result);

} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
} else {
flatResult->set(row, initialValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and

flatResult->set(row, result);

}
}
};

class AllMatchFunction : public MatchFunction<true, false> {};
class AnyMatchFunction : public MatchFunction<false, true> {};
class NoneMatchFunction : public MatchFunction<true, true> {};

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), function(T) -> boolean
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("boolean")
.argumentType("array(T)")
.argumentType("function(T, boolean)")
.build()};
}

} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_all_match,
AllMatchFunction::signatures(),
signatures(),
std::make_unique<AllMatchFunction>());

VELOX_DECLARE_VECTOR_FUNCTION(
udf_any_match,
signatures(),
std::make_unique<AnyMatchFunction>());

VELOX_DECLARE_VECTOR_FUNCTION(
udf_none_match,
signatures(),
std::make_unique<NoneMatchFunction>());

} // namespace facebook::velox::functions
2 changes: 1 addition & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ add_subdirectory(window)

add_library(
velox_functions_prestosql_impl
ArrayAllMatch.cpp
ArrayConstructor.cpp
ArrayContains.cpp
ArrayDistinct.cpp
ArrayDuplicates.cpp
ArrayIntersectExcept.cpp
ArrayMatch.cpp
ArrayPosition.cpp
ArrayShuffle.cpp
ArraySort.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ inline void registerArrayNormalizeFunctions(const std::string& prefix) {
void registerArrayFunctions(const std::string& prefix) {
registerArrayConstructor(prefix + "array_constructor");
VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_none_match, prefix + "none_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_array_duplicates, prefix + "array_duplicates");
Expand Down
155 changes: 155 additions & 0 deletions velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <folly/Format.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this and use fmt::format instead.

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox;
using namespace facebook::velox::test;

class ArrayAnyMatchTest : public functions::test::FunctionBaseTest {
protected:
// Evaluate an expression.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

void testExpr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: testAnyMatch to distinguish from testExpr

const std::vector<std::optional<bool>>& expected,
const std::string& lambdaExpr,
const VectorPtr& input) {
auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr);
auto result = evaluate(expression, makeRowVector({input}));
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}

template <typename T>
void testExpr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testAnyMatch

const std::vector<std::optional<bool>>& expected,
const std::string& lambdaExpr,
const std::vector<std::vector<std::optional<T>>>& input) {
auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps, testAnyMatch(expected, lambdaExpr, makeNullableArrayVector(input)) to reduce code duplication

auto result =
evaluate(expression, makeRowVector({makeNullableArrayVector(input)}));
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}

void testExpr(
const std::vector<std::optional<bool>>& expected,
const std::string& expression,
const RowVectorPtr& input) {
auto result = evaluate(expression, (input));
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}
};

TEST_F(ArrayAnyMatchTest, basic) {
std::vector<std::vector<std::optional<int32_t>>> ints{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comma after last value in the vector for readability. This way each array value will appear on a separate line making it easier to understand.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks.

{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
true, true, false, false, std::nullopt};
testExpr(expectedResult, "x > 1", ints);

expectedResult = {true, false, false, false, true};
testExpr(expectedResult, "x is null", ints);

std::vector<std::vector<std::optional<bool>>> bools{
{false, true},
{false, false},
{std::nullopt, true},
{std::nullopt, false}};
expectedResult = {true, false, true, std::nullopt};
testExpr(expectedResult, "x", bools);
}

TEST_F(ArrayAnyMatchTest, complexTypes) {
auto baseVector =
makeArrayVector<int64_t>({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}});
// Create an array of array vector using above base vector using offsets.
// [
// [[1, 2, 3], []],
// [[2, 2], [3, 3], [4, 4], [5, 5]],
// [[], []]
// ]
auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector);
std::vector<std::optional<bool>> expectedResult{true, true, false};
testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays);

// Create an array of array vector using above base vector using offsets.
// [
// [[1, 2, 3]], cardinalities is 3
// [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities is 2
// [[]],
// null
// ]
arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3});
expectedResult = {true, false, false, std::nullopt};
testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays);
}

TEST_F(ArrayAnyMatchTest, strings) {
std::vector<std::vector<std::optional<StringView>>> input{
{}, {"abc"}, {"ab", "abc"}, {std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
false, true, true, std::nullopt};
testExpr(expectedResult, "x = 'abc'", input);
}

TEST_F(ArrayAnyMatchTest, doubles) {
std::vector<std::vector<std::optional<double>>> input{
{}, {0.2}, {3.0, 0}, {std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
false, false, true, std::nullopt};
testExpr(expectedResult, "x > 1.1", input);
}

TEST_F(ArrayAnyMatchTest, errors) {
// No throw and return false if there are unmatched elements except nulls
auto expression = "(10 / x) > 2";
std::vector<std::vector<std::optional<int8_t>>> input{
{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}};
std::vector<std::optional<bool>> expectedResult = {true, true};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps,

testAnyMatch({true, true}, "(10 / x) > 2", input);

testExpr(expectedResult, expression, input);

// Throw error if others are matched or null
static constexpr std::string_view kErrorMessage{"division by zero"};
auto errorInput = makeNullableArrayVector<int8_t>(
{{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}});
VELOX_ASSERT_THROW(
testExpr(expectedResult, expression, errorInput), kErrorMessage);
// Rerun using TRY to get right results
auto errorInputRow = makeRowVector({errorInput});
expectedResult = {true, true, false, std::nullopt, true};
testExpr(
expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow);
testExpr(
expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow);
}

TEST_F(ArrayAnyMatchTest, conditional) {
// No throw and return false if there are unmatched elements except nulls
auto c0 = makeFlatVector<uint32_t>({1, 2, 3, 4, 5});
auto c1 = makeNullableArrayVector<int32_t>(
{{4, 100, std::nullopt},
{50, 12},
{std::nullopt},
{3, std::nullopt, 0},
{300, 100}});
auto input = makeRowVector({c0, c1});
std::vector<std::optional<bool>> expectedResult = {
std::nullopt, false, std::nullopt, true, false};
testExpr(
expectedResult,
"any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))",
input);
}
Loading