diff --git a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
index a591e777d271..73088a47b620 100644
--- a/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
+++ b/velox/functions/sparksql/aggregates/tests/CollectListAggregateTest.cpp
@@ -49,14 +49,17 @@ TEST_F(CollectListAggregateTest, groupBy) {
     batches.push_back(data);
   }
 
-  createDuckDbTable(batches);
+  auto expected = makeRowVector(
+      {makeFlatVector({0, 1, 2, 3, 4}),
+       makeArrayVectorFromJson(
+           {"[1, 2]", "[1, 3]", "[2, 3]", "[3, 4, 5]", "[4, 5, 6]"})});
+
   testAggregations(
       batches,
       {"c0"},
       {"spark_collect_list(c1)"},
       {"c0", "array_sort(a0)"},
-      "SELECT c0, array_sort(array_agg(c1)"
-      "filter (where c1 is not null)) FROM tmp GROUP BY c0");
+      {expected});
   testAggregationsWithCompanion(
       batches,
       [](auto& /*builder*/) {},
@@ -64,22 +67,18 @@ TEST_F(CollectListAggregateTest, groupBy) {
       {"spark_collect_list(c1)"},
       {{BIGINT()}},
       {"c0", "array_sort(a0)"},
-      "SELECT c0, array_sort(array_agg(c1)"
-      "filter (where c1 is not null)) FROM tmp GROUP BY c0");
+      {expected},
+      {});
 }
 
 TEST_F(CollectListAggregateTest, global) {
   auto data = makeRowVector({makeNullableFlatVector(
       {std::nullopt, 1, 2, std::nullopt, 4, 5})});
+  auto expected =
+      makeRowVector({makeArrayVectorFromJson({"[1, 2, 4, 5]"})});
 
-  createDuckDbTable({data});
   testAggregations(
-      {data},
-      {},
-      {"spark_collect_list(c0)"},
-      {"array_sort(a0)"},
-      "SELECT array_sort(array_agg(c0)"
-      "filter (where c0 is not null)) FROM tmp");
+      {data}, {}, {"spark_collect_list(c0)"}, {"array_sort(a0)"}, {expected});
   testAggregationsWithCompanion(
       {data},
       [](auto& /*builder*/) {},
@@ -87,8 +86,7 @@ TEST_F(CollectListAggregateTest, global) {
       {"spark_collect_list(c0)"},
       {{INTEGER()}},
       {"array_sort(a0)"},
-      "SELECT array_sort(array_agg(c0)"
-      "filter (where c0 is not null)) FROM tmp");
+      {expected});
 }
 
 TEST_F(CollectListAggregateTest, ignoreNulls) {