From abdba9fade9e6b16026ad44c9c93680bd5930980 Mon Sep 17 00:00:00 2001
From: Firestarman
Date: Mon, 1 Mar 2021 13:54:21 +0800
Subject: [PATCH] Enable the tests for collect over window.

cuDF now supports skipping null values when collecting (rapidsai/cudf#7264
and rapidsai/cudf#7457), so the xfail marker can be dropped and the
nulls-disabled fallback test removed.

Signed-off-by: Firestarman
---
 .../src/main/python/window_function_test.py | 88 +++++++------------
 1 file changed, 33 insertions(+), 55 deletions(-)

diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index d3372f83e35..c9889791278 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -223,66 +223,44 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
         'from window_agg_table'
     )
 
+_gen_data_for_collect = [
+    ('a', RepeatSeqGen(LongGen(), length=20)),
+    ('b', IntegerGen()),
+    ('c_int', IntegerGen()),
+    ('c_long', LongGen()),
+    ('c_time', DateGen()),
+    ('c_string', StringGen()),
+    ('c_float', FloatGen()),
+    ('c_decimal', DecimalGen(precision=8, scale=3)),
+    ('c_struct', StructGen(children=[
+        ['child_int', IntegerGen()],
+        ['child_time', DateGen()],
+        ['child_string', StringGen()],
+        ['child_decimal', DecimalGen(precision=8, scale=3)]]))]
 
-def _gen_data_for_collect(nullable=True):
-    return [
-        ('a', RepeatSeqGen(LongGen(), length=20)),
-        ('b', IntegerGen()),
-        ('c_int', IntegerGen(nullable=nullable)),
-        ('c_long', LongGen(nullable=nullable)),
-        ('c_time', DateGen(nullable=nullable)),
-        ('c_string', StringGen(nullable=nullable)),
-        ('c_float', FloatGen(nullable=nullable)),
-        ('c_decimal', DecimalGen(nullable=nullable, precision=8, scale=3)),
-        ('c_struct', StructGen(nullable=nullable, children=[
-            ['child_int', IntegerGen()],
-            ['child_time', DateGen()],
-            ['child_string', StringGen()],
-            ['child_decimal', DecimalGen(precision=8, scale=3)]]))]
-
-
-
-_collect_sql_string =\
-    '''
-    select
-      collect_list(c_int) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_int,
-      collect_list(c_long) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_long,
-      collect_list(c_time) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_time,
-      collect_list(c_string) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_string,
-      collect_list(c_float) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_float,
-      collect_list(c_decimal) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_decimal,
-      collect_list(c_struct) over
-        (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_struct
-    from window_collect_table
-    '''
 
 # SortExec does not support array type, so sort the result locally.
 @ignore_order(local=True)
-@pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1638")
 def test_window_aggs_for_rows_collect_list():
     assert_gpu_and_cpu_are_equal_sql(
-        lambda spark : gen_df(spark, _gen_data_for_collect(), length=2048),
-        "window_collect_table",
-        _collect_sql_string,
-        {'spark.rapids.sql.expression.CollectList': 'true'})
-
-
-'''
- Spark will drop nulls when collecting, but seems GPU does not yet, so exceptions come up.
- Now set nullable to false to verify the current functionality without null values.
- Once native supports dropping nulls, will enable the tests above and remove this one.
-'''
-# SortExec does not support array type, so sort the result locally.
-@ignore_order(local=True)
-def test_window_aggs_for_rows_collect_list_no_nulls():
-    assert_gpu_and_cpu_are_equal_sql(
-        lambda spark : gen_df(spark, _gen_data_for_collect(False), length=2048),
+        lambda spark : gen_df(spark, _gen_data_for_collect),
         "window_collect_table",
-        _collect_sql_string,
+        '''
+        select
+          collect_list(c_int) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_int,
+          collect_list(c_long) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_long,
+          collect_list(c_time) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_time,
+          collect_list(c_string) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_string,
+          collect_list(c_float) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_float,
+          collect_list(c_decimal) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_decimal,
+          collect_list(c_struct) over
+            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_struct
+        from window_collect_table
+        ''',
        {'spark.rapids.sql.expression.CollectList': 'true'})
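
Note on the semantics this test exercises: Spark's collect_list skips null
inputs rather than collecting them, which is why the data generators can keep
their columns nullable once the GPU does the same. The following minimal
sketch of that behavior is not part of the patch; it assumes only a local
PySpark session, and the table and column names simply mirror the test data:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # One partition (a=0), three ordered rows; c_int is null at b=2.
    df = spark.createDataFrame(
        [(0, 1, 10), (0, 2, None), (0, 3, 30)],
        "a int, b int, c_int int")
    df.createOrReplaceTempView("window_collect_table")

    spark.sql('''
        select b,
          collect_list(c_int) over
            (partition by a order by b rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_int
        from window_collect_table
        ''').show(truncate=False)
    # b=1 -> [10, 30]  (the null at b=2 is skipped, not collected)
    # b=2 -> [30]
    # b=3 -> [30]

Previously the GPU did not drop nulls when collecting, so the CPU/GPU
comparison was xfailed and a nulls-disabled variant stood in for it;
rapidsai/cudf#7264 and rapidsai/cudf#7457 align cuDF with Spark's
null-skipping semantics, letting the single test above run on nullable data.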