Skip to content

Commit

Permalink
improve decimal tests to cover different rounding cases
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Feb 22, 2022
1 parent ec31dc9 commit 3d121bb
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
13 changes: 9 additions & 4 deletions integration_tests/src/main/python/csv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,12 @@
_double_schema = StructType([
StructField('number', DoubleType())])

_decimal_schema = StructType([
_decimal_10_2_schema = StructType([
StructField('number', DecimalType(10, 2))])

_decimal_10_3_schema = StructType([
StructField('number', DecimalType(10, 3))])

_number_as_string_schema = StructType([
StructField('number', StringType())])

Expand Down Expand Up @@ -223,8 +226,9 @@ def read_impl(spark):
pytest.param('simple_int_values.csv', _long_schema, {'header': 'true'}),
('simple_int_values.csv', _float_schema, {'header': 'true'}),
('simple_int_values.csv', _double_schema, {'header': 'true'}),
('simple_int_values.csv', _decimal_schema, {'header': 'true'}),
('decimals.csv', _decimal_schema, {'header': 'true'}),
('simple_int_values.csv', _decimal_10_2_schema, {'header': 'true'}),
('decimals.csv', _decimal_10_2_schema, {'header': 'true'}),
('decimals.csv', _decimal_10_3_schema, {'header': 'true'}),
pytest.param('empty_int_values.csv', _empty_byte_schema, {'header': 'true'}),
pytest.param('empty_int_values.csv', _empty_short_schema, {'header': 'true'}),
pytest.param('empty_int_values.csv', _empty_int_schema, {'header': 'true'}),
Expand All @@ -240,7 +244,8 @@ def read_impl(spark):
pytest.param('simple_float_values.csv', _long_schema, {'header': 'true'}),
pytest.param('simple_float_values.csv', _float_schema, {'header': 'true'}),
pytest.param('simple_float_values.csv', _double_schema, {'header': 'true'}),
pytest.param('simple_float_values.csv', _decimal_schema, {'header': 'true'}),
pytest.param('simple_float_values.csv', _decimal_10_2_schema, {'header': 'true'}),
pytest.param('simple_float_values.csv', _decimal_10_3_schema, {'header': 'true'}),
pytest.param('simple_boolean_values.csv', _bool_schema, {'header': 'true'}),
pytest.param('ints_with_whitespace.csv', _number_as_string_schema, {'header': 'true'}, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/2069')),
pytest.param('ints_with_whitespace.csv', _byte_schema, {'header': 'true'}, marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/130'))
Expand Down
7 changes: 5 additions & 2 deletions integration_tests/src/main/python/json_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@
_double_schema = StructType([
StructField('number', DoubleType())])

_decimal_schema = StructType([
_decimal_10_2_schema = StructType([
StructField('number', DecimalType(10, 2))])

_decimal_10_3_schema = StructType([
StructField('number', DecimalType(10, 3))])

_string_schema = StructType([
StructField('a', StringType())])

Expand Down Expand Up @@ -201,7 +204,7 @@ def test_json_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_ena
pytest.param('floats_edge_cases.json', marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/4647')),
'decimals.json',
])
@pytest.mark.parametrize('schema', [_bool_schema, _byte_schema, _short_schema, _int_schema, _long_schema, _float_schema, _double_schema, _decimal_schema])
@pytest.mark.parametrize('schema', [_bool_schema, _byte_schema, _short_schema, _int_schema, _long_schema, _float_schema, _double_schema, _decimal_10_2_schema, _decimal_10_3_schema])
@pytest.mark.parametrize('read_func', [read_json_df, read_json_sql])
@pytest.mark.parametrize('allow_non_numeric_numbers', ["true", "false"])
@pytest.mark.parametrize('allow_numeric_leading_zeros', ["true"])
Expand Down
2 changes: 2 additions & 0 deletions integration_tests/src/test/resources/decimals.csv
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@
12.34
12.3456
12345678.12
33.545454
33.454545

2 changes: 2 additions & 0 deletions integration_tests/src/test/resources/decimals.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@
{ "number": 12.3456 }
{ "number": 12.345678 }
{ "number": 123456.78 }
{ "number": 33.454545 }
{ "number": 33.545454 }
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class JsonPartitionReader(
/**
 * JSON has strict rules about valid numeric formats. See https://www.json.org/ for the specification.
*
* Spark then has it's own rules for supporting NaN and Infinity, which are not
* Spark then has its own rules for supporting NaN and Infinity, which are not
* valid numbers in JSON.
*/
private def sanitizeNumbers(input: ColumnVector): ColumnVector = {
Expand Down

0 comments on commit 3d121bb

Please sign in to comment.