Fix Databricks shim to support ANSI mode when casting from string to date [databricks] #5494

Merged
27 changes: 24 additions & 3 deletions integration_tests/src/main/python/cast_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_py4j_exception
from data_gen import *
from spark_session import is_before_spark_320, is_before_spark_330, with_gpu_session
from spark_session import is_before_spark_320, is_before_spark_330, is_databricks91_or_later, with_gpu_session
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
@@ -83,15 +83,36 @@ def test_cast_string_date_valid_format():
]
values_string_to_data = invalid_values_string_to_date + valid_values_string_to_date

# test Spark versions < 3.2.0, ANSI mode
@pytest.mark.skipif(not is_before_spark_320(), reason="ansi cast(string as date) throws exception only in 3.2.0+")
# Spark 3.2.0+ and Databricks support ANSI mode when casting string to date, which means an exception
# is thrown when casting an invalid string to date there (see the standalone sketch after this file's diff)
# test Spark versions < 3.2.0 and non-Databricks, ANSI mode
@pytest.mark.skipif((not is_before_spark_320()) or is_databricks91_or_later(), reason="ansi cast(string as date) throws exception only in 3.2.0+ or db")
def test_cast_string_date_invalid_ansi_before_320():
data_rows = [(v,) for v in values_string_to_data]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.createDataFrame(data_rows, "a string").select(f.col('a').cast(DateType())),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'}, )

# test Databricks, ANSI mode; all Databricks versions support ANSI mode when casting string to date
@pytest.mark.skipif(not is_databricks91_or_later(), reason="Spark versions (< 3.2.0) do not support ANSI mode when casting string to date")
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
def test_cast_string_date_invalid_ansi_databricks(invalid):
assert_gpu_and_cpu_error(
lambda spark: spark.createDataFrame([(invalid,)], "a string").select(f.col('a').cast(DateType())).collect(),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'},
error_message="DateTimeException")

# test Databricks, ANSI mode, valid values
@pytest.mark.skipif(not is_databricks91_or_later(), reason="Spark versions (< 3.2.0) do not support ANSI mode when casting string to date")
def test_cast_string_date_valid_ansi_databricks():
data_rows = [(v,) for v in valid_values_string_to_date]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.createDataFrame(data_rows, "a string").select(f.col('a').cast(DateType())),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'})

# test Spark versions >= 3.2.0, ANSI mode
@pytest.mark.skipif(is_before_spark_320(), reason="ansi cast(string as date) throws exception only in 3.2.0+")
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
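
For context, here is a minimal standalone sketch (not part of the PR's test suite) of the behavior the new Databricks tests assert, assuming a plain PySpark session on Spark 3.2.0+ or Databricks 9.1+; the invalid literal is just an example, not taken from the invalid_values_string_to_date list:

# Sketch: in ANSI mode, casting an invalid string to date raises a DateTimeException
# on Spark 3.2.0+ and Databricks 9.1+; earlier Spark versions do not raise here.
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DateType

spark = (SparkSession.builder
         .config("spark.sql.ansi.enabled", "true")  # ANSI cast semantics
         .getOrCreate())

df = spark.createDataFrame([("2022_01_01",)], "a string")  # example invalid date literal
try:
    df.select(f.col("a").cast(DateType())).collect()
except Exception as e:
    print(type(e).__name__, e)  # the underlying JVM error is a DateTimeException
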
14 changes: 7 additions & 7 deletions integration_tests/src/main/python/parquet_test.py
@@ -878,14 +878,14 @@ def test_parquet_check_schema_compatibility(spark_tmp_path):
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_int_as_long).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_dec32_as_dec64 = StructType(
[StructField('int', IntegerType()), StructField('dec32', DecimalType(15, 10))])
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_dec32_as_dec64).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')


# For nested types, the GPU throws an incompatibility exception with a different message than the CPU (see the helper sketch after this diff).
@@ -902,32 +902,32 @@ def test_parquet_check_schema_compatibility_nested_types(spark_tmp_path):
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_array_long_as_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_arr_arr_int_as_long = StructType(
[StructField('array_array_int', ArrayType(ArrayType(LongType())))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_arr_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_struct_flt_as_dbl = StructType([StructField(
'struct_float', StructType([StructField('f', DoubleType())]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_flt_as_dbl).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_struct_arr_int_as_long = StructType([StructField(
'struct_array_int', StructType([StructField('a', ArrayType(LongType()))]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_map_str_str_as_str_int = StructType([StructField(
'map', MapType(StringType(), IntegerType()))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_map_str_str_as_str_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')
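
The assertions above only shorten the expected error text from 'Parquet column cannot be converted in' to the prefix 'Parquet column cannot be converted'. A hypothetical minimal stand-in for the test helpers (the real assert_gpu_and_cpu_error and assert_py4j_exception live in the suite's asserts.py) illustrates the assumption behind that change: the expected message is matched as a substring of the raised exception's text, so a shorter prefix still matches builds whose message continues differently after it.

# Hypothetical helper, for illustration only; the real assertions are in asserts.py.
def assert_error_contains(func, error_message):
    """Run func, expect an exception, and check its text contains error_message."""
    try:
        func()
    except Exception as e:
        assert error_message in str(e), f"expected {error_message!r} in: {e}"
        return
    raise AssertionError("expected an exception, but none was raised")

# Usage mirroring the tests above (read_int_as_long and data_path come from that code):
# assert_error_contains(
#     lambda: spark.read.schema(read_int_as_long).parquet(data_path).collect(),
#     'Parquet column cannot be converted')
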
@@ -59,9 +59,12 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
GpuOverrides.expr[Cast](
"Convert a column of one type of data into another type",
new CastChecks(),
// 312db supports ANSI mode when casting string to date, which means an exception
// will be thrown when casting an invalid value to date in ANSI mode,
// so set `stringToAnsiDate` = true.
(cast, conf, p, r) => new CastExprMeta[Cast](cast,
SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = false)),
doFloatToIntCheck = true, stringToAnsiDate = true)),
GpuOverrides.expr[AnsiCast](
"Convert a column of one type of data into another type",
new CastChecks {
@@ -112,8 +115,11 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
// 312db supports ANSI mode when casting string to date, which means an exception
// will be thrown when casting an invalid value to date in ANSI mode,
// so set `stringToAnsiDate` = true (an end-to-end config sketch follows this diff).
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = false)),
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = true)),
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
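
To exercise the shim change end to end, here is a sketch (an assumed configuration, not from the PR) of the kind of session in which it matters: the RAPIDS Accelerator enabled together with ANSI mode, so the GPU cast from string to date has to raise the same DateTimeException as CPU Spark does on Databricks 9.1+. The plugin class and spark.rapids.* keys are the standard spark-rapids ones; the Databricks cluster setup itself is assumed rather than shown.

from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")      # RAPIDS Accelerator
         .config("spark.rapids.sql.enabled", "true")
         .config("spark.sql.ansi.enabled", "true")                   # ANSI cast semantics
         .config("spark.rapids.sql.hasExtendedYearValues", "false")  # as in the tests above
         .getOrCreate())

# With stringToAnsiDate = true in the 312db shim, this cast raises a DateTimeException
# on the GPU in ANSI mode, matching the CPU behavior the tests above check.
spark.sql("SELECT CAST('not-a-date' AS DATE)").collect()
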