diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index 9c11b0e358ef..255e9b77f1eb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -51,7 +51,13 @@ class GpuWindowExpressionMeta( windowFunction match { case aggregateExpression : AggregateExpression => aggregateExpression.aggregateFunction match { - case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported. + case Count(exp) => { + if (!exp.forall(x => x.isInstanceOf[Literal])) { + willNotWorkOnGpu(s"Currently, only COUNT(1) and COUNT(*) are supported. " + + s"COUNT($exp) is not supported in windowing.") + } + } + case Sum(_) | Min(_) | Max(_) => // Supported. case other: AggregateFunction => willNotWorkOnGpu(s"AggregateFunction ${other.prettyName} " + s"is not supported in windowing.") diff --git a/tests/src/test/resources/window-function-test.orc b/tests/src/test/resources/window-function-test.orc index 529cade894a8..c947c8e2f8b6 100644 Binary files a/tests/src/test/resources/window-function-test.orc and b/tests/src/test/resources/window-function-test.orc differ diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala index 07c29e1438d5..0d4d102997d6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WindowFunctionSuite.scala @@ -27,7 +27,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite { sum("dollars").over(windowSpec), min("dollars").over(windowSpec), max("dollars").over(windowSpec), - count("dollars").over(windowSpec), count("*").over(windowSpec) ) @@ -36,7 +35,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite { sum("dollars").over(windowSpec), min("dollars").over(windowSpec), max("dollars").over(windowSpec), - count("dollars").over(windowSpec), row_number().over(windowSpec), count("*").over(windowSpec) ) @@ -130,7 +128,7 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite { | SUM(dollars) OVER $windowClause, | MIN(dollars) OVER $windowClause, | MAX(dollars) OVER $windowClause, - | COUNT(dollars) OVER $windowClause, + | COUNT(1) OVER $windowClause, | COUNT(*) OVER $windowClause | FROM mytable | @@ -406,4 +404,32 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite { // scalastyle:on line.size.limit } } + + ALLOW_NON_GPU_testSparkResultsAreEqual( + "[Window] [RANGE] [ ASC] [-2 DAYS, 3 DAYS] ", + windowTestDfOrc, + Seq("AggregateExpression", + "Alias", + "AttributeReference", + "Count", + "Literal", + "SpecifiedWindowFrame", + "WindowExec", + "WindowExpression", + "WindowSpecDefinition")) { + (df : DataFrame) => { + df.createOrReplaceTempView("mytable") + // scalastyle:off line.size.limit + df.sparkSession.sql( + """ + | SELECT COUNT(dollars+1) OVER + | (PARTITION BY uid + | ORDER BY CAST(dateLong AS TIMESTAMP) ASC + | RANGE BETWEEN INTERVAL 2 DAYS PRECEDING AND INTERVAL 3 DAYS FOLLOWING) + | FROM mytable + | + |""".stripMargin) + // scalastyle:on line.size.limit + } + } }