Skip to content

Commit

Permalink
resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Mar 21, 2023
1 parent 83dcfab commit e781eb9
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 95 deletions.
2 changes: 1 addition & 1 deletion velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Array Functions

.. function:: any_match(array(T), function(T, boolean)) → boolean

Returns whether at least one element of an array match the given predicate.
Returns whether at least one element of an array matchs the given predicate.

Returns true if one or more elements match the predicate;
Returns false if none of the elements matches (a special case is when the array is empty);
Expand Down
75 changes: 39 additions & 36 deletions velox/functions/prestosql/ArrayMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
namespace facebook::velox::functions {
namespace {

template <bool match>
enum class MatchMethod { ALL = 0, ANY = 1, NONE = 2 };

template <MatchMethod mathMethod>
class MatchFunction : public exec::VectorFunction {
private:
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
const TypePtr& /*outputType*/,
exec::EvalCtx& context,
VectorPtr& result) const override {
exec::LocalDecodedVector arrayDecoder(context, *args[0], rows);
Expand Down Expand Up @@ -89,7 +91,6 @@ class MatchFunction : public exec::VectorFunction {
}
}

protected:
static FOLLY_ALWAYS_INLINE bool hasError(
const ErrorVectorPtr& errors,
int idx) {
Expand All @@ -108,47 +109,49 @@ class MatchFunction : public exec::VectorFunction {
const exec::LocalDecodedVector& bitsDecoder) const {
auto size = sizes[row];
auto offset = offsets[row];

// All or any match.
// All: early exit on non-match. Initial value = true.
// Any: early exit on match. Initial value = false.
bool result = match;
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
if (mathMethod == MatchMethod::NONE) {
// TODO: Add none_match
} else {
// All or any match.
// All: early exit on non-match. Initial value = true.
// Any: early exit on match. Initial value = false.
auto match = mathMethod == MatchMethod::ALL;
bool result = match;
auto hasNull = false;
std::exception_ptr errorPtr{nullptr};
for (auto i = 0; i < size; ++i) {
auto idx = offset + i;
if (hasError(elementErrors, idx)) {
errorPtr = *std::static_pointer_cast<std::exception_ptr>(
elementErrors->valueAt(idx));
continue;
}

if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (bitsDecoder->valueAt<bool>(idx) == !match) {
result = !result;
break;
}
}

if (bitsDecoder->isNullAt(idx)) {
hasNull = true;
} else if (bitsDecoder->valueAt<bool>(idx) == !match) {
result = !result;
break;
if (result != match) {
flatResult->set(row, !match);
} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
} else {
flatResult->set(row, match);
}
}

// Errors for individual array elements should be suppressed only if the
// outcome can be decided by some other array element, e.g. if there is
// another element that returns 'true' for the predicate.
if (result != match) {
flatResult->set(row, !match);
} else if (errorPtr) {
context.setError(row, errorPtr);
} else if (hasNull) {
flatResult->setNull(row, true);
} else {
flatResult->set(row, match);
}
}
};

class AllMatchFunction : public MatchFunction<true> {};
class AllMatchFunction : public MatchFunction<MatchMethod::ALL> {};
class AnyMatchFunction : public MatchFunction<MatchMethod::ANY> {};

class AnyMatchFunction : public MatchFunction<false> {};
// TODO: add class NoneMatchFunction

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), function(T) -> boolean
Expand Down
90 changes: 32 additions & 58 deletions velox/functions/prestosql/tests/ArrayAnyMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <folly/Format.h>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

Expand All @@ -24,19 +25,20 @@ class ArrayAnyMatchTest : public functions::test::FunctionBaseTest {
protected:
// Evaluate an expression.
void testExpr(
const VectorPtr& expected,
const std::string& expression,
const std::vector<std::optional<bool>>& expected,
const std::string& lambdaExpr,
const VectorPtr& input) {
auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr);
auto result = evaluate(expression, makeRowVector({input}));
assertEqualVectors(expected, result);
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}

void testExpr(
const VectorPtr& expected,
const std::vector<std::optional<bool>>& expected,
const std::string& expression,
const RowVectorPtr& input) {
auto result = evaluate(expression, (input));
assertEqualVectors(expected, result);
assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
}
};

Expand All @@ -47,25 +49,19 @@ TEST_F(ArrayAnyMatchTest, basic) {
{2, 3},
{},
{std::nullopt, std::nullopt}});
auto expectedResult =
makeNullableFlatVector<bool>({true, true, true, false, std::nullopt});
testExpr(expectedResult, "any_match(c0, x -> (x > 1))", input);
expectedResult = makeFlatVector<bool>({true, false, false, false, true});
testExpr(expectedResult, "any_match(c0, x -> (x is null))", input);
std::vector<std::optional<bool>> expectedResult{
true, true, true, false, std::nullopt};
testExpr(expectedResult, "x > 1", input);
expectedResult = {true, false, false, false, true};
testExpr(expectedResult, "x is null", input);

input = makeNullableArrayVector<bool>(
{{false, true},
{false, false},
{std::nullopt, true},
{std::nullopt, false}});
expectedResult =
makeNullableFlatVector<bool>({true, false, true, std::nullopt});
testExpr(expectedResult, "any_match(c0, x -> x)", input);

auto emptyInput = makeArrayVector<int32_t>({{}});
expectedResult = makeFlatVector<bool>(std::vector<bool>{false});
testExpr(expectedResult, "any_match(c0, x -> (x > 1))", emptyInput);
testExpr(expectedResult, "any_match(c0, x -> (x <= 1))", emptyInput);
expectedResult = {true, false, true, std::nullopt};
testExpr(expectedResult, "x", input);
}

TEST_F(ArrayAnyMatchTest, complexTypes) {
Expand All @@ -78,11 +74,8 @@ TEST_F(ArrayAnyMatchTest, complexTypes) {
// [[], []]
// ]
auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector);
auto expectedResult = makeNullableFlatVector<bool>({true, true, false});
testExpr(
expectedResult,
"any_match(c0, x -> (cardinality(x) > 0))",
arrayOfArrays);
std::vector<std::optional<bool>> expectedResult{true, true, false};
testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays);

// Create an array of array vector using above base vector using offsets.
// [
Expand All @@ -92,51 +85,32 @@ TEST_F(ArrayAnyMatchTest, complexTypes) {
// null
// ]
arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3});
expectedResult =
makeNullableFlatVector<bool>({true, false, false, std::nullopt});
testExpr(
expectedResult,
"any_match(c0, x -> (cardinality(x) > 2))",
arrayOfArrays);
}

TEST_F(ArrayAnyMatchTest, bigints) {
auto input = makeNullableArrayVector<int64_t>(
{{},
{2},
{std::numeric_limits<int64_t>::max()},
{std::numeric_limits<int64_t>::min()},
{std::nullopt, std::nullopt}, // return null if all is null
{2,
std::nullopt}, // return null if one or more is null and others matched
{1, std::nullopt, 2}}); // return false if one is not matched
auto expectedResult = makeNullableFlatVector<bool>(
{false, true, false, true, std::nullopt, true, true});
testExpr(expectedResult, "any_match(c0, x -> (x % 2 = 0))", input);
expectedResult = {true, false, false, std::nullopt};
testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays);
}

TEST_F(ArrayAnyMatchTest, strings) {
auto input = makeNullableArrayVector<StringView>(
{{}, {"abc"}, {"ab", "abc"}, {std::nullopt}});
auto expectedResult =
makeNullableFlatVector<bool>({false, true, true, std::nullopt});
testExpr(expectedResult, "any_match(c0, x -> (x = 'abc'))", input);
std::vector<std::optional<bool>> expectedResult{
false, true, true, std::nullopt};
testExpr(expectedResult, "x = 'abc'", input);
}

TEST_F(ArrayAnyMatchTest, doubles) {
auto input =
makeNullableArrayVector<double>({{}, {0.2}, {3.0, 0}, {std::nullopt}});
auto expectedResult =
makeNullableFlatVector<bool>({false, false, true, std::nullopt});
testExpr(expectedResult, "any_match(c0, x -> (x > 1.1))", input);
std::vector<std::optional<bool>> expectedResult{
false, false, true, std::nullopt};
testExpr(expectedResult, "x > 1.1", input);
}

TEST_F(ArrayAnyMatchTest, errors) {
// No throw and return false if there are unmatched elements except nulls
auto expression = "any_match(c0, x -> ((10 / x) > 2))";
auto expression = "(10 / x) > 2";
auto input = makeNullableArrayVector<int8_t>(
{{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}});
auto expectedResult = makeFlatVector<bool>({true, true});
std::vector<std::optional<bool>> expectedResult = {true, true};
testExpr(expectedResult, expression, input);

// Throw error if others are matched or null
Expand All @@ -146,12 +120,12 @@ TEST_F(ArrayAnyMatchTest, errors) {
VELOX_ASSERT_THROW(
testExpr(expectedResult, expression, errorInput), kErrorMessage);
// Rerun using TRY to get right results
expectedResult =
makeNullableFlatVector<bool>({true, true, false, std::nullopt, true});
auto errorInputRow = makeRowVector({errorInput});
expectedResult = {true, true, false, std::nullopt, true};
testExpr(
expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInput);
expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow);
testExpr(
expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInput);
expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow);
}

TEST_F(ArrayAnyMatchTest, conditional) {
Expand All @@ -164,8 +138,8 @@ TEST_F(ArrayAnyMatchTest, conditional) {
{3, std::nullopt, 0},
{300, 100}});
auto input = makeRowVector({c0, c1});
auto expectedResult = makeNullableFlatVector<bool>(
{std::nullopt, false, std::nullopt, true, false});
std::vector<std::optional<bool>> expectedResult = {
std::nullopt, false, std::nullopt, true, false};
testExpr(
expectedResult,
"any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))",
Expand Down

0 comments on commit e781eb9

Please sign in to comment.