Update GpuIf to support expressions with side effects (#4358)
* Update GpuIf to support expressions with side effects

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* prep for review

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* improve checks for all true/false

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* refactor for readability and add documentation

* improve GpuCast side-effect check

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* release resources earlier

* remove isValid check

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala

Co-authored-by: Liangcai Li <firestarmanllc@gmail.com>

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala

Co-authored-by: Liangcai Li <firestarmanllc@gmail.com>

* fix resource leak and add test for CAST

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* revert add blank line

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala

Co-authored-by: Jason Lowe <jlowe@nvidia.com>

* partially address PR review feedback

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Change gather signature to return ColumnVector. Also add missing import.

Signed-off-by: Andy Grove <andygrove@nvidia.com>

Co-authored-by: Liangcai Li <firestarmanllc@gmail.com>
Co-authored-by: Jason Lowe <jlowe@nvidia.com>
3 people authored Dec 17, 2021
1 parent 713636b commit c07e1e2
Showing 7 changed files with 237 additions and 4 deletions.
29 changes: 29 additions & 0 deletions integration_tests/src/main/python/conditionals_test.py
@@ -20,6 +20,9 @@
from pyspark.sql.types import *
import pyspark.sql.functions as f

def mk_str_gen(pattern):
    return StringGen(pattern).with_special_case('').with_special_pattern('.{0,10}')

all_gens = all_gen + [NullGen()]
all_nested_gens = array_gens_sample + struct_gens_sample + map_gens_sample
all_nested_gens_nonempty_struct = array_gens_sample + nonempty_struct_gens_sample
@@ -182,3 +185,29 @@ def test_ifnull(data_gen):
            'ifnull({}, b)'.format(s1),
            'ifnull({}, b)'.format(null_lit),
            'ifnull(a, {})'.format(null_lit)))

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
def test_conditional_with_side_effects_col_col(data_gen):
    gen = IntegerGen().with_special_case(2147483647)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
            'IF(b < 2147483647, b + 1, b)'),
        conf = {'spark.sql.ansi.enabled':True})

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
def test_conditional_with_side_effects_col_scalar(data_gen):
    gen = IntegerGen().with_special_case(2147483647)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
            'IF(b < 2147483647, b + 1, 2147483647)',
            'IF(b >= 2147483646, 2147483647, b + 1)'),
        conf = {'spark.sql.ansi.enabled':True})

@pytest.mark.parametrize('data_gen', int_n_long_gens, ids=idfn)
def test_conditional_with_side_effects_cast(data_gen):
    gen = mk_str_gen('[0-9]{1,20}')
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark : two_col_df(spark, data_gen, gen).selectExpr(
            'IF(a RLIKE "^[0-9]{1,5}$", CAST(a AS INT), 0)'),
        conf = {'spark.sql.ansi.enabled':True,
                'spark.rapids.sql.expression.RLike': True})
17 changes: 17 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -1447,6 +1447,23 @@ case class GpuCast(

  import GpuCast._

  // when ansi mode is enabled, some cast expressions can throw exceptions on invalid inputs
  override def hasSideEffects: Boolean = {
    (child.dataType, dataType) match {
      case (StringType, _) if ansiMode => true
      case (TimestampType, ByteType | ShortType | IntegerType) if ansiMode => true
      case (_: DecimalType, LongType) if ansiMode => true
      case (LongType | _: DecimalType, IntegerType) if ansiMode => true
      case (LongType | IntegerType | _: DecimalType, ShortType) if ansiMode => true
      case (LongType | IntegerType | ShortType | _: DecimalType, ByteType) if ansiMode => true
      case (FloatType | DoubleType, ByteType) if ansiMode => true
      case (FloatType | DoubleType, ShortType) if ansiMode => true
      case (FloatType | DoubleType, IntegerType) if ansiMode => true
      case (FloatType | DoubleType, LongType) if ansiMode => true
      case _ => false
    }
  }

  override def toString: String = if (ansiMode) {
    s"ansi_cast($child as ${dataType.simpleString})"
  } else {
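To see why these cast combinations are flagged, recall that under ANSI mode an invalid string-to-int cast throws instead of producing a null, so the cast may only be evaluated on rows the predicate actually selects. The following is a minimal CPU-only sketch of that behavior in plain Spark (the local session, object name, and sample data are illustrative and not part of this commit):

```scala
import org.apache.spark.sql.SparkSession

object AnsiCastGuardSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ansi-cast-sketch").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")
    import spark.implicits._

    val df = Seq("123", "not a number").toDF("a")
    // The guarded cast only runs where the RLIKE predicate holds, so it succeeds;
    // an unguarded CAST(a AS INT) would throw on the second row under ANSI mode.
    df.selectExpr("IF(a RLIKE '^[0-9]{1,5}$', CAST(a AS INT), 0)").show()
    spark.stop()
  }
}
```

This mirrors the new test_conditional_with_side_effects_cast integration test above.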
@@ -160,6 +160,13 @@ trait GpuExpression extends Expression with Arm {
   */
  def convertToAst(numFirstTableColumns: Int): ast.AstExpression =
    throw new IllegalStateException(s"Cannot convert ${this.getClass.getSimpleName} to AST")

  /** Could evaluating this expression cause side-effects, such as throwing an exception? */
  def hasSideEffects: Boolean =
    children.exists {
      case c: GpuExpression => c.hasSideEffects
      case _ => false // This path should never really happen
    }
}

abstract class GpuLeafExpression extends GpuExpression with ShimExpression {
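The default above simply propagates side-effect information up the expression tree: a node reports side effects if any child does, and expressions that can throw override it themselves. A self-contained toy model of that propagation (not plugin code; Expr, Lit, ThrowingAdd, and Alias are made-up names for illustration):

```scala
object SideEffectPropagationSketch {
  sealed trait Expr {
    def children: Seq[Expr]
    // mirrors the default: side-effecting if any child is
    def hasSideEffects: Boolean = children.exists(_.hasSideEffects)
  }
  case class Lit(value: Int) extends Expr { val children: Seq[Expr] = Nil }
  // stands in for something like an ANSI-mode add that can throw on overflow
  case class ThrowingAdd(l: Expr, r: Expr) extends Expr {
    val children: Seq[Expr] = Seq(l, r)
    override def hasSideEffects: Boolean = true
  }
  case class Alias(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

  def main(args: Array[String]): Unit = {
    println(Alias(ThrowingAdd(Lit(1), Lit(2))).hasSideEffects) // true, inherited from the add
    println(Alias(Lit(1)).hasSideEffects)                      // false
  }
}
```

In this commit the real overriding points are GpuCast, the UDF traits, and the arithmetic and math expressions further down.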
@@ -41,6 +41,8 @@ trait GpuUserDefinedFunction extends GpuExpression
  /** True if the UDF is deterministic */
  val udfDeterministic: Boolean

  override def hasSideEffects: Boolean = true

  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)

  private[this] val nvtxRangeName = s"UDF: $name"
@@ -103,6 +105,7 @@ trait GpuRowBasedUserDefinedFunction extends GpuExpression
  private[this] lazy val outputType = dataType.catalogString

  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
  override def hasSideEffects: Boolean = true

  override def columnarEval(batch: ColumnarBatch): Any = {
    val cpuUDFStart = System.nanoTime
sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
@@ -16,16 +16,17 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimExpression

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression}
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.sql.types.{BooleanType, DataType, DataTypes}
import org.apache.spark.sql.vectorized.ColumnarBatch

trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpression
    with ShimExpression {

  protected def computeIfElse(
      batch: ColumnarBatch,
@@ -54,6 +55,31 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpr
    }
  }

  protected def isAllTrue(col: GpuColumnVector): Boolean = {
    assert(BooleanType == col.dataType())
    if (col.getRowCount == 0) {
      return true
    }
    if (col.hasNull) {
      return false
    }
    withResource(col.getBase.all()) { allTrue =>
      // Guaranteed there is at least one row and no nulls so result must be valid
      allTrue.getBoolean
    }
  }

  protected def isAllFalse(col: GpuColumnVector): Boolean = {
    assert(BooleanType == col.dataType())
    if (col.getRowCount == col.numNulls()) {
      // all nulls, and null values are false values here
      return true
    }
    withResource(col.getBase.any()) { anyTrue =>
      // null values are considered false values in this context
      !anyTrue.getBoolean
    }
  }
}

case class GpuIf(
@@ -82,8 +108,153 @@ case class GpuIf(
    }
  }

  override def columnarEval(batch: ColumnarBatch): Any = computeIfElse(batch, predicateExpr,
    trueExpr, falseExpr.columnarEval(batch))
  override def columnarEval(batch: ColumnarBatch): Any = {

    val gpuTrueExpr = trueExpr.asInstanceOf[GpuExpression]
    val gpuFalseExpr = falseExpr.asInstanceOf[GpuExpression]

    withResource(GpuExpressionsUtils.columnarEvalToColumn(predicateExpr, batch)) { pred =>
      if (isAllTrue(pred)) {
        GpuExpressionsUtils.columnarEvalToColumn(trueExpr, batch)
      } else if (isAllFalse(pred)) {
        GpuExpressionsUtils.columnarEvalToColumn(falseExpr, batch)
      } else if (gpuTrueExpr.hasSideEffects || gpuFalseExpr.hasSideEffects) {
        conditionalWithSideEffects(batch, pred, gpuTrueExpr, gpuFalseExpr)
      } else {
        withResourceIfAllowed(trueExpr.columnarEval(batch)) { trueRet =>
          withResourceIfAllowed(falseExpr.columnarEval(batch)) { falseRet =>
            val finalRet = (trueRet, falseRet) match {
              case (t: GpuColumnVector, f: GpuColumnVector) =>
                pred.getBase.ifElse(t.getBase, f.getBase)
              case (t: GpuScalar, f: GpuColumnVector) =>
                pred.getBase.ifElse(t.getBase, f.getBase)
              case (t: GpuColumnVector, f: GpuScalar) =>
                pred.getBase.ifElse(t.getBase, f.getBase)
              case (t: GpuScalar, f: GpuScalar) =>
                pred.getBase.ifElse(t.getBase, f.getBase)
              case (t, f) =>
                throw new IllegalStateException(s"Unexpected inputs" +
                  s" ($t: ${t.getClass}, $f: ${f.getClass})")
            }
            GpuColumnVector.from(finalRet, dataType)
          }
        }
      }
    }
  }

  /**
   * When computing conditional expressions on the CPU, the true and false
   * expressions are evaluated lazily, meaning that the true expression is
   * only evaluated for rows where the predicate is true, and the false
   * expression is only evaluated for rows where the predicate is false.
   * This is important in the case where the expressions can have
   * side-effects, such as throwing exceptions for invalid inputs.
   *
   * This method performs lazy evaluation on the GPU by first filtering the
   * input batch into two batches - one for rows where the predicate is true
   * and one for rows where the predicate is false. The expressions are
   * evaluated against these batches and then the results are combined
   * back into a single batch using the gather algorithm.
   */
  private def conditionalWithSideEffects(
      batch: ColumnarBatch,
      pred: GpuColumnVector,
      gpuTrueExpr: GpuExpression,
      gpuFalseExpr: GpuExpression): GpuColumnVector = {

    val colTypes = GpuColumnVector.extractTypes(batch)

    withResource(GpuColumnVector.from(batch)) { tbl =>
      withResource(pred.getBase.unaryOp(UnaryOp.NOT)) { inverted =>
        // evaluate true expression against true batch
        val tt = withResource(filterBatch(tbl, pred.getBase, colTypes)) { trueBatch =>
          gpuTrueExpr.columnarEval(trueBatch)
        }
        withResourceIfAllowed(tt) { _ =>
          // evaluate false expression against false batch
          val ff = withResource(filterBatch(tbl, inverted, colTypes)) { falseBatch =>
            gpuFalseExpr.columnarEval(falseBatch)
          }
          withResourceIfAllowed(ff) { _ =>
            val finalRet = (tt, ff) match {
              case (t: GpuColumnVector, f: GpuColumnVector) =>
                withResource(gather(pred.getBase, t)) { trueValues =>
                  withResource(gather(inverted, f)) { falseValues =>
                    pred.getBase.ifElse(trueValues, falseValues)
                  }
                }
              case (t: GpuScalar, f: GpuColumnVector) =>
                withResource(gather(inverted, f)) { falseValues =>
                  pred.getBase.ifElse(t.getBase, falseValues)
                }
              case (t: GpuColumnVector, f: GpuScalar) =>
                withResource(gather(pred.getBase, t)) { trueValues =>
                  pred.getBase.ifElse(trueValues, f.getBase)
                }
              case (_: GpuScalar, _: GpuScalar) =>
                throw new IllegalStateException(
                  "scalar expressions can never have side effects")
            }
            GpuColumnVector.from(finalRet, dataType)
          }
        }
      }
    }
  }

  private def filterBatch(
      tbl: Table,
      pred: ColumnVector,
      colTypes: Array[DataType]): ColumnarBatch = {
    withResource(tbl.filter(pred)) { filteredData =>
      GpuColumnVector.from(filteredData, colTypes)
    }
  }

  private def boolToInt(cv: ColumnVector): ColumnVector = {
    withResource(GpuScalar.from(1, DataTypes.IntegerType)) { one =>
      withResource(GpuScalar.from(0, DataTypes.IntegerType)) { zero =>
        cv.ifElse(one, zero)
      }
    }
  }

  private def gather(predicate: ColumnVector, t: GpuColumnVector): ColumnVector = {
    // convert the predicate boolean column to numeric where 1 = true
    // and 0 = false and then use `scan` with `sum` to convert to
    // indices.
    //
    // For example, if the predicate evaluates to [F, F, T, F, T] then this
    // gets translated first to [0, 0, 1, 0, 1] and then the scan operation
    // will perform an exclusive sum on these values and
    // produce [0, 0, 0, 1, 1]. Combining this with the original
    // predicate boolean array results in the two T values mapping to
    // indices 0 and 1, respectively.

    withResource(boolToInt(predicate)) { boolsAsInts =>
      withResource(boolsAsInts.scan(
        ScanAggregation.sum(),
        ScanType.EXCLUSIVE,
        NullPolicy.INCLUDE)) { prefixSumExclusive =>

        // for the entries in the gather map that do not represent valid
        // values to be gathered, we change the value to Int.MinValue which
        // will be treated as null values in the gather algorithm
        val gatherMap = withResource(Scalar.fromInt(Int.MinValue)) {
          outOfBoundsFlag => predicate.ifElse(prefixSumExclusive, outOfBoundsFlag)
        }

        withResource(new Table(t.getBase)) { tbl =>
          withResource(gatherMap) { _ =>
            withResource(tbl.gather(gatherMap)) { gatherTbl =>
              gatherTbl.getColumn(0).incRefCount()
            }
          }
        }
      }
    }
  }

  override def toString: String = s"if ($predicateExpr) $trueExpr else $falseExpr"

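The gather step above is easiest to follow with the [F, F, T, F, T] example from the comment. Below is a CPU-only sketch using plain Scala collections, so the cudf scan, ifElse, and gather calls are modeled by hand; Option stands in for column validity and the values in trueBatchValues are invented for illustration:

```scala
object GatherSketch {
  def main(args: Array[String]): Unit = {
    val predicate = Array(false, false, true, false, true)

    // rows of the filtered "true" batch, i.e. the values the true expression
    // produced for the two predicate-true rows
    val trueBatchValues = Array(10, 11)

    // boolToInt: [F, F, T, F, T] -> [0, 0, 1, 0, 1]
    val asInts = predicate.map(p => if (p) 1 else 0)

    // exclusive prefix sum: [0, 0, 1, 0, 1] -> [0, 0, 0, 1, 1]
    val exclusiveScan = asInts.scanLeft(0)(_ + _).init

    // rows where the predicate is false get an out-of-bounds index
    // (Int.MinValue), which the gather turns into null
    val gatherMap = predicate.zip(exclusiveScan).map {
      case (true, idx) => idx
      case (false, _)  => Int.MinValue
    }

    // gather: out-of-bounds indices become None (null), valid ones pull
    // the corresponding row from the filtered batch
    val gathered: Array[Option[Int]] = gatherMap.map { idx =>
      if (idx >= 0 && idx < trueBatchValues.length) Some(trueBatchValues(idx)) else None
    }

    println(gatherMap.mkString("[", ", ", "]")) // [-2147483648, -2147483648, 0, -2147483648, 1]
    println(gathered.mkString("[", ", ", "]"))  // [None, None, Some(10), None, Some(11)]
  }
}
```

The false branch is gathered the same way with the inverted predicate, and pred.ifElse then stitches the two gathered columns back together row by row.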
@@ -142,6 +142,8 @@ case class GpuAbs(child: Expression, failOnError: Boolean) extends CudfUnaryExpr

abstract class CudfBinaryArithmetic extends CudfBinaryOperator with NullIntolerant {
  override def dataType: DataType = left.dataType
  // arithmetic operations can overflow and throw exceptions in ANSI mode
  override def hasSideEffects: Boolean = SQLConf.get.ansiEnabled
}

object GpuAdd extends Arm {
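The ANSI check above matters for exactly the case the new integration tests exercise: b + 1 overflows at Int.MaxValue, so it must only run on rows the predicate keeps. A hedged CPU-only sketch in plain Spark (the session, object name, and sample data are illustrative, not part of the commit):

```scala
import org.apache.spark.sql.SparkSession

object AnsiOverflowGuardSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ansi-overflow-sketch").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")
    import spark.implicits._

    val df = Seq(1, 2, Int.MaxValue).toDF("b")
    // b + 1 is only evaluated where b < Int.MaxValue, so no overflow is thrown;
    // evaluating b + 1 eagerly for every row would fail on the last one.
    df.selectExpr("IF(b < 2147483647, b + 1, b)").show()
    spark.stop()
  }
}
```

With spark.sql.ansi.enabled left at its default, the same addition would wrap around silently and hasSideEffects would report false.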
@@ -171,6 +171,8 @@ case class GpuCeil(child: Expression) extends CudfUnaryMathExpression("CEIL") {
    case _ => LongType
  }

  override def hasSideEffects: Boolean = true

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, DecimalType, LongType))

@@ -245,6 +247,8 @@ case class GpuFloor(child: Expression) extends CudfUnaryMathExpression("FLOOR")
    case _ => LongType
  }

  override def hasSideEffects: Boolean = true

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, DecimalType, LongType))
