-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for collect_list Spark aggregate function #9231
Changes from 6 commits
7870dd9
b89d2ee
c4dd224
36a7595
0f2b0d6
48edcdd
3b0a10a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/sparksql/aggregates/CollectListAggregate.h" | ||
|
||
#include "velox/exec/SimpleAggregateAdapter.h" | ||
#include "velox/functions/lib/aggregates/ValueList.h" | ||
|
||
using namespace facebook::velox::aggregate; | ||
using namespace facebook::velox::exec; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
namespace { | ||
class CollectListAggregate { | ||
public: | ||
using InputType = Row<Generic<T1>>; | ||
|
||
using IntermediateType = Array<Generic<T1>>; | ||
|
||
using OutputType = Array<Generic<T1>>; | ||
|
||
/// In Spark, when all inputs are null, the output is an empty array instead | ||
/// of null. Therefore, in the writeIntermediateResult and writeFinalResult, | ||
/// we still need to output the empty element_ when the group is null. This | ||
/// behavior can only be achieved when the default-null behavior is disabled. | ||
static constexpr bool default_null_behavior_ = false; | ||
|
||
static bool toIntermediate( | ||
exec::out_type<Array<Generic<T1>>>& out, | ||
exec::optional_arg_type<Generic<T1>> in) { | ||
if (in.has_value()) { | ||
out.add_item().copy_from(in.value()); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
struct AccumulatorType { | ||
ValueList elements_; | ||
|
||
explicit AccumulatorType(HashStringAllocator* /*allocator*/) | ||
: elements_{} {} | ||
|
||
static constexpr bool is_fixed_size_ = false; | ||
|
||
bool addInput( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<Generic<T1>> data) { | ||
if (data.has_value()) { | ||
elements_.appendValue(data, allocator); | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
bool combine( | ||
HashStringAllocator* allocator, | ||
exec::optional_arg_type<IntermediateType> other) { | ||
if (!other.has_value()) { | ||
return false; | ||
} | ||
for (auto element : other.value()) { | ||
elements_.appendValue(element, allocator); | ||
} | ||
return true; | ||
} | ||
|
||
bool writeIntermediateResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<IntermediateType>& out) { | ||
// If the group's accumulator is null, the corresponding intermediate | ||
// result is an empty array. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
bool writeFinalResult( | ||
bool /*nonNullGroup*/, | ||
exec::out_type<OutputType>& out) { | ||
// If the group's accumulator is null, the corresponding result is an | ||
// empty array. | ||
copyValueListToArrayWriter(out, elements_); | ||
return true; | ||
} | ||
|
||
void destroy(HashStringAllocator* allocator) { | ||
elements_.free(allocator); | ||
} | ||
}; | ||
}; | ||
|
||
// Registers the collect_list aggregate under `name` with signature
// collect_list(E) -> array(E); the intermediate type is the same array type.
AggregateRegistrationResult registerCollectList(
    const std::string& name,
    bool withCompanionFunctions,
    bool overwrite) {
  std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
      exec::AggregateFunctionSignatureBuilder()
          .typeVariable("E")
          .returnType("array(E)")
          .intermediateType("array(E)")
          .argumentType("E")
          .build()};
  return exec::registerAggregateFunction(
      name,
      std::move(signatures),
      [name](
          core::AggregationNode::Step /*step*/,
          const std::vector<TypePtr>& argTypes,
          const TypePtr& resultType,
          const core::QueryConfig& /*config*/)
          -> std::unique_ptr<exec::Aggregate> {
        // The signature requires exactly one argument, so the message must
        // say "exactly", not "at most", to match the EQ check.
        VELOX_CHECK_EQ(
            argTypes.size(), 1, "{} takes exactly one argument", name);
        return std::make_unique<SimpleAggregateAdapter<CollectListAggregate>>(
            resultType);
      },
      withCompanionFunctions,
      overwrite);
}
} // namespace | ||
|
||
void registerCollectListAggregate( | ||
const std::string& prefix, | ||
bool withCompanionFunctions, | ||
bool overwrite) { | ||
registerCollectList( | ||
prefix + "collect_list", withCompanionFunctions, overwrite); | ||
} | ||
} // namespace facebook::velox::functions::aggregate::sparksql |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql { | ||
|
||
/// Registers the Spark collect_list aggregate function under the given
/// name prefix (e.g. "spark_"). When `withCompanionFunctions` is true,
/// companion functions are registered alongside the main aggregate;
/// `overwrite` controls whether an existing registration is replaced.
void registerCollectListAggregate(
    const std::string& prefix,
    bool withCompanionFunctions,
    bool overwrite);
|
||
} // namespace facebook::velox::functions::aggregate::sparksql |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" | ||
#include "velox/functions/sparksql/aggregates/Register.h" | ||
|
||
using namespace facebook::velox::functions::aggregate::test; | ||
|
||
namespace facebook::velox::functions::aggregate::sparksql::test { | ||
|
||
namespace { | ||
|
||
class CollectListAggregateTest : public AggregationTestBase { | ||
protected: | ||
void SetUp() override { | ||
AggregationTestBase::SetUp(); | ||
registerAggregateFunctions("spark_"); | ||
} | ||
}; | ||
|
||
TEST_F(CollectListAggregateTest, groupBy) { | ||
liujiayi771 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
std::vector<RowVectorPtr> batches; | ||
// Creating 3 batches of input data. | ||
// 0: {0, null} {0, 1} {0, 2} | ||
// 1: {1, 1} {1, null} {1, 3} | ||
// 2: {2, 2} {2, 3} {2, null} | ||
// 3: {3, 3} {3, 4} {3, 5} | ||
// 4: {4, 4} {4, 5} {4, 6} | ||
for (auto i = 0; i < 3; i++) { | ||
RowVectorPtr data = makeRowVector( | ||
{makeFlatVector<int32_t>({0, 1, 2, 3, 4}), | ||
makeFlatVector<int64_t>( | ||
5, | ||
[&i](const vector_size_t& row) { return i + row; }, | ||
[&i](const auto& row) { return i == row; })}); | ||
batches.push_back(data); | ||
} | ||
|
||
auto expected = makeRowVector( | ||
{makeFlatVector<int32_t>({0, 1, 2, 3, 4}), | ||
makeArrayVectorFromJson<int64_t>( | ||
{"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"})}); | ||
|
||
testAggregations( | ||
batches, | ||
{"c0"}, | ||
{"spark_collect_list(c1)"}, | ||
{"c0", "array_sort(a0)"}, | ||
{expected}); | ||
testAggregationsWithCompanion( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any particular reasons companion function testing is not included as part of testAggregations? testAggregationsWithCompanion calls appear too verbose and repetitive. CC: @kagamiori There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @liujiayi771 Would you take a look at this comment? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mbasmanova I think the possible reason is that some aggregate functions have not registered the companion functions due to certain restrictions, such as when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I don't understand is why we need to pass Why can't we just call
and have it test both regular functions as well as companion functions. CC: @kagamiori There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is better, and the config parameter is also not necessary. Right now, many tests are calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's do this refactoring in a follow-up. |
||
batches, | ||
[](auto& /*builder*/) {}, | ||
{"c0"}, | ||
{"spark_collect_list(c1)"}, | ||
{{BIGINT()}}, | ||
{"c0", "array_sort(a0)"}, | ||
{expected}, | ||
{}); | ||
} | ||
|
||
TEST_F(CollectListAggregateTest, global) {
  auto data = makeRowVector({makeNullableFlatVector<int32_t>(
      {std::nullopt, 1, 2, std::nullopt, 4, 5})});
  // Global aggregation (no grouping keys); null inputs are dropped.
  auto expected =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 5]"})});

  // Sort the result: collect_list output order is unspecified.
  testAggregations(
      {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {data},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected});
}
|
||
TEST_F(CollectListAggregateTest, ignoreNulls) {
  auto input = makeRowVector({makeNullableFlatVector<int32_t>(
      {1, 2, std::nullopt, 4, std::nullopt, 6})});
  // Spark will ignore all null values in the input.
  auto expected =
      makeRowVector({makeArrayVectorFromJson<int32_t>({"[1, 2, 4, 6]"})});
  testAggregations(
      {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{INTEGER()}},
      {"array_sort(a0)"},
      {expected},
      {});
}
|
||
TEST_F(CollectListAggregateTest, allNullsInput) {
  auto input = makeRowVector({makeAllNullFlatVector<int64_t>(100)});
  // If all input data is null, Spark outputs an empty array rather than
  // null -- this is the reason default-null behavior is disabled in the
  // aggregate implementation.
  auto expected = makeRowVector({makeArrayVectorFromJson<int32_t>({"[]"})});
  testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
  testAggregationsWithCompanion(
      {input},
      [](auto& /*builder*/) {},
      {},
      {"spark_collect_list(c0)"},
      {{BIGINT()}},
      {},
      {expected},
      {});
}
} // namespace | ||
} // namespace facebook::velox::functions::aggregate::sparksql::test |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,7 +71,8 @@ int main(int argc, char** argv) { | |
{"max_by", nullptr}, | ||
{"min_by", nullptr}, | ||
{"skewness", nullptr}, | ||
{"kurtosis", nullptr}}; | ||
{"kurtosis", nullptr}, | ||
{"collect_list", nullptr}}; | ||
|
||
size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; | ||
auto duckQueryRunner = | ||
|
@@ -88,6 +89,9 @@ int main(int argc, char** argv) { | |
// coefficient. Meanwhile, DuckDB employs the sample kurtosis calculation | ||
// formula. The results from the two methods are completely different. | ||
"kurtosis", | ||
// When all data in a group are null, Spark returns an empty array while | ||
// DuckDB returns null. | ||
"collect_list", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mbasmanova I think this function should not be compared with DuckDB. If the fuzzer generates a group where all the data is null, DuckDB's result will be null, while Spark will return an empty array. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. We need to change Fuzzer to verify results against Spark, not DuckDB: #9270 |
||
}); | ||
|
||
using Runner = facebook::velox::exec::test::AggregationFuzzerRunner; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mbasmanova I noticed that there's an omission here that hasn't been removed. I have removed it. Please help to re-import.