Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for collect_list Spark aggregate function #9231

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions velox/docs/functions/spark/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ General Aggregate Functions

``hash`` cannot be null.

.. spark:function:: collect_list(x) -> array<[same as x]>
    Returns an array containing the non-null input ``x`` elements. Null
    inputs are ignored; if all inputs are null, returns an empty array.

.. spark:function:: first(x) -> x
Returns the first value of `x`.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_library(
BitwiseXorAggregate.cpp
BloomFilterAggAggregate.cpp
CentralMomentsAggregate.cpp
CollectListAggregate.cpp
FirstLastAggregate.cpp
MinMaxByAggregate.cpp
Register.cpp
Expand Down
142 changes: 142 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"

#include "velox/exec/SimpleAggregateAdapter.h"
#include "velox/functions/lib/aggregates/ValueList.h"

using namespace facebook::velox::aggregate;
using namespace facebook::velox::exec;

namespace facebook::velox::functions::aggregate::sparksql {
namespace {
// Spark-compatible collect_list aggregate implemented against the
// SimpleAggregateAdapter interface. Collects all non-null input values of a
// group into an array. Unlike a default-null aggregate, a group with no
// non-null inputs produces an empty array rather than null, matching Spark.
class CollectListAggregate {
 public:
  // One generic-typed argument; element type is resolved at registration.
  using InputType = Row<Generic<T1>>;

  // Partial aggregation exchanges the collected values as an array.
  using IntermediateType = Array<Generic<T1>>;

  using OutputType = Array<Generic<T1>>;

  /// In Spark, when all inputs are null, the output is an empty array instead
  /// of null. Therefore, in the writeIntermediateResult and writeFinalResult,
  /// we still need to output the empty element_ when the group is null. This
  /// behavior can only be achieved when the default-null behavior is disabled.
  static constexpr bool default_null_behavior_ = false;

  // Converts a single raw input value directly to the intermediate (array)
  // form. Returns false for a null input, leaving the intermediate result
  // null so the value is ignored.
  static bool toIntermediate(
      exec::out_type<Array<Generic<T1>>>& out,
      exec::optional_arg_type<Generic<T1>> in) {
    if (in.has_value()) {
      out.add_item().copy_from(in.value());
      return true;
    }
    return false;
  }

  struct AccumulatorType {
    // Values collected so far for one group. Backing memory is owned by the
    // HashStringAllocator passed into each mutating call and released in
    // destroy().
    ValueList elements_;

    explicit AccumulatorType(HashStringAllocator* /*allocator*/)
        : elements_{} {}

    // Accumulator size grows with the number of collected elements.
    static constexpr bool is_fixed_size_ = false;

    // Appends one raw input value. Returns false for null inputs so the
    // group is not marked non-null until a non-null value arrives.
    bool addInput(
        HashStringAllocator* allocator,
        exec::optional_arg_type<Generic<T1>> data) {
      if (data.has_value()) {
        elements_.appendValue(data, allocator);
        return true;
      }
      return false;
    }

    // Merges an intermediate array (produced by a partial aggregation or a
    // companion function) into this accumulator, element by element.
    bool combine(
        HashStringAllocator* allocator,
        exec::optional_arg_type<IntermediateType> other) {
      if (!other.has_value()) {
        return false;
      }
      for (auto element : other.value()) {
        elements_.appendValue(element, allocator);
      }
      return true;
    }

    bool writeIntermediateResult(
        bool /*nonNullGroup*/,
        exec::out_type<IntermediateType>& out) {
      // If the group's accumulator is null, the corresponding intermediate
      // result is an empty array.
      copyValueListToArrayWriter(out, elements_);
      return true;
    }

    bool writeFinalResult(
        bool /*nonNullGroup*/,
        exec::out_type<OutputType>& out) {
      // If the group's accumulator is null, the corresponding result is an
      // empty array.
      copyValueListToArrayWriter(out, elements_);
      return true;
    }

    // Releases the memory backing the collected values.
    void destroy(HashStringAllocator* allocator) {
      elements_.free(allocator);
    }
  };
};

// Registers the collect_list aggregate under 'name' with the signature
// collect_list(E) -> array(E) (intermediate type array(E)) for any type E.
// 'withCompanionFunctions' and 'overwrite' are forwarded to
// exec::registerAggregateFunction.
AggregateRegistrationResult registerCollectList(
    const std::string& name,
    bool withCompanionFunctions,
    bool overwrite) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .typeVariable("E")
          .returnType("array(E)")
          .intermediateType("array(E)")
          .argumentType("E")
          .build()};
  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step /*step*/,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // CHECK_EQ rejects both zero and multiple arguments, so the message
        // says "exactly one" (the previous "at most one" was misleading for
        // the zero-argument case).
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly one argument", name);
        return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>(
            resultType);
      },
      withCompanionFunctions,
      overwrite);
}
} // namespace

// Registers Spark's collect_list aggregate, prepending the configured
// function-name prefix (e.g. "spark_").
void registerCollectListAggregate(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite) {
  const auto functionName = prefix + "collect_list";
  registerCollectList(functionName, withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
28 changes: 28 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions::aggregate::sparksql {

/// Registers the Spark collect_list aggregate function under the name
/// 'prefix' + "collect_list". 'withCompanionFunctions' and 'overwrite' are
/// forwarded to the underlying aggregate-function registration.
void registerCollectListAggregate(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
2 changes: 2 additions & 0 deletions velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h"
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"
#include "velox/functions/sparksql/aggregates/SumAggregate.h"

namespace facebook::velox::functions::aggregate::sparksql {
Expand All @@ -45,5 +46,6 @@ void registerAggregateFunctions(
registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
registerSum(prefix + "sum", withCompanionFunctions, overwrite);
registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite);
registerCollectListAggregate(prefix, withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/aggregates/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_executable(
BitwiseXorAggregationTest.cpp
BloomFilterAggAggregateTest.cpp
CentralMomentsAggregationTest.cpp
CollectListAggregateTest.cpp
FirstAggregateTest.cpp
LastAggregateTest.cpp
Main.cpp
Expand Down
127 changes: 127 additions & 0 deletions velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/sparksql/aggregates/Register.h"

using namespace facebook::velox::functions::aggregate::test;

namespace facebook::velox::functions::aggregate::sparksql::test {

namespace {

// Test fixture for the Spark collect_list aggregate. Registers the Spark
// aggregate functions under the "spark_" prefix before each test.
class CollectListAggregateTest : public AggregationTestBase {
 protected:
  void SetUp() override {
    AggregationTestBase::SetUp();
    registerAggregateFunctions("spark_");
  }
};

// Grouped aggregation: nulls are ignored per group, and results are compared
// after array_sort() since collect_list's output order is unspecified.
TEST_F(CollectListAggregateTest, groupBy) {
liujiayi771 marked this conversation as resolved.
Show resolved Hide resolved
  std::vector<RowVectorPtr> batches;
  // Creating 3 batches of input data.
  // 0: {0, null} {0, 1} {0, 2}
  // 1: {1, 1} {1, null} {1, 3}
  // 2: {2, 2} {2, 3} {2, null}
  // 3: {3, 3} {3, 4} {3, 5}
  // 4: {4, 4} {4, 5} {4, 6}
  for (auto i = 0; i < 3; i++) {
    RowVectorPtr data = makeRowVector(
        {makeFlatVector<int32_t>({0, 1, 2, 3, 4}),
         makeFlatVector<int64_t>(
             5,
             [&i](const vector_size_t& row) { return i + row; },
             // Row i of batch i is null, so groups 0-2 each see one null.
             [&i](const auto& row) { return i == row; })});
    batches.push_back(data);
  }

  auto expected = makeRowVector(
      {makeFlatVector<int32_t>({0, 1, 2, 3, 4}),
       makeArrayVectorFromJson<int64_t>(
           {"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"})});

  testAggregations(
      batches,
      {"c0"},
      {"spark_collect_list(c1)"},
      {"c0", "array_sort(a0)"},
      {expected});
  testAggregationsWithCompanion(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any particular reasons companion function testing is not included as part of testAggregations? testAggregationsWithCompanion calls appear too verbose and repetitive.

CC: @kagamiori

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@liujiayi771 Would you take a look at this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mbasmanova I think the possible reason is that some aggregate functions have not registered the companion functions due to certain restrictions, such as when isResultTypeResolvableGivenIntermediateType is false.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I don't understand is why we need to pass [](auto& /*builder*/) {}, and {{BIGINT()}}, to testAggregationsWithCompanion and why do we need to call both testAggregations and testAggregationsWithCompanion.

Why can't we just call

testAggregationsWithCompanion(
      batches,
      {"c0"},
      {"spark_collect_list(c1)"},
      {"c0", "array_sort(a0)"},
      "SELECT c0, array_sort(array_agg(c1)"
      "filter (where c1 is not null)) FROM tmp GROUP BY c0");

and have it test both regular functions as well as companion functions.

CC: @kagamiori

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is better, and the config parameter is also not necessary. Right now, many tests are calling testAggregations followed by testAggregationsWithCompanion. We need to combine these two test functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this refactoring in a follow-up.

      batches,
      [](auto& /*builder*/) {},
      {"c0"},
      {"spark_collect_list(c1)"},
      {{BIGINT()}},
      {"c0", "array_sort(a0)"},
      {expected},
      {});
}

// Global (no grouping keys) aggregation over a mix of null and non-null
// values; nulls are dropped and the single result row holds the rest.
TEST_F(CollectListAggregateTest, global) {
  auto data = makeRowVector({makeNullableFlatVector<int32_t>(
      {std::nullopt, 1, 2, std::nullopt, 4, 5})});
  auto expected =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 5]"})});

  // array_sort() makes the comparison order-insensitive.
  testAggregations(
      {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {data},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected});
}

// Verifies that null inputs are excluded from the collected array.
TEST_F(CollectListAggregateTest, ignoreNulls) {
  auto input = makeRowVector({makeNullableFlatVector<int32_t>(
liujiayi771 marked this conversation as resolved.
Show resolved Hide resolved
      {1, 2, std::nullopt, 4, std::nullopt, 6})});
  // Spark will ignore all null values in the input.
  auto expected =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 6]"})});
  testAggregations(
      {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected},
      {});
}

// Verifies the Spark-specific corner case: when every input value is null,
// the result is an empty array, not null.
TEST_F(CollectListAggregateTest, allNullsInput) {
  auto input = makeRowVector({makeAllNullFlatVector<int64_t>(100)});
  // If all input data is null, Spark will output an empty array.
  // The expected array element type must match the input type (BIGINT); the
  // original used int32_t here, which builds an ARRAY(INTEGER) and does not
  // match the ARRAY(BIGINT) result type (the companion intermediate type
  // below is also declared as BIGINT).
  auto expected = makeRowVector({makeArrayVectorFromJson<int64_t>({"[]"})});
  testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{BIGINT()}},
      {},
      {expected},
      {});
}
} // namespace
} // namespace facebook::velox::functions::aggregate::sparksql::test
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ int main(int argc, char** argv) {
{"max_by", nullptr},
{"min_by", nullptr},
{"skewness", nullptr},
{"kurtosis", nullptr}};
{"kurtosis", nullptr},
{"collect_list", nullptr}};

size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed;
auto duckQueryRunner =
Expand All @@ -88,6 +89,9 @@ int main(int argc, char** argv) {
// coefficient. Meanwhile, DuckDB employs the sample kurtosis calculation
// formula. The results from the two methods are completely different.
"kurtosis",
// When all data in a group are null, Spark returns an empty array while
// DuckDB returns null.
"collect_list",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mbasmanova I think this function should not be compared with DuckDB. If the fuzzer generates a group where all the data is null, DuckDB's result will be null, while Spark will return an empty array.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. We need to change Fuzzer to verify results against Spark, not DuckDB: #9270

});

using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;
Expand Down
Loading