Skip to content

Commit

Permalink
Add in window support for average (NVIDIA#1615)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Jan 28, 2021
1 parent 830bda9 commit ba0b177
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
14 changes: 7 additions & 7 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15722,12 +15722,12 @@ Accelerator support is described below.
<td rowSpan="2">window</td>
<td>input</td>
<td> </td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
Expand All @@ -15748,7 +15748,7 @@ Accelerator support is described below.
<td> </td>
<td> </td>
<td> </td>
<td><b>NS</b></td>
<td>S</td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
16 changes: 13 additions & 3 deletions integration_tests/src/main/python/window_function_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,9 +82,12 @@ def test_window_aggs_for_rows(data_gen):
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_1, '
' count(c) over '
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, '
' avg(c) over '
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as avg_c, '
' row_number() over '
' (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num '
'from window_agg_table ')
'from window_agg_table ',
conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})


part_and_order_gens = [long_gen, DoubleGen(no_nans=True, special_cases=[]),
Expand Down Expand Up @@ -179,6 +182,9 @@ def test_window_aggs_for_ranges(data_gen):
' sum(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between interval 1 day preceding and interval 1 day following) as sum_c_asc, '
' avg(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between interval 1 day preceding and interval 1 day following) as avg_c_asc, '
' max(c) over '
' (partition by a order by cast(b as timestamp) desc '
' range between interval 2 days preceding and interval 1 days following) as max_c_desc, '
Expand All @@ -191,13 +197,17 @@ def test_window_aggs_for_ranges(data_gen):
' count(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, '
' avg(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and CURRENT ROW) as avg_c_unbounded, '
' sum(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and CURRENT ROW) as sum_c_unbounded, '
' max(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and UNBOUNDED following) as max_c_unbounded '
'from window_agg_table')
'from window_agg_table',
conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})

@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over non-timestamp columns "
"(https://github.com/NVIDIA/spark-rapids/issues/216)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1786,7 +1786,7 @@ object GpuOverrides {
}),
expr[Average](
"Average aggregate operator",
ExprChecks.aggNotWindow(
ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -387,7 +387,8 @@ case class GpuCount(children: Seq[Expression]) extends GpuDeclarativeAggregate
Aggregation.count(false).onColumn(inputs.head._2)
}

case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate {
case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate
with GpuAggregateWindowFunction {
// averages are either Decimal or Double. We don't support decimal yet, so making this double.
private lazy val cudfSum = AttributeReference("sum", DoubleType)()
private lazy val cudfCount = AttributeReference("count", LongType)()
Expand Down Expand Up @@ -444,6 +445,10 @@ case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function gpu average")

override val windowInputProjection: Seq[Expression] = Seq(children.head)
override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
Aggregation.mean().onColumn(inputs.head._2)
}

/*
Expand Down

0 comments on commit ba0b177

Please sign in to comment.