Fix correctness issue with CASE WHEN with expressions that have side-effects #4383

Merged · 22 commits · Jan 6, 2022

Commits
aa6db16
Fix correctness issue with CASE WHEN with expressions that have side-…
andygrove Dec 17, 2021
e0879f9
code cleanup and comments
andygrove Dec 17, 2021
5756850
Revert unnecessary change
andygrove Dec 17, 2021
f509f69
Revert unnecessary change
andygrove Dec 17, 2021
58bfdad
Add license header
andygrove Dec 20, 2021
5b5247d
Add more comments. Add optimization to stop processing branches once …
andygrove Dec 21, 2021
e9e5f50
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
58d2c26
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
2ffd030
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Dec 22, 2021
044c98a
remove redundant check
andygrove Dec 22, 2021
e83b187
close thenValues resource earlier
andygrove Dec 22, 2021
76d63a4
close elseValues resource earlier
andygrove Dec 22, 2021
89da423
refactor for readability
andygrove Dec 22, 2021
96fc175
refactor to remove duplicate code
andygrove Dec 22, 2021
4a79414
simplify inverting cumulativePred for else condition
andygrove Jan 4, 2022
aad30f4
Merge remote-tracking branch 'nvidia/branch-22.02' into gpu-case-when…
andygrove Jan 4, 2022
1233870
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
andygrove Jan 5, 2022
bde6b85
fix compilation error
andygrove Jan 5, 2022
034cc01
Use AST NULL_LOGICAL_OR for computing cumulative predicate
andygrove Jan 5, 2022
5ae6922
Update isFirstTrueWhen to be null-safe
andygrove Jan 5, 2022
bb80e29
address feedback
andygrove Jan 5, 2022
98d2207
Add tests for predicates evaluating to null
andygrove Jan 6, 2022
13 changes: 12 additions & 1 deletion integration_tests/src/main/python/conditionals_test.py
@@ -207,4 +207,15 @@ def test_conditional_with_side_effects_cast(data_gen):
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'IF(a RLIKE "^[0-9]{1,5}$", CAST(a AS INT), 0)'),
conf = {'spark.sql.ansi.enabled':True,
'spark.rapids.sql.expression.RLike': True})

@pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,9}')], ids=idfn)
def test_conditional_with_side_effects_case_when(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr(
'CASE \
WHEN a RLIKE "^[0-9]{1,3}$" THEN CAST(a AS INT) \
WHEN a RLIKE "^[0-9]{4,6}$" THEN CAST(a AS INT) + 123 \
ELSE -1 END'),
conf = {'spark.sql.ansi.enabled':True,
'spark.rapids.sql.expression.RLike': True})
sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalE…
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp}
import ai.rapids.cudf.ast
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimExpression

@@ -80,6 +81,69 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr
!anyTrue.getBoolean
}
}

protected def filterBatch(
tbl: Table,
pred: ColumnVector,
colTypes: Array[DataType]): ColumnarBatch = {
withResource(tbl.filter(pred)) { filteredData =>
GpuColumnVector.from(filteredData, colTypes)
}
}

private def boolToInt(cv: ColumnVector): ColumnVector = {
withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
cv.ifElse(one, zero)
}
}
}

/**
* Invert boolean values and convert null values to true
*/
def boolInverted(cv: ColumnVector): ColumnVector = {
withResource(GpuScalar.from(true, DataTypes.BooleanType)) { t =>
withResource(GpuScalar.from(false, DataTypes.BooleanType)) { f =>
cv.ifElse(f, t)
}
}
}

protected def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
// convert the predicate boolean column to numeric where 1 = true
// and 0 (or null) = false and then use `scan` with `sum` to convert to
// indices.
//
// For example, if the predicate evaluates to [F, null, T, F, T] then this
// gets translated first to [0, 0, 1, 0, 1] and then the scan operation
// will perform an exclusive sum on these values and
// produce [0, 0, 0, 1, 1]. Combining this with the original
// predicate boolean array results in the two T values mapping to
// indices 0 and 1, respectively.

val prefixSumExclusive = withResource(boolToInt(predicate)) { boolsAsInts =>
boolsAsInts.scan(
ScanAggregation.sum(),
ScanType.EXCLUSIVE,
NullPolicy.INCLUDE)
}
val gatherMap = withResource(prefixSumExclusive) { prefixSumExclusive =>
// for the entries in the gather map that do not represent valid
// values to be gathered, we change the value to Int.MinValue, which
// will be treated as a null value by the gather algorithm
withResource(Scalar.fromInt(Int.MinValue)) {
outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
}
}
withResource(gatherMap) { _ =>
withResource(new Table(t.getBase)) { tbl =>
withResource(tbl.gather(gatherMap)) { gatherTbl =>
gatherTbl.getColumn(0).incRefCount()
}
}
}
}
}
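
The gather step above can be traced end to end with ordinary Scala collections. The following sketch mirrors the example in the comment, with `Option[Boolean]` standing in for a nullable cudf boolean column (the names and the `main` driver are illustrative, not part of the plugin):

```scala
// Illustrative CPU-side walk-through of the gather-map construction in
// `gather` above; not the cudf API.
object GatherMapSketch {
  def buildGatherMap(predicate: Seq[Option[Boolean]]): Seq[Int] = {
    // 1 where the predicate is true, 0 where it is false or null
    val boolsAsInts = predicate.map(p => if (p.contains(true)) 1 else 0)
    // exclusive prefix sum: each true row's position among the filtered rows
    val prefixSumExclusive = boolsAsInts.scanLeft(0)(_ + _).dropRight(1)
    // rows that were not selected get Int.MinValue, which the gather step
    // interprets as "emit null for this row"
    predicate.zip(prefixSumExclusive).map {
      case (p, idx) => if (p.contains(true)) idx else Int.MinValue
    }
  }

  def main(args: Array[String]): Unit = {
    // [F, null, T, F, T] -> [0, 0, 1, 0, 1] -> exclusive sum [0, 0, 0, 1, 1]
    // -> gather map [MIN, MIN, 0, MIN, 1]
    val pred = Seq(Some(false), None, Some(true), Some(false), Some(true))
    println(buildGatherMap(pred).mkString(", "))
  }
}
```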

case class GpuIf(
@@ -203,59 +267,6 @@ case class GpuIf(
}
}

private def filterBatch(
tbl: Table,
pred: ColumnVector,
colTypes: Array[DataType]): ColumnarBatch = {
withResource(tbl.filter(pred)) { filteredData =>
GpuColumnVector.from(filteredData, colTypes)
}
}

private def boolToInt(cv: ColumnVector): ColumnVector = {
withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
cv.ifElse(one, zero)
}
}
}

private def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
// convert the predicate boolean column to numeric where 1 = true
// and 0 = false and then use `scan` with `sum` to convert to
// indices.
//
// For example, if the predicate evaluates to [F, F, T, F, T] then this
// gets translated first to [0, 0, 1, 0, 1] and then the scan operation
// will perform an exclusive sum on these values and
// produce [0, 0, 0, 1, 1]. Combining this with the original
// predicate boolean array results in the two T values mapping to
// indices 0 and 1, respectively.

withResource(boolToInt(predicate)) { boolsAsInts =>
withResource(boolsAsInts.scan(
ScanAggregation.sum(),
ScanType.EXCLUSIVE,
NullPolicy.INCLUDE)) { prefixSumExclusive =>

// for the entries in the gather map that do not represent valid
// values to be gathered, we change the value to -MAX_INT which
// will be treated as null values in the gather algorithm
val gatherMap = withResource(Scalar.fromInt(Int.MinValue)) {
outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
}

withResource(new Table(t.getBase)) { tbl =>
withResource(gatherMap) { _ =>
withResource(tbl.gather(gatherMap)) { gatherTbl =>
gatherTbl.getColumn(0).incRefCount()
}
}
}
}
}
}

override def toString: String = s"if ($predicateExpr) $trueExpr else $falseExpr"

override def sql: String = s"(IF(${predicateExpr.sql}, ${trueExpr.sql}, ${falseExpr.sql}))"
@@ -274,6 +285,9 @@ case class GpuCaseWhen(
branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
}

private lazy val branchesWithSideEffects =
branches.exists(_._2.asInstanceOf[GpuExpression].hasSideEffects)

override def nullable: Boolean = {
// Result is nullable if any of the branches is nullable, or if the else value is nullable
branches.exists(_._2.nullable) || elseValue.forall(_.nullable)
@@ -301,12 +315,182 @@ }
}

override def columnarEval(batch: ColumnarBatch): Any = {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEval(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
branches.foldRight[Any](elseRet) { case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
if (branchesWithSideEffects) {
columnarEvalWithSideEffects(batch)
} else {
// `elseRet` will be closed in `computeIfElse`.
val elseRet = elseValue
.map(_.columnarEval(batch))
.getOrElse(GpuScalar(null, branches.last._2.dataType))
branches.foldRight[Any](elseRet) {
case ((predicateExpr, trueExpr), falseRet) =>
computeIfElse(batch, predicateExpr, trueExpr, falseRet)
}
}
}

/**
* Perform lazy evaluation of each branch so that we only evaluate the THEN expressions
* against rows where the WHEN expression is true.
*/
private def columnarEvalWithSideEffects(batch: ColumnarBatch): Any = {
val colTypes = GpuColumnVector.extractTypes(batch)

// track cumulative state of predicate evaluation per row so that we never evaluate expressions
// for a row if an earlier expression has already been evaluated to true for that row
var cumulativePred: Option[GpuColumnVector] = None

// this variable contains the currently evaluated value for each row and gets updated
// as each branch is evaluated
var currentValue: Option[GpuColumnVector] = None

try {
withResource(GpuColumnVector.from(batch)) { tbl =>

// iterate over the WHEN THEN branches first
branches.foreach {
case (whenExpr, thenExpr) =>
// evaluate the WHEN predicate
withResource(GpuExpressionsUtils.columnarEvalToColumn(whenExpr, batch)) { whenBool =>
// we only want to evaluate where this WHEN is true and no previous WHEN has been true
val firstTrueWhen = isFirstTrueWhen(cumulativePred, whenBool)

withResource(firstTrueWhen) { _ =>
if (isAllTrue(firstTrueWhen)) {
// if this WHEN predicate is true for all rows and no previous predicate has
// been true then we can return immediately
return GpuExpressionsUtils.columnarEvalToColumn(thenExpr, batch)
}
val thenValues = filterEvaluateWhenThen(colTypes, tbl, firstTrueWhen.getBase,
thenExpr)
withResource(thenValues) { _ =>
currentValue = Some(calcCurrentValue(currentValue, firstTrueWhen, thenValues))
}
cumulativePred = Some(calcCumulativePredicate(
cumulativePred, whenBool, firstTrueWhen))

if (isAllTrue(cumulativePred.get)) {
// no need to process any more branches or the else condition
return currentValue.get.incRefCount()
}
}
}
}

// invert the cumulative predicate to get the ELSE predicate
withResource(boolInverted(cumulativePred.get.getBase)) { elsePredNoNulls =>
elseValue match {
case Some(expr) =>
if (isAllFalse(cumulativePred.get)) {
GpuExpressionsUtils.columnarEvalToColumn(expr, batch)
} else {
val elseValues = filterEvaluateWhenThen(colTypes, tbl, elsePredNoNulls, expr)
withResource(elseValues) { _ =>
GpuColumnVector.from(elsePredNoNulls.ifElse(
elseValues, currentValue.get.getBase), dataType)
}
}

case None =>
// if there is no ELSE condition then we return NULL for any rows not matched by
// previous branches
withResource(GpuScalar.from(null, dataType)) { nullScalar =>
if (isAllFalse(cumulativePred.get)) {
GpuColumnVector.from(nullScalar, elsePredNoNulls.getRowCount.toInt, dataType)
} else {
GpuColumnVector.from(
elsePredNoNulls.ifElse(nullScalar, currentValue.get.getBase),
dataType)
}
}
}
}
}
} finally {
currentValue.foreach(_.safeClose())
cumulativePred.foreach(_.safeClose())
}
}
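
Conceptually, the method above is a column-wise version of evaluating CASE WHEN lazily per row. A minimal per-row analogue, assuming plain Scala functions for the WHEN predicates and THEN expressions (a hypothetical helper, not the plugin code, which works on whole columns and filters the batch rather than looping over rows):

```scala
// Illustrative only: a THEN function runs solely for rows whose WHEN predicate
// is the first one to evaluate to true, so a side-effecting THEN (e.g. an ANSI
// cast that can throw) never sees rows it should not.
def caseWhenPerRow[A, B](
    rows: Seq[A],
    branches: Seq[(A => Boolean, A => B)],
    elseExpr: A => B): Seq[B] =
  rows.map { row =>
    branches.collectFirst {
      case (whenPred, thenExpr) if whenPred(row) => thenExpr(row)
    }.getOrElse(elseExpr(row))
  }
```

This mirrors what the new test in conditionals_test.py relies on: the ANSI cast in each THEN is only ever evaluated against rows matched by its own WHEN.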

/**
* Filter the batch to just the rows where the WHEN condition is true and
* then evaluate the THEN expression.
*/
private def filterEvaluateWhenThen(
colTypes: Array[DataType],
tbl: Table,
whenBool: ColumnVector,
thenExpr: Expression): ColumnVector = {
val filteredBatch = filterBatch(tbl, whenBool, colTypes)
val thenValues = withResource(filteredBatch) { trueBatch =>
GpuExpressionsUtils.columnarEvalToColumn(thenExpr, trueBatch)
}
withResource(thenValues) { _ =>
gather(whenBool, thenValues)
}
}
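
A collection-level sketch of the filter-then-gather pattern used here, with `None` standing in for rows left untouched by this branch (names are illustrative):

```scala
// Evaluate `thenF` only on rows where the predicate is true, then scatter the
// results back to their original positions; other rows get None (null).
def filterEvaluateGather[A, B](
    rows: Seq[A],
    predicate: Seq[Boolean],
    thenF: A => B): Seq[Option[B]] = {
  // filter: THEN is evaluated only on the matching rows
  val thenValues = rows.zip(predicate).collect { case (row, true) => thenF(row) }
  // gather: walk the predicate again, consuming the filtered results in order
  val it = thenValues.iterator
  predicate.map(p => if (p) Some(it.next()) else None)
}
```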

/**
* Calculate the cumulative predicate so far using the logical expression
* `prevPredicate OR thisPredicate`.
*/
private def calcCumulativePredicate(
cumulativePred: Option[GpuColumnVector],
whenBool: GpuColumnVector,
firstTrueWhen: GpuColumnVector): GpuColumnVector = {
cumulativePred match {
case Some(prev) =>
withResource(prev) { _ =>
val result = withResource(new Table(prev.getBase, whenBool.getBase)) { t =>
val or = new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR,
new ast.ColumnReference(0),
new ast.ColumnReference(1)
)
withResource(or.compile()) {
_.computeColumn(t)
}
}
GpuColumnVector.from(result, DataTypes.BooleanType)
}
case _ =>
firstTrueWhen.incRefCount()
}
}
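
The NULL_LOGICAL_OR used above is relied on for SQL-style three-valued OR, where a null on one side does not hide a true on the other. A sketch of those semantics with `Option[Boolean]` (an assumption about the operator's behavior, not the cudf AST evaluator itself):

```scala
// Sketch of three-valued (SQL-style) OR; None stands in for a null boolean.
// Assumes NULL_LOGICAL_OR behaves this way: true wins over null, while
// false OR null stays null.
def nullSafeOr(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] =
  (l, r) match {
    case (Some(true), _) | (_, Some(true)) => Some(true)
    case (Some(false), Some(false))        => Some(false)
    case _                                 => None
  }

assert(nullSafeOr(Some(true), None).contains(true)) // an earlier true is never lost
assert(nullSafeOr(Some(false), None).isEmpty)       // still undetermined for this row
```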

/**
* Calculate the current values by merging the THEN values for this branch (where the WHEN
* predicate was true) with the previous values.
*/
private def calcCurrentValue(
prevValue: Option[GpuColumnVector],
whenBool: GpuColumnVector,
thenValues: ColumnVector): GpuColumnVector = {
prevValue match {
case Some(v) =>
withResource(v) { _ =>
GpuColumnVector.from(whenBool.getBase.ifElse(thenValues, v.getBase), dataType)
}
case _ =>
GpuColumnVector.from(thenValues.incRefCount(), dataType)
}
}

/**
* Determine for each row whether this is the first WHEN predicate so far to evaluate to true
*/
private def isFirstTrueWhen(
cumulativePred: Option[GpuColumnVector],
whenBool: GpuColumnVector): GpuColumnVector = {
cumulativePred match {
case Some(prev) =>
withResource(boolInverted(prev.getBase)) { notPrev =>
val whenReplaced = withResource(Scalar.fromBool(false)) { falseScalar =>
whenBool.getBase.replaceNulls(falseScalar)
}
GpuColumnVector.from(whenReplaced.and(notPrev), DataTypes.BooleanType)
}
case None =>
whenBool.incRefCount()
}
}
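
Per row, the check above reduces to "this WHEN is true (a null counts as false) and no earlier WHEN was already true"; a scalar sketch with illustrative names:

```scala
// A WHEN is "first true" for a row only if it evaluates to true (a null WHEN
// counts as false) and the cumulative predicate so far is not already true
// (a null cumulative predicate counts as "not yet true").
def isFirstTrue(when: Option[Boolean], cumulativeSoFar: Option[Boolean]): Boolean =
  when.contains(true) && !cumulativeSoFar.contains(true)
```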
