diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py index 6b8a4d4b4d0..08435f61d76 100644 --- a/integration_tests/src/main/python/window_function_test.py +++ b/integration_tests/src/main/python/window_function_test.py @@ -60,9 +60,8 @@ def test_window_aggs_for_rows(data_gen): ' (partition by a order by b,c rows between 2 preceding and current row) as min_c_asc, ' ' count(1) over ' ' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_1, ' - # once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this - #' count(c) over ' - #' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, ' + ' count(c) over ' + ' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, ' ' row_number() over ' ' (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num ' 'from window_agg_table ') @@ -101,10 +100,9 @@ def test_multi_types_window_aggs_for_rows_lead_lag(a_gen, b_gen, c_gen): defaultVal = gen_scalar_value(c_gen, force_no_nulls=False) def do_it(spark): - # once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this and put it in place below - #.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \ return gen_df(spark, data_gen, length=2048) \ .withColumn('inc_count_1', f.count('*').over(inclusiveWindowSpec)) \ + .withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \ .withColumn('inc_max_c', f.max('c').over(inclusiveWindowSpec)) \ .withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \ .withColumn('lead_5_c', f.lead('c', 5).over(baseWindowSpec)) \ @@ -136,10 +134,9 @@ def test_multi_types_window_aggs_for_rows(a_gen, b_gen, c_gen): inclusiveWindowSpec = baseWindowSpec.rowsBetween(-10, 100) def do_it(spark): - # once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this and put it in place below - #.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \ return gen_df(spark, data_gen, length=2048) \ .withColumn('inc_count_1', f.count('*').over(inclusiveWindowSpec)) \ + .withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \ .withColumn('inc_max_c', f.max('c').over(inclusiveWindowSpec)) \ .withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \ .withColumn('row_num', f.row_number().over(baseWindowSpec)) @@ -168,10 +165,9 @@ def test_window_aggs_for_ranges(data_gen): ' count(1) over ' ' (partition by a order by cast(b as timestamp) asc ' ' range between CURRENT ROW and UNBOUNDED following) as count_1_asc, ' - # once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this - #' count(c) over ' - #' (partition by a order by cast(b as timestamp) asc ' - #' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, ' + ' count(c) over ' + ' (partition by a order by cast(b as timestamp) asc ' + ' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, ' ' sum(c) over ' ' (partition by a order by cast(b as timestamp) asc ' ' range between UNBOUNDED preceding and CURRENT ROW) as sum_c_unbounded, ' @@ -194,18 +190,3 @@ def test_window_aggs_for_ranges_of_dates(data_gen): ' range between 1 preceding and 1 following) as sum_c_asc ' 'from window_agg_table' ) - -@pytest.mark.xfail(reason="[BUG] `COUNT(x)` should not count null values of `x` " - "(https://github.com/NVIDIA/spark-rapids/issues/218)") -@ignore_order -@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls], ids=idfn) -def test_window_aggs_for_rows_count_non_null(data_gen): - assert_gpu_and_cpu_are_equal_sql( - lambda spark: gen_df(spark, data_gen, length=2048), - "window_agg_table", - 'select ' - ' count(c) over ' - ' (partition by a order by b,c ' - ' rows between UNBOUNDED preceding and UNBOUNDED following) as count_non_null ' - 'from window_agg_table ' - ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index 04caf236d1c..a7a08c32c98 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -74,16 +74,6 @@ class GpuWindowExpressionMeta( windowFunction match { case aggregateExpression : AggregateExpression => aggregateExpression.aggregateFunction match { - // Count does not work in these cases because of a bug in cudf where a rolling count - // does not do the correct thing for null entries - // Once https://github.com/rapidsai/cudf/issues/6343 - // is fixed this can be deleted and the check will go to the next case - // where it will match and pass. - case Count(exp) => - if (!exp.forall(x => x.isInstanceOf[Literal])) { - willNotWorkOnGpu(s"Currently, only COUNT(1) and COUNT(*) are supported. " + - s"COUNT($exp) is not supported in windowing.") - } // Sadly not all aggregations work for window operations yet, so explicitly allow the // ones that do work. case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported. diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala index b90aba9dfb4..a7aa1ad28aa 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala @@ -136,7 +136,8 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite { | SUM(dollars) OVER $windowClause, | MIN(dollars) OVER $windowClause, | MAX(dollars) OVER $windowClause, - | COUNT(1) OVER $windowClause, + | COUNT(dollars) OVER $windowClause, + | COUNT(1) OVER $windowClause, | COUNT(*) OVER $windowClause | FROM mytable |