Skip to content

Commit

Permalink
[window] Disable GPU for non-COUNT(*)
Browse files Browse the repository at this point in the history
GpuWindowExec returns incorrect results for COUNT(col)
or COUNT(expr) window functions. (Null values are inadvertently
counted when they should be skipped.)

This commit disables GPU acceleration for such queries. This may
be reverted after COUNT(col) is handled correctly in CUDF.

Signed-off-by: Mithun RK <mythrocks@gmail.com>
  • Loading branch information
mythrocks committed Sep 3, 2020
1 parent 99c67db commit 4e3395e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ class GpuWindowExpressionMeta(
windowFunction match {
case aggregateExpression : AggregateExpression =>
aggregateExpression.aggregateFunction match {
case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported.
// COUNT in a window is only left on the GPU when every argument is a literal
// (the COUNT(1) / COUNT(*) forms); any other argument tags the expression as
// not runnable on the GPU.
case Count(exp) =>
  if (!exp.forall(_.isInstanceOf[Literal])) {
    // Join the arguments explicitly: interpolating the collection itself would
    // print its wrapper (e.g. "List(...)") inside the user-facing reason.
    willNotWorkOnGpu("Currently, only COUNT(1) and COUNT(*) are supported. " +
      s"COUNT(${exp.mkString(", ")}) is not supported in windowing.")
  }
case Sum(_) | Min(_) | Max(_) => // Supported.
case other: AggregateFunction =>
willNotWorkOnGpu(s"AggregateFunction ${other.prettyName} " +
s"is not supported in windowing.")
Expand Down
Binary file modified tests/src/test/resources/window-function-test.orc
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
sum("dollars").over(windowSpec),
min("dollars").over(windowSpec),
max("dollars").over(windowSpec),
count("dollars").over(windowSpec),
count("*").over(windowSpec)
)

Expand All @@ -36,7 +35,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
sum("dollars").over(windowSpec),
min("dollars").over(windowSpec),
max("dollars").over(windowSpec),
count("dollars").over(windowSpec),
row_number().over(windowSpec),
count("*").over(windowSpec)
)
Expand Down Expand Up @@ -130,7 +128,7 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
| SUM(dollars) OVER $windowClause,
| MIN(dollars) OVER $windowClause,
| MAX(dollars) OVER $windowClause,
| COUNT(dollars) OVER $windowClause,
| COUNT(1) OVER $windowClause,
| COUNT(*) OVER $windowClause
| FROM mytable
|
Expand Down Expand Up @@ -406,4 +404,32 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
// scalastyle:on line.size.limit
}
}

// Exercises COUNT(<expression>) over a RANGE window (2 days preceding through
// 3 days following, per uid, ordered by dateLong cast to TIMESTAMP).
// NOTE(review): the Seq below is presumably the set of plan/expression class
// names that are permitted to stay off the GPU for this query -- confirm
// against the helper's definition.
ALLOW_NON_GPU_testSparkResultsAreEqual(
  "[Window] [RANGE] [ ASC] [-2 DAYS, 3 DAYS] ",
  windowTestDfOrc,
  Seq(
    "AggregateExpression",
    "Alias",
    "AttributeReference",
    "Count",
    "Literal",
    "SpecifiedWindowFrame",
    "WindowExec",
    "WindowExpression",
    "WindowSpecDefinition")) {
  (input : DataFrame) => {
    // scalastyle:off line.size.limit
    val countOverRangeSql =
      """
        | SELECT COUNT(dollars+1) OVER
        | (PARTITION BY uid
        | ORDER BY CAST(dateLong AS TIMESTAMP) ASC
        | RANGE BETWEEN INTERVAL 2 DAYS PRECEDING AND INTERVAL 3 DAYS FOLLOWING)
        | FROM mytable
        |
        |""".stripMargin
    // scalastyle:on line.size.limit

    // Register the input under the table name the query refers to, then run it.
    input.createOrReplaceTempView("mytable")
    input.sparkSession.sql(countOverRangeSql)
  }
}
}

0 comments on commit 4e3395e

Please sign in to comment.