Fall back to CPU for try_cast in Spark 3.4.0 [databricks] (#8179)
* Fall back to CPU for try_cast in Spark 3.4.0

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Refactor

* specific tests for 320 and 340

* update test

* update copyright year

* fix 31x shim

* refactor and fix failure in 330db

* Update integration_tests/src/main/python/cast_test.py

Co-authored-by: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com>

* update test for db 11.3

* fix imports, rename test

* remove redundant code

* address feedback

* simplify code

* revert changes to pytest

* update skipif

* rename test

* fix another test error

* revert changes to Spark340PlusShims

---------

Signed-off-by: Andy Grove <andygrove@nvidia.com>
Co-authored-by: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com>
andygrove and NVnavkumar authored May 4, 2023
1 parent bb8f474 commit 8703289
Showing 8 changed files with 97 additions and 17 deletions.
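Background for the diffs below: through Spark 3.3.x, try_cast is parsed into a dedicated TryCast expression, but starting with Spark 3.4.0 (and Databricks 11.3) it becomes a plain Cast carrying EvalMode.TRY, which the plugin cannot yet run on the GPU. A minimal sketch of the try_cast semantics in question, assuming a plain local Spark session (this snippet is illustrative and not part of the commit):

import org.apache.spark.sql.SparkSession

object TryCastDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // Even with ANSI mode on, try_cast returns NULL for invalid input
    // instead of raising an error, unlike a regular ANSI cast.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    spark.sql("SELECT try_cast('not-a-date' AS DATE) AS d").show()
    // Prints one row whose value d is NULL.
    spark.stop()
  }
}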
25 changes: 24 additions & 1 deletion integration_tests/src/main/python/cast_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_spark_exception
from data_gen import *
-from spark_session import is_before_spark_320, is_before_spark_330, with_gpu_session
+from spark_session import is_before_spark_320, is_before_spark_330, is_spark_340_or_later, is_databricks113_or_later, with_gpu_session
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
@@ -115,6 +115,29 @@ def test_cast_string_date_invalid_ansi(invalid):
              'spark.sql.ansi.enabled': 'true'},
        error_message="DateTimeException")


# test try_cast in Spark versions >= 320 and < 340
@pytest.mark.skipif(is_before_spark_320() or is_spark_340_or_later() or is_databricks113_or_later(), reason="try_cast is a separate TryCast expression only in Spark 3.2.0 through 3.3.x")
@allow_non_gpu('ProjectExec', 'TryCast')
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
def test_try_cast_fallback(invalid):
    assert_gpu_fallback_collect(
        lambda spark: spark.createDataFrame([(invalid,)], "a string").selectExpr("try_cast(a as date)"),
        'TryCast',
        conf={'spark.rapids.sql.hasExtendedYearValues': False,
              'spark.sql.ansi.enabled': True})

# test try_cast in Spark versions >= 340
@pytest.mark.skipif(not (is_spark_340_or_later() or is_databricks113_or_later()), reason="Cast with EvalMode only in Spark 3.4+")
@allow_non_gpu('ProjectExec','Cast')
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
def test_try_cast_fallback_340(invalid):
    assert_gpu_fallback_collect(
        lambda spark: spark.createDataFrame([(invalid,)], "a string").selectExpr("try_cast(a as date)"),
        'Cast',
        conf={'spark.rapids.sql.hasExtendedYearValues': False,
              'spark.sql.ansi.enabled': True})

# test all Spark versions; in non-ANSI mode an invalid value is converted to NULL
def test_cast_string_date_non_ansi():
    data_rows = [(v,) for v in values_string_to_data]
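The two fallback tests above assert on different operator names ('TryCast' vs 'Cast') because the parsed expression changed between versions. A hedged sketch for seeing the difference yourself, assuming a plain Spark build without the plugin (illustrative only):

import org.apache.spark.sql.SparkSession

object TryCastPlanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // On Spark 3.2.x-3.3.x the analyzed plan reflects a TryCast expression;
    // on Spark 3.4.0+ it reflects Cast with TRY eval mode instead.
    spark.sql("SELECT try_cast('2023_01' AS DATE)").explain(true)
    spark.stop()
  }
}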
13 changes: 9 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
/** Meta-data for cast and ansi_cast. */
final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression with NullIntolerant](
cast: INPUT,
-val ansiEnabled: Boolean,
+val evalMode: GpuEvalMode.Value,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule,
@@ -53,14 +53,19 @@ final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression w
extends UnaryExprMeta[INPUT](cast, conf, parent, rule) {

def withToTypeOverride(newToType: DecimalType): CastExprMeta[INPUT] =
-new CastExprMeta[INPUT](cast, ansiEnabled, conf, parent, rule,
+new CastExprMeta[INPUT](cast, evalMode, conf, parent, rule,
doFloatToIntCheck, stringToAnsiDate, Some(newToType))

val fromType: DataType = cast.child.dataType
val toType: DataType = toTypeOverride.getOrElse(cast.dataType)
val legacyCastToString: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING)

-override def tagExprForGpu(): Unit = recursiveTagExprForGpuCheck()
+override def tagExprForGpu(): Unit = {
+  if (evalMode == GpuEvalMode.TRY) {
+    willNotWorkOnGpu("try_cast is not supported on the GPU")
+  }
+  recursiveTagExprForGpuCheck()
+}

private def recursiveTagExprForGpuCheck(
fromDataType: DataType = fromType,
@@ -156,7 +161,7 @@ final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression w
}

override def convertToGpu(child: Expression): GpuExpression =
-GpuCast(child, toType, ansiEnabled, cast.timeZoneId, legacyCastToString,
+GpuCast(child, toType, evalMode == GpuEvalMode.ANSI, cast.timeZoneId, legacyCastToString,
stringToAnsiDate)

// timezone tagging in type checks is good enough, so always false
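The tagging change above is the heart of the fix: when the eval mode is TRY, the meta marks the cast as unable to run on the GPU and the plan keeps the CPU expression. A self-contained sketch of that tag-then-fallback pattern, with illustrative stand-in names rather than the plugin's real CastExprMeta API:

object TagFallbackDemo {
  // Stand-in for an expression meta: it records reasons the expression
  // cannot run on the GPU, and any recorded reason forces CPU fallback.
  final class ExprMetaSketch(evalMode: String) {
    private var reasons: List[String] = Nil
    def willNotWorkOnGpu(reason: String): Unit = reasons ::= reason
    def tagExprForGpu(): Unit =
      if (evalMode == "TRY") willNotWorkOnGpu("try_cast is not supported on the GPU")
    def canRunOnGpu: Boolean = reasons.isEmpty
  }

  def main(args: Array[String]): Unit = {
    val tryMeta = new ExprMetaSketch("TRY")
    tryMeta.tagExprForGpu()
    assert(!tryMeta.canRunOnGpu) // TRY-mode cast stays on the CPU

    val ansiMeta = new ExprMetaSketch("ANSI")
    ansiMeta.tagExprForGpu()
    assert(ansiMeta.canRunOnGpu) // ANSI-mode cast can still be converted
  }
}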
GpuEvalMode.scala (new file, package com.nvidia.spark.rapids)
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

/**
 * Expression evaluation modes.
 * - LEGACY: the default evaluation mode, which is compliant with Hive SQL.
 * - ANSI: an evaluation mode that is compliant with the ANSI SQL standard.
 * - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI
 *   mode except that it returns null results on errors instead of throwing.
 */
object GpuEvalMode extends Enumeration {
val LEGACY, ANSI, TRY = Value
}
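As a usage sketch, the enumeration lends itself to the same match-based dispatch the cast meta performs; supportedOnGpu below is an illustrative helper, not a real plugin method, and the snippet assumes the plugin classes are on the classpath:

import com.nvidia.spark.rapids.GpuEvalMode

object GpuEvalModeDemo {
  // TRY mode forces CPU fallback; LEGACY and ANSI casts can run on the GPU.
  def supportedOnGpu(mode: GpuEvalMode.Value): Boolean = mode match {
    case GpuEvalMode.TRY => false
    case _ => true
  }

  def main(args: Array[String]): Unit = {
    assert(!supportedOnGpu(GpuEvalMode.TRY))
    assert(supportedOnGpu(GpuEvalMode.ANSI) && supportedOnGpu(GpuEvalMode.LEGACY))
  }
}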
AnsiCastShim.scala (pre-Spark 3.4.0 shims)
@@ -32,18 +32,29 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.GpuCast
+import com.nvidia.spark.rapids.{GpuCast, GpuEvalMode}

import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression}

object AnsiCastShim {
def isAnsiCast(e: Expression): Boolean = e match {
case c: GpuCast => c.ansiMode
case _: AnsiCast => true
-case _: Cast =>
-  val m = e.getClass.getDeclaredField("ansiEnabled")
-  m.setAccessible(true)
-  m.getBoolean(e)
+case _: Cast => isAnsiEnabled(e)
case _ => false
}

def getEvalMode(c: Cast): GpuEvalMode.Value = {
  if (isAnsiEnabled(c)) {
    GpuEvalMode.ANSI
  } else {
    GpuEvalMode.LEGACY
  }
}

private def isAnsiEnabled(e: Expression) = {
  val m = e.getClass.getDeclaredField("ansiEnabled")
  m.setAccessible(true)
  m.getBoolean(e)
}
}
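Pre-3.4 Cast has no evalMode field, so isAnsiEnabled above reads Cast's private ansiEnabled field reflectively. A self-contained sketch of that reflection pattern; FakeCast is an illustrative stand-in, not Spark's Cast:

object ReflectionDemo {
  // Stand-in with a private boolean field, like Cast.ansiEnabled.
  class FakeCast(private val ansiEnabled: Boolean) {
    override def toString: String = s"FakeCast($ansiEnabled)"
  }

  // Mirrors the shim: look up the field by name, make it accessible, read it.
  def readAnsiEnabled(o: AnyRef): Boolean = {
    val f = o.getClass.getDeclaredField("ansiEnabled")
    f.setAccessible(true)
    f.getBoolean(o)
  }

  def main(args: Array[String]): Unit = {
    assert(readAnsiEnabled(new FakeCast(true)))
    assert(!readAnsiEnabled(new FakeCast(false)))
  }
}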
Spark31XShims.scala
@@ -235,7 +235,7 @@ abstract class Spark31XShims extends Spark31Xuntil33XShims with Logging {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
-(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
+(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, GpuEvalMode.ANSI, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = false))
}

@@ -244,7 +244,7 @@ abstract class Spark31XShims extends Spark31Xuntil33XShims with Logging {
"Convert a column of one type of data into another type",
new CastChecks(),
(cast, conf, p, r) => new CastExprMeta[Cast](cast,
-SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
+AnsiCastShim.getEvalMode(cast), conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = false)),
GpuOverrides.expr[Average](
"Average aggregate operator",
AnsiCastRuleShims.scala
@@ -92,7 +92,7 @@ trait AnsiCastRuleShims extends SparkShims {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
-(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
+(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, GpuEvalMode.ANSI, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = true))
}
}
Spark320PlusShims.scala
@@ -139,9 +139,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
GpuOverrides.expr[Cast](
"Convert a column of one type of data into another type",
new CastChecks(),
-(cast, conf, p, r) => new CastExprMeta[Cast](cast,
-  SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
-  doFloatToIntCheck = true, stringToAnsiDate = true)),
+(cast, conf, p, r) => {
+  new CastExprMeta[Cast](cast,
+    AnsiCastShim.getEvalMode(cast), conf, p, r,
+    doFloatToIntCheck = true, stringToAnsiDate = true)
+}),
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
AnsiCastShim.scala (Spark 3.4.0+ shims)
@@ -19,7 +19,7 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

-import com.nvidia.spark.rapids.GpuCast
+import com.nvidia.spark.rapids.{GpuCast, GpuEvalMode}

import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Expression}

@@ -29,4 +29,12 @@ object AnsiCastShim {
case c: Cast => c.evalMode == EvalMode.ANSI
case _ => false
}

def getEvalMode(c: Cast): GpuEvalMode.Value = {
  c.evalMode match {
    case EvalMode.LEGACY => GpuEvalMode.LEGACY
    case EvalMode.ANSI => GpuEvalMode.ANSI
    case EvalMode.TRY => GpuEvalMode.TRY
  }
}
}
