Fix databricks test case failure [databricks] (#5494)
* 312db supports Ansi mode when casting string to date

Signed-off-by: Chong Gao <res_life@163.com>

* fix test_parquet_check_schema_compatibility for Databricks runtime

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

Co-authored-by: sperlingxx <lovedreamf@gmail.com>
res-life and sperlingxx authored May 17, 2022
1 parent c477305 commit 80761f8
Showing 3 changed files with 39 additions and 12 deletions.
27 changes: 24 additions & 3 deletions integration_tests/src/main/python/cast_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_py4j_exception
from data_gen import *
from spark_session import is_before_spark_320, is_before_spark_330, with_gpu_session
from spark_session import is_before_spark_320, is_before_spark_330, is_databricks91_or_later, with_gpu_session
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
@@ -83,15 +83,36 @@ def test_cast_string_date_valid_format():
]
values_string_to_data = invalid_values_string_to_date + valid_values_string_to_date

# test Spark versions < 3.2.0, ANSI mode
@pytest.mark.skipif(not is_before_spark_320(), reason="ansi cast(string as date) throws exception only in 3.2.0+")
# Spark 320+ and Databricks support ANSI mode when casting string to date.
# This means an exception will be thrown when casting an invalid string to date on Spark 320+ or Databricks.
# test Spark versions < 3.2.0 and non-Databricks, ANSI mode
@pytest.mark.skipif((not is_before_spark_320()) or is_databricks91_or_later(), reason="ansi cast(string as date) throws exception only in 3.2.0+ or db")
def test_cast_string_date_invalid_ansi_before_320():
data_rows = [(v,) for v in values_string_to_data]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.createDataFrame(data_rows, "a string").select(f.col('a').cast(DateType())),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'}, )

# test Databricks, ANSI mode; all Databricks versions support ANSI mode when casting string to date
@pytest.mark.skipif(not is_databricks91_or_later(), reason="Spark versions(< 320) not support Ansi mode when casting string to date")
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
def test_cast_string_date_invalid_ansi_databricks(invalid):
assert_gpu_and_cpu_error(
lambda spark: spark.createDataFrame([(invalid,)], "a string").select(f.col('a').cast(DateType())).collect(),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'},
error_message="DateTimeException")

# test databricks, ANSI mode, valid values
@pytest.mark.skipif(not is_databricks91_or_later(), reason="Spark versions(< 320) not support Ansi mode when casting string to date")
def test_cast_string_date_valid_ansi_databricks():
data_rows = [(v,) for v in valid_values_string_to_date]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.createDataFrame(data_rows, "a string").select(f.col('a').cast(DateType())),
conf={'spark.rapids.sql.hasExtendedYearValues': 'false',
'spark.sql.ansi.enabled': 'true'})

# test Spark versions >= 320, ANSI mode
@pytest.mark.skipif(is_before_spark_320(), reason="ansi cast(string as date) throws exception only in 3.2.0+")
@pytest.mark.parametrize('invalid', invalid_values_string_to_date)
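A minimal sketch of the behavior the new Databricks tests exercise (illustrative only, not part of this commit; the invalid value and session setup are assumptions): without ANSI mode an unparseable date string casts to NULL, while under ANSI mode Spark 3.2.0+ and Databricks 9.1+ raise a DateTimeException instead.

# Sketch: ANSI vs. non-ANSI cast(string as date), assuming a plain PySpark session.
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DateType

spark = SparkSession.builder.appName("ansi-cast-sketch").getOrCreate()
df = spark.createDataFrame([("2022_01_01",)], "a string")   # illustrative invalid date string

spark.conf.set("spark.sql.ansi.enabled", "false")
print(df.select(f.col("a").cast(DateType())).collect())     # non-ANSI: [Row(a=None)]

spark.conf.set("spark.sql.ansi.enabled", "true")
try:
    df.select(f.col("a").cast(DateType())).collect()        # ANSI on 3.2.0+/Databricks: throws
except Exception as e:
    print(type(e).__name__)                                  # DateTimeException, surfaced via Py4J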
14 changes: 7 additions & 7 deletions integration_tests/src/main/python/parquet_test.py
@@ -878,14 +878,14 @@ def test_parquet_check_schema_compatibility(spark_tmp_path):
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_int_as_long).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_dec32_as_dec64 = StructType(
[StructField('int', IntegerType()), StructField('dec32', DecimalType(15, 10))])
assert_gpu_and_cpu_error(
lambda spark: spark.read.schema(read_dec32_as_dec64).parquet(data_path).collect(),
conf={},
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')


# For nested types, GPU throws incompatible exception with a different message from CPU.
@@ -902,32 +902,32 @@ def test_parquet_check_schema_compatibility_nested_types(spark_tmp_path):
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_array_long_as_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_arr_arr_int_as_long = StructType(
[StructField('array_array_int', ArrayType(ArrayType(LongType())))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_arr_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_struct_flt_as_dbl = StructType([StructField(
'struct_float', StructType([StructField('f', DoubleType())]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_flt_as_dbl).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_struct_arr_int_as_long = StructType([StructField(
'struct_array_int', StructType([StructField('a', ArrayType(LongType()))]))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_struct_arr_int_as_long).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')

read_map_str_str_as_str_int = StructType([StructField(
'map', MapType(StringType(), IntegerType()))])
assert_py4j_exception(
lambda: with_gpu_session(
lambda spark: spark.read.schema(read_map_str_str_as_str_int).parquet(data_path).collect()),
error_message='Parquet column cannot be converted in')
error_message='Parquet column cannot be converted')
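For reference, a minimal CPU-side sketch of the failure these assertions look for (the file path and column name are illustrative, not from the commit): reading an INT32 parquet column through a LongType read schema fails with a message that starts with 'Parquet column cannot be converted', while the wording after that prefix (for example 'in file ...') can differ between Apache Spark and the Databricks runtime, which is why the match was shortened.

# Sketch: schema-compatibility failure whose message the relaxed match targets.
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, LongType

spark = SparkSession.builder.appName("parquet-compat-sketch").getOrCreate()
data_path = "/tmp/parquet_compat_sketch"                     # assumed temporary path

spark.range(10).selectExpr("cast(id as int) as int").write.mode("overwrite").parquet(data_path)
read_int_as_long = StructType([StructField("int", LongType())])

try:
    spark.read.schema(read_int_as_long).parquet(data_path).collect()
except Exception as e:
    # Apache Spark typically reports 'Parquet column cannot be converted in file ...';
    # Databricks wording can differ after the common prefix, hence the shorter match.
    assert "Parquet column cannot be converted" in str(e)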
10 changes: 8 additions & 2 deletions Spark31XdbShims.scala
@@ -59,9 +59,12 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
GpuOverrides.expr[Cast](
"Convert a column of one type of data into another type",
new CastChecks(),
// 312db supports ANSI mode when casting string to date. This means that an exception
// will be thrown when casting an invalid value to date in ANSI mode.
// Set `stringToAnsiDate` = true
(cast, conf, p, r) => new CastExprMeta[Cast](cast,
SparkSession.active.sessionState.conf.ansiEnabled, conf, p, r,
doFloatToIntCheck = true, stringToAnsiDate = false)),
doFloatToIntCheck = true, stringToAnsiDate = true)),
GpuOverrides.expr[AnsiCast](
"Convert a column of one type of data into another type",
new CastChecks {
@@ -112,8 +115,11 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
override val udtChecks: TypeSig = none
override val sparkUdtSig: TypeSig = UDT
},
// 312db supports ANSI mode when casting string to date. This means that an exception
// will be thrown when casting an invalid value to date in ANSI mode.
// Set `stringToAnsiDate` = true
(cast, conf, p, r) => new CastExprMeta[AnsiCast](cast, ansiEnabled = true, conf = conf,
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = false)),
parent = p, rule = r, doFloatToIntCheck = true, stringToAnsiDate = true)),
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
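A hedged end-to-end sketch of what flipping `stringToAnsiDate` to true means for the plugin (assuming a 312db cluster with the RAPIDS Accelerator installed; the config keys come from the test above, the data value is illustrative): with the cast kept on the GPU, an invalid date string under ANSI mode should raise the same DateTimeException the Databricks CPU raises, instead of returning NULL.

# Sketch: GPU cast mirroring the Databricks CPU ANSI behavior (illustrative values).
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DateType

spark = SparkSession.builder.getOrCreate()                   # assumes the RAPIDS plugin is on the cluster
spark.conf.set("spark.rapids.sql.enabled", "true")           # keep the cast on the GPU
spark.conf.set("spark.rapids.sql.hasExtendedYearValues", "false")
spark.conf.set("spark.sql.ansi.enabled", "true")

df = spark.createDataFrame([("2022-01-50",)], "a string")    # illustrative invalid date
try:
    df.select(f.col("a").cast(DateType())).collect()
except Exception as e:
    print("GPU cast raised:", type(e).__name__)              # expected: DateTimeException via Py4J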
