diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index df0c3d0909a..9673913241e 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -207,4 +207,15 @@ def test_conditional_with_side_effects_cast(data_gen):
         lambda spark : unary_op_df(spark, data_gen).selectExpr(
             'IF(a RLIKE "^[0-9]{1,5}$", CAST(a AS INT), 0)'),
         conf = {'spark.sql.ansi.enabled':True,
-            'spark.rapids.sql.expression.RLike': True})
\ No newline at end of file
+            'spark.rapids.sql.expression.RLike': True})
+
+@pytest.mark.parametrize('data_gen', [mk_str_gen('[0-9]{1,9}')], ids=idfn)
+def test_conditional_with_side_effects_case_when(data_gen):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).selectExpr(
+            'CASE \
+                WHEN a RLIKE "^[0-9]{1,3}$" THEN CAST(a AS INT) \
+                WHEN a RLIKE "^[0-9]{4,6}$" THEN CAST(a AS INT) + 123 \
+                ELSE -1 END'),
+        conf = {'spark.sql.ansi.enabled':True,
+            'spark.rapids.sql.expression.RLike': True})
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
index 5b1c1172d5a..7c82883c383 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{ColumnVector, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp}
+import ai.rapids.cudf.ast
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.shims.v2.ShimExpression
 
@@ -80,6 +81,69 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr
       !anyTrue.getBoolean
     }
   }
+
+  protected def filterBatch(
+      tbl: Table,
+      pred: ColumnVector,
+      colTypes: Array[DataType]): ColumnarBatch = {
+    withResource(tbl.filter(pred)) { filteredData =>
+      GpuColumnVector.from(filteredData, colTypes)
+    }
+  }
+
+  private def boolToInt(cv: ColumnVector): ColumnVector = {
+    withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
+      withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
+        cv.ifElse(one, zero)
+      }
+    }
+  }
+
+  /**
+   * Invert boolean values and convert null values to true.
+   */
+  def boolInverted(cv: ColumnVector): ColumnVector = {
+    withResource(GpuScalar.from(true, DataTypes.BooleanType)) { t =>
+      withResource(GpuScalar.from(false, DataTypes.BooleanType)) { f =>
+        cv.ifElse(f, t)
+      }
+    }
+  }
+
+  protected def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
+    // convert the predicate boolean column to numeric where 1 = true
+    // and 0 (or null) = false and then use `scan` with `sum` to convert to
+    // indices.
+    //
+    // For example, if the predicate evaluates to [F, null, T, F, T] then this
+    // gets translated first to [0, 0, 1, 0, 1] and then the scan operation
+    // will perform an exclusive sum on these values and
+    // produce [0, 0, 0, 1, 1]. Combining this with the original
+    // predicate boolean array results in the two T values mapping to
+    // indices 0 and 1, respectively.
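+    //
+    // Rows where the predicate was false or null are then flagged in the
+    // gather map with Int.MinValue (see below), which the gather step
+    // treats as out of bounds and turns into null, so the final gather
+    // map for this example is [MIN_INT, MIN_INT, 0, MIN_INT, 1].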
+
+    val prefixSumExclusive = withResource(boolToInt(predicate)) { boolsAsInts =>
+      boolsAsInts.scan(
+        ScanAggregation.sum(),
+        ScanType.EXCLUSIVE,
+        NullPolicy.INCLUDE)
+    }
+    val gatherMap = withResource(prefixSumExclusive) { prefixSumExclusive =>
+      // for the entries in the gather map that do not represent valid
+      // values to be gathered, we change the value to Int.MinValue which
+      // will be treated as null values in the gather algorithm
+      withResource(Scalar.fromInt(Int.MinValue)) {
+        outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
+      }
+    }
+    withResource(gatherMap) { _ =>
+      withResource(new Table(t.getBase)) { tbl =>
+        withResource(tbl.gather(gatherMap)) { gatherTbl =>
+          gatherTbl.getColumn(0).incRefCount()
+        }
+      }
+    }
+  }
 }
 
 case class GpuIf(
@@ -203,59 +267,6 @@ case class GpuIf(
     }
   }
 
-  private def filterBatch(
-      tbl: Table,
-      pred: ColumnVector,
-      colTypes: Array[DataType]): ColumnarBatch = {
-    withResource(tbl.filter(pred)) { filteredData =>
-      GpuColumnVector.from(filteredData, colTypes)
-    }
-  }
-
-  private def boolToInt(cv: ColumnVector): ColumnVector = {
-    withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
-      withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
-        cv.ifElse(one, zero)
-      }
-    }
-  }
-
-  private def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
-    // convert the predicate boolean column to numeric where 1 = true
-    // amd 0 = false and then use `scan` with `sum` to convert to
-    // indices.
-    //
-    // For example, if the predicate evaluates to [F, F, T, F, T] then this
-    // gets translated first to [0, 0, 1, 0, 1] and then the scan operation
-    // will perform an exclusive sum on these values and
-    // produce [0, 0, 0, 1, 1]. Combining this with the original
-    // predicate boolean array results in the two T values mapping to
-    // indices 0 and 1, respectively.
-
-    withResource(boolToInt(predicate)) { boolsAsInts =>
-      withResource(boolsAsInts.scan(
-        ScanAggregation.sum(),
-        ScanType.EXCLUSIVE,
-        NullPolicy.INCLUDE)) { prefixSumExclusive =>
-
-        // for the entries in the gather map that do not represent valid
-        // values to be gathered, we change the value to -MAX_INT which
-        // will be treated as null values in the gather algorithm
-        val gatherMap = withResource(Scalar.fromInt(Int.MinValue)) {
-          outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
-        }
-
-        withResource(new Table(t.getBase)) { tbl =>
-          withResource(gatherMap) { _ =>
-            withResource(tbl.gather(gatherMap)) { gatherTbl =>
-              gatherTbl.getColumn(0).incRefCount()
-            }
-          }
-        }
-      }
-    }
-  }
-
   override def toString: String = s"if ($predicateExpr) $trueExpr else $falseExpr"
 
   override def sql: String = s"(IF(${predicateExpr.sql}, ${trueExpr.sql}, ${falseExpr.sql}))"
@@ -274,6 +285,9 @@ case class GpuCaseWhen(
     branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
   }
 
+  private lazy val branchesWithSideEffects =
+    branches.exists(_._2.asInstanceOf[GpuExpression].hasSideEffects)
+
   override def nullable: Boolean = {
     // Result is nullable if any of the branch is nullable, or if the else value is nullable
     branches.exists(_._2.nullable) || elseValue.forall(_.nullable)
@@ -301,12 +315,182 @@ case class GpuCaseWhen(
   }
 
   override def columnarEval(batch: ColumnarBatch): Any = {
-    // `elseRet` will be closed in `computeIfElse`.
-    val elseRet = elseValue
-      .map(_.columnarEval(batch))
-      .getOrElse(GpuScalar(null, branches.last._2.dataType))
-    branches.foldRight[Any](elseRet) { case ((predicateExpr, trueExpr), falseRet) =>
-      computeIfElse(batch, predicateExpr, trueExpr, falseRet)
+    if (branchesWithSideEffects) {
+      columnarEvalWithSideEffects(batch)
+    } else {
+      // `elseRet` will be closed in `computeIfElse`.
+      val elseRet = elseValue
+        .map(_.columnarEval(batch))
+        .getOrElse(GpuScalar(null, branches.last._2.dataType))
+      branches.foldRight[Any](elseRet) {
+        case ((predicateExpr, trueExpr), falseRet) =>
+          computeIfElse(batch, predicateExpr, trueExpr, falseRet)
+      }
+    }
+  }
+
+  /**
+   * Perform lazy evaluation of each branch so that we only evaluate the THEN expressions
+   * against rows where the WHEN expression is true.
+   */
+  private def columnarEvalWithSideEffects(batch: ColumnarBatch): Any = {
+    val colTypes = GpuColumnVector.extractTypes(batch)
+
+    // track cumulative state of predicate evaluation per row so that we never evaluate
+    // expressions for a row if an earlier expression has already been evaluated to true
+    // for that row
+    var cumulativePred: Option[GpuColumnVector] = None
+
+    // this variable contains the currently evaluated value for each row and gets updated
+    // as each branch is evaluated
+    var currentValue: Option[GpuColumnVector] = None
+
+    try {
+      withResource(GpuColumnVector.from(batch)) { tbl =>
+
+        // iterate over the WHEN THEN branches first
+        branches.foreach {
+          case (whenExpr, thenExpr) =>
+            // evaluate the WHEN predicate
+            withResource(GpuExpressionsUtils.columnarEvalToColumn(whenExpr, batch)) { whenBool =>
+              // we only want to evaluate where this WHEN is true and no previous WHEN
+              // has been true
+              val firstTrueWhen = isFirstTrueWhen(cumulativePred, whenBool)
+
+              withResource(firstTrueWhen) { _ =>
+                if (isAllTrue(firstTrueWhen)) {
+                  // if this WHEN predicate is true for all rows and no previous predicate has
+                  // been true then we can return immediately
+                  return GpuExpressionsUtils.columnarEvalToColumn(thenExpr, batch)
+                }
+                val thenValues = filterEvaluateWhenThen(colTypes, tbl, firstTrueWhen.getBase,
+                  thenExpr)
+                withResource(thenValues) { _ =>
+                  currentValue = Some(calcCurrentValue(currentValue, firstTrueWhen, thenValues))
+                }
+                cumulativePred = Some(calcCumulativePredicate(
+                  cumulativePred, whenBool, firstTrueWhen))
+
+                if (isAllTrue(cumulativePred.get)) {
+                  // no need to process any more branches or the else condition
+                  return currentValue.get.incRefCount()
+                }
+              }
+            }
+        }
+
+        // invert the cumulative predicate to get the ELSE predicate
+        withResource(boolInverted(cumulativePred.get.getBase)) { elsePredNoNulls =>
+          elseValue match {
+            case Some(expr) =>
+              if (isAllFalse(cumulativePred.get)) {
+                GpuExpressionsUtils.columnarEvalToColumn(expr, batch)
+              } else {
+                val elseValues = filterEvaluateWhenThen(colTypes, tbl, elsePredNoNulls, expr)
+                withResource(elseValues) { _ =>
+                  GpuColumnVector.from(elsePredNoNulls.ifElse(
+                    elseValues, currentValue.get.getBase), dataType)
+                }
+              }
+
+            case None =>
+              // if there is no ELSE condition then we return NULL for any rows not matched by
+              // previous branches
+              withResource(GpuScalar.from(null, dataType)) { nullScalar =>
+                if (isAllFalse(cumulativePred.get)) {
+                  GpuColumnVector.from(nullScalar, elsePredNoNulls.getRowCount.toInt, dataType)
+                } else {
+                  GpuColumnVector.from(
+                    elsePredNoNulls.ifElse(nullScalar, currentValue.get.getBase),
+                    dataType)
+                }
+              }
+          }
+        }
+      }
+    } finally {
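+      // note that the early returns in the branch loop above still pass
+      // through this finally block, so these intermediate columns are
+      // always closed exactly once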
+      currentValue.foreach(_.safeClose())
+      cumulativePred.foreach(_.safeClose())
+    }
+  }
+
+  /**
+   * Filter the batch to just the rows where the WHEN condition is true and
+   * then evaluate the THEN expression.
+   */
+  private def filterEvaluateWhenThen(
+      colTypes: Array[DataType],
+      tbl: Table,
+      whenBool: ColumnVector,
+      thenExpr: Expression): ColumnVector = {
+    val filteredBatch = filterBatch(tbl, whenBool, colTypes)
+    val thenValues = withResource(filteredBatch) { trueBatch =>
+      GpuExpressionsUtils.columnarEvalToColumn(thenExpr, trueBatch)
+    }
+    withResource(thenValues) { _ =>
+      gather(whenBool, thenValues)
+    }
+  }
+
+  /**
+   * Calculate the cumulative predicate so far using the logical expression
+   * `prevPredicate OR thisPredicate`.
+   */
+  private def calcCumulativePredicate(
+      cumulativePred: Option[GpuColumnVector],
+      whenBool: GpuColumnVector,
+      firstTrueWhen: GpuColumnVector): GpuColumnVector = {
+    cumulativePred match {
+      case Some(prev) =>
+        withResource(prev) { _ =>
+          val result = withResource(new Table(prev.getBase, whenBool.getBase)) { t =>
+            val or = new ast.BinaryOperation(ast.BinaryOperator.NULL_LOGICAL_OR,
+              new ast.ColumnReference(0),
+              new ast.ColumnReference(1)
+            )
+            withResource(or.compile()) {
+              _.computeColumn(t)
+            }
+          }
+          GpuColumnVector.from(result, DataTypes.BooleanType)
+        }
+      case _ =>
+        firstTrueWhen.incRefCount()
+    }
+  }
+
+  /**
+   * Calculate the current values by merging the THEN values for this branch (where the WHEN
+   * predicate was true) with the previous values.
+   */
+  private def calcCurrentValue(
+      prevValue: Option[GpuColumnVector],
+      whenBool: GpuColumnVector,
+      thenValues: ColumnVector): GpuColumnVector = {
+    prevValue match {
+      case Some(v) =>
+        withResource(v) { _ =>
+          GpuColumnVector.from(whenBool.getBase.ifElse(thenValues, v.getBase), dataType)
+        }
+      case _ =>
+        GpuColumnVector.from(thenValues.incRefCount(), dataType)
+    }
+  }
+
+  /**
+   * Determine for each row whether this is the first WHEN predicate so far to evaluate to true.
+   */
+  private def isFirstTrueWhen(
+      cumulativePred: Option[GpuColumnVector],
+      whenBool: GpuColumnVector): GpuColumnVector = {
+    cumulativePred match {
+      case Some(prev) =>
+        withResource(boolInverted(prev.getBase)) { notPrev =>
+          val whenReplaced = withResource(Scalar.fromBool(false)) { falseScalar =>
+            whenBool.getBase.replaceNulls(falseScalar)
+          }
+          withResource(whenReplaced) { _ =>
+            GpuColumnVector.from(whenReplaced.and(notPrev), DataTypes.BooleanType)
+          }
+        }
+      case None =>
+        whenBool.incRefCount()
     }
   }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala
new file mode 100644
index 00000000000..e197099c96d
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ConditionalsSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.nvidia.spark.rapids + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.functions.expr + +class ConditionalsSuite extends SparkQueryCompareTestSuite { + + private val conf = new SparkConf() + .set("spark.sql.ansi.enabled", "true") + .set("spark.rapids.sql.expression.RLike", "true") + + testSparkResultsAreEqual("CASE WHEN test all branches", testData, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN a RLIKE '^[0-9]{1,3}$' THEN CAST(a AS INT) " + + "WHEN a RLIKE '^[0-9]{4,6}$' THEN CAST(a AS INT) + 123 " + + "ELSE -1 END")) + } + + testSparkResultsAreEqual("CASE WHEN first branch always true", testData2, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN a RLIKE '^[0-9]{1,3}$' THEN CAST(a AS INT) " + + "WHEN a RLIKE '^[0-9]{4,6}$' THEN CAST(a AS INT) + 123 " + + "ELSE -1 END")) + } + + testSparkResultsAreEqual("CASE WHEN second branch always true", testData2, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN a RLIKE '^[0-9]{4,6}$' THEN CAST(a AS INT) " + + "WHEN a RLIKE '^[0-9]{1,3}$' THEN CAST(a AS INT) + 123 " + + "ELSE -1 END")) + } + + testSparkResultsAreEqual("CASE WHEN else condition always true", testData2, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN a RLIKE '^[0-9]{4,6}$' THEN CAST(a AS INT) " + + "WHEN a RLIKE '^[0-9]{7,9}$' THEN CAST(a AS INT) + 123 " + + "ELSE CAST(a AS INT) END")) + } + + testSparkResultsAreEqual("CASE WHEN first or second branch is true", testData3, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN a RLIKE '^[0-9]{1,3}$' THEN CAST(a AS INT) " + + "WHEN a RLIKE '^[0-9]{4,6}$' THEN CAST(a AS INT) + 123 " + + "ELSE -1 END")) + } + + testSparkResultsAreEqual("CASE WHEN with null predicate values on first branch", + testData3, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN char_length(a) < 4 THEN CAST(a AS INT) " + + "WHEN char_length(a) < 7 THEN CAST(a AS INT) + 123 " + + "WHEN char_length(a) IS NULL THEN -999 " + + "ELSE -1 END")) + } + + testSparkResultsAreEqual("CASE WHEN with null predicate values after first branch", + testData3, conf) { df => + df.withColumn("test", expr( + "CASE " + + "WHEN char_length(a) IS NULL THEN -999 " + + "WHEN char_length(a) < 4 THEN CAST(a AS INT) " + + "WHEN char_length(a) < 7 THEN CAST(a AS INT) + 123 " + + "ELSE -1 END")) + } + + private def testData(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq( + "123", + "123456", + "123456789", + null, + "99999999999999999999" + ).toDF("a").repartition(2) + } + + private def testData2(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq( + null, + "123", + "456" + ).toDF("a").repartition(2) + } + + private def testData3(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq( + null, + "123", + "123456" + ).toDF("a").repartition(2) + } + +}