Add support for collect_list Spark aggregate function
liujiayi771 committed Mar 29, 2024
1 parent cb58cba commit 1d09263
Showing 6 changed files with 323 additions and 0 deletions.
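
As a quick orientation before the diff, here is a minimal, hypothetical sketch of how the newly added aggregate can be invoked from a Velox query plan. The helper name buildCollectListPlan and the column names c0/c1 are illustrative assumptions; the "spark_" prefix matches what Register.cpp and the new tests use. As in Spark, the function ignores null inputs and returns an empty array for a group whose inputs are all null.

// Illustrative sketch only; the helper and column names are hypothetical.
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/sparksql/aggregates/Register.h"
#include "velox/vector/ComplexVector.h"

using namespace facebook::velox;

core::PlanNodePtr buildCollectListPlan(const RowVectorPtr& data) {
  // Register the Spark aggregate functions under the "spark_" prefix,
  // mirroring the test fixture added in this commit.
  functions::aggregate::sparksql::registerAggregateFunctions("spark_");

  // Group by c0 and collect the non-null values of c1 into one array per group.
  return exec::test::PlanBuilder()
      .values({data})
      .singleAggregation({"c0"}, {"spark_collect_list(c1)"})
      .planNode();
}

The resulting plan can then be executed and verified with AssertQueryBuilder, which is exactly what the new tests below do.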
1 change: 1 addition & 0 deletions velox/functions/sparksql/aggregates/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(
BitwiseXorAggregate.cpp
BloomFilterAggAggregate.cpp
CentralMomentsAggregate.cpp
CollectListAggregate.cpp
FirstLastAggregate.cpp
MinMaxByAggregate.cpp
Register.cpp
141 changes: 141 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.cpp
@@ -0,0 +1,141 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"

using namespace facebook::velox::aggregate;
using namespace facebook::velox::exec;

namespace facebook::velox::functions::aggregate::sparksql {
namespace {
class CollectListAggregate {
public:
using InputType = Row<Generic<T1>>;

using IntermediateType = Array<Generic<T1>>;

using OutputType = Array<Generic<T1>>;

/// In Spark, when all inputs in a group are null, the output is an empty
/// array instead of null. Therefore, writeIntermediateResult and
/// writeFinalResult must still write out the (empty) elements_ list when the
/// group is null. This behavior is only achievable when the default-null
/// behavior is disabled.
static constexpr bool default_null_behavior_ = false;

static bool toIntermediate(
exec::out_type<Array<Generic<T1>>>& out,
exec::optional_arg_type<Generic<T1>> in) {
if (in.has_value()) {
out.add_item().copy_from(in.value());
return true;
}
return false;
}

struct AccumulatorType {
ValueList elements_;

AccumulatorType() = delete;

explicit AccumulatorType(HashStringAllocator* /*allocator*/)
: elements_{} {}

static constexpr bool is_fixed_size_ = false;

bool addInput(
HashStringAllocator* allocator,
exec::optional_arg_type<Generic<T1>> data) {
if (data.has_value()) {
elements_.appendValue(data, allocator);
return true;
}
return false;
}

bool combine(
HashStringAllocator* allocator,
exec::optional_arg_type<IntermediateType> other) {
if (!other.has_value()) {
return false;
}
for (auto element : other.value()) {
elements_.appendValue(element, allocator);
}
return true;
}

bool writeIntermediateResult(
bool /*nonNullGroup*/,
exec::out_type<IntermediateType>& out) {
// If the group's accumulator is null, the corresponding intermediate
// result is an empty list.
copyValueListToArrayWriter(out, elements_);
return true;
}

bool writeFinalResult(
bool /*nonNullGroup*/,
exec::out_type<OutputType>& out) {
// If the group's accumulator is null, the corresponding result is an
// empty list.
copyValueListToArrayWriter(out, elements_);
return true;
}

void destroy(HashStringAllocator* allocator) {
elements_.free(allocator);
}
};
};

AggregateRegistrationResult registerCollectList(
const std::string& name,
bool withCompanionFunctions,
bool overwrite) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.typeVariable("E")
.returnType("array(E)")
.intermediateType("array(E)")
.argumentType("E")
.build()};
return exec::registerAggregateFunction(
name,
std::move(signatures),
[name](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& /*config*/)
-> std::unique_ptr<exec::Aggregate> {
VELOX_CHECK_EQ(
argTypes.size(), 1, "{} takes exactly one argument", name);
return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>(
resultType);
},
withCompanionFunctions,
overwrite);
}
} // namespace

void registerCollectListAggregate(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite) {
registerCollectList(
prefix + "collect_list", withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
29 changes: 29 additions & 0 deletions velox/functions/sparksql/aggregates/CollectListAggregate.h
@@ -0,0 +1,29 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "velox/exec/SimpleAggregateAdapter.h"
#include "velox/functions/lib/aggregates/ValueList.h"

namespace facebook::velox::functions::aggregate::sparksql {

void registerCollectListAggregate(
const std::string& prefix,
bool withCompanionFunctions,
bool overwrite);

} // namespace facebook::velox::functions::aggregate::sparksql
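
A small usage note on the header above: an application that wants only this function, rather than the whole Spark aggregate registry, can call the declared entry point directly. The snippet below is a hedged, standalone sketch; the "spark_" prefix is an assumption that mirrors Register.cpp.

// Hypothetical standalone example; the prefix is an assumption mirroring
// Register.cpp in this commit.
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"

int main() {
  facebook::velox::functions::aggregate::sparksql::registerCollectListAggregate(
      "spark_", /*withCompanionFunctions=*/true, /*overwrite=*/false);
  // After this call, "spark_collect_list" (and its companion functions) is
  // available to Velox aggregation operators.
  return 0;
}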
2 changes: 2 additions & 0 deletions velox/functions/sparksql/aggregates/Register.cpp
@@ -20,6 +20,7 @@
#include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
#include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h"
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"
#include "velox/functions/sparksql/aggregates/SumAggregate.h"

namespace facebook::velox::functions::aggregate::sparksql {
@@ -45,5 +46,6 @@ void registerAggregateFunctions(
registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
registerSum(prefix + "sum", withCompanionFunctions, overwrite);
registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite);
registerCollectListAggregate(prefix, withCompanionFunctions, overwrite);
}
} // namespace facebook::velox::functions::aggregate::sparksql
2 changes: 2 additions & 0 deletions velox/functions/sparksql/aggregates/tests/CMakeLists.txt
@@ -18,6 +18,7 @@ add_executable(
BitwiseXorAggregationTest.cpp
BloomFilterAggAggregateTest.cpp
CentralMomentsAggregationTest.cpp
CollectListAggregateTest.cpp
FirstAggregateTest.cpp
LastAggregateTest.cpp
Main.cpp
@@ -34,6 +35,7 @@
velox_functions_aggregates_test_lib
velox_functions_spark_aggregates
velox_hive_connector
velox_vector_fuzzer
gflags::gflags
gtest
gtest_main)
148 changes: 148 additions & 0 deletions velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
@@ -0,0 +1,148 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
#include "velox/functions/sparksql/aggregates/Register.h"

using namespace facebook::velox::functions::aggregate::test;
using facebook::velox::exec::test::AssertQueryBuilder;
using facebook::velox::exec::test::PlanBuilder;

namespace facebook::velox::functions::aggregate::sparksql::test {

namespace {

class CollectListAggregateTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
registerAggregateFunctions("spark_");
}

RowVectorPtr fuzzFlat(const RowTypePtr& rowType, size_t size) {
VectorFuzzer::Options options;
options.vectorSize = size;
VectorFuzzer fuzzer(options, pool());
return fuzzer.fuzzInputFlatRow(rowType);
}
};

TEST_F(CollectListAggregateTest, groupBy) {
constexpr int32_t kNumGroups = 10;
std::vector<RowVectorPtr> batches;
batches.push_back(
fuzzFlat(ROW({"c0", "a"}, {INTEGER(), ARRAY(VARCHAR())}), 100));
auto keys = batches[0]->childAt(0)->as<FlatVector<int32_t>>();
auto values = batches[0]->childAt(1)->as<ArrayVector>();
for (auto i = 0; i < keys->size(); ++i) {
if (i % 10 == 0) {
keys->setNull(i, true);
} else {
keys->set(i, i % kNumGroups);
}

if (i % 7 == 0) {
values->setNull(i, true);
}
}

for (auto i = 0; i < 9; ++i) {
batches.push_back(batches[0]);
}

createDuckDbTable(batches);
testAggregations(
batches,
{"c0"},
{"spark_collect_list(a)"},
{"c0", "array_sort(a0)"},
"SELECT c0, array_sort(array_agg(a)"
"filter (where a is not null)) FROM tmp GROUP BY c0");
testAggregationsWithCompanion(
batches,
[](auto& /*builder*/) {},
{"c0"},
{"spark_collect_list(a)"},
{{ARRAY(VARCHAR())}},
{"c0", "array_sort(a0)"},
"SELECT c0, array_sort(array_agg(a)"
"filter (where a is not null)) FROM tmp GROUP BY c0");
}

TEST_F(CollectListAggregateTest, global) {
vector_size_t size = 10;
std::vector<RowVectorPtr> vectors = {makeRowVector({makeFlatVector<int32_t>(
size, [](vector_size_t row) { return row * 2; }, nullEvery(3))})};

createDuckDbTable(vectors);
testAggregations(
vectors,
{},
{"spark_collect_list(c0)"},
{"array_sort(a0)"},
"SELECT array_sort(array_agg(c0)"
"filter (where c0 is not null)) FROM tmp");
testAggregationsWithCompanion(
vectors,
[](auto& /*builder*/) {},
{},
{"spark_collect_list(c0)"},
{{ARRAY(VARCHAR())}},
{"array_sort(a0)"},
"SELECT array_sort(array_agg(c0)"
"filter (where c0 is not null)) FROM tmp");
}

TEST_F(CollectListAggregateTest, ignoreNulls) {
auto input = makeRowVector({makeNullableFlatVector<int32_t>(
{1, 2, std::nullopt, 4, std::nullopt, 6}, INTEGER())});
// Spark will ignore all null values in the input.
auto expected = makeRowVector({makeArrayVector<int32_t>({{1, 2, 4, 6}})});
testAggregations(
{input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
testAggregationsWithCompanion(
{input},
[](auto& /*builder*/) {},
{},
{"spark_collect_list(c0)"},
{{INTEGER()}},
{"array_sort(a0)"},
{expected},
{});
}

TEST_F(CollectListAggregateTest, allNullsInput) {
std::vector<std::optional<int64_t>> allNull(100, std::nullopt);
auto input =
makeRowVector({makeNullableFlatVector<int64_t>(allNull, BIGINT())});
// If all input data is null, Spark will output an empty array.
auto expected = makeRowVector({makeArrayVector<int64_t>({{}})});
testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
testAggregationsWithCompanion(
{input},
[](auto& /*builder*/) {},
{},
{"spark_collect_list(c0)"},
{{BIGINT()}},
{},
{expected},
{});
}
} // namespace
} // namespace facebook::velox::functions::aggregate::sparksql::test
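
The all-null coverage above is global (no grouping keys). A grouped variant, sketched below under the assumption that it lives inside the same CollectListAggregateTest fixture, restates the doc comment from CollectListAggregate.cpp: with default-null behavior disabled, a group whose inputs are all null yields an empty array rather than a null one.

// Hypothetical extra test; assumes the CollectListAggregateTest fixture above.
TEST_F(CollectListAggregateTest, allNullsInGroup) {
  auto input = makeRowVector(
      {makeFlatVector<int32_t>({1, 1, 2, 2}), // grouping keys
       makeNullableFlatVector<int64_t>(
           {std::nullopt, std::nullopt, 10, std::nullopt})}); // values
  // Group 1 contains only nulls and yields []; group 2 keeps its single
  // non-null value.
  auto expected = makeRowVector(
      {makeFlatVector<int32_t>({1, 2}),
       makeArrayVector<int64_t>({{}, {10}})});
  testAggregations({input}, {"c0"}, {"spark_collect_list(c1)"}, {expected});
}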
