Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Mar 22, 2023
1 parent 4bb611d commit 1941a79
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 55 deletions.
11 changes: 10 additions & 1 deletion velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ Array Functions

.. function:: any_match(array(T), function(T, boolean)) → boolean

Returns whether at least one element of an array matchs the given predicate.
Returns whether at least one element of an array matches the given predicate.

Returns true if one or more elements match the predicate;
Returns false if none of the elements matches (a special case is when the array is empty);
Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements.
Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest.

.. function:: none_match(array(T), function(T, boolean)) → boolean

Returns whether no elements of an array match the given predicate.

Returns true if none of the elements matches the predicate (a special case is when the array is empty);
Returns false if one or more elements match;
Returns NULL if the predicate function returns NULL for one or more elements and false for all other elements.
Throws an exception if the predicate fails for one or more elements and returns false or NULL for the rest.

.. function:: array_average(array(double)) -> double

Returns the average of all non-null elements of the array. If there are no non-null elements, returns null.
Expand Down
97 changes: 59 additions & 38 deletions velox/functions/prestosql/ArrayMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
namespace facebook::velox::functions {
namespace {

enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 };
// Selects which quantifier a MatchFunction instantiation implements.
enum class MatchMethod {
  kAll = 0, // used by AllMatchFunction
  kAny = 1, // used by AnyMatchFunction
  kNone = 2, // used by NoneMatchFunction
};

template <MatchMethod mathMethod>
template <MatchMethod matchMethod>
class MatchFunction : public exec::VectorFunction {
private:
void apply(
Expand Down Expand Up @@ -94,8 +94,7 @@ class MatchFunction : public exec::VectorFunction {
static FOLLY_ALWAYS_INLINE bool hasError(
const ErrorVectorPtr& errors,
int idx) {
return errors && idx < errors->size() &&
!errors->isNullAt(idx);
return errors && idx < errors->size() && !errors->isNullAt(idx);
}

private:
Expand All @@ -109,49 +108,66 @@ class MatchFunction : public exec::VectorFunction {
const exec::LocalDecodedVector& bitsDecoder) const {
auto size = sizes[row];
auto offset = offsets[row];
if (mathMethod == MatchMethod::NONE) {
// TODO: Add none_match
} else {
// All or any match.
// All: early exit on non-match. Initial value = true.
// Any: early exit on match. Initial value = false.
auto match = mathMethod == MatchMethod::ALL;
bool result = match;
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
}

if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (bitsDecoder->valueAt<bool>(idx) == !match) {
result = !result;
break;
}
// All, any, and none match share one loop but differ in the initial value,
// the flip condition, and how the result is finalized:
//
// Initial value:
// All is true
// Any is false
// None is true
//
// Flip logic:
// All flips once it encounters an unmatched element
// Any flips once it encounters a matched element
// None flips once it encounters a matched element
//
// Result finalization:
// All: ignore errors and nulls if one or more elements are unmatched, and
// return false
// Any: ignore errors and nulls if one or more elements are matched, and
// return true
// None: ignore errors and nulls if one or more elements are matched, and
// return false
auto match = (matchMethod != MatchMethod::kAny);
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
}

if (result != match) {
flatResult->set(row, !match);
} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (matchMethod == MatchMethod::kAll) {
if (!bitsDecoder->valueAt<bool>(idx)) {
match = !match;
break;
}
} else {
flatResult->set(row, match);
if (bitsDecoder->valueAt<bool>(idx)) {
match = !match;
break;
}
}
}

if ((matchMethod == MatchMethod::kAny) == match) {
flatResult->set(row, match);
} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
} else {
flatResult->set(row, match);
}
}
};

class AllMatchFunction : public MatchFunction<MatchMethod::ALL> {};
class AnyMatchFunction : public MatchFunction<MatchMethod::ANY> {};

// TODO: add class NoneMatchFunction
// Concrete vector functions, one per match method. Each class only fixes the
// MatchMethod template argument of MatchFunction and adds no members.
class AllMatchFunction : public MatchFunction<MatchMethod::kAll> {};
class AnyMatchFunction : public MatchFunction<MatchMethod::kAny> {};
class NoneMatchFunction : public MatchFunction<MatchMethod::kNone> {};

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), function(T) -> boolean
Expand All @@ -175,4 +191,9 @@ VELOX_DECLARE_VECTOR_FUNCTION(
signatures(),
std::make_unique<AnyMatchFunction>());

// Declares udf_none_match with the shared signatures() and a
// NoneMatchFunction instance as its implementation.
VELOX_DECLARE_VECTOR_FUNCTION(
udf_none_match,
signatures(),
std::make_unique<NoneMatchFunction>());

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ void registerArrayFunctions(const std::string& prefix) {
registerArrayConstructor(prefix + "array_constructor");
VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "all_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "any_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_none_match, prefix + "none_match");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct");
VELOX_REGISTER_VECTOR_FUNCTION(
udf_array_duplicates, prefix + "array_duplicates");
Expand Down
43 changes: 27 additions & 16 deletions velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest {
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}

// Evaluates any_match(c0, x -> (lambdaExpr)) with `input` as column c0 and
// asserts that the result equals `expected`.
template <typename T>
void testExpr(
    const std::vector<std::optional<bool>>& expected,
    const std::string& lambdaExpr,
    const std::vector<std::vector<std::optional<T>>>& input) {
  // Wrap the nested std::vector input in a single-column row vector.
  auto rowInput = makeRowVector({makeNullableArrayVector(input)});
  auto actual = evaluate(
      folly::sformat("any_match(c0, x -> ({}))", lambdaExpr), rowInput);
  assertEqualVectors(makeNullableFlatVector<bool>(expected), actual);
}

void testExpr(
const std::vector<std::optional<bool>>& expected,
const std::string& expression,
Expand All @@ -43,22 +54,22 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest {
};

TEST_F(ArrayAnyMatchTest, basic) {
auto input = makeNullableArrayVector<int64_t>(
{{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}});
std::vector<std::vector<std::optional<int32_t>>> ints{
{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
true, true, false, false, std::nullopt};
testExpr(expectedResult, "x > 1", input);
testExpr(expectedResult, "x > 1", ints);

expectedResult = {true, false, false, false, true};
testExpr(expectedResult, "x is null", input);
testExpr(expectedResult, "x is null", ints);

input = makeNullableArrayVector<bool>(
{{false, true},
{false, false},
{std::nullopt, true},
{std::nullopt, false}});
std::vector<std::vector<std::optional<bool>>> bools{
{false, true},
{false, false},
{std::nullopt, true},
{std::nullopt, false}};
expectedResult = {true, false, true, std::nullopt};
testExpr(expectedResult, "x", input);
testExpr(expectedResult, "x", bools);
}

TEST_F(ArrayAnyMatchTest, complexTypes) {
Expand Down Expand Up @@ -87,16 +98,16 @@ TEST_F(ArrayAnyMatchTest, complexTypes) {
}

TEST_F(ArrayAnyMatchTest, strings) {
auto input = makeNullableArrayVector<StringView>(
{{}, {"abc"}, {"ab", "abc"}, {std::nullopt}});
std::vector<std::vector<std::optional<StringView>>> input{
{}, {"abc"}, {"ab", "abc"}, {std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
false, true, true, std::nullopt};
testExpr(expectedResult, "x = 'abc'", input);
}

TEST_F(ArrayAnyMatchTest, doubles) {
auto input =
makeNullableArrayVector<double>({{}, {0.2}, {3.0, 0}, {std::nullopt}});
std::vector<std::vector<std::optional<double>>> input{
{}, {0.2}, {3.0, 0}, {std::nullopt}};
std::vector<std::optional<bool>> expectedResult{
false, false, true, std::nullopt};
testExpr(expectedResult, "x > 1.1", input);
Expand All @@ -105,8 +116,8 @@ TEST_F(ArrayAnyMatchTest, doubles) {
TEST_F(ArrayAnyMatchTest, errors) {
// No throw and return true if there are matched elements, even when other
// elements fail or are null.
auto expression = "(10 / x) > 2";
auto input = makeNullableArrayVector<int8_t>(
{{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}});
std::vector<std::vector<std::optional<int8_t>>> input{
{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}};
std::vector<std::optional<bool>> expectedResult = {true, true};
testExpr(expectedResult, expression, input);

Expand Down
Loading

0 comments on commit 1941a79

Please sign in to comment.