-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add any_match and none_match Presto functions #4327
Changes from 5 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,8 +24,9 @@ | |
namespace facebook::velox::functions { | ||
namespace { | ||
|
||
class AllMatchFunction : public exec::VectorFunction { | ||
public: | ||
template <bool initialValue, bool earlyReturn> | ||
class MatchFunction : public exec::VectorFunction { | ||
private: | ||
void apply( | ||
const SelectivityVector& rows, | ||
std::vector<VectorPtr>& args, | ||
|
@@ -76,66 +77,114 @@ class AllMatchFunction : public exec::VectorFunction { | |
|
||
bitsDecoder.get()->decode(*matchBits, elementRows); | ||
entry.rows->applyToSelected([&](vector_size_t row) { | ||
auto size = sizes[row]; | ||
auto offset = offsets[row]; | ||
auto allMatch = true; | ||
auto hasNull = false; | ||
std::exception_ptr errorPtr{nullptr}; | ||
for (auto i = 0; i < size; ++i) { | ||
auto idx = offset + i; | ||
if (hasError(elementErrors, idx)) { | ||
errorPtr = *std::static_pointer_cast<std::exception_ptr>( | ||
elementErrors->valueAt(idx)); | ||
continue; | ||
} | ||
|
||
if (bitsDecoder->isNullAt(idx)) { | ||
hasNull = true; | ||
} else if (!bitsDecoder->valueAt<bool>(idx)) { | ||
allMatch = false; | ||
break; | ||
} | ||
} | ||
|
||
// Errors for individual array elements should be suppressed only if the | ||
// outcome can be decided by some other array element, e.g. if there is | ||
// another element that returns 'false' for the predicate. | ||
if (!allMatch) { | ||
flatResult->set(row, false); | ||
} else if (errorPtr) { | ||
context.setError(row, errorPtr); | ||
} else if (hasNull) { | ||
flatResult->setNull(row, true); | ||
} else { | ||
flatResult->set(row, true); | ||
} | ||
applyInternal( | ||
flatResult, | ||
context, | ||
row, | ||
offsets, | ||
sizes, | ||
elementErrors, | ||
bitsDecoder); | ||
}); | ||
} | ||
} | ||
|
||
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() { | ||
// array(T), function(T) -> boolean | ||
return {exec::FunctionSignatureBuilder() | ||
.typeVariable("T") | ||
.returnType("boolean") | ||
.argumentType("array(T)") | ||
.argumentType("function(T, boolean)") | ||
.build()}; | ||
// Returns true when 'errors' is non-null, 'idx' is within its size, and the | ||
// entry at 'idx' is not null, i.e. an error was recorded for that element. | ||
static FOLLY_ALWAYS_INLINE bool hasError( | ||
const ErrorVectorPtr& errors, | ||
int idx) { | ||
return errors && idx < errors->size() && !errors->isNullAt(idx); | ||
} | ||
|
||
private: | ||
FOLLY_ALWAYS_INLINE bool hasError( | ||
// Computes the result for one top-level 'row' and stores it in 'flatResult' | ||
// (or records an error for the row in 'context'). The array elements of the | ||
// row are read from 'bitsDecoder' at positions | ||
// [offsets[row], offsets[row] + sizes[row]); 'elementErrors' holds errors | ||
// recorded for individual elements, if any. | ||
void applyInternal( | ||
FlatVector<bool>* flatResult, | ||
exec::EvalCtx& context, | ||
vector_size_t row, | ||
const vector_size_t* offsets, | ||
const vector_size_t* sizes, | ||
const ErrorVectorPtr& elementErrors, | ||
int idx) const { | ||
return elementErrors && idx < elementErrors->size() && | ||
!elementErrors->isNullAt(idx); | ||
const exec::LocalDecodedVector& bitsDecoder) const { | ||
auto size = sizes[row]; | ||
auto offset = offsets[row]; | ||
|
||
// All, none, and any match have different and similar logic intertwined in | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this comment is easy to understand. Would it be helpful to show pseudo-code for different loops along with a generic loop?
|
||
// terms of the initial value, flip condition, and the result finalization: | ||
// | ||
// Initial value: | ||
// All is true | ||
// Any is false | ||
// None is true | ||
// | ||
// Flip logic: | ||
// All flips once it encounters an unmatched element | ||
// Any flips once it encounters a matched element | ||
// None flips once it encounters a matched element | ||
// | ||
// Result finalization: | ||
// All: ignore the error and null if one or more elements are unmatched and | ||
// return false | ||
// Any: ignore the error and null if one or more elements matched and | ||
// return true | ||
// None: ignore the error and null if one or more elements matched and | ||
// return false | ||
bool result = initialValue; | ||
auto hasNull = false; | ||
std::exception_ptr errorPtr{nullptr}; | ||
for (auto i = 0; i < size; ++i) { | ||
auto idx = offset + i; | ||
if (hasError(elementErrors, idx)) { | ||
// Remember the error (the last one seen wins) but keep scanning: a later | ||
// element may still decide the result, in which case it is ignored. | ||
errorPtr = *std::static_pointer_cast<std::exception_ptr>( | ||
elementErrors->valueAt(idx)); | ||
continue; | ||
} | ||
|
||
if (bitsDecoder->isNullAt(idx)) { | ||
hasNull = true; | ||
} else if (bitsDecoder->valueAt<bool>(idx) == earlyReturn) { | ||
// This element decides the outcome; no need to look at the rest. | ||
result = !result; | ||
break; | ||
} | ||
} | ||
|
||
// Precedence: a deciding element first, then a recorded error, then null, | ||
// and only otherwise the initial value. | ||
if (result != initialValue) { | ||
flatResult->set(row, !initialValue); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: perhaps, it would be more readable to do
|
||
} else if (errorPtr) { | ||
context.setError(row, errorPtr); | ||
} else if (hasNull) { | ||
flatResult->setNull(row, true); | ||
} else { | ||
flatResult->set(row, initialValue); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and
|
||
} | ||
} | ||
}; | ||
|
||
// Template arguments are <initialValue, earlyReturn>: the result starts as | ||
// 'initialValue' and flips on the first non-null element whose predicate | ||
// value equals 'earlyReturn'. | ||
// all_match: starts true, flips to false on the first 'false' element. | ||
class AllMatchFunction : public MatchFunction<true, false> {}; | ||
// any_match: starts false, flips to true on the first 'true' element. | ||
class AnyMatchFunction : public MatchFunction<false, true> {}; | ||
// none_match: starts true, flips to false on the first 'true' element. | ||
class NoneMatchFunction : public MatchFunction<true, true> {}; | ||
|
||
// Builds the one-element signature list used by the function declarations in | ||
// this file: a generic array plus a lambda over its element type, returning | ||
// boolean. | ||
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() { | ||
// array(T), function(T) -> boolean | ||
return {exec::FunctionSignatureBuilder() | ||
.typeVariable("T") | ||
.returnType("boolean") | ||
.argumentType("array(T)") | ||
.argumentType("function(T, boolean)") | ||
.build()}; | ||
} | ||
|
||
} // namespace | ||
|
||
// Declares the all_match implementation under the name udf_all_match. | ||
VELOX_DECLARE_VECTOR_FUNCTION( | ||
udf_all_match, | ||
AllMatchFunction::signatures(), | ||
signatures(), | ||
std::make_unique<AllMatchFunction>()); | ||
|
||
// Declares the any_match implementation under the name udf_any_match. | ||
VELOX_DECLARE_VECTOR_FUNCTION( | ||
udf_any_match, | ||
signatures(), | ||
std::make_unique<AnyMatchFunction>()); | ||
|
||
// Declares the none_match implementation under the name udf_none_match. | ||
VELOX_DECLARE_VECTOR_FUNCTION( | ||
udf_none_match, | ||
signatures(), | ||
std::make_unique<NoneMatchFunction>()); | ||
|
||
} // namespace facebook::velox::functions |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <folly/Format.h> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove this and use fmt::format instead. |
||
#include "velox/common/base/tests/GTestUtils.h" | ||
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" | ||
|
||
using namespace facebook::velox; | ||
using namespace facebook::velox::test; | ||
|
||
// Test fixture for the any_match function. | ||
class ArrayAnyMatchTest : public functions::test::FunctionBaseTest { | ||
protected: | ||
// Wraps 'lambdaExpr' as any_match(c0, x -> (lambdaExpr)), evaluates it with | ||
// 'input' as the only column (c0), and asserts that the result equals | ||
// 'expected'. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
void testExpr( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: testAnyMatch to distinguish from testExpr |
||
const std::vector<std::optional<bool>>& expected, | ||
const std::string& lambdaExpr, | ||
const VectorPtr& input) { | ||
auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); | ||
auto result = evaluate(expression, makeRowVector({input})); | ||
assertEqualVectors(makeNullableFlatVector<bool>(expected), result); | ||
} | ||
|
||
// Same as above, but builds the array column from nested std::vector data | ||
// via makeNullableArrayVector. | ||
template <typename T> | ||
void testExpr( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. testAnyMatch |
||
const std::vector<std::optional<bool>>& expected, | ||
const std::string& lambdaExpr, | ||
const std::vector<std::vector<std::optional<T>>>& input) { | ||
auto expression = folly::sformat("any_match(c0, x -> ({}))", lambdaExpr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps, |
||
auto result = | ||
evaluate(expression, makeRowVector({makeNullableArrayVector(input)})); | ||
assertEqualVectors(makeNullableFlatVector<bool>(expected), result); | ||
} | ||
|
||
// Evaluates the complete 'expression' as given (no any_match wrapping) on | ||
// the row vector 'input' and asserts that the result equals 'expected'. | ||
void testExpr( | ||
const std::vector<std::optional<bool>>& expected, | ||
const std::string& expression, | ||
const RowVectorPtr& input) { | ||
auto result = evaluate(expression, (input)); | ||
assertEqualVectors(makeNullableFlatVector<bool>(expected), result); | ||
} | ||
}; | ||
|
||
// Covers integer and boolean arrays, including empty arrays and arrays with | ||
// null elements: a match yields true, no match plus a null yields null, and | ||
// an empty array yields false. | ||
TEST_F(ArrayAnyMatchTest, basic) { | ||
std::vector<std::vector<std::optional<int32_t>>> ints{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comma after last value in the vector for readability. This way each array value will appear on a separate line making it easier to understand. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, thanks. |
||
{std::nullopt, 2, 0}, {-1, 3}, {-2, -3}, {}, {0, std::nullopt}}; | ||
std::vector<std::optional<bool>> expectedResult{ | ||
true, true, false, false, std::nullopt}; | ||
testExpr(expectedResult, "x > 1", ints); | ||
|
||
// The predicate itself never returns null here, so no row is null. | ||
expectedResult = {true, false, false, false, true}; | ||
testExpr(expectedResult, "x is null", ints); | ||
|
||
// The elements themselves are used as the predicate value. | ||
std::vector<std::vector<std::optional<bool>>> bools{ | ||
{false, true}, | ||
{false, false}, | ||
{std::nullopt, true}, | ||
{std::nullopt, false}}; | ||
expectedResult = {true, false, true, std::nullopt}; | ||
testExpr(expectedResult, "x", bools); | ||
} | ||
|
||
// Applies any_match to arrays whose elements are themselves arrays. | ||
TEST_F(ArrayAnyMatchTest, complexTypes) { | ||
auto baseVector = | ||
makeArrayVector<int64_t>({{1, 2, 3}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {}}); | ||
// Create an array-of-arrays vector from the base vector above using offsets. | ||
// [ | ||
// [[1, 2, 3], []], | ||
// [[2, 2], [3, 3], [4, 4], [5, 5]], | ||
// [[], []] | ||
// ] | ||
auto arrayOfArrays = makeArrayVector({0, 1, 5}, baseVector); | ||
std::vector<std::optional<bool>> expectedResult{true, true, false}; | ||
testExpr(expectedResult, "cardinality(x) > 0", arrayOfArrays); | ||
|
||
// Create an array-of-arrays vector from the base vector above using offsets, | ||
// with the last row set to null. | ||
// [ | ||
// [[1, 2, 3]], cardinality is 3 | ||
// [[2, 2], [3, 3], [4, 4], [5, 5]], all cardinalities are 2 | ||
// [[]], | ||
// null | ||
// ] | ||
arrayOfArrays = makeArrayVector({0, 1, 5, 6}, baseVector, {3}); | ||
expectedResult = {true, false, false, std::nullopt}; | ||
testExpr(expectedResult, "cardinality(x) > 2", arrayOfArrays); | ||
} | ||
|
||
// String equality predicate over empty, matching, and null-only arrays. | ||
TEST_F(ArrayAnyMatchTest, strings) { | ||
std::vector<std::vector<std::optional<StringView>>> input{ | ||
{}, {"abc"}, {"ab", "abc"}, {std::nullopt}}; | ||
std::vector<std::optional<bool>> expectedResult{ | ||
false, true, true, std::nullopt}; | ||
testExpr(expectedResult, "x = 'abc'", input); | ||
} | ||
|
||
// Comparison predicate over double arrays, including an empty array and an | ||
// array holding only a null. | ||
TEST_F(ArrayAnyMatchTest, doubles) { | ||
std::vector<std::vector<std::optional<double>>> input{ | ||
{}, {0.2}, {3.0, 0}, {std::nullopt}}; | ||
std::vector<std::optional<bool>> expectedResult{ | ||
false, false, true, std::nullopt}; | ||
testExpr(expectedResult, "x > 1.1", input); | ||
} | ||
|
||
// Checks how per-element evaluation errors (division by zero) interact with | ||
// the any_match result. | ||
TEST_F(ArrayAnyMatchTest, errors) { | ||
// No throw and return true if at least one element matches, even when other | ||
// elements fail or are null. | ||
auto expression = "(10 / x) > 2"; | ||
std::vector<std::vector<std::optional<int8_t>>> input{ | ||
{0, 2, 0, 5, 0}, {2, 5, std::nullopt, 0}}; | ||
std::vector<std::optional<bool>> expectedResult = {true, true}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps,
|
||
testExpr(expectedResult, expression, input); | ||
|
||
// Throw the error if no element matches, i.e. all the other elements are | ||
// unmatched or null (4th row below). | ||
static constexpr std::string_view kErrorMessage{"division by zero"}; | ||
auto errorInput = makeNullableArrayVector<int8_t>( | ||
{{1, 0}, {2}, {6}, {10, 9, 0, std::nullopt}, {0, std::nullopt, 1}}); | ||
VELOX_ASSERT_THROW( | ||
testExpr(expectedResult, expression, errorInput), kErrorMessage); | ||
// Rerun using TRY, both around any_match and inside the lambda, to get the | ||
// right results: the failing row becomes null. | ||
auto errorInputRow = makeRowVector({errorInput}); | ||
expectedResult = {true, true, false, std::nullopt, true}; | ||
testExpr( | ||
expectedResult, "TRY(any_match(c0, x -> ((10 / x) > 2)))", errorInputRow); | ||
testExpr( | ||
expectedResult, "any_match(c0, x -> (TRY((10 / x) > 2)))", errorInputRow); | ||
} | ||
|
||
// The lambda passed to any_match is chosen per row by an IF on c0. | ||
TEST_F(ArrayAnyMatchTest, conditional) { | ||
// Rows with c0 <= 2 use x > 100; the remaining rows use 10 / x > 2. The 4th | ||
// row contains a 0 (division error) but also a matching element, so it | ||
// yields true. | ||
auto c0 = makeFlatVector<uint32_t>({1, 2, 3, 4, 5}); | ||
auto c1 = makeNullableArrayVector<int32_t>( | ||
{{4, 100, std::nullopt}, | ||
{50, 12}, | ||
{std::nullopt}, | ||
{3, std::nullopt, 0}, | ||
{300, 100}}); | ||
auto input = makeRowVector({c0, c1}); | ||
std::vector<std::optional<bool>> expectedResult = { | ||
std::nullopt, false, std::nullopt, true, false}; | ||
testExpr( | ||
expectedResult, | ||
"any_match(c1, if (c0 <= 2, x -> (x > 100), x -> (10 / x > 2)))", | ||
input); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "private:" is not needed as there is already one above it.