Fall back to CPU for try_cast in Spark 3.4.0 [databricks] #8179

Merged: 19 commits, merged on May 4, 2023
Changes from 9 commits
15 changes: 14 additions & 1 deletion integration_tests/src/main/python/cast_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_py4j_exception
from data_gen import *
from spark_session import is_before_spark_320, is_before_spark_330, with_gpu_session
from spark_session import is_before_spark_320, is_before_spark_330, is_spark_32X, is_spark_33X, with_gpu_session
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
@@ -115,6 +115,19 @@ def test_cast_string_date_invalid_ansi(invalid):
'spark.sql.ansi.enabled': 'true'},
error_message="DateTimeException")


# test try_cast in Spark versions >= 340
test_try_cast_fallback_non_gpu = ['ProjectExec', 'Cast'] if is_spark_340_or_later() or is_databricks113_or_later() else ['ProjectExec','TryCast']
@pytest.mark.skipif(is_before_spark_320(), reason="try_cast only in Spark 3.2+")
@allow_non_gpu(test_try_cast_fallback_non_gpu)
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
def test_try_cast_fallback_340(invalid):
assert_gpu_fallback_collect(
lambda spark: spark.createDataFrame([(invalid,)], "a string").selectExpr("try_cast(a as date)"),
'Cast',
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'})
Collaborator (review comment):

nit

Suggested change:
-    conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
-          'spark.sql.ansi.enabled': 'true'})
+    conf={'spark.rapids.sql.hasExtendedYearValues': False,
+          'spark.sql.ansi.enabled': True})


# test all Spark versions, non ANSI mode, invalid value will be converted to NULL
def test_cast_string_date_non_ansi():
data_rows = [(v,) for v in values_string_to_data]
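For context on the new try_cast test above: on Spark 3.4.0+ (and Databricks 11.3+) try_cast is planned as a regular Cast carrying a TRY evaluation mode rather than as a separate TryCast expression, which is why the expected fallback expression changes from TryCast to Cast. A minimal, hypothetical spark-shell sketch (not part of this diff) of the semantics the CPU fallback preserves:

// 'spark' is the active SparkSession; column names are illustrative.
// try_cast returns NULL for unparseable input instead of raising an error,
// which the plugin does not yet support on the GPU, hence the fallback.
val df = spark.sql(
  "SELECT try_cast('not_a_date' AS DATE) AS bad, try_cast('2023-01-01' AS DATE) AS good")
df.show()
// Expected output (roughly):
// +----+----------+
// | bad|      good|
// +----+----------+
// |null|2023-01-01|
// +----+----------+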
3 changes: 3 additions & 0 deletions integration_tests/src/main/python/spark_session.py
@@ -163,6 +163,9 @@ def is_spark_340_or_later():
def is_spark_330():
return spark_version() == "3.3.0"

def is_spark_32X():
return "3.2.0" <= spark_version() < "3.3.0"

def is_spark_33X():
return "3.3.0" <= spark_version() < "3.4.0"

21 changes: 17 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
/** Meta-data for cast and ansi_cast. */
final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression with NullIntolerant](
cast: INPUT,
val ansiEnabled: Boolean,
val evalMode: GpuEvalMode.Value,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule,
@@ -52,15 +52,28 @@ final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression w
toTypeOverride: Option[DataType] = None)
extends UnaryExprMeta[INPUT](cast, conf, parent, rule) {

def withToTypeOverride(newToType: DecimalType): CastExprMeta[INPUT] =
new CastExprMeta[INPUT](cast, ansiEnabled, conf, parent, rule,
val ansiEnabled = evalMode == GpuEvalMode.ANSI

def withToTypeOverride(newToType: DecimalType): CastExprMeta[INPUT] = {
val evalMode = if (ansiEnabled) {
Review thread:

Collaborator: nit: Why use evalMode to compute ansiEnabled and then turn around and use ansiEnabled to compute evalMode?

Contributor Author: I was trying to minimize changes to existing code, but I can revisit this.

Collaborator: It is just a nit so you can decide what to do. But if you delete lines 58 to 62 I would not complain about it.

Contributor Author: Fixed. [A simplified sketch follows this hunk.]

GpuEvalMode.ANSI
} else {
GpuEvalMode.LEGACY
}
new CastExprMeta[INPUT](cast, evalMode, conf, parent, rule,
doFloatToIntCheck, stringToAnsiDate, Some(newToType))
}

val fromType: DataType = cast.child.dataType
val toType: DataType = toTypeOverride.getOrElse(cast.dataType)
val legacyCastToString: Boolean = SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING)

override def tagExprForGpu(): Unit = recursiveTagExprForGpuCheck()
override def tagExprForGpu(): Unit = {
if (evalMode == GpuEvalMode.TRY) {
willNotWorkOnGpu("try_cast is not supported on the GPU")
}
recursiveTagExprForGpuCheck()
}

private def recursiveTagExprForGpuCheck(
fromDataType: DataType = fromType,
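As a reference for the resolved review nit above, here is a minimal sketch of the suggested simplification: withToTypeOverride forwards the existing evalMode instead of round-tripping through ansiEnabled (illustrative only; the final committed code may differ):

  def withToTypeOverride(newToType: DecimalType): CastExprMeta[INPUT] =
    // Forward the incoming evalMode directly so LEGACY, ANSI, and TRY are all preserved
    // without being recomputed from ansiEnabled.
    new CastExprMeta[INPUT](cast, evalMode, conf, parent, rule,
      doFloatToIntCheck, stringToAnsiDate, Some(newToType))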
New file: com/nvidia/spark/rapids/GpuEvalMode.scala
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

/**
* Expression evaluation modes.
* - LEGACY: the default evaluation mode, which is compliant to Hive SQL.
* - ANSI: an evaluation mode which is compliant to the ANSI SQL standard.
* - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode
* except for returning null result on errors.
*/
object GpuEvalMode extends Enumeration {
val LEGACY, ANSI, TRY = Value
}
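A hypothetical usage sketch (not part of this diff; the function name is illustrative) showing how plugin code can branch on the new enumeration, mirroring the doc comment above:

  // Illustrative only: summarizes the error behavior each mode implies for a cast.
  def onErrorBehavior(mode: GpuEvalMode.Value): String = mode match {
    case GpuEvalMode.LEGACY => "invalid input yields a Hive-compatible null result, no error"
    case GpuEvalMode.ANSI   => "invalid input raises an error (ANSI SQL)"
    case GpuEvalMode.TRY    => "like ANSI, but returns null instead of raising an error"
  }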
sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/AnsiCastShim.scala
@@ -32,8 +32,9 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.GpuCast
import com.nvidia.spark.rapids.{GpuCast, GpuEvalMode}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression}

object AnsiCastShim {
@@ -46,4 +47,12 @@ object AnsiCastShim {
m.getBoolean(e)
case _ => false
}

def getEvalMode(c: Cast): GpuEvalMode.Value = {
if (SparkSession.active.sessionState.conf.ansiEnabled) {
GpuEvalMode.ANSI
} else {
GpuEvalMode.LEGACY
}
}
}
@@ -235,7 +235,7 @@ abstract class Spark31XShims extends Spark31Xuntil33XShims with Logging {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, GpuEvalMode.ANSI, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = false))
}

@@ -244,7 +244,7 @@ abstract class Spark31XShims extends Spark31Xuntil33XShims with Logging {
"Convert a column of one type of data into another type",
new CastChecks(),
(cast, conf, p, r) => new CastExprMeta[Cast](cast,
SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
AnsiCastShim.getEvalMode(cast), conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = false)),
GpuOverrides.expr[Average](
"Average aggregate operator",
@@ -92,7 +92,7 @@ trait AnsiCastRuleShims extends SparkShims {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, GpuEvalMode.ANSI, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = true))
}
}
@@ -139,9 +139,11 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
GpuOverrides.expr[Cast](
"Convert a column of one type of data into another type",
new CastChecks(),
(cast, conf, p, r) => new CastExprMeta[Cast](cast,
SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = true)),
(cast, conf, p, r) => {
new CastExprMeta[Cast](cast,
AnsiCastShim.getEvalMode(cast), conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = true)
}),
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/AnsiCastShim.scala
@@ -19,7 +19,7 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.GpuCast
import com.nvidia.spark.rapids.{GpuCast, GpuEvalMode}

import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Expression}

@@ -29,4 +29,12 @@ object AnsiCastShim {
case c: Cast => c.evalMode == EvalMode.ANSI
case _ => false
}

def getEvalMode(c: Cast): GpuEvalMode.Value = {
c.evalMode match {
case EvalMode.LEGACY => GpuEvalMode.LEGACY
case EvalMode.ANSI => GpuEvalMode.ANSI
case EvalMode.TRY => GpuEvalMode.TRY
}
}
}
Spark340PlusShims.scala
@@ -22,8 +22,7 @@ package com.nvidia.spark.rapids.shims
import com.nvidia.spark.rapids._

import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
import org.apache.spark.sql.catalyst.expressions.{Expression, KnownNullable}
import org.apache.spark.sql.catalyst.expressions.Empty2Null
import org.apache.spark.sql.catalyst.expressions.{Empty2Null, Expression, KnownNullable}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{CollectLimitExec, GlobalLimitExec, SparkPlan}
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
Review thread:

Collaborator: I am confused, and it is not clear how the Spark340Plus shims are going to know whether the cast is ANSI, TRY, or LEGACY. This inherits from Spark331PlusShims, but 331 does not do this, only Databricks 330 does, and I don't think this inherits from that.

Contributor Author (@andygrove, Apr 26, 2023): We provide overrides for Cast in Spark320PlusShims and Spark31XShims. Spark340PlusShims indirectly extends Spark320PlusShims (via Spark331PlusShims, Spark330PlusNonDBShims, Spark330PlusShims, and Spark321PlusShims). These shims delegate to AnsiCastShim, which is shimmed for 311+ and 330db/340 as follows:

sql-plugin/src/main/spark311/scala/com/nvidia/spark/rapids/shims/AnsiCastShim.scala (311+)
sql-plugin/src/main/spark330db/scala/com/nvidia/spark/rapids/shims/AnsiCastShim.scala (330db + 340)

It is all very confusing, for sure.

Contributor Author: For me, it was especially confusing that we have a db shim that shims for non-db. I understand that it makes sense, but I think this is not a pattern we are used to seeing so far.

Collaborator: Yes, it works because the tests pass, but we should fix this. Can you file an issue so that we don't have 340 depend on 330db, unless we have one to do it already?

Contributor Author: There was a discussion about this in #8169 (comment) and it seems that it is correct that we have both 330db/340 in the same shim.

Collaborator: That is right. Not that it makes it any less confusing.

Contributor Author: I filed #8188 for removing the dependency from 340 to 330db.

Collaborator: I think the best way to make the code more clear is to break up the trait inheritance and stop encoding version ranges in the class/trait/object names.
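To make the layering described in this thread concrete, here is a condensed sketch of the two AnsiCastShim variants shown earlier in this diff (imports omitted; only one variant is compiled into any given shim build, which is how Spark340PlusShims picks up TRY detection without new code of its own):

// sql-plugin/src/main/spark311/.../AnsiCastShim.scala (311+):
// Cast has no evalMode field on these versions, so ANSI vs LEGACY comes from the session conf.
object AnsiCastShim {
  def getEvalMode(c: Cast): GpuEvalMode.Value =
    if (SparkSession.active.sessionState.conf.ansiEnabled) GpuEvalMode.ANSI
    else GpuEvalMode.LEGACY
}

// sql-plugin/src/main/spark330db/.../AnsiCastShim.scala (330db + 340):
// Cast carries an EvalMode, so LEGACY/ANSI/TRY map across directly.
object AnsiCastShim {
  def getEvalMode(c: Cast): GpuEvalMode.Value = c.evalMode match {
    case EvalMode.LEGACY => GpuEvalMode.LEGACY
    case EvalMode.ANSI   => GpuEvalMode.ANSI
    case EvalMode.TRY    => GpuEvalMode.TRY
  }
}

Common code such as Spark320PlusShims then has a single version-agnostic call site, AnsiCastShim.getEvalMode(cast), and the build's source set determines which implementation it binds to.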

@@ -79,6 +78,7 @@ trait Spark340PlusShims extends Spark331PlusShims {
// AnsiCast is removed from Spark3.4.0
override def ansiCastRule: ExprRule[_ <: Expression] = null


override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val shimExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
// Empty2Null is pulled out of FileFormatWriter by default since Spark 3.4.0,