diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt
index 011ff1dfeb39..3c5fd2f9e54e 100644
--- a/velox/functions/sparksql/aggregates/CMakeLists.txt
+++ b/velox/functions/sparksql/aggregates/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(
   BitwiseXorAggregate.cpp
   BloomFilterAggAggregate.cpp
   CentralMomentsAggregate.cpp
+  CollectListAggregate.cpp
   FirstLastAggregate.cpp
   MinMaxByAggregate.cpp
   Register.cpp
diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.cpp b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp
new file mode 100644
index 000000000000..101eee9d8a11
--- /dev/null
+++ b/velox/functions/sparksql/aggregates/CollectListAggregate.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"
+
+using namespace facebook::velox::aggregate;
+using namespace facebook::velox::exec;
+
+namespace facebook::velox::functions::aggregate::sparksql {
+namespace {
+/// Spark-compatible collect_list: gathers the non-null inputs of each group
+/// into an array.
+class CollectListAggregate {
+ public:
+  using InputType = Row>;
+
+  using IntermediateType = Array>;
+
+  using OutputType = Array>;
+
+  /// In Spark, when all inputs are null, the output is an empty array instead
+  /// of null. Therefore, in the writeIntermediateResult and writeFinalResult,
+  /// we still need to output the empty elements_ when the group is null. This
+  /// behavior can only be achieved when the default-null behavior is disabled.
+  static constexpr bool default_null_behavior_ = false;
+
+  /// Converts one raw input into an intermediate array: a non-null input is
+  /// appended to 'out' and true is returned; a null input returns false.
+  static bool toIntermediate(
+      exec::out_type>>& out,
+      exec::optional_arg_type> in) {
+    if (in.has_value()) {
+      out.add_item().copy_from(in.value());
+      return true;
+    }
+    return false;
+  }
+
+  struct AccumulatorType {
+    ValueList elements_;
+
+    AccumulatorType() = delete;
+
+    explicit AccumulatorType(HashStringAllocator* /*allocator*/)
+        : elements_{} {}
+
+    static constexpr bool is_fixed_size_ = false;
+
+    /// Appends a non-null input to 'elements_' and returns true. A null input
+    /// is skipped and false is returned.
+    bool addInput(
+        HashStringAllocator* allocator,
+        exec::optional_arg_type> data) {
+      if (data.has_value()) {
+        elements_.appendValue(data, allocator);
+        return true;
+      }
+      return false;
+    }
+
+    /// Appends every element of another intermediate array to 'elements_'.
+    /// Returns false without appending anything when 'other' is null.
+    bool combine(
+        HashStringAllocator* allocator,
+        exec::optional_arg_type other) {
+      if (!other.has_value()) {
+        return false;
+      }
+      for (auto element : other.value()) {
+        elements_.appendValue(element, allocator);
+      }
+      return true;
+    }
+
+    bool writeIntermediateResult(
+        bool /*nonNullGroup*/,
+        exec::out_type& out) {
+      // If the group's accumulator is null, the corresponding intermediate
+      // result is an empty list.
+      copyValueListToArrayWriter(out, elements_);
+      return true;
+    }
+
+    bool writeFinalResult(
+        bool /*nonNullGroup*/,
+        exec::out_type& out) {
+      // If the group's accumulator is null, the corresponding result is an
+      // empty list.
+      copyValueListToArrayWriter(out, elements_);
+      return true;
+    }
+
+    void destroy(HashStringAllocator* allocator) {
+      elements_.free(allocator);
+    }
+  };
+};
+
+/// Registers the aggregate as 'name' with the signature E -> array(E) and
+/// array(E) as the intermediate type.
+AggregateRegistrationResult registerCollectList(
+    const std::string& name,
+    bool withCompanionFunctions,
+    bool overwrite) {
+  std::vector> signatures{
+      exec::AggregateFunctionSignatureBuilder()
+          .typeVariable("E")
+          .returnType("array(E)")
+          .intermediateType("array(E)")
+          .argumentType("E")
+          .build()};
+  return exec::registerAggregateFunction(
+      name,
+      std::move(signatures),
+      [name](
+          core::AggregationNode::Step /*step*/,
+          const std::vector& argTypes,
+          const TypePtr& resultType,
+          const core::QueryConfig& /*config*/)
+          -> std::unique_ptr {
+        VELOX_CHECK_EQ(
+            argTypes.size(), 1, "{} takes exactly one argument", name);
+        return std::make_unique>(
+            resultType);
+      },
+      withCompanionFunctions,
+      overwrite);
+}
+} // namespace
+
+void registerCollectListAggregate(
+    const std::string& prefix,
+    bool withCompanionFunctions,
+    bool overwrite) {
+  registerCollectList(
+      prefix + "collect_list", withCompanionFunctions, overwrite);
+}
+} // namespace facebook::velox::functions::aggregate::sparksql
diff --git a/velox/functions/sparksql/aggregates/CollectListAggregate.h b/velox/functions/sparksql/aggregates/CollectListAggregate.h
new file mode 100644
index 000000000000..5696d68e1f0d
--- /dev/null
+++ b/velox/functions/sparksql/aggregates/CollectListAggregate.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "velox/exec/SimpleAggregateAdapter.h"
+#include "velox/functions/lib/aggregates/ValueList.h"
+
+namespace facebook::velox::functions::aggregate::sparksql {
+
+/// Registers Spark's collect_list aggregate as prefix + "collect_list".
+/// @param prefix Prepended to the function name.
+/// @param withCompanionFunctions Passed on to the aggregate registration.
+/// @param overwrite Passed on to the aggregate registration.
+void registerCollectListAggregate(
+    const std::string& prefix,
+    bool withCompanionFunctions,
+    bool overwrite);
+
+} // namespace facebook::velox::functions::aggregate::sparksql
diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp
index 7a2886dff2de..8b60c91b8894 100644
--- a/velox/functions/sparksql/aggregates/Register.cpp
+++ b/velox/functions/sparksql/aggregates/Register.cpp
@@ -20,6 +20,7 @@
 #include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
 #include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
 #include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h"
+#include "velox/functions/sparksql/aggregates/CollectListAggregate.h"
 #include "velox/functions/sparksql/aggregates/SumAggregate.h"
 
 namespace facebook::velox::functions::aggregate::sparksql {
@@ -45,5 +46,6 @@ void registerAggregateFunctions(
   registerAverage(prefix + "avg", withCompanionFunctions, overwrite);
   registerSum(prefix + "sum", withCompanionFunctions, overwrite);
   registerCentralMomentsAggregate(prefix, withCompanionFunctions, overwrite);
+  registerCollectListAggregate(prefix, withCompanionFunctions, overwrite);
 }
 } // namespace facebook::velox::functions::aggregate::sparksql
diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt
index f7a5fdf05cdc..ee45daca5e7e 100644
--- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt
+++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt
@@ -18,6 +18,7 @@ add_executable(
   BitwiseXorAggregationTest.cpp
   BloomFilterAggAggregateTest.cpp
   CentralMomentsAggregationTest.cpp
+  CollectListAggregateTest.cpp
   FirstAggregateTest.cpp
   LastAggregateTest.cpp
   Main.cpp
@@ -34,6 +35,7 @@ target_link_libraries(
   velox_functions_aggregates_test_lib
   velox_functions_spark_aggregates
   velox_hive_connector
+  velox_vector_fuzzer
   gflags::gflags
   gtest
   gtest_main)
diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
new file mode 100644
index 000000000000..3d69420abe5f
--- /dev/null
+++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/exec/tests/SimpleAggregateFunctionsRegistration.h"
+#include "velox/exec/tests/utils/AssertQueryBuilder.h"
+#include "velox/exec/tests/utils/PlanBuilder.h"
+#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"
+#include "velox/functions/sparksql/aggregates/Register.h"
+
+using namespace facebook::velox::functions::aggregate::test;
+using facebook::velox::exec::test::AssertQueryBuilder;
+using facebook::velox::exec::test::PlanBuilder;
+
+namespace facebook::velox::functions::aggregate::sparksql::test {
+
+namespace {
+
+/// Tests Spark's collect_list, registered in SetUp() with the "spark_" prefix.
+class CollectListAggregateTest : public AggregationTestBase {
+ protected:
+  void SetUp() override {
+    AggregationTestBase::SetUp();
+    registerAggregateFunctions("spark_");
+  }
+
+  /// Fuzzes a flat row vector of 'rowType' with options.vectorSize = 'size'.
+  RowVectorPtr fuzzFlat(const RowTypePtr& rowType, size_t size) {
+    VectorFuzzer::Options options;
+    options.vectorSize = size;
+    VectorFuzzer fuzzer(options, pool());
+    return fuzzer.fuzzInputFlatRow(rowType);
+  }
+};
+
+// Grouped aggregation compared against DuckDB's array_agg with nulls filtered.
+TEST_F(CollectListAggregateTest, groupBy) {
+  constexpr int32_t kNumGroups = 10;
+  std::vector batches;
+  batches.push_back(
+      fuzzFlat(ROW({"c0", "a"}, {INTEGER(), ARRAY(VARCHAR())}), 100));
+  auto keys = batches[0]->childAt(0)->as>();
+  auto values = batches[0]->childAt(1)->as();
+  for (auto i = 0; i < keys->size(); ++i) {
+    if (i % 10 == 0) {
+      keys->setNull(i, true);
+    } else {
+      keys->set(i, i % kNumGroups);
+    }
+
+    if (i % 7 == 0) {
+      values->setNull(i, true);
+    }
+  }
+
+  for (auto i = 0; i < 9; ++i) {
+    batches.push_back(batches[0]);
+  }
+
+  createDuckDbTable(batches);
+  testAggregations(
+      batches,
+      {"c0"},
+      {"spark_collect_list(a)"},
+      {"c0", "array_sort(a0)"},
+      "SELECT c0, array_sort(array_agg(a)"
+      "filter (where a is not null)) FROM tmp GROUP BY c0");
+  testAggregationsWithCompanion(
+      batches,
+      [](auto& /*builder*/) {},
+      {"c0"},
+      {"spark_collect_list(a)"},
+      {{ARRAY(VARCHAR())}},
+      {"c0", "array_sort(a0)"},
+      "SELECT c0, array_sort(array_agg(a)"
+      "filter (where a is not null)) FROM tmp GROUP BY c0");
+}
+
+// Global (no grouping keys) aggregation over an integer column with nulls.
+TEST_F(CollectListAggregateTest, global) {
+  vector_size_t size = 10;
+  std::vector vectors = {makeRowVector({makeFlatVector(
+      size, [](vector_size_t row) { return row * 2; }, nullEvery(3))})};
+
+  createDuckDbTable(vectors);
+  testAggregations(
+      vectors,
+      {},
+      {"spark_collect_list(c0)"},
+      {"array_sort(a0)"},
+      "SELECT array_sort(array_agg(c0)"
+      "filter (where c0 is not null)) FROM tmp");
+  testAggregationsWithCompanion(
+      vectors,
+      [](auto& /*builder*/) {},
+      {},
+      {"spark_collect_list(c0)"},
+      {{INTEGER()}},
+      {"array_sort(a0)"},
+      "SELECT array_sort(array_agg(c0)"
+      "filter (where c0 is not null)) FROM tmp");
+}
+
+TEST_F(CollectListAggregateTest, ignoreNulls) {
+  auto input = makeRowVector({makeNullableFlatVector(
+      {1, 2, std::nullopt, 4, std::nullopt, 6}, INTEGER())});
+  // Spark will ignore all null values in the input.
+  auto expected = makeRowVector({makeArrayVector({{1, 2, 4, 6}})});
+  testAggregations(
+      {input}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
+  testAggregationsWithCompanion(
+      {input},
+      [](auto& /*builder*/) {},
+      {},
+      {"spark_collect_list(c0)"},
+      {{INTEGER()}},
+      {"array_sort(a0)"},
+      {expected},
+      {});
+}
+
+TEST_F(CollectListAggregateTest, allNullsInput) {
+  std::vector> allNull(100, std::nullopt);
+  auto input =
+      makeRowVector({makeNullableFlatVector(allNull, BIGINT())});
+  // If all input data is null, Spark will output an empty array.
+  auto expected = makeRowVector({makeArrayVector({{}})});
+  testAggregations({input}, {}, {"spark_collect_list(c0)"}, {expected});
+  testAggregationsWithCompanion(
+      {input},
+      [](auto& /*builder*/) {},
+      {},
+      {"spark_collect_list(c0)"},
+      {{BIGINT()}},
+      {},
+      {expected},
+      {});
+}
+} // namespace
+} // namespace facebook::velox::functions::aggregate::sparksql::test