Skip to content

Commit

Permalink
Add in support/tests for a window count on a column (NVIDIA#935)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Oct 14, 2020
1 parent 645a0cf commit ea7dd8f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 37 deletions.
33 changes: 7 additions & 26 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,8 @@ def test_window_aggs_for_rows(data_gen):
' (partition by a order by b,c rows between 2 preceding and current row) as min_c_asc, '
' count(1) over '
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_1, '
# once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this
#' count(c) over '
#' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, '
' count(c) over '
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, '
' row_number() over '
' (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num '
'from window_agg_table ')
Expand Down Expand Up @@ -101,10 +100,9 @@ def test_multi_types_window_aggs_for_rows_lead_lag(a_gen, b_gen, c_gen):
defaultVal = gen_scalar_value(c_gen, force_no_nulls=False)

def do_it(spark):
# once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this and put it in place below
#.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \
return gen_df(spark, data_gen, length=2048) \
.withColumn('inc_count_1', f.count('*').over(inclusiveWindowSpec)) \
.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \
.withColumn('inc_max_c', f.max('c').over(inclusiveWindowSpec)) \
.withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \
.withColumn('lead_5_c', f.lead('c', 5).over(baseWindowSpec)) \
Expand Down Expand Up @@ -136,10 +134,9 @@ def test_multi_types_window_aggs_for_rows(a_gen, b_gen, c_gen):
inclusiveWindowSpec = baseWindowSpec.rowsBetween(-10, 100)

def do_it(spark):
# once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this and put it in place below
#.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \
return gen_df(spark, data_gen, length=2048) \
.withColumn('inc_count_1', f.count('*').over(inclusiveWindowSpec)) \
.withColumn('inc_count_c', f.count('c').over(inclusiveWindowSpec)) \
.withColumn('inc_max_c', f.max('c').over(inclusiveWindowSpec)) \
.withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \
.withColumn('row_num', f.row_number().over(baseWindowSpec))
Expand Down Expand Up @@ -168,10 +165,9 @@ def test_window_aggs_for_ranges(data_gen):
' count(1) over '
' (partition by a order by cast(b as timestamp) asc '
' range between CURRENT ROW and UNBOUNDED following) as count_1_asc, '
# once https://github.com/NVIDIA/spark-rapids/issues/218 is fixed uncomment this
#' count(c) over '
#' (partition by a order by cast(b as timestamp) asc '
#' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, '
' count(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, '
' sum(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and CURRENT ROW) as sum_c_unbounded, '
Expand All @@ -194,18 +190,3 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
' range between 1 preceding and 1 following) as sum_c_asc '
'from window_agg_table'
)

# NOTE(review): this test is marked xfail pending the upstream null-count fix;
# the rest of this commit suggests it is being superseded by inline count(c)
# checks in the other window tests — confirm before keeping both.
@pytest.mark.xfail(reason="[BUG] `COUNT(x)` should not count null values of `x` "
                          "(https://github.com/NVIDIA/spark-rapids/issues/218)")
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls], ids=idfn)
def test_window_aggs_for_rows_count_non_null(data_gen):
    """Check COUNT(c) over an unbounded row window matches between CPU and GPU.

    COUNT(c), unlike COUNT(1)/COUNT(*), must skip rows where c is null, so a
    no-nulls grouping-key generator is used and correctness hinges on the
    count-of-column semantics on the GPU backend.
    """
    assert_gpu_and_cpu_are_equal_sql(
        # 2048 rows is large enough to span multiple partitions of column a
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'select '
        ' count(c) over '
        '   (partition by a order by b,c '
        '    rows between UNBOUNDED preceding and UNBOUNDED following) as count_non_null '
        'from window_agg_table '
        )
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,6 @@ class GpuWindowExpressionMeta(
windowFunction match {
case aggregateExpression : AggregateExpression =>
aggregateExpression.aggregateFunction match {
// Count does not work in these cases because of a bug in cudf where a rolling count
// does not do the correct thing for null entries
// Once https://github.com/rapidsai/cudf/issues/6343
// is fixed this can be deleted and the check will go to the next case
// where it will match and pass.
case Count(exp) =>
if (!exp.forall(x => x.isInstanceOf[Literal])) {
willNotWorkOnGpu(s"Currently, only COUNT(1) and COUNT(*) are supported. " +
s"COUNT($exp) is not supported in windowing.")
}
// Sadly not all aggregations work for window operations yet, so explicitly allow the
// ones that do work.
case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
| SUM(dollars) OVER $windowClause,
| MIN(dollars) OVER $windowClause,
| MAX(dollars) OVER $windowClause,
| COUNT(1) OVER $windowClause,
| COUNT(dollars) OVER $windowClause,
| COUNT(1) OVER $windowClause,
| COUNT(*) OVER $windowClause
| FROM mytable
|
Expand Down

0 comments on commit ea7dd8f

Please sign in to comment.