Skip to content

Commit

Permalink
Avoid using floating point values as partition values in tests (#9978)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Lowe <jlowe@nvidia.com>
  • Loading branch information
jlowe authored Dec 6, 2023
1 parent a307dec commit 468abdf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 18 deletions.
11 changes: 2 additions & 9 deletions integration_tests/src/main/python/orc_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True})

orc_part_write_gens = [
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
byte_gen, short_gen, int_gen, long_gen, boolean_gen,
# Some file systems have issues with UTF8 strings so to help the test pass even there
StringGen('(\\w| ){0,50}'),
# Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with
Expand All @@ -118,14 +118,7 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl):
@pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_part_write_round_trip(spark_tmp_path, orc_gen):
part_gen = orc_gen
# Avoid generating NaNs for partition values.
# Spark does not handle partition switching properly since NaN != NaN.
if isinstance(part_gen, FloatGen):
part_gen = FloatGen(no_nans=True)
elif isinstance(part_gen, DoubleGen):
part_gen = DoubleGen(no_nans=True)
gen_list = [('a', RepeatSeqGen(part_gen, 10)),
gen_list = [('a', RepeatSeqGen(orc_gen, 10)),
('b', orc_gen)]
data_path = spark_tmp_path + '/ORC_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
Expand Down
11 changes: 2 additions & 9 deletions integration_tests/src/main/python/parquet_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):


parquet_part_write_gens = [
byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
byte_gen, short_gen, int_gen, long_gen,
# Some file systems have issues with UTF8 strings so to help the test pass even there
StringGen('(\\w| ){0,50}'),
boolean_gen, date_gen,
Expand All @@ -176,14 +176,7 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase):
@pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn)
@pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653')
def test_part_write_round_trip(spark_tmp_path, parquet_gen):
part_gen = parquet_gen
# Avoid generating NaNs for partition values.
# Spark does not handle partition switching properly since NaN != NaN.
if isinstance(part_gen, FloatGen):
part_gen = FloatGen(no_nans=True)
elif isinstance(part_gen, DoubleGen):
part_gen = DoubleGen(no_nans=True)
gen_list = [('a', RepeatSeqGen(part_gen, 10)),
gen_list = [('a', RepeatSeqGen(parquet_gen, 10)),
('b', parquet_gen)]
data_path = spark_tmp_path + '/PARQUET_DATA'
assert_gpu_and_cpu_writes_are_equal_collect(
Expand Down

0 comments on commit 468abdf

Please sign in to comment.