Skip to content

Commit

Permalink
Add recursive type checking and fallback tests for casting array with…
Browse files Browse the repository at this point in the history
… unsupported element types to string (#4449)

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 authored Jan 4, 2022
1 parent c9125bf commit 4606172
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
8 changes: 4 additions & 4 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -16771,12 +16771,12 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td><em>PS<br/>the array's child type must also support being cast to string</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>The array's child type must also support being cast to the desired child type;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>The array's child type must also support being cast to the desired child type(s);<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down Expand Up @@ -17175,12 +17175,12 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
<td>S</td>
<td><em>PS<br/>the array's child type must also support being cast to string</em></td>
<td> </td>
<td> </td>
<td> </td>
<td> </td>
<td><em>PS<br/>The array's child type must also support being cast to the desired child type;<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>The array's child type must also support being cast to the desired child type(s);<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
Expand Down
34 changes: 23 additions & 11 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_cast_long_to_decimal_overflow():
# casting these types to string is not exact match, marked as xfail when testing
not_matched_gens_for_cast_to_string = [float_gen, double_gen, decimal_gen_neg_scale]
# casting these types to string is not supported, marked as xfail when testing
not_support_gens_for_cast_to_string = decimal_128_gens
not_support_gens_for_cast_to_string = decimal_128_gens + [MapGen(ByteGen(False), ByteGen())]

single_level_array_gens_for_cast_to_string = [ArrayGen(sub_gen) for sub_gen in basic_gens_for_cast_to_string]
nested_array_gens_for_cast_to_string = [
Expand All @@ -188,13 +188,23 @@ def test_cast_long_to_decimal_overflow():

def _assert_cast_to_string_equal (data_gen, conf):
"""
helper function for casting to string
helper function for casting to string of supported type
"""
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
conf
)

def _assert_cast_to_string_fallback (data_gen, conf):
"""
helper function for casting to string of unsupported type
"""
assert_gpu_fallback_collect(
lambda spark: unary_op_df(spark, data_gen).select(f.col('a').cast("STRING")),
"Cast",
conf
)

@pytest.mark.parametrize('data_gen', all_gens_for_cast_to_string, ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
def test_cast_array_to_string(data_gen, legacy):
Expand All @@ -219,12 +229,13 @@ def test_cast_array_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', [ArrayGen(sub) for sub in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@pytest.mark.xfail(reason='casting this type to string is not supported')
def test_cast_array_with_unsupported_element_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
def test_cast_array_with_unsupported_element_to_string_fallback(data_gen, legacy):
_assert_cast_to_string_fallback(
data_gen,
{"spark.rapids.sql.castDecimalToString.enabled" : 'true',
"spark.sql.legacy.castComplexTypesToString.enabled": legacy}
"spark.sql.legacy.castComplexTypesToString.enabled": legacy,
"spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)


Expand Down Expand Up @@ -285,11 +296,12 @@ def test_cast_struct_with_unmatched_element_to_string(data_gen, legacy):

@pytest.mark.parametrize('data_gen', [StructGen([["first", element_gen]]) for element_gen in not_support_gens_for_cast_to_string], ids=idfn)
@pytest.mark.parametrize('legacy', ['true', 'false'])
@pytest.mark.xfail(reason='casting this type to string is not supported')
def test_cast_struct_with_unsupported_element_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
@allow_non_gpu('ProjectExec', 'Cast', 'Alias')
def test_cast_struct_with_unsupported_element_to_string_fallback(data_gen, legacy):
_assert_cast_to_string_fallback(
data_gen,
{"spark.rapids.sql.castDecimalToString.enabled" : 'true',
"spark.sql.legacy.castComplexTypesToString.enabled": legacy}
{"spark.rapids.sql.castDecimalToString.enabled" : 'true',
"spark.sql.legacy.castComplexTypesToString.enabled": legacy,
"spark.sql.legacy.allowNegativeScaleOfDecimal": 'true'}
)

Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ final class CastExprMeta[INPUT <: CastBase](
case (fromChild, toChild) =>
recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType, depth + 1)
}
case (ArrayType(elementType, _), StringType) =>
recursiveTagExprForGpuCheck(elementType, StringType, depth + 1)

case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) =>
recursiveTagExprForGpuCheck(nestedFrom, nestedTo, depth + 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1270,10 +1270,11 @@ class CastChecks extends ExprChecks {
val calendarChecks: TypeSig = none
val sparkCalendarSig: TypeSig = CALENDAR + STRING

val arrayChecks: TypeSig = STRING + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " +
"being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL +
ARRAY + BINARY + STRUCT + MAP) +
psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
"the desired child type")
psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " +
"desired child type(s)")

val sparkArraySig: TypeSig = STRING + ARRAY.nested(all)

Expand Down

0 comments on commit 4606172

Please sign in to comment.