diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 0e6d8b7bd93..02a1641b4e7 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -15722,12 +15722,12 @@ Accelerator support is described below.
window |
input |
|
-NS |
-NS |
-NS |
-NS |
-NS |
-NS |
+S |
+S |
+S |
+S |
+S |
+S |
|
|
|
@@ -15748,7 +15748,7 @@ Accelerator support is described below.
|
|
|
-NS |
+S |
|
|
|
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index 9bbf93a6e26..64127f7d497 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,9 +82,12 @@ def test_window_aggs_for_rows(data_gen):
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_1, '
' count(c) over '
' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, '
+ ' avg(c) over '
+ ' (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as avg_c, '
' row_number() over '
' (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num '
- 'from window_agg_table ')
+ 'from window_agg_table ',
+ conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})
part_and_order_gens = [long_gen, DoubleGen(no_nans=True, special_cases=[]),
@@ -179,6 +182,9 @@ def test_window_aggs_for_ranges(data_gen):
' sum(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between interval 1 day preceding and interval 1 day following) as sum_c_asc, '
+ ' avg(c) over '
+ ' (partition by a order by cast(b as timestamp) asc '
+ ' range between interval 1 day preceding and interval 1 day following) as avg_c_asc, '
' max(c) over '
' (partition by a order by cast(b as timestamp) desc '
' range between interval 2 days preceding and interval 1 days following) as max_c_desc, '
@@ -191,13 +197,17 @@ def test_window_aggs_for_ranges(data_gen):
' count(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between CURRENT ROW and UNBOUNDED following) as count_c_asc, '
+ ' avg(c) over '
+ ' (partition by a order by cast(b as timestamp) asc '
+ ' range between UNBOUNDED preceding and CURRENT ROW) as avg_c_unbounded, '
' sum(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and CURRENT ROW) as sum_c_unbounded, '
' max(c) over '
' (partition by a order by cast(b as timestamp) asc '
' range between UNBOUNDED preceding and UNBOUNDED following) as max_c_unbounded '
- 'from window_agg_table')
+ 'from window_agg_table',
+ conf = {'spark.rapids.sql.castFloatToDecimal.enabled': True})
@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over non-timestamp columns "
"(https://github.com/NVIDIA/spark-rapids/issues/216)")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index f4ef7374eee..44d041efac8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1786,7 +1786,7 @@ object GpuOverrides {
}),
expr[Average](
"Average aggregate operator",
- ExprChecks.aggNotWindow(
+ ExprChecks.fullAgg(
TypeSig.DOUBLE, TypeSig.DOUBLE + TypeSig.DECIMAL,
Seq(ParamCheck("input", TypeSig.integral + TypeSig.fp, TypeSig.numeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index ec1db228972..e76051ef36d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -387,7 +387,8 @@ case class GpuCount(children: Seq[Expression]) extends GpuDeclarativeAggregate
Aggregation.count(false).onColumn(inputs.head._2)
}
-case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate {
+case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate
+ with GpuAggregateWindowFunction {
// averages are either Decimal or Double. We don't support decimal yet, so making this double.
private lazy val cudfSum = AttributeReference("sum", DoubleType)()
private lazy val cudfCount = AttributeReference("count", LongType)()
@@ -444,6 +445,10 @@ case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function gpu average")
+
+ override val windowInputProjection: Seq[Expression] = Seq(children.head)
+ override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
+ Aggregation.mean().onColumn(inputs.head._2)
}
/*