-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for collect_list Spark aggregate function
- Loading branch information
1 parent
cb58cba
commit 1d09263
Showing
6 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
141 changes: 141 additions & 0 deletions
141
velox/functions/sparksql/aggregates/CollectListAggregate.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h" | ||
|
||
using namespace facebook::velox::aggregate; | ||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
namespace { | ||
class CollectListAggregate { | ||
public: | ||
using InputType = Row<Generic<T1>>; | ||
|
||
using IntermediateType = Array<Generic<T1>>; | ||
|
||
using OutputType = Array<Generic<T1>>; | ||
|
||
/// In Spark, when all inputs are null, the output is an empty array instead | ||
/// of null. Therefore, in the writeIntermediateResult and writeFinalResult, | ||
/// we still need to output the empty element_ when the group is null. This | ||
/// behavior can only be achieved when the default-null behavior is disabled. | ||
static constexpr bool default_null_behavior_ = false; | ||
|
||
static bool toIntermediate( | ||
exec::out_type<Array<Generic<T1>>>& out, | ||
exec::optional_arg_type<Generic<T1>> in) { | ||
if (in.has_value()) { | ||
out.add_item().copy_from(in.value()); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
struct AccumulatorType { | ||
ValueList elements_; | ||
|
||
AccumulatorType() = delete; | ||
|
||
explicit AccumulatorType(HashStringAllocator* /*allocator*/) | ||
: elements_{} {} | ||
|
||
static constexpr bool is_fixed_size_ = false; | ||
|
||
bool addInput( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<Generic<T1>> data) { | ||
if (data.has_value()) { | ||
elements_.appendValue(data, allocator); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
bool combine( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<IntermediateType> other) { | ||
if (!other.has_value()) { | ||
return false; | ||
} | ||
for (auto element : other.value()) { | ||
elements_.appendValue(element, allocator); | ||
} | ||
return true; | ||
} | ||
|
||
bool writeIntermediateResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<IntermediateType>& out) { | ||
// If the group's accumulator is null, the corresponding intermediate | ||
// result is an empty list. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
bool writeFinalResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<OutputType>& out) { | ||
// If the group's accumulator is null, the corresponding result is an | ||
// empty list. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
void destroy(HashStringAllocator* allocator) { | ||
elements_.free(allocator); | ||
} | ||
}; | ||
}; | ||
|
||
AggregateRegistrationResult registerCollectList( | ||
const std::string& name, | ||
bool withCompanionFunctions, | ||
bool overwrite) { | ||
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{ | ||
exec::AggregateFunctionSignatureBuilder() | ||
.typeVariable("E") | ||
.returnType("array(E)") | ||
.intermediateType("array(E)") | ||
.argumentType("E") | ||
.build()}; | ||
return exec::registerAggregateFunction( | ||
name, | ||
std::move(signatures), | ||
[name]( | ||
core::AggregationNode::Step /*step*/, | ||
const std::vector<TypePtr>& argTypes, | ||
const TypePtr& resultType, | ||
const core::QueryConfig& /*config*/) | ||
-> std::unique_ptr<exec::Aggregate> { | ||
VELOX_CHECK_EQ( | ||
argTypes.size(), 1, "{} takes at most one argument", name); | ||
return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>( | ||
resultType); | ||
}, | ||
withCompanionFunctions, | ||
overwrite); | ||
} | ||
} // namespace | ||
|
||
void registerCollectListAggregate( | ||
const std::string& prefix, | ||
bool withCompanionFunctions, | ||
bool overwrite) { | ||
registerCollectList( | ||
prefix + "collect_list", withCompanionFunctions, overwrite); | ||
} | ||
} // namespace facebook::velox::functions::aggregate::sparksql |
29 changes: 29 additions & 0 deletions
29
velox/functions/sparksql/aggregates/CollectListAggregate.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "velox/exec/SimpleAggregateAdapter.h" | ||
#include "velox/functions/lib/aggregates/ValueList.h" | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
|
||
void registerCollectListAggregate( | ||
const std::string& prefix, | ||
bool withCompanionFunctions, | ||
bool overwrite); | ||
|
||
} // namespace facebook::velox::functions::aggregate::sparksql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
148 changes: 148 additions & 0 deletions
148
velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h" | ||
#include "velox/exec/tests/utils/AssertQueryBuilder.h" | ||
#include "velox/exec/tests/utils/PlanBuilder.h" | ||
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" | ||
#include "velox/functions/sparksql/aggregates/Register.h" | ||
|
||
using namespace facebook::velox::functions::aggregate::test; | ||
using facebook::velox::exec::test::AssertQueryBuilder; | ||
using facebook::velox::exec::test::PlanBuilder; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql::test { | ||
|
||
namespace { | ||
|
||
class CollectListAggregateTest : public AggregationTestBase { | ||
protected: | ||
void SetUp() override { | ||
AggregationTestBase::SetUp(); | ||
registerAggregateFunctions("spark_"); | ||
} | ||
|
||
RowVectorPtr fuzzFlat(const RowTypePtr& rowType, size_t size) { | ||
VectorFuzzer::Options options; | ||
options.vectorSize = size; | ||
VectorFuzzer fuzzer(options, pool()); | ||
return fuzzer.fuzzInputFlatRow(rowType); | ||
} | ||
}; | ||
|
||
TEST_F(CollectListAggregateTest, groupBy) { | ||
constexpr int32_t kNumGroups = 10; | ||
std::vector<RowVectorPtr> batches; | ||
batches.push_back( | ||
fuzzFlat(ROW({"c0", "a"}, {INTEGER(), ARRAY(VARCHAR())}), 100)); | ||
auto keys = batches[0]->childAt(0)->as<FlatVector<int32_t>>(); | ||
auto values = batches[0]->childAt(1)->as<ArrayVector>(); | ||
for (auto i = 0; i < keys->size(); ++i) { | ||
if (i % 10 == 0) { | ||
keys->setNull(i, true); | ||
} else { | ||
keys->set(i, i % kNumGroups); | ||
} | ||
|
||
if (i % 7 == 0) { | ||
values->setNull(i, true); | ||
} | ||
} | ||
|
||
for (auto i = 0; i < 9; ++i) { | ||
batches.push_back(batches[0]); | ||
} | ||
|
||
createDuckDbTable(batches); | ||
testAggregations( | ||
batches, | ||
{"c0"}, | ||
{"spark_collect_list(a)"}, | ||
{"c0", "array_sort(a0)"}, | ||
"SELECT c0, array_sort(array_agg(a)" | ||
"filter (where a is not null)) FROM tmp GROUP BY c0"); | ||
testAggregationsWithCompanion( | ||
batches, | ||
[](auto& /*builder*/) {}, | ||
{"c0"}, | ||
{"spark_collect_list(a)"}, | ||
{{ARRAY(VARCHAR())}}, | ||
{"c0", "array_sort(a0)"}, | ||
"SELECT c0, array_sort(array_agg(a)" | ||
"filter (where a is not null)) FROM tmp GROUP BY c0"); | ||
} | ||
|
||
TEST_F(CollectListAggregateTest, global) { | ||
vector_size_t size = 10; | ||
std::vector<RowVectorPtr> vectors = {makeRowVector({makeFlatVector<int32_t>( | ||
size, [](vector_size_t row) { return row * 2; }, nullEvery(3))})}; | ||
|
||
createDuckDbTable(vectors); | ||
testAggregations( | ||
vectors, | ||
{}, | ||
{"spark_collect_list(c0)"}, | ||
{"array_sort(a0)"}, | ||
"SELECT array_sort(array_agg(c0)" | ||
"filter (where c0 is not null)) FROM tmp"); | ||
testAggregationsWithCompanion( | ||
vectors, | ||
[](auto& /*builder*/) {}, | ||
{}, | ||
{"spark_collect_list(c0)"}, | ||
{{ARRAY(VARCHAR())}}, | ||
{"array_sort(a0)"}, | ||
"SELECT array_sort(array_agg(c0)" | ||
"filter (where c0 is not null)) FROM tmp"); | ||
} | ||
|
||
TEST_F(CollectListAggregateTest, ignoreNulls) { | ||
auto input = makeRowVector({makeNullableFlatVector<int32_t>( | ||
{1, 2, std::nullopt, 4, std::nullopt, 6}, INTEGER())}); | ||
// Spark will ignore all null values in the input. | ||
auto expected = makeRowVector({makeArrayVector<int32_t>({{1, 2, 4, 6}})}); | ||
testAggregations( | ||
{input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected}); | ||
testAggregationsWithCompanion( | ||
{input}, | ||
[](auto& /*builder*/) {}, | ||
{}, | ||
{"spark_collect_list(c0)"}, | ||
{{INTEGER()}}, | ||
{"array_sort(a0)"}, | ||
{expected}, | ||
{}); | ||
} | ||
|
||
TEST_F(CollectListAggregateTest, allNullsInput) { | ||
std::vector<std::optional<int64_t>> allNull(100, std::nullopt); | ||
auto input = | ||
makeRowVector({makeNullableFlatVector<int64_t>(allNull, BIGINT())}); | ||
// If all input data is null, Spark will output an empty array. | ||
auto expected = makeRowVector({makeArrayVector<int32_t>({{}})}); | ||
testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected}); | ||
testAggregationsWithCompanion( | ||
{input}, | ||
[](auto& /*builder*/) {}, | ||
{}, | ||
{"spark_collect_list(c0)"}, | ||
{{BIGINT()}}, | ||
{}, | ||
{expected}, | ||
{}); | ||
} | ||
} // namespace | ||
} // namespace facebook::velox::functions::aggregate::sparksql::test |