Skip to content

Commit

Permalink
Add shuffle Presto function (#3404)
Browse files Browse the repository at this point in the history
Summary:
`shuffle(x) -> array`

> Generate a random permutation of the given array x.

For example:
```
SELECT shuffle(cast(array[1,2,2,3,4,null,5, null] as array(int))) -- [null, 1, 3, 4, 2, 5, 2, null] or any other permutation
SELECT shuffle(cast(array[null, null] as array(int))) -- [null, null]
SELECT shuffle(array['a', 'a', 'a']) -- ['a', 'a', 'a']
SELECT shuffle(cast(null as array(int))) -- null
```

NOTE: this is a non-deterministic function, hence, it may return different results for the same input even when input is a constant or dictionary encoded vector.

Pull Request resolved: #3404

Differential Revision: D42646294

Pulled By: darrenfu

fbshipit-source-id: ca4052758ae485c5a4522298d2a35d445f2857d7
  • Loading branch information
darrenfu authored and facebook-github-bot committed Jan 23, 2023
1 parent a49d6eb commit 3f7437c
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 0 deletions.
8 changes: 8 additions & 0 deletions velox/docs/functions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ Array Functions

Returns an array which has the reversed order of the input array.

.. function:: shuffle(array(E)) -> array(E)

Returns a random permutation of the given ``array``::

SELECT shuffle(ARRAY [1, 2, 3]); -- [3, 1, 2] or any other random permutation
SELECT shuffle(ARRAY [0, 0, 0]); -- [0, 0, 0]
SELECT shuffle(ARRAY [1, NULL, 1, NULL, 2]); -- [2, NULL, 1, NULL, 1] or any other random permutation

.. function:: slice(array(E), start, length) -> array(E)

Returns a subarray starting from index ``start``(or starting from the end
Expand Down
126 changes: 126 additions & 0 deletions velox/functions/prestosql/ArrayShuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <random>
#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions {
namespace {
// See documentation at
// https://prestodb.io/docs/current/functions/array.html#shuffle
//
// This function will shuffle identical arrays independently, i.e. even when
// the input has duplicate rows represented using constant and dictionary
// encoding, the output is flat and likely yields different values.
//
// E.g.1: constant encoding
// Input: ConstantVector(base=ArrayVector[{1,2,3}], length=3, index=0)
// Possible Output: ArrayVector[{1,3,2},{2,3,1},{3,2,1}]
//
// E.g.2: dict encoding
// Input: DictionaryVector(
// dictionaryValues=ArrayVector[{1,2,3},{4,5},{1,2,3}],
// dictionaryIndices=[1,2,0])
// Possible Output: ArrayVector[{5,4},{2,1,3},{1,3,2}]
//
class ArrayShuffleFunction : public exec::VectorFunction {
public:
bool isDeterministic() const override {
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /*outputType*/,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 1);

// This is a non-deterministic function, which violates the guarantee on a
// deterministic single-arg function that the expression evaluation will
// peel off encodings, and we will only see flat or constant inputs. Hence,
// we need to use DecodedVector to handle ALL encodings.
exec::DecodedArgs decodedArgs(rows, args, context);
auto decodedArg = decodedArgs.at(0);
auto arrayVector = decodedArg->base()->as<ArrayVector>();
auto elementsVector = arrayVector->elements();

vector_size_t numElements = 0;
context.applyToSelectedNoThrow(rows, [&](auto row) {
const auto size = arrayVector->sizeAt(decodedArg->index(row));
numElements += size;
});

// Allocate new buffer to hold shuffled indices.
BufferPtr shuffledIndices = allocateIndices(numElements, context.pool());
BufferPtr offsets = allocateOffsets(rows.size(), context.pool());
BufferPtr sizes = allocateSizes(rows.size(), context.pool());

vector_size_t* rawIndices = shuffledIndices->asMutable<vector_size_t>();
vector_size_t* rawOffsets = offsets->asMutable<vector_size_t>();
vector_size_t* rawSizes = sizes->asMutable<vector_size_t>();

vector_size_t newOffset = 0;
std::mt19937 randGen(std::random_device{}());
context.applyToSelectedNoThrow(rows, [&](auto row) {
vector_size_t arrayRow = decodedArg->index(row);
vector_size_t size = arrayVector->sizeAt(arrayRow);
vector_size_t offset = arrayVector->offsetAt(arrayRow);

std::iota(rawIndices + newOffset, rawIndices + newOffset + size, offset);
std::shuffle(
rawIndices + newOffset, rawIndices + newOffset + size, randGen);

rawSizes[row] = size;
rawOffsets[row] = newOffset;
newOffset += size;
});

auto resultElements = BaseVector::wrapInDictionary(
nullptr, shuffledIndices, numElements, elementsVector);
auto localResult = std::make_shared<ArrayVector>(
context.pool(),
arrayVector->type(),
nullptr,
rows.size(),
std::move(offsets),
std::move(sizes),
std::move(resultElements));

context.moveOrCopyResult(localResult, rows, result);
}
};

// Returns the signatures accepted by shuffle. There is exactly one, generic
// over the element type T: array(T) -> array(T).
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
  auto arrayToArray = exec::FunctionSignatureBuilder()
                          .typeVariable("T")
                          .returnType("array(T)")
                          .argumentType("array(T)")
                          .build();
  return {arrayToArray};
}

} // namespace

// Declares the vector function under the tag udf_array_shuffle, pairing the
// signatures returned by signatures() with an ArrayShuffleFunction instance.
// The SQL-visible name is bound separately, where the tag is passed to
// VELOX_REGISTER_VECTOR_FUNCTION (registered there as "shuffle").
VELOX_DECLARE_VECTOR_FUNCTION(
udf_array_shuffle,
signatures(),
std::make_unique<ArrayShuffleFunction>());
} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_library(
ArrayDuplicates.cpp
ArrayIntersectExcept.cpp
ArrayPosition.cpp
ArrayShuffle.cpp
ArraySort.cpp
ArraySum.cpp
Comparisons.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void registerArrayFunctions() {
VELOX_REGISTER_VECTOR_FUNCTION(udf_zip, "zip");
VELOX_REGISTER_VECTOR_FUNCTION(udf_zip_with, "zip_with");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_position, "array_position");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_shuffle, "shuffle");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_sort, "array_sort");
VELOX_REGISTER_VECTOR_FUNCTION(udf_array_sum, "array_sum");
VELOX_REGISTER_VECTOR_FUNCTION(udf_repeat, "repeat");
Expand Down
Loading

0 comments on commit 3f7437c

Please sign in to comment.