Commit 7d5b904
Improve some CSV integration tests [databricks] (#9146)
* improve some csv tests
* remove unrelated change
* remove unrelated test
* use named parameters to improve readability
* remove trailing comma
* update comment
* address feedback (add link to issue for future improvements)

Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Sep 28, 2023
1 parent d7230b6 commit 7d5b904

Showing 1 changed file with 23 additions and 8 deletions.

integration_tests/src/main/python/csv_test.py
@@ -463,14 +463,15 @@ def test_input_meta_fallback(spark_tmp_path, v1_enabled_list, disable_conf):
     updated_conf = copy_and_update(_enable_all_types_conf, {
         'spark.sql.sources.useV1SourceList': v1_enabled_list,
         disable_conf: 'false'})
-    assert_gpu_and_cpu_are_equal_collect(
+    assert_gpu_fallback_collect(
         lambda spark : spark.read.schema(gen.data_type)\
                 .csv(data_path)\
                 .filter(f.col('a') > 0)\
                 .selectExpr('a',
                     'input_file_name()',
                     'input_file_block_start()',
                     'input_file_block_length()'),
+        cpu_fallback_class_name = 'FileSourceScanExec' if v1_enabled_list == 'csv' else 'BatchScanExec',
         conf=updated_conf)
 
 @allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec')
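For context: the hunk above swaps a plain result-equality assertion for `assert_gpu_fallback_collect`, which, as the `cpu_fallback_class_name` argument suggests, also checks that the named scan operator (`FileSourceScanExec` for the v1 `csv` source list, `BatchScanExec` otherwise) appears in the captured plan, i.e. that the scan really fell back to the CPU. A minimal sketch of that kind of plan check, assuming a PySpark DataFrame `df`; this is an illustration, not the framework's code:

    # Illustrative plan-capture check (an assumption, not spark-rapids code).
    def plan_contains(df, class_name):
        # Stringify the executed physical plan and search for the operator name.
        plan = df._jdf.queryExecution().executedPlan().toString()
        return class_name in plan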
@@ -529,16 +530,18 @@ def test_round_trip_for_interval(spark_tmp_path, v1_enabled_list):
         lambda spark: spark.read.schema(schema).csv(data_path),
         conf=updated_conf)
 
-@allow_non_gpu(any = True)
+@allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
 def test_csv_read_case_insensitivity(spark_tmp_path):
     gen_list = [('one', int_gen), ('tWo', byte_gen), ('THREE', boolean_gen)]
     data_path = spark_tmp_path + '/CSV_DATA'
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.option('header', True).csv(data_path))
 
-    assert_gpu_and_cpu_are_equal_collect(
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark: spark.read.option('header', True).csv(data_path).select('one', 'two', 'three'),
-        {'spark.sql.caseSensitive': 'false'}
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec',
+        conf = {'spark.sql.caseSensitive': 'false'}
     )
 
 @allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
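What the rewritten test exercises: the file is written with headers 'one', 'tWo', 'THREE', and the read selects 'one', 'two', 'three', which can only resolve while `spark.sql.caseSensitive` is 'false'. A hypothetical standalone repro, assuming a local session; the /tmp/CSV_DATA path is an illustrative stand-in for the test's temp path:

    from pyspark.sql import SparkSession

    # Hypothetical repro; the local session and path are illustrative assumptions.
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.caseSensitive", "false")
    df = spark.read.option("header", True).csv("/tmp/CSV_DATA")
    # 'two' and 'three' resolve case-insensitively against 'tWo' and 'THREE'.
    df.select("one", "two", "three").show()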
@@ -549,7 +552,13 @@ def test_csv_read_count(spark_tmp_path):
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.csv(data_path))
 
-    assert_gpu_and_cpu_row_counts_equal(lambda spark: spark.read.csv(data_path))
+    # TODO This does not really test that the GPU count actually runs on the GPU
+    # because this test has @allow_non_gpu for operators that fall back to CPU
+    # when Spark performs an initial scan to infer the schema. To resolve this
+    # we would need a new `assert_gpu_and_cpu_row_counts_equal_with_capture` function.
+    # Tracking issue: https://github.com/NVIDIA/spark-rapids/issues/9199
+    assert_gpu_and_cpu_row_counts_equal(lambda spark: spark.read.csv(data_path),
+        conf = {'spark.rapids.sql.explain': 'ALL'})
 
 @allow_non_gpu('FileSourceScanExec', 'ProjectExec', 'CollectLimitExec', 'DeserializeToObjectExec')
 @pytest.mark.skipif(is_before_spark_340(), reason='`TIMESTAMP_NTZ` is only supported in Spark 340+')
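Sketch of the helper the TODO asks for: the name `assert_gpu_and_cpu_row_counts_equal_with_capture` comes from the comment in the hunk above, but no such function exists yet, so the body below is an assumption built from `with_cpu_session` (used elsewhere in this file) and a `with_gpu_session` counterpart, not the project's implementation:

    # Hedged sketch of the proposed helper; only the name comes from the TODO.
    def assert_gpu_and_cpu_row_counts_equal_with_capture(func, exist_classes, conf={}):
        def count_and_capture(spark):
            df = func(spark)
            # Capture the physical plan of the query being counted.
            plan = df._jdf.queryExecution().executedPlan().toString()
            return df.count(), plan
        cpu_count, _ = with_cpu_session(count_and_capture, conf=conf)
        gpu_count, gpu_plan = with_gpu_session(count_and_capture, conf=conf)
        assert cpu_count == gpu_count
        for clz in exist_classes.split(','):
            # Fail if an expected GPU operator never appeared in the GPU plan.
            assert clz in gpu_plan, f"{clz} not found in plan:\n{gpu_plan}"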
@@ -613,12 +622,18 @@ def test_csv_prefer_date_with_infer_schema(spark_tmp_path):
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.csv(data_path))
 
-    assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.option("inferSchema", "true").csv(data_path))
-    assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.option("inferSchema", "true").option("preferDate", "false").csv(data_path))
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark: spark.read.option("inferSchema", "true").csv(data_path),
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec')
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark: spark.read.option("inferSchema", "true").option("preferDate", "false").csv(data_path),
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec')
 
 @allow_non_gpu('FileSourceScanExec')
 @pytest.mark.skipif(is_before_spark_340(), reason='enableDateTimeParsingFallback is supported from Spark3.4.0')
-@pytest.mark.parametrize('filename,schema',[("date.csv", _date_schema), ("date.csv", _ts_schema,),
+@pytest.mark.parametrize('filename,schema',[("date.csv", _date_schema), ("date.csv", _ts_schema),
                                             ("ts.csv", _ts_schema)])
 def test_csv_datetime_parsing_fallback_cpu_fallback(std_input_path, filename, schema):
     data_path = std_input_path + "/" + filename
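On `preferDate` itself (a Spark 3.4+ CSV option): with `inferSchema` enabled, date-like strings may be inferred as DateType unless `preferDate` is set to 'false', in which case they stay strings. A hypothetical illustration, assuming an existing SparkSession `spark` and a CSV of date-like values at `path`:

    # Hypothetical illustration; `spark` and `path` are assumed, not from the test.
    spark.read.option("inferSchema", "true").csv(path).printSchema()
    # date-like columns may come back as DateType

    spark.read.option("inferSchema", "true").option("preferDate", "false") \
        .csv(path).printSchema()
    # the same columns stay StringType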
