diff --git a/integration_tests/src/main/python/csv_test.py b/integration_tests/src/main/python/csv_test.py
index 575446963af..3067bb1ee5b 100644
--- a/integration_tests/src/main/python/csv_test.py
+++ b/integration_tests/src/main/python/csv_test.py
@@ -127,9 +127,12 @@
 _double_schema = StructType([
     StructField('number', DoubleType())])
 
-_decimal_schema = StructType([
+_decimal_10_2_schema = StructType([
     StructField('number', DecimalType(10, 2))])
 
+_decimal_10_3_schema = StructType([
+    StructField('number', DecimalType(10, 3))])
+
 _number_as_string_schema = StructType([
     StructField('number', StringType())])
 
@@ -223,8 +226,9 @@ def read_impl(spark):
     pytest.param('simple_int_values.csv', _long_schema, {'header': 'true'}),
     ('simple_int_values.csv', _float_schema, {'header': 'true'}),
     ('simple_int_values.csv', _double_schema, {'header': 'true'}),
-    ('simple_int_values.csv', _decimal_schema, {'header': 'true'}),
-    ('decimals.csv', _decimal_schema, {'header': 'true'}),
+    ('simple_int_values.csv', _decimal_10_2_schema, {'header': 'true'}),
+    ('decimals.csv', _decimal_10_2_schema, {'header': 'true'}),
+    ('decimals.csv', _decimal_10_3_schema, {'header': 'true'}),
     pytest.param('empty_int_values.csv', _empty_byte_schema, {'header': 'true'}),
     pytest.param('empty_int_values.csv', _empty_short_schema, {'header': 'true'}),
     pytest.param('empty_int_values.csv', _empty_int_schema, {'header': 'true'}),
@@ -240,7 +244,8 @@ def read_impl(spark):
     pytest.param('simple_float_values.csv', _long_schema, {'header': 'true'}),
     pytest.param('simple_float_values.csv', _float_schema, {'header': 'true'}),
     pytest.param('simple_float_values.csv', _double_schema, {'header': 'true'}),
-    pytest.param('simple_float_values.csv', _decimal_schema, {'header': 'true'}),
+    pytest.param('simple_float_values.csv', _decimal_10_2_schema, {'header': 'true'}),
+    pytest.param('simple_float_values.csv', _decimal_10_3_schema, {'header': 'true'}),
     pytest.param('simple_boolean_values.csv', _bool_schema, {'header': 'true'}),
     pytest.param('ints_with_whitespace.csv', _number_as_string_schema, {'header': 'true'}, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/2069')),
     pytest.param('ints_with_whitespace.csv', _byte_schema, {'header': 'true'}, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/130'))
diff --git a/integration_tests/src/main/python/json_test.py b/integration_tests/src/main/python/json_test.py
index 26f8cd24da5..6c27df15572 100644
--- a/integration_tests/src/main/python/json_test.py
+++ b/integration_tests/src/main/python/json_test.py
@@ -59,9 +59,12 @@
 _double_schema = StructType([
     StructField('number', DoubleType())])
 
-_decimal_schema = StructType([
+_decimal_10_2_schema = StructType([
     StructField('number', DecimalType(10, 2))])
 
+_decimal_10_3_schema = StructType([
+    StructField('number', DecimalType(10, 3))])
+
 _string_schema = StructType([
     StructField('a', StringType())])
 
@@ -201,7 +204,7 @@ def test_json_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_ena
     pytest.param('floats_edge_cases.json', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/4647')),
     'decimals.json',
 ])
-@pytest.mark.parametrize('schema', [_bool_schema, _byte_schema, _short_schema, _int_schema, _long_schema, _float_schema, _double_schema, _decimal_schema])
+@pytest.mark.parametrize('schema', [_bool_schema, _byte_schema, _short_schema, _int_schema, _long_schema, _float_schema, _double_schema, _decimal_10_2_schema, _decimal_10_3_schema])
 @pytest.mark.parametrize('read_func', [read_json_df, read_json_sql])
 @pytest.mark.parametrize('allow_non_numeric_numbers', ["true", "false"])
 @pytest.mark.parametrize('allow_numeric_leading_zeros', ["true"])
diff --git a/integration_tests/src/test/resources/decimals.csv b/integration_tests/src/test/resources/decimals.csv
index a199a772d0a..b8e83af6fff 100644
--- a/integration_tests/src/test/resources/decimals.csv
+++ b/integration_tests/src/test/resources/decimals.csv
@@ -11,4 +11,6 @@
 12.34
 12.3456
 12345678.12
+33.545454
+33.454545
 
diff --git a/integration_tests/src/test/resources/decimals.json b/integration_tests/src/test/resources/decimals.json
index 5a8fd685ff9..0a98fd05474 100644
--- a/integration_tests/src/test/resources/decimals.json
+++ b/integration_tests/src/test/resources/decimals.json
@@ -10,3 +10,5 @@
 { "number": 12.3456 }
 { "number": 12.345678 }
 { "number": 123456.78 }
+{ "number": 33.454545 }
+{ "number": 33.545454 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
index abaf13d5728..cbc561c8811 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala
@@ -371,7 +371,7 @@ class JsonPartitionReader(
   /**
    * JSON has strict rules about valid numeric formats. See https://www.json.org/ for specification.
    *
-   * Spark then has it's own rules for supporting NaN and Infinity, which are not
+   * Spark then has its own rules for supporting NaN and Infinity, which are not
    * valid numbers in JSON.
   */
   private def sanitizeNumbers(input: ColumnVector): ColumnVector = {
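
Note on the new coverage (an illustrative sketch, not part of the patch): the added rows 33.545454 and 33.454545 carry more fractional digits than either target scale, so reading them as DecimalType(10, 2) or DecimalType(10, 3) forces the reader to fit each value to the declared scale, hitting both rounding directions (at scale 2, 33.545454 rounds up to 33.55 while 33.454545 rounds down to 33.45; at scale 3 the roles swap). A minimal standalone approximation of what the new parameters exercise, assuming a local SparkSession and the decimals.csv fixture from this diff:

    # Sketch only; the session setup and relative file path are assumptions,
    # not part of the patch.
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, DecimalType

    spark = SparkSession.builder.appName("decimal-scale-sketch").getOrCreate()

    # Same shape as _decimal_10_3_schema in the patch.
    schema = StructType([StructField('number', DecimalType(10, 3))])

    # Values like 33.545454 have more fractional digits than scale 3, so the
    # CSV reader must round them to fit the declared decimal type.
    df = spark.read.schema(schema).option('header', 'true') \
        .csv('integration_tests/src/test/resources/decimals.csv')
    df.show()

The integration tests themselves wrap this kind of read so it runs on both CPU and GPU and the collected results are compared, which is what makes the rounding behavior a meaningful test case.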