Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zip_with Presto lambda function #2685

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion velox/docs/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,15 @@ Array Functions
The M-th element of the N-th argument will be the N-th field of the M-th output element.
If the arguments have an uneven length, missing values are filled with ``NULL`` ::

SELECT zip(ARRAY[1, 2], ARRAY['1b', null, '3b']); -- [ROW(1, '1b'), ROW(2, null), ROW(null, '3b')]
SELECT zip(ARRAY[1, 2], ARRAY['1b', null, '3b']); -- [ROW(1, '1b'), ROW(2, null), ROW(null, '3b')]

.. function:: zip_with(array(T), array(U), function(T,U,R)) -> array(R)

Merges the two given arrays, element-wise, into a single array using ``function``.
If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying ``function``::

SELECT zip_with(ARRAY[1, 3, 5], ARRAY['a', 'b', 'c'], (x, y) -> (y, x)); -- [ROW('a', 1), ROW('b', 3), ROW('c', 5)]
SELECT zip_with(ARRAY[1, 2], ARRAY[3, 4], (x, y) -> x + y); -- [4, 6]
SELECT zip_with(ARRAY['a', 'b', 'c'], ARRAY['d', 'e', 'f'], (x, y) -> concat(x, y)); -- ['ad', 'be', 'cf']
SELECT zip_with(ARRAY['a'], ARRAY['d', null, 'f'], (x, y) -> coalesce(x, y)); -- ['a', null, 'f']

3 changes: 2 additions & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ add_library(
URLFunctions.cpp
VectorArithmetic.cpp
WidthBucketArray.cpp
Zip.cpp)
Zip.cpp
ZipWith.cpp)

target_link_libraries(
velox_functions_prestosql_impl
Expand Down
287 changes: 287 additions & 0 deletions velox/functions/prestosql/ZipWith.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/Expr.h"
#include "velox/expression/VarSetter.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/vector/FunctionVector.h"

namespace facebook::velox::functions {
namespace {

struct Buffers {
BufferPtr offsets;
BufferPtr sizes;
BufferPtr nulls;
vector_size_t numElements;
};

struct DecodedInputs {
DecodedVector* decodedLeft;
DecodedVector* decodedRight;
const ArrayVector* baseLeft;
const ArrayVector* baseRight;

DecodedInputs(DecodedVector* _decodedLeft, DecodedVector* _decodeRight)
: decodedLeft{_decodedLeft},
decodedRight{_decodeRight},
baseLeft{decodedLeft->base()->asUnchecked<ArrayVector>()},
baseRight{decodedRight->base()->asUnchecked<ArrayVector>()} {}
};

// See documentation at
// https://prestodb.io/docs/current/functions/array.html#zip_with
class ZipWithFunction : public exec::VectorFunction {
public:
bool isDefaultNullBehavior() const override {
// zip_with is null preserving for the arrays, but since an
// expr tree with a lambda depends on all named fields, including
// captures, a null in a capture does not automatically make a
// null result.
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 3);
exec::DecodedArgs decodedArgs(rows, {args[0], args[1]}, context);
DecodedInputs decodedInputs{decodedArgs.at(0), decodedArgs.at(1)};

// Number of elements in the result vector.
// Sizes, offsets and nulls for the result ArrayVector.
// Size of the result array is the max of sizes of the input arrays.
// Result array is null if one or both of the input arrays are null.
bool leftNeedsPadding = false;
bool rightNeedsPadding = false;
auto resultBuffers = computeResultBuffers(
decodedInputs,
rows,
context.pool(),
leftNeedsPadding,
rightNeedsPadding);

// If one array is shorter than the other, add nulls at the end of the
// shorter array. Use dictionary encoding to represent elements of the
// padded arrays.
auto lambdaArgs = flattenAndPadArrays(
decodedInputs,
resultBuffers,
rows,
context.pool(),
leftNeedsPadding,
rightNeedsPadding);

const auto numResultElements = resultBuffers.numElements;
auto rawOffsets = resultBuffers.offsets->as<vector_size_t>();
auto rawSizes = resultBuffers.sizes->as<vector_size_t>();

const SelectivityVector allElementRows(numResultElements);

VectorPtr newElements;

// Loop over lambda functions and apply these to (leftElements,
// rightElements). In most cases there will be only one function and the
// loop will run once.
auto it = args[2]->asUnchecked<FunctionVector>()->iterator(&rows);
while (auto entry = it.next()) {
SelectivityVector elementRows(numResultElements, false);
entry.rows->applyToSelected([&](auto row) {
elementRows.setValidRange(
rawOffsets[row], rawOffsets[row] + rawSizes[row], true);
});
elementRows.updateBounds();

BufferPtr wrapCapture;
if (entry.callable->hasCapture()) {
wrapCapture = allocateIndices(numResultElements, context.pool());
auto rawWrapCaptures = wrapCapture->asMutable<vector_size_t>();

vector_size_t offset = 0;
entry.rows->applyToSelected([&](auto row) {
for (auto i = 0; i < rawSizes[row]; ++i) {
rawWrapCaptures[offset++] = row;
}
});
}

// Make sure already populated entries in newElements do not get
// overwritten.
VarSetter finalSelection(
context.mutableFinalSelection(), &allElementRows);
VarSetter isFinalSelection(context.mutableIsFinalSelection(), false);

entry.callable->apply(
elementRows,
allElementRows,
wrapCapture,
&context,
lambdaArgs,
&newElements);
}

auto localResult = std::make_shared<ArrayVector>(
context.pool(),
outputType,
resultBuffers.nulls,
rows.end(),
resultBuffers.offsets,
resultBuffers.sizes,
newElements);
context.moveOrCopyResult(localResult, rows, result);
}

static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
// array(T), array(U), function(T, U, R) -> array(R)
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.typeVariable("U")
.typeVariable("R")
.returnType("array(R)")
.argumentType("array(T)")
.argumentType("array(U)")
.argumentType("function(T, U, R)")
.build()};
}

private:
static Buffers computeResultBuffers(
const DecodedInputs& decodedInputs,
const SelectivityVector& rows,
memory::MemoryPool* pool,
bool& leftNeedsPadding,
bool& rightNeedsPadding) {
BufferPtr sizes = allocateSizes(rows.end(), pool);
auto rawSizes = sizes->asMutable<vector_size_t>();

BufferPtr offsets = allocateOffsets(rows.end(), pool);
auto rawOffsets = offsets->asMutable<vector_size_t>();

BufferPtr nulls = allocateNulls(rows.end(), pool);
auto rawNulls = nulls->asMutable<uint64_t>();

vector_size_t offset = 0;
rows.applyToSelected([&](auto row) {
if (decodedInputs.decodedLeft->isNullAt(row) ||
decodedInputs.decodedRight->isNullAt(row)) {
rawSizes[row] = 0;
rawOffsets[row] = 0;
bits::setNull(rawNulls, row);
return;
}

auto leftRow = decodedInputs.decodedLeft->index(row);
auto rightRow = decodedInputs.decodedRight->index(row);
auto leftSize = decodedInputs.baseLeft->sizeAt(leftRow);
auto rightSize = decodedInputs.baseRight->sizeAt(rightRow);
auto size = std::max(leftSize, rightSize);
if (leftSize < size) {
leftNeedsPadding = true;
}
if (rightSize < size) {
rightNeedsPadding = true;
}
rawOffsets[row] = offset;
rawSizes[row] = size;
offset += size;
});

return {offsets, sizes, nulls, offset};
}

static VectorPtr flattenAndPadArray(
DecodedVector* decoded,
const ArrayVector* base,
const SelectivityVector& rows,
memory::MemoryPool* pool,
vector_size_t numResultElements,
const vector_size_t* resultSizes,
bool needsPadding) {
BufferPtr indices = allocateIndices(numResultElements, pool);
auto* rawIndices = indices->asMutable<vector_size_t>();

BufferPtr nulls;
uint64_t* rawNulls = nullptr;
if (needsPadding) {
nulls = allocateNulls(numResultElements, pool);
rawNulls = nulls->asMutable<uint64_t>();
}

vector_size_t resultOffset = 0;
rows.applyToSelected([&](auto row) {
const auto resultSize = resultSizes[row];
if (resultSize == 0) {
return;
}

auto baseRow = decoded->index(row);
auto size = base->sizeAt(baseRow);
auto offset = base->offsetAt(baseRow);

for (auto i = 0; i < size; ++i) {
rawIndices[resultOffset + i] = offset + i;
}
for (auto i = size; i < resultSize; ++i) {
bits::setNull(rawNulls, resultOffset + i);
}
resultOffset += resultSize;
});

return BaseVector::wrapInDictionary(
nulls, indices, numResultElements, base->elements());
}

static std::vector<VectorPtr> flattenAndPadArrays(
const DecodedInputs& decodedInputs,
const Buffers& resultBuffers,
const SelectivityVector& rows,
memory::MemoryPool* pool,
bool leftNeedsPadding,
bool rightNeedsPadding) {
auto resultSizes = resultBuffers.sizes->as<vector_size_t>();

auto paddedLeft = flattenAndPadArray(
decodedInputs.decodedLeft,
decodedInputs.baseLeft,
rows,
pool,
resultBuffers.numElements,
resultSizes,
leftNeedsPadding);

auto paddedRight = flattenAndPadArray(
decodedInputs.decodedRight,
decodedInputs.baseRight,
rows,
pool,
resultBuffers.numElements,
resultSizes,
rightNeedsPadding);

return {paddedLeft, paddedRight};
}
};
} // namespace

VELOX_DECLARE_VECTOR_FUNCTION(
udf_zip_with,
ZipWithFunction::signatures(),
std::make_unique<ZipWithFunction>());

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ void registerArrayFunctions() {
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_intersect, "array_intersect");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, "contains");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_except, "array_except");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_duplicates, "array_duplicates");
VELOX_REGISTER_VECTOR_FUNCTION(udf_arrays_overlap, "arrays_overlap");
VELOX_REGISTER_VECTOR_FUNCTION(udf_slice, "slice");
VELOX_REGISTER_VECTOR_FUNCTION(udf_zip, "zip");
VELOX_REGISTER_VECTOR_FUNCTION(udf_zip_with, "zip_with");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_position, "array_position");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_sort, "array_sort");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_sum, "array_sum");
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ add_executable(
URLFunctionsTest.cpp
WidthBucketArrayTest.cpp
GreatestLeastTest.cpp
ZipTest.cpp)
ZipTest.cpp
ZipWithTest.cpp)

add_test(velox_functions_test velox_functions_test)

Expand All @@ -77,6 +78,7 @@ target_link_libraries(
velox_functions_lib
velox_exec_test_lib
velox_dwio_common_test_utils
velox_vector_fuzzer
gtest
gtest_main
${gflags_LIBRARIES}
Expand Down
Loading