Skip to content

Commit

Permalink
Add support for collect_list Spark aggregate function (facebookincubator#9231)
Browse files Browse the repository at this point in the history

Summary:
The semantics of Spark's `collect_list` and Presto's `array_agg` are
generally consistent, but there are inconsistencies in the handling of null
values. Spark always ignores null values in the input, whereas Presto has a
parameter that controls whether to retain them. Moreover, Presto returns null
when all inputs are null, while Spark returns an empty array.

Because of these differences, we need to re-implement the `array_agg`
function for Spark.

Pull Request resolved: facebookincubator#9231

Reviewed By: xiaoxmeng

Differential Revision: D55639676

Pulled By: mbasmanova

fbshipit-source-id: 958471779a1fa66dba27569a6c12538ad5489f46
  • Loading branch information
liujiayi771 authored and Joe-Abraham committed Apr 4, 2024
1 parent 63372ca commit 405c0a1
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 1 deletion.
5 changes: 5 additions & 0 deletions velox/docs/functions/spark/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ General Aggregate Functions

``hash`` cannot be null.

.. spark:function:: collect_list(x) -> array<[same as x]>

    Returns an array created from the input ``x`` elements. Ignores null
    inputs, and returns an empty array when all inputs are null.

.. spark:function:: first(x) -> x
Returns the first value of `x`.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_library(
BitwiseXorAggregate.cpp
BloomFilterAggAggregate.cpp
CentralMomentsAggregate.cpp
CollectListAggregate.cpp
FirstLastAggregate.cpp
MinMaxByAggregate.cpp
Register.cpp
Expand Down
142 changes: 142 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"

#include "velox/exec/SimpleAggregateAdapter.h"
#include "velox/functions/lib/aggregates/ValueList.h"

using namespace facebook::velox::aggregate;
using namespace facebook::velox::exec;

namespace facebook::velox::functions::aggregate::sparksql {
namespace {
/// Spark's collect_list aggregate: gathers all non-null input values of a
/// group into an array. It differs from Presto's array_agg in null handling:
/// nulls are always ignored, and a group whose inputs are all null produces
/// an empty array rather than null.
class CollectListAggregate {
public:
  // Raw input: a single column of any type.
  using InputType = Row<Generic<T1>>;

  // Intermediate (partial-aggregation) type: array of collected elements.
  using IntermediateType = Array<Generic<T1>>;

  using OutputType = Array<Generic<T1>>;

  /// In Spark, when all inputs are null, the output is an empty array instead
  /// of null. Therefore, in the writeIntermediateResult and writeFinalResult,
  /// we still need to output the empty element_ when the group is null. This
  /// behavior can only be achieved when the default-null behavior is disabled.
  static constexpr bool default_null_behavior_ = false;

  // Converts a single raw input value straight to the intermediate (array)
  // form. Returning false for a null input leaves that row's intermediate
  // result null.
  static bool toIntermediate(
      exec::out_type<Array<Generic<T1>>>& out,
      exec::optional_arg_type<Generic<T1>> in) {
    if (in.has_value()) {
      out.add_item().copy_from(in.value());
      return true;
    }
    return false;
  }

  struct AccumulatorType {
    // Collected non-null values, stored via the HashStringAllocator.
    ValueList elements_;

    explicit AccumulatorType(HashStringAllocator* /*allocator*/)
        : elements_{} {}

    // Accumulator size grows with the number of appended elements.
    static constexpr bool is_fixed_size_ = false;

    // Appends one raw input value. Null inputs are skipped (returns false),
    // matching Spark's semantics of ignoring nulls.
    bool addInput(
        HashStringAllocator* allocator,
        exec::optional_arg_type<Generic<T1>> data) {
      if (data.has_value()) {
        elements_.appendValue(data, allocator);
        return true;
      }
      return false;
    }

    // Merges an intermediate array from another partial accumulator into
    // this one. A null intermediate contributes nothing.
    bool combine(
        HashStringAllocator* allocator,
        exec::optional_arg_type<IntermediateType> other) {
      if (!other.has_value()) {
        return false;
      }
      for (auto element : other.value()) {
        elements_.appendValue(element, allocator);
      }
      return true;
    }

    bool writeIntermediateResult(
        bool /*nonNullGroup*/,
        exec::out_type<IntermediateType>& out) {
      // If the group's accumulator is null, the corresponding intermediate
      // result is an empty array.
      copyValueListToArrayWriter(out, elements_);
      return true;
    }

    bool writeFinalResult(
        bool /*nonNullGroup*/,
        exec::out_type<OutputType>& out) {
      // If the group's accumulator is null, the corresponding result is an
      // empty array.
      copyValueListToArrayWriter(out, elements_);
      return true;
    }

    // Returns the ValueList's memory to the allocator.
    void destroy(HashStringAllocator* allocator) {
      elements_.free(allocator);
    }
  };
};

/// Builds the signature (E) -> array(E), intermediate array(E), for any type
/// E, and registers the collect_list aggregate under 'name'.
/// 'withCompanionFunctions' and 'overwrite' are forwarded to
/// exec::registerAggregateFunction.
AggregateRegistrationResult registerCollectList(
    const std::string& name,
    bool withCompanionFunctions,
    bool overwrite) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .typeVariable("E")
          .returnType("array(E)")
          .intermediateType("array(E)")
          .argumentType("E")
          .build()};
  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step /*step*/,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // The check enforces exactly one argument, so the message says so
        // (previously it misleadingly read "at most one argument").
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly one argument", name);
        return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>(
            resultType);
      },
      withCompanionFunctions,
      overwrite);
}
} // namespace

/// Registers the Spark collect_list aggregate under the given function-name
/// prefix, forwarding the companion-function and overwrite flags.
void registerCollectListAggregate(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite) {
  const auto functionName = prefix + "collect_list";
  registerCollectList(functionName, withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
28 changes: 28 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions::aggregate::sparksql {

/// Registers the Spark collect_list aggregate function; 'prefix' is
/// prepended to the function name "collect_list". 'withCompanionFunctions'
/// presumably also registers the partial/merge/extract companion functions,
/// and 'overwrite' replaces any existing registration — confirm against
/// exec::registerAggregateFunction.
void registerCollectListAggregate(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
2 changes: 2 additions & 0 deletions velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h"
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"
#include "velox/functions/sparksql/aggregates/SumAggregate.h"

namespace facebook::velox::functions::aggregate::sparksql {
Expand All @@ -45,5 +46,6 @@ void registerAggregateFunctions(
registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
registerSum(prefix + "sum", withCompanionFunctions, overwrite);
registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite);
registerCollectListAggregate(prefix, withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/aggregates/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_executable(
BitwiseXorAggregationTest.cpp
BloomFilterAggAggregateTest.cpp
CentralMomentsAggregationTest.cpp
CollectListAggregateTest.cpp
FirstAggregateTest.cpp
LastAggregateTest.cpp
Main.cpp
Expand Down
127 changes: 127 additions & 0 deletions velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/sparksql/aggregates/Register.h"

using namespace facebook::velox::functions::aggregate::test;

namespace facebook::velox::functions::aggregate::sparksql::test {

namespace {

// Test fixture for the Spark collect_list aggregate. Registers the Spark
// aggregate functions under the "spark_" prefix so the tests can invoke
// spark_collect_list.
class CollectListAggregateTest : public AggregationTestBase {
protected:
  void SetUp() override {
    AggregationTestBase::SetUp();
    registerAggregateFunctions("spark_");
  }
};

// Grouped aggregation: each group collects its non-null values; the null
// sprinkled into each of the first three groups is dropped.
TEST_F(CollectListAggregateTest, groupBy) {
  std::vector<RowVectorPtr> batches;
  // Creating 3 batches of input data.
  // 0: {0, null} {0, 1} {0, 2}
  // 1: {1, 1} {1, null} {1, 3}
  // 2: {2, 2} {2, 3} {2, null}
  // 3: {3, 3} {3, 4} {3, 5}
  // 4: {4, 4} {4, 5} {4, 6}
  for (auto i = 0; i < 3; i++) {
    RowVectorPtr data = makeRowVector(
        {makeFlatVector<int32_t>({0, 1, 2, 3, 4}),
         makeFlatVector<int64_t>(
             5,
             // Capture the loop counter by value: the lambdas are consumed
             // synchronously, but copying a cheap int is the safer idiom
             // than holding a reference to the loop variable.
             [i](const vector_size_t& row) { return i + row; },
             // Row i of batch i is null.
             [i](const auto& row) { return i == row; })});
    batches.push_back(data);
  }

  auto expected = makeRowVector(
      {makeFlatVector<int32_t>({0, 1, 2, 3, 4}),
       makeArrayVectorFromJson<int64_t>(
           {"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"})});

  // Sort the collected arrays: collect_list's element order is unspecified.
  testAggregations(
      batches,
      {"c0"},
      {"spark_collect_list(c1)"},
      {"c0", "array_sort(a0)"},
      {expected});
  testAggregationsWithCompanion(
      batches,
      [](auto& /*builder*/) {},
      {"c0"},
      {"spark_collect_list(c1)"},
      {{BIGINT()}},
      {"c0", "array_sort(a0)"},
      {expected},
      {});
}

// Global (ungrouped) aggregation: nulls are skipped and the remaining
// values end up in a single collected array.
TEST_F(CollectListAggregateTest, global) {
  auto input = makeRowVector({makeNullableFlatVector<int32_t>(
      {std::nullopt, 1, 2, std::nullopt, 4, 5})});
  auto expectedResult =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 5]"})});

  // Element order is unspecified, so compare sorted arrays.
  testAggregations(
      {input},
      {},
      {"spark_collect_list(c0)"},
      {"array_sort(a0)"},
      {expectedResult});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expectedResult});
}

// Verifies Spark's null handling: every null in the input is simply
// dropped from the collected array.
TEST_F(CollectListAggregateTest, ignoreNulls) {
  auto data = makeRowVector({makeNullableFlatVector<int32_t>(
      {1, 2, std::nullopt, 4, std::nullopt, 6})});
  auto expectedResult =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 6]"})});

  // Element order is unspecified, so compare sorted arrays.
  testAggregations(
      {data},
      {},
      {"spark_collect_list(c0)"},
      {"array_sort(a0)"},
      {expectedResult});
  testAggregationsWithCompanion(
      {data},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expectedResult},
      {});
}

TEST_F(CollectListAggregateTest, allNullsInput) {
  auto input = makeRowVector({makeAllNullFlatVector<int64_t>(100)});
  // If all input data is null, Spark will output an empty array.
  // Build the expected array with int64_t so its element type matches the
  // BIGINT input (the original used int32_t, yielding an ARRAY(INTEGER)
  // vector that disagrees with the aggregate's ARRAY(BIGINT) result type).
  auto expected = makeRowVector({makeArrayVectorFromJson<int64_t>({"[]"})});
  testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{BIGINT()}},
      {},
      {expected},
      {});
}
} // namespace
} // namespace facebook::velox::functions::aggregate::sparksql::test
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ int main(int argc, char** argv) {
{"max_by", nullptr},
{"min_by", nullptr},
{"skewness", nullptr},
{"kurtosis", nullptr}};
{"kurtosis", nullptr},
{"collect_list", nullptr}};

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
auto duckQueryRunner =
Expand All @@ -88,6 +89,9 @@ int main(int argc, char** argv) {
// coefficient. Meanwhile, DuckDB employs the sample kurtosis calculation
// formula. The results from the two methods are completely different.
"kurtosis",
// When all data in a group are null, Spark returns an empty array while
// DuckDB returns null.
"collect_list",
});

using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;
Expand Down

0 comments on commit 405c0a1

Please sign in to comment.