diff --git a/integration_tests/src/main/python/csv_test.py b/integration_tests/src/main/python/csv_test.py
index 5227dd0a41c..5f593be88c1 100644
--- a/integration_tests/src/main/python/csv_test.py
+++ b/integration_tests/src/main/python/csv_test.py
@@ -463,7 +463,7 @@ def test_input_meta_fallback(spark_tmp_path, v1_enabled_list, disable_conf):
     updated_conf = copy_and_update(_enable_all_types_conf, {
         'spark.sql.sources.useV1SourceList': v1_enabled_list,
         disable_conf: 'false'})
-    assert_gpu_and_cpu_are_equal_collect(
+    assert_gpu_fallback_collect(
             lambda spark : spark.read.schema(gen.data_type)\
                     .csv(data_path)\
                     .filter(f.col('a') > 0)\
@@ -471,6 +471,7 @@ def test_input_meta_fallback(spark_tmp_path, v1_enabled_list, disable_conf):
                     'input_file_name()',
                     'input_file_block_start()',
                     'input_file_block_length()'),
+            cpu_fallback_class_name = 'FileSourceScanExec' if v1_enabled_list == 'csv' else 'BatchScanExec',
             conf=updated_conf)
 
 @allow_non_gpu('DataWritingCommandExec,ExecutedCommandExec,WriteFilesExec')
@@ -529,16 +530,18 @@ def test_round_trip_for_interval(spark_tmp_path, v1_enabled_list):
         lambda spark: spark.read.schema(schema).csv(data_path),
         conf=updated_conf)
 
-@allow_non_gpu(any = True)
+@allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
 def test_csv_read_case_insensitivity(spark_tmp_path):
     gen_list = [('one', int_gen), ('tWo', byte_gen), ('THREE', boolean_gen)]
     data_path = spark_tmp_path + '/CSV_DATA'
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.option('header', True).csv(data_path))
 
-    assert_gpu_and_cpu_are_equal_collect(
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
         lambda spark: spark.read.option('header', True).csv(data_path).select('one', 'two', 'three'),
-        {'spark.sql.caseSensitive': 'false'}
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec',
+        conf = {'spark.sql.caseSensitive': 'false'}
     )
 
 @allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
@@ -549,7 +552,13 @@ def test_csv_read_count(spark_tmp_path):
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.csv(data_path))
 
-    assert_gpu_and_cpu_row_counts_equal(lambda spark: spark.read.csv(data_path))
+    # TODO This does not really test that the GPU count actually runs on the GPU
+    # because this test has @allow_non_gpu for operators that fall back to CPU
+    # when Spark performs an initial scan to infer the schema. To resolve this
+    # we would need a new `assert_gpu_and_cpu_row_counts_equal_with_capture` function.
+    # Tracking issue: https://github.com/NVIDIA/spark-rapids/issues/9199
+    assert_gpu_and_cpu_row_counts_equal(lambda spark: spark.read.csv(data_path),
+        conf = {'spark.rapids.sql.explain': 'ALL'})
 
 @allow_non_gpu('FileSourceScanExec', 'CollectLimitExec', 'DeserializeToObjectExec')
 @pytest.mark.skipif(is_before_spark_340(), reason='`preferDate` is only supported in Spark 340+')
@@ -561,12 +570,18 @@ def test_csv_prefer_date_with_infer_schema(spark_tmp_path):
 
     with_cpu_session(lambda spark: gen_df(spark, gen_list).write.csv(data_path))
 
-    assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.option("inferSchema", "true").csv(data_path))
-    assert_gpu_and_cpu_are_equal_collect(lambda spark: spark.read.option("inferSchema", "true").option("preferDate", "false").csv(data_path))
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark: spark.read.option("inferSchema", "true").csv(data_path),
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec')
+    assert_cpu_and_gpu_are_equal_collect_with_capture(
+        lambda spark: spark.read.option("inferSchema", "true").option("preferDate", "false").csv(data_path),
+        exist_classes = 'GpuFileSourceScanExec',
+        non_exist_classes = 'FileSourceScanExec')
 
 @allow_non_gpu('FileSourceScanExec')
 @pytest.mark.skipif(is_before_spark_340(), reason='enableDateTimeParsingFallback is supported from Spark3.4.0')
-@pytest.mark.parametrize('filename,schema',[("date.csv", _date_schema), ("date.csv", _ts_schema,),
+@pytest.mark.parametrize('filename,schema',[("date.csv", _date_schema), ("date.csv", _ts_schema),
                                             ("ts.csv", _ts_schema)])
 def test_csv_datetime_parsing_fallback_cpu_fallback(std_input_path, filename, schema):
     data_path = std_input_path + "/" + filename