From 36ca2c8127337a9cb7f65a9dd98434694f8e0391 Mon Sep 17 00:00:00 2001 From: Mithun RK Date: Thu, 3 Sep 2020 16:50:11 -0700 Subject: [PATCH] [window] Disable GPU for COUNT(exp) queries GpuWindowExec currently counts null-rows when running COUNT(col) (or generally COUNT(expr)) window queries, owing to a bug in CUDF/Java. Left unchecked, this will produce incorrect results for said queries. This commit disables GPU acceleration for COUNT(expr) queries, while retaining support for COUNT(1) and COUNT(*). This may be reverted once we have a fix in CUDF/Java. Signed-off-by: Mithun RK --- .../spark/rapids/GpuWindowExpression.scala | 8 ++++- .../test/resources/window-function-test.orc | Bin 1259 -> 1278 bytes .../spark/rapids/WindowFunctionSuite.scala | 32 ++++++++++++++++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index 9c11b0e358e..255e9b77f1e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -51,7 +51,13 @@ class GpuWindowExpressionMeta( windowFunction match { case aggregateExpression : AggregateExpression => aggregateExpression.aggregateFunction match { - case Count(_) | Sum(_) | Min(_) | Max(_) => // Supported. + case Count(exp) => { + if (!exp.forall(x => x.isInstanceOf[Literal])) { + willNotWorkOnGpu(s"Currently, only COUNT(1) and COUNT(*) are supported. " + + s"COUNT($exp) is not supported in windowing.") + } + } + case Sum(_) | Min(_) | Max(_) => // Supported. case other: AggregateFunction => willNotWorkOnGpu(s"AggregateFunction ${other.prettyName} " + s"is not supported in windowing.") diff --git a/tests/src/test/resources/window-function-test.orc b/tests/src/test/resources/window-function-test.orc index 529cade894a89bba2b7a428e8157e8089b68668f..c947c8e2f8b6bef06c199eba218aa09ca240fa85 100644 GIT binary patch delta 248 zcmaFO`HyqLWmyvj25A8_TmS!q zHUop7kd&NU#w3+BN>$U2%vmKfVaJN5EGb5YWs?(_wD>jHBxJeRI2Z*OCHPr6JSNX% z(&ph{5MTqc83ouRBqpC?l3`4m{G4f?u1_-~mjJ3&5%L904Ck0Wgm5qmd}Wp51yXMr zCB8B#OqjfpS&=^{KQCR1DWyb;JteUuwRrMF=Idq|&5TwUrZ`OS^I(WzXWWhs3I1};%9W(EcZAwCWfAr1~Hfg_9(eS!fD8$X0G%IY#O_z7D{$;o9* zQdy%^HSNfpRWcKHtZ2&OV`P{&Ig?3?pMi~4pny?;QG%a~je}8O@**Z}HV;MtHVMhe zH<@G@6DEISny16k%*beAf^2$(WEK;{8Kw^*9LxgGSfzM@)GJ1b_e=_1lMga0^84iH zrAslTlt{6sB$lKGPyWb!-56|$2dW_s6Z||FICd}zFiA8t2q-Zz@MvsiG+=M~!p!U+ H { + df.createOrReplaceTempView("mytable") + // scalastyle:off line.size.limit + df.sparkSession.sql( + """ + | SELECT COUNT(dollars+1) OVER + | (PARTITION BY uid + | ORDER BY CAST(dateLong AS TIMESTAMP) ASC + | RANGE BETWEEN INTERVAL 2 DAYS PRECEDING AND INTERVAL 3 DAYS FOLLOWING) + | FROM mytable + | + |""".stripMargin) + // scalastyle:on line.size.limit + } + } }