diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index 18c501ff99ac..a9b9ca9b3774 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -53,6 +53,11 @@ General Aggregate Functions ``hash`` cannot be null. +.. spark:function:: collect_list(x) -> array<[same as x]> + + Returns an array created from the input ``x`` elements. Ignores null + inputs, and returns an empty array when all inputs are null. + .. spark:function:: first(x) -> x Returns the first value of `x`. diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt index 011ff1dfeb39..3c5fd2f9e54e 100644 --- a/velox/functions/sparksql/aggregates/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/CMakeLists.txt @@ -17,6 +17,7 @@ add_library( BitwiseXorAggregate.cpp BloomFilterAggAggregate.cpp CentralMomentsAggregate.cpp + CollectListAggregate.cpp FirstLastAggregate.cpp MinMaxByAggregate.cpp Register.cpp diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp new file mode 100644 index 000000000000..e2c14cfa7969 --- /dev/null +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/aggregates/CollectListAggregate.h" + +#include "velox/exec/SimpleAggregateAdapter.h" +#include "velox/functions/lib/aggregates/ValueList.h" + +using namespace facebook::velox::aggregate; +using namespace facebook::velox::exec; + +namespace facebook::velox::functions::aggregate::sparksql { +namespace { +class CollectListAggregate { + public: + using InputType = Row>; + + using IntermediateType = Array>; + + using OutputType = Array>; + + /// In Spark, when all inputs are null, the output is an empty array instead + /// of null. Therefore, in the writeIntermediateResult and writeFinalResult, + /// we still need to output the empty element_ when the group is null. This + /// behavior can only be achieved when the default-null behavior is disabled. + static constexpr bool default_null_behavior_ = false; + + static bool toIntermediate( + exec::out_type>>& out, + exec::optional_arg_type> in) { + if (in.has_value()) { + out.add_item().copy_from(in.value()); + return true; + } + return false; + } + + struct AccumulatorType { + ValueList elements_; + + explicit AccumulatorType(HashStringAllocator* /*allocator*/) + : elements_{} {} + + static constexpr bool is_fixed_size_ = false; + + bool addInput( + HashStringAllocator* allocator, + exec::optional_arg_type> data) { + if (data.has_value()) { + elements_.appendValue(data, allocator); + return true; + } + return false; + } + + bool combine( + HashStringAllocator* allocator, + exec::optional_arg_type other) { + if (!other.has_value()) { + return false; + } + for (auto element : other.value()) { + elements_.appendValue(element, allocator); + } + return true; + } + + bool writeIntermediateResult( + bool /*nonNullGroup*/, + exec::out_type& out) { + // If the group's accumulator is null, the corresponding intermediate + // result is an empty array. + copyValueListToArrayWriter(out, elements_); + return true; + } + + bool writeFinalResult( + bool /*nonNullGroup*/, + exec::out_type& out) { + // If the group's accumulator is null, the corresponding result is an + // empty array. + copyValueListToArrayWriter(out, elements_); + return true; + } + + void destroy(HashStringAllocator* allocator) { + elements_.free(allocator); + } + }; +}; + +AggregateRegistrationResult registerCollectList( + const std::string& name, + bool withCompanionFunctions, + bool overwrite) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .typeVariable("E") + .returnType("array(E)") + .intermediateType("array(E)") + .argumentType("E") + .build()}; + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step /*step*/, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& /*config*/) + -> std::unique_ptr { + VELOX_CHECK_EQ( + argTypes.size(), 1, "{} takes at most one argument", name); + return std::make_unique>( + resultType); + }, + withCompanionFunctions, + overwrite); +} +} // namespace + +void registerCollectListAggregate( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite) { + registerCollectList( + prefix + "collect_list", withCompanionFunctions, overwrite); +} +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.h b/velox/functions/sparksql/aggregates/CollectListAggregate.h new file mode 100644 index 000000000000..3c32023db95e --- /dev/null +++ b/velox/functions/sparksql/aggregates/CollectListAggregate.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::functions::aggregate::sparksql { + +void registerCollectListAggregate( + const std::string& prefix, + bool withCompanionFunctions, + bool overwrite); + +} // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index 7a2886dff2de..8b60c91b8894 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -20,6 +20,7 @@ #include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h" #include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" #include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" +#include "velox/functions/sparksql/aggregates/CollectListAggregate.h" #include "velox/functions/sparksql/aggregates/SumAggregate.h" namespace facebook::velox::functions::aggregate::sparksql { @@ -45,5 +46,6 @@ void registerAggregateFunctions( registerAverage(prefix + "avg", withCompanionFunctions, overwrite); registerSum(prefix + "sum", withCompanionFunctions, overwrite); registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite); + registerCollectListAggregate(prefix, withCompanionFunctions, overwrite); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt index f7a5fdf05cdc..a769b5c3d640 100644 --- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( BitwiseXorAggregationTest.cpp BloomFilterAggAggregateTest.cpp CentralMomentsAggregationTest.cpp + CollectListAggregateTest.cpp FirstAggregateTest.cpp LastAggregateTest.cpp Main.cpp diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp new file mode 100644 index 000000000000..73088a47b620 --- /dev/null +++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" +#include "velox/functions/sparksql/aggregates/Register.h" + +using namespace facebook::velox::functions::aggregate::test; + +namespace facebook::velox::functions::aggregate::sparksql::test { + +namespace { + +class CollectListAggregateTest : public AggregationTestBase { + protected: + void SetUp() override { + AggregationTestBase::SetUp(); + registerAggregateFunctions("spark_"); + } +}; + +TEST_F(CollectListAggregateTest, groupBy) { + std::vector batches; + // Creating 3 batches of input data. + // 0: {0, null} {0, 1} {0, 2} + // 1: {1, 1} {1, null} {1, 3} + // 2: {2, 2} {2, 3} {2, null} + // 3: {3, 3} {3, 4} {3, 5} + // 4: {4, 4} {4, 5} {4, 6} + for (auto i = 0; i < 3; i++) { + RowVectorPtr data = makeRowVector( + {makeFlatVector({0, 1, 2, 3, 4}), + makeFlatVector( + 5, + [&i](const vector_size_t& row) { return i + row; }, + [&i](const auto& row) { return i == row; })}); + batches.push_back(data); + } + + auto expected = makeRowVector( + {makeFlatVector({0, 1, 2, 3, 4}), + makeArrayVectorFromJson( + {"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"})}); + + testAggregations( + batches, + {"c0"}, + {"spark_collect_list(c1)"}, + {"c0", "array_sort(a0)"}, + {expected}); + testAggregationsWithCompanion( + batches, + [](auto& /*builder*/) {}, + {"c0"}, + {"spark_collect_list(c1)"}, + {{BIGINT()}}, + {"c0", "array_sort(a0)"}, + {expected}, + {}); +} + +TEST_F(CollectListAggregateTest, global) { + auto data = makeRowVector({makeNullableFlatVector( + {std::nullopt, 1, 2, std::nullopt, 4, 5})}); + auto expected = + makeRowVector({makeArrayVectorFromJson({"[1, 2, 4, 5]"})}); + + testAggregations( + {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected}); + testAggregationsWithCompanion( + {data}, + [](auto& /*builder*/) {}, + {}, + {"spark_collect_list(c0)"}, + {{INTEGER()}}, + {"array_sort(a0)"}, + {expected}); +} + +TEST_F(CollectListAggregateTest, ignoreNulls) { + auto input = makeRowVector({makeNullableFlatVector( + {1, 2, std::nullopt, 4, std::nullopt, 6})}); + // Spark will ignore all null values in the input. + auto expected = + makeRowVector({makeArrayVectorFromJson({"[1, 2, 4, 6]"})}); + testAggregations( + {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected}); + testAggregationsWithCompanion( + {input}, + [](auto& /*builder*/) {}, + {}, + {"spark_collect_list(c0)"}, + {{INTEGER()}}, + {"array_sort(a0)"}, + {expected}, + {}); +} + +TEST_F(CollectListAggregateTest, allNullsInput) { + auto input = makeRowVector({makeAllNullFlatVector(100)}); + // If all input data is null, Spark will output an empty array. + auto expected = makeRowVector({makeArrayVectorFromJson({"[]"})}); + testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected}); + testAggregationsWithCompanion( + {input}, + [](auto& /*builder*/) {}, + {}, + {"spark_collect_list(c0)"}, + {{BIGINT()}}, + {}, + {expected}, + {}); +} +} // namespace +} // namespace facebook::velox::functions::aggregate::sparksql::test diff --git a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp index 89d47224fde1..2b18f21dc7b5 100644 --- a/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp +++ b/velox/functions/sparksql/fuzzer/SparkAggregationFuzzerTest.cpp @@ -71,7 +71,8 @@ int main(int argc, char** argv) { {"max_by", nullptr}, {"min_by", nullptr}, {"skewness", nullptr}, - {"kurtosis", nullptr}}; + {"kurtosis", nullptr}, + {"collect_list", nullptr}}; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; auto duckQueryRunner = @@ -88,6 +89,9 @@ int main(int argc, char** argv) { // coefficient. Meanwhile, DuckDB employs the sample kurtosis calculation // formula. The results from the two methods are completely different. "kurtosis", + // When all data in a group are null, Spark returns an empty array while + // DuckDB returns null. + "collect_list", }); using Runner = facebook::velox::exec::test::AggregationFuzzerRunner;