diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py index 912004ed8c7..2cb73d8e24d 100644 --- a/integration_tests/src/main/python/orc_write_test.py +++ b/integration_tests/src/main/python/orc_write_test.py @@ -103,7 +103,7 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl): conf={'spark.sql.orc.impl': orc_impl, 'spark.rapids.sql.format.orc.write.enabled': True}) orc_part_write_gens = [ - byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, + byte_gen, short_gen, int_gen, long_gen, boolean_gen, # Some file systems have issues with UTF8 strings so to help the test pass even there StringGen('(\\w| ){0,50}'), # Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with @@ -118,14 +118,7 @@ def test_write_round_trip_corner(spark_tmp_path, orc_gen, orc_impl): @pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn) @pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653') def test_part_write_round_trip(spark_tmp_path, orc_gen): - part_gen = orc_gen - # Avoid generating NaNs for partition values. - # Spark does not handle partition switching properly since NaN != NaN. - if isinstance(part_gen, FloatGen): - part_gen = FloatGen(no_nans=True) - elif isinstance(part_gen, DoubleGen): - part_gen = DoubleGen(no_nans=True) - gen_list = [('a', RepeatSeqGen(part_gen, 10)), + gen_list = [('a', RepeatSeqGen(orc_gen, 10)), ('b', orc_gen)] data_path = spark_tmp_path + '/ORC_DATA' assert_gpu_and_cpu_writes_are_equal_collect( diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 015e4481700..1ad35f0aaae 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -164,7 +164,7 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): parquet_part_write_gens = [ - byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, + byte_gen, short_gen, int_gen, long_gen, # Some file systems have issues with UTF8 strings so to help the test pass even there StringGen('(\\w| ){0,50}'), boolean_gen, date_gen, @@ -176,14 +176,7 @@ def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): @pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn) @pytest.mark.xfail(condition = is_not_utc(), reason = 'xfail non-UTC time zone tests because of https://github.com/NVIDIA/spark-rapids/issues/9653') def test_part_write_round_trip(spark_tmp_path, parquet_gen): - part_gen = parquet_gen - # Avoid generating NaNs for partition values. - # Spark does not handle partition switching properly since NaN != NaN. - if isinstance(part_gen, FloatGen): - part_gen = FloatGen(no_nans=True) - elif isinstance(part_gen, DoubleGen): - part_gen = DoubleGen(no_nans=True) - gen_list = [('a', RepeatSeqGen(part_gen, 10)), + gen_list = [('a', RepeatSeqGen(parquet_gen, 10)), ('b', parquet_gen)] data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_writes_are_equal_collect(