Skip to content

Commit

Permalink
[window] Disable GPU for COUNT(exp) queries (NVIDIA#666)
Browse files Browse the repository at this point in the history
GpuWindowExec currently counts null-rows when running COUNT(col)
(or generally COUNT(expr)) window queries, owing to a bug in CUDF/Java.
Left unchecked, this will produce incorrect results for said queries.

This commit disables GPU acceleration for COUNT(expr) queries, while
retaining support for COUNT(1) and COUNT(*).

This may be reverted once we have a fix in CUDF/Java.

Signed-off-by: Mithun RK <mythrocks@gmail.com>
  • Loading branch information
mythrocks authored Sep 5, 2020
1 parent d500916 commit de7ed04
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ class GpuWindowExpressionMeta(
windowFunction match {
case aggregateExpression : AggregateExpression =>
aggregateExpression.aggregateFunction match {
case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported.
case Count(exp) => {
if (!exp.forall(x => x.isInstanceOf[Literal])) {
willNotWorkOnGpu(s"Currently, only COUNT(1) and COUNT(*) are supported. " +
s"COUNT($exp) is not supported in windowing.")
}
}
case Sum(_) | Min(_) | Max(_) => // Supported.
case other: AggregateFunction =>
willNotWorkOnGpu(s"AggregateFunction ${other.prettyName} " +
s"is not supported in windowing.")
Expand Down
Binary file modified tests/src/test/resources/window-function-test.orc
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
sum("dollars").over(windowSpec),
min("dollars").over(windowSpec),
max("dollars").over(windowSpec),
count("dollars").over(windowSpec),
count("*").over(windowSpec)
)

Expand All @@ -36,7 +35,6 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
sum("dollars").over(windowSpec),
min("dollars").over(windowSpec),
max("dollars").over(windowSpec),
count("dollars").over(windowSpec),
row_number().over(windowSpec),
count("*").over(windowSpec)
)
Expand Down Expand Up @@ -130,7 +128,7 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
| SUM(dollars) OVER $windowClause,
| MIN(dollars) OVER $windowClause,
| MAX(dollars) OVER $windowClause,
| COUNT(dollars) OVER $windowClause,
| COUNT(1) OVER $windowClause,
| COUNT(*) OVER $windowClause
| FROM mytable
|
Expand Down Expand Up @@ -406,4 +404,32 @@ class WindowFunctionSuite extends SparkQueryCompareTestSuite {
// scalastyle:on line.size.limit
}
}

ALLOW_NON_GPU_testSparkResultsAreEqual(
"[Window] [RANGE] [ ASC] [-2 DAYS, 3 DAYS] ",
windowTestDfOrc,
Seq("AggregateExpression",
"Alias",
"AttributeReference",
"Count",
"Literal",
"SpecifiedWindowFrame",
"WindowExec",
"WindowExpression",
"WindowSpecDefinition")) {
(df : DataFrame) => {
df.createOrReplaceTempView("mytable")
// scalastyle:off line.size.limit
df.sparkSession.sql(
"""
| SELECT COUNT(dollars+1) OVER
| (PARTITION BY uid
| ORDER BY CAST(dateLong AS TIMESTAMP) ASC
| RANGE BETWEEN INTERVAL 2 DAYS PRECEDING AND INTERVAL 3 DAYS FOLLOWING)
| FROM mytable
|
|""".stripMargin)
// scalastyle:on line.size.limit
}
}
}

0 comments on commit de7ed04

Please sign in to comment.