diff --git a/integration_tests/src/main/python/csv_test.py b/integration_tests/src/main/python/csv_test.py index 153c302780b..d74e8bc9aa1 100644 --- a/integration_tests/src/main/python/csv_test.py +++ b/integration_tests/src/main/python/csv_test.py @@ -119,9 +119,12 @@ def read_impl(spark): ('str.csv', _good_str_schema, ',', True) ]) @pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql]) -def test_basic_read(std_input_path, name, schema, sep, header, read_func): +@pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) +def test_basic_read(std_input_path, name, schema, sep, header, read_func, v1_enabled_list): + updated_conf=_enable_ts_conf + updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list assert_gpu_and_cpu_are_equal_collect(read_func(std_input_path + '/' + name, schema, header, sep), - conf=_enable_ts_conf) + conf=updated_conf) csv_supported_gens = [ # Spark does not escape '\r' or '\n' even though it uses it to mark end of record @@ -141,15 +144,18 @@ def test_basic_read(std_input_path, name, schema, sep, header, read_func): @approximate_float @pytest.mark.parametrize('data_gen', csv_supported_gens, ids=idfn) -def test_round_trip(spark_tmp_path, data_gen): +@pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) +def test_round_trip(spark_tmp_path, data_gen, v1_enabled_list): gen = StructGen([('a', data_gen)], nullable=False) data_path = spark_tmp_path + '/CSV_DATA' schema = gen.data_type + updated_conf=_enable_ts_conf + updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list with_cpu_session( lambda spark : gen_df(spark, gen).write.csv(data_path)) assert_gpu_and_cpu_are_equal_collect( lambda spark : spark.read.schema(schema).csv(data_path), - conf=_enable_ts_conf) + conf=updated_conf) @allow_non_gpu('FileSourceScanExec') @pytest.mark.parametrize('read_func', [read_csv_df, read_csv_sql]) @@ -174,7 +180,8 @@ def test_csv_fallback(spark_tmp_path, read_func, disable_conf): csv_supported_date_formats = ['yyyy-MM-dd', 'yyyy/MM/dd', 'yyyy-MM', 'yyyy/MM', 'MM-yyyy', 'MM/yyyy', 'MM-dd-yyyy', 'MM/dd/yyyy'] @pytest.mark.parametrize('date_format', csv_supported_date_formats, ids=idfn) -def test_date_formats_round_trip(spark_tmp_path, date_format): +@pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) +def test_date_formats_round_trip(spark_tmp_path, date_format, v1_enabled_list): gen = StructGen([('a', DateGen())], nullable=False) data_path = spark_tmp_path + '/CSV_DATA' schema = gen.data_type @@ -186,7 +193,8 @@ def test_date_formats_round_trip(spark_tmp_path, date_format): lambda spark : spark.read\ .schema(schema)\ .option('dateFormat', date_format)\ - .csv(data_path)) + .csv(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) csv_supported_ts_parts = ['', # Just the date "'T'HH:mm:ss.SSSXXX", @@ -199,7 +207,8 @@ def test_date_formats_round_trip(spark_tmp_path, date_format): @pytest.mark.parametrize('ts_part', csv_supported_ts_parts) @pytest.mark.parametrize('date_format', csv_supported_date_formats) -def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part): +@pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) +def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part, v1_enabled_list): full_format = date_format + ts_part # Once https://github.com/NVIDIA/spark-rapids/issues/122 is fixed the full range should be used data_gen = TimestampGen(start=datetime(1902, 1, 1, tzinfo=timezone.utc), @@ -211,14 +220,17 @@ def test_ts_formats_round_trip(spark_tmp_path, date_format, ts_part): lambda spark : 
gen_df(spark, gen).write\ .option('timestampFormat', full_format)\ .csv(data_path)) + updated_conf=_enable_ts_conf + updated_conf['spark.sql.sources.useV1SourceList']=v1_enabled_list assert_gpu_and_cpu_are_equal_collect( lambda spark : spark.read\ .schema(schema)\ .option('timestampFormat', full_format)\ .csv(data_path), - conf=_enable_ts_conf) + conf=updated_conf) -def test_input_meta(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "csv"]) +def test_input_meta(spark_tmp_path, v1_enabled_list): gen = StructGen([('a', long_gen), ('b', long_gen)], nullable=False) first_data_path = spark_tmp_path + '/CSV_DATA/key=0' with_cpu_session( @@ -234,4 +246,5 @@ def test_input_meta(spark_tmp_path): .selectExpr('a', 'input_file_name()', 'input_file_block_start()', - 'input_file_block_length()')) + 'input_file_block_length()'), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py index 571bb4e511f..e9e121f6354 100644 --- a/integration_tests/src/main/python/orc_test.py +++ b/integration_tests/src/main/python/orc_test.py @@ -29,9 +29,11 @@ def read_orc_sql(data_path): @pytest.mark.parametrize('name', ['timestamp-date-test.orc']) @pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql]) -def test_basic_read(std_input_path, name, read_func): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_basic_read(std_input_path, name, read_func, v1_enabled_list): assert_gpu_and_cpu_are_equal_collect( - read_func(std_input_path + '/' + name)) + read_func(std_input_path + '/' + name), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) orc_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), @@ -59,13 +61,15 @@ def test_orc_fallback(spark_tmp_path, read_func, disable_conf): @pytest.mark.parametrize('orc_gens', orc_gens_list, ids=idfn) @pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql]) -def test_read_round_trip(spark_tmp_path, orc_gens, read_func): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_read_round_trip(spark_tmp_path, orc_gens, read_func, v1_enabled_list): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] data_path = spark_tmp_path + '/ORC_DATA' with_cpu_session( lambda spark : gen_df(spark, gen_list).write.orc(data_path)) assert_gpu_and_cpu_are_equal_collect( - read_func(data_path)) + read_func(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) orc_pred_push_gens = [ byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, @@ -79,7 +83,8 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func): @pytest.mark.parametrize('orc_gen', orc_pred_push_gens, ids=idfn) @pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql]) -def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func, v1_enabled_list): data_path = spark_tmp_path + '/ORC_DATA' gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen)] s0 = gen_scalar(orc_gen, force_no_nulls=True) @@ -87,22 +92,26 @@ def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func): lambda spark : gen_df(spark, gen_list).orderBy('a').write.orc(data_path)) rf = read_func(data_path) assert_gpu_and_cpu_are_equal_collect( - lambda spark: rf(spark).select(f.col('a') >= s0)) + 
lambda spark: rf(spark).select(f.col('a') >= s0), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) orc_compress_options = ['none', 'uncompressed', 'snappy', 'zlib'] # The following need extra jars 'lzo' # https://github.com/NVIDIA/spark-rapids/issues/143 @pytest.mark.parametrize('compress', orc_compress_options) -def test_compress_read_round_trip(spark_tmp_path, compress): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list): data_path = spark_tmp_path + '/ORC_DATA' with_cpu_session( lambda spark : binary_op_df(spark, long_gen).write.orc(data_path), conf={'spark.sql.orc.compression.codec': compress}) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.orc(data_path)) + lambda spark : spark.read.orc(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) -def test_simple_partitioned_read(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list): # Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed # we should go with a more standard set of generators orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -117,10 +126,12 @@ def test_simple_partitioned_read(spark_tmp_path): lambda spark : gen_df(spark, gen_list).write.orc(second_data_path)) data_path = spark_tmp_path + '/ORC_DATA' assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.orc(data_path)) + lambda spark : spark.read.orc(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) @pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/135') -def test_merge_schema_read(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "orc"]) +def test_merge_schema_read(spark_tmp_path, v1_enabled_list): # Once https://github.com/NVIDIA/spark-rapids/issues/131 is fixed # we should go with a more standard set of generators orc_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -136,7 +147,8 @@ def test_merge_schema_read(spark_tmp_path): lambda spark : gen_df(spark, second_gen_list).write.orc(second_data_path)) data_path = spark_tmp_path + '/ORC_DATA' assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.option('mergeSchema', 'true').orc(data_path)) + lambda spark : spark.read.option('mergeSchema', 'true').orc(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) orc_write_gens_list = [ [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index bac8cd11d29..a80208ce9e8 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -20,7 +20,7 @@ from data_gen import * from marks import * from pyspark.sql.types import * -from spark_session import with_cpu_session +from spark_session import with_cpu_session, with_gpu_session def read_parquet_df(data_path): return lambda spark : spark.read.parquet(data_path) @@ -35,14 +35,15 @@ def read_parquet_sql(data_path): @pytest.mark.parametrize('parquet_gens', parquet_gens_list, ids=idfn) @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql]) -def test_read_round_trip(spark_tmp_path, parquet_gens, read_func): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_read_round_trip(spark_tmp_path, parquet_gens, read_func, 
v1_enabled_list): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] data_path = spark_tmp_path + '/PARQUET_DATA' with_cpu_session( lambda spark : gen_df(spark, gen_list).write.parquet(data_path), conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'}) - assert_gpu_and_cpu_are_equal_collect( - read_func(data_path)) + assert_gpu_and_cpu_are_equal_collect(read_func(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) @allow_non_gpu('FileSourceScanExec') @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql]) @@ -67,13 +68,15 @@ def test_parquet_fallback(spark_tmp_path, read_func, disable_conf): # https://github.com/NVIDIA/spark-rapids/issues/143 @pytest.mark.parametrize('compress', parquet_compress_options) -def test_compress_read_round_trip(spark_tmp_path, compress): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_compress_read_round_trip(spark_tmp_path, compress, v1_enabled_list): data_path = spark_tmp_path + '/PARQUET_DATA' with_cpu_session( lambda spark : binary_op_df(spark, long_gen).write.parquet(data_path), conf={'spark.sql.parquet.compression.codec': compress}) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.parquet(data_path)) + lambda spark : spark.read.parquet(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) parquet_pred_push_gens = [ byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, @@ -84,8 +87,9 @@ def test_compress_read_round_trip(spark_tmp_path, compress): @pytest.mark.parametrize('parquet_gen', parquet_pred_push_gens, ids=idfn) @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql]) -def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func): - data_path = spark_tmp_path + '/ORC_DATA' +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1_enabled_list): + data_path = spark_tmp_path + '/PARQUET_DATA' gen_list = [('a', RepeatSeqGen(parquet_gen, 100)), ('b', parquet_gen)] s0 = gen_scalar(parquet_gen, force_no_nulls=True) with_cpu_session( @@ -93,13 +97,15 @@ def test_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func): conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'}) rf = read_func(data_path) assert_gpu_and_cpu_are_equal_collect( - lambda spark: rf(spark).select(f.col('a') >= s0)) + lambda spark: rf(spark).select(f.col('a') >= s0), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS'] @pytest.mark.parametrize('ts_write', parquet_ts_write_options) @pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY']) -def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase, v1_enabled_list): # Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with # timestamp_gen gen = TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)) @@ -109,7 +115,8 @@ def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase): conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, 'spark.sql.parquet.outputTimestampType': ts_write}) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.parquet(data_path)) + lambda spark : spark.read.parquet(data_path), + conf={'spark.sql.sources.useV1SourceList': 
v1_enabled_list}) parquet_gens_legacy_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), @@ -118,16 +125,19 @@ def test_ts_read_round_trip(spark_tmp_path, ts_write, ts_rebase): pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/133'))] @pytest.mark.parametrize('parquet_gens', parquet_gens_legacy_list, ids=idfn) -def test_read_round_trip_legacy(spark_tmp_path, parquet_gens): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_read_round_trip_legacy(spark_tmp_path, parquet_gens, v1_enabled_list): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] data_path = spark_tmp_path + '/PARQUET_DATA' with_cpu_session( lambda spark : gen_df(spark, gen_list).write.parquet(data_path), conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY'}) assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.parquet(data_path)) + lambda spark : spark.read.parquet(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) -def test_simple_partitioned_read(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_simple_partitioned_read(spark_tmp_path, v1_enabled_list): # Once https://github.com/NVIDIA/spark-rapids/issues/133 and https://github.com/NVIDIA/spark-rapids/issues/132 are fixed # we should go with a more standard set of generators parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -144,11 +154,13 @@ def test_simple_partitioned_read(spark_tmp_path): conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'}) data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.parquet(data_path)) + lambda spark : spark.read.parquet(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) @pytest.mark.xfail(condition=is_databricks_runtime(), reason='https://github.com/NVIDIA/spark-rapids/issues/192') -def test_read_merge_schema(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_read_merge_schema(spark_tmp_path, v1_enabled_list): # Once https://github.com/NVIDIA/spark-rapids/issues/133 and https://github.com/NVIDIA/spark-rapids/issues/132 are fixed # we should go with a more standard set of generators parquet_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -166,14 +178,16 @@ def test_read_merge_schema(spark_tmp_path): conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED'}) data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_are_equal_collect( - lambda spark : spark.read.option('mergeSchema', 'true').parquet(data_path)) + lambda spark : spark.read.option('mergeSchema', 'true').parquet(data_path), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) parquet_write_gens_list = [ [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, string_gen, boolean_gen, date_gen, timestamp_gen]] @pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn) -def test_write_round_trip(spark_tmp_path, parquet_gens): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_writes_are_equal_collect( @@ -181,7 +195,8 @@ def 
test_write_round_trip(spark_tmp_path, parquet_gens): lambda spark, path: spark.read.parquet(path), data_path, conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS'}) + 'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS', + 'spark.sql.sources.useV1SourceList': v1_enabled_list}) parquet_part_write_gens = [ byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, @@ -192,7 +207,8 @@ def test_write_round_trip(spark_tmp_path, parquet_gens): # There are race conditions around when individual files are read in for partitioned data @ignore_order @pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn) -def test_part_write_round_trip(spark_tmp_path, parquet_gen): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_part_write_round_trip(spark_tmp_path, parquet_gen, v1_enabled_list): gen_list = [('a', RepeatSeqGen(parquet_gen, 10)), ('b', parquet_gen)] data_path = spark_tmp_path + '/PARQUET_DATA' @@ -201,19 +217,23 @@ def test_part_write_round_trip(spark_tmp_path, parquet_gen): lambda spark, path: spark.read.parquet(path), data_path, conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS'}) + 'spark.sql.parquet.outputTimestampType': 'TIMESTAMP_MICROS', + 'spark.sql.sources.useV1SourceList': v1_enabled_list}) parquet_write_compress_options = ['none', 'uncompressed', 'snappy'] @pytest.mark.parametrize('compress', parquet_write_compress_options) -def test_compress_write_round_trip(spark_tmp_path, compress): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_compress_write_round_trip(spark_tmp_path, compress, v1_enabled_list): data_path = spark_tmp_path + '/PARQUET_DATA' assert_gpu_and_cpu_writes_are_equal_collect( lambda spark, path : binary_op_df(spark, long_gen).coalesce(1).write.parquet(path), lambda spark, path : spark.read.parquet(path), data_path, - conf={'spark.sql.parquet.compression.codec': compress}) + conf={'spark.sql.parquet.compression.codec': compress, + 'spark.sql.sources.useV1SourceList': v1_enabled_list}) -def test_input_meta(spark_tmp_path): +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_input_meta(spark_tmp_path, v1_enabled_list): first_data_path = spark_tmp_path + '/PARQUET_DATA/key=0' with_cpu_session( lambda spark : unary_op_df(spark, long_gen).write.parquet(first_data_path)) @@ -227,4 +247,6 @@ def test_input_meta(spark_tmp_path): .selectExpr('a', 'input_file_name()', 'input_file_block_start()', - 'input_file_block_length()')) + 'input_file_block_length()'), + conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) + diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuOrcScan.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuOrcScan.scala new file mode 100644 index 00000000000..148a05838d2 --- /dev/null +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuOrcScan.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims.spark300 + +import com.nvidia.spark.rapids.{GpuOrcScanBase, RapidsConf} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// FileScan changed in Spark 3.1.0 so need to compile in Shim +case class GpuOrcScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression], + rapidsConf: RapidsConf) + extends GpuOrcScanBase(sparkSession, hadoopConf, dataSchema, readDataSchema, + readPartitionSchema, pushedFilters, rapidsConf) with FileScan { + + override def isSplitable(path: Path): Boolean = super.isSplitableBase(path) + + override def createReaderFactory(): PartitionReaderFactory = super.createReaderFactoryBase() + + override def equals(obj: Any): Boolean = obj match { + case o: GpuOrcScan => + super.equals(o) && dataSchema == o.dataSchema && options == o.options && + equivalentFilters(pushedFilters, o.pushedFilters) && rapidsConf == o.rapidsConf + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() + + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuParquetScan.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuParquetScan.scala new file mode 100644 index 00000000000..fe0b0dbb114 --- /dev/null +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/GpuParquetScan.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.spark300 + +import com.nvidia.spark.rapids.{GpuParquetScanBase, RapidsConf} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// FileScan changed in Spark 3.1.0 so need to compile in Shim +case class GpuParquetScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + pushedFilters: Array[Filter], + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression], + rapidsConf: RapidsConf) + extends GpuParquetScanBase(sparkSession, hadoopConf, dataSchema, + readDataSchema, readPartitionSchema, pushedFilters, rapidsConf) with FileScan { + + override def isSplitable(path: Path): Boolean = super.isSplitableBase(path) + + override def createReaderFactory(): PartitionReaderFactory = super.createReaderFactoryBase() + + override def equals(obj: Any): Boolean = obj match { + case p: GpuParquetScan => + super.equals(p) && dataSchema == p.dataSchema && options == p.options && + equivalentFilters(pushedFilters, p.pushedFilters) && rapidsConf == p.rapidsConf + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() + + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index cb11b3652b9..3cc0fbe24ae 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -31,14 +31,18 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec -import org.apache.spark.sql.execution.datasources.{BucketingUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import 
org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub, ShuffleManagerShimBase} -import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastMeta, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase, GpuShuffleMeta} +import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase} import org.apache.spark.sql.rapids.shims.spark300._ import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} @@ -216,6 +220,46 @@ class Spark300Shims extends SparkShims { ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap } + override def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Seq( + GpuOverrides.scan[ParquetScan]( + "Parquet parsing", + (a, conf, p, r) => new ScanMeta[ParquetScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = GpuParquetScanBase.tagSupport(this) + + override def convertToGpu(): Scan = + GpuParquetScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.pushedFilters, + a.options, + a.partitionFilters, + a.dataFilters, + conf) + }), + GpuOverrides.scan[OrcScan]( + "ORC parsing", + (a, conf, p, r) => new ScanMeta[OrcScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = + GpuOrcScanBase.tagSupport(this) + + override def convertToGpu(): Scan = + GpuOrcScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.options, + a.pushedFilters, + a.partitionFilters, + a.dataFilters, + conf) + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap + override def getBuildSide(join: HashJoin): GpuBuildSide = { GpuJoinUtils.getGpuBuildSide(join.buildSide) } diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuOrcScan.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuOrcScan.scala new file mode 100644 index 00000000000..1c65fcb4fee --- /dev/null +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuOrcScan.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.spark310 + +import com.nvidia.spark.rapids.{GpuOrcScanBase, RapidsConf} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// FileScan changed in Spark 3.1.0 so need to compile in Shim +case class GpuOrcScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + options: CaseInsensitiveStringMap, + pushedFilters: Array[Filter], + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression], + rapidsConf: RapidsConf) + extends GpuOrcScanBase(sparkSession, hadoopConf, dataSchema, readDataSchema, + readPartitionSchema, pushedFilters, rapidsConf) with FileScan { + + override def isSplitable(path: Path): Boolean = super.isSplitableBase(path) + + override def createReaderFactory(): PartitionReaderFactory = super.createReaderFactoryBase() + + override def equals(obj: Any): Boolean = obj match { + case o: GpuOrcScan => + super.equals(o) && dataSchema == o.dataSchema && options == o.options && + equivalentFilters(pushedFilters, o.pushedFilters) && rapidsConf == o.rapidsConf + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() + + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuParquetScan.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuParquetScan.scala new file mode 100644 index 00000000000..aac3b13098a --- /dev/null +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/GpuParquetScan.scala @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.shims.spark310 + +import com.nvidia.spark.rapids.{GpuParquetScanBase, RapidsConf} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.read.PartitionReaderFactory +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.v2.FileScan +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// FileScan changed in Spark 3.1.0 so need to compile in Shim +case class GpuParquetScan( + sparkSession: SparkSession, + hadoopConf: Configuration, + fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, + readDataSchema: StructType, + readPartitionSchema: StructType, + pushedFilters: Array[Filter], + options: CaseInsensitiveStringMap, + partitionFilters: Seq[Expression], + dataFilters: Seq[Expression], + rapidsConf: RapidsConf) + extends GpuParquetScanBase(sparkSession, hadoopConf, dataSchema, + readDataSchema, readPartitionSchema, pushedFilters, rapidsConf) with FileScan { + + override def isSplitable(path: Path): Boolean = super.isSplitableBase(path) + + override def createReaderFactory(): PartitionReaderFactory = super.createReaderFactoryBase() + + override def equals(obj: Any): Boolean = obj match { + case p: GpuParquetScan => + super.equals(p) && dataSchema == p.dataSchema && options == p.options && + equivalentFilters(pushedFilters, p.pushedFilters) && rapidsConf == p.rapidsConf + case _ => false + } + + override def hashCode(): Int = getClass.hashCode() + + override def description(): String = { + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + } + + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) +} diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala index 9840c8f229d..02d3e175eb6 100644 --- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala @@ -26,8 +26,12 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.HadoopFsRelation +import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuTimeSub, ShuffleManagerShimBase} @@ -163,6 +167,47 @@ class Spark310Shims extends Spark301Shims { ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap } + override def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] 
= Seq( + GpuOverrides.scan[ParquetScan]( + "Parquet parsing", + (a, conf, p, r) => new ScanMeta[ParquetScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = GpuParquetScanBase.tagSupport(this) + + override def convertToGpu(): Scan = + GpuParquetScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.pushedFilters, + a.options, + a.partitionFilters, + a.dataFilters, + conf) + }), + GpuOverrides.scan[OrcScan]( + "ORC parsing", + (a, conf, p, r) => new ScanMeta[OrcScan](a, conf, p, r) { + override def tagSelfForGpu(): Unit = + GpuOrcScanBase.tagSupport(this) + + override def convertToGpu(): Scan = + GpuOrcScan(a.sparkSession, + a.hadoopConf, + a.fileIndex, + a.dataSchema, + a.readDataSchema, + a.readPartitionSchema, + a.options, + a.pushedFilters, + a.partitionFilters, + a.dataFilters, + conf) + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap + + override def getBuildSide(join: HashJoin): GpuBuildSide = { GpuJoinUtils.getGpuBuildSide(join.buildSide) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index 55417797bef..4a6fc01310e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -45,40 +45,34 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.execution.datasources.{PartitionedFile, PartitioningAwareFileIndex} +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.orc.OrcUtils -import org.apache.spark.sql.execution.datasources.v2.{FilePartitionReaderFactory, FileScan} +import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.OrcFilters import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -case class GpuOrcScan( +abstract class GpuOrcScanBase( sparkSession: SparkSession, hadoopConf: Configuration, - fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - options: CaseInsensitiveStringMap, pushedFilters: Array[Filter], - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression], rapidsConf: RapidsConf) - extends FileScan with ScanWithMetrics { + extends ScanWithMetrics { - override def isSplitable(path: Path): Boolean = true + def isSplitableBase(path: Path): Boolean = true - override def createReaderFactory(): PartitionReaderFactory = { + def createReaderFactoryBase(): PartitionReaderFactory = { // Unset any serialized search argument setup by Spark's OrcScanBuilder as // it will be incompatible due to shading and 
potential ORC classifier mismatch. hadoopConf.unset(OrcConf.KRYO_SARG.getAttribute) @@ -88,26 +82,9 @@ case class GpuOrcScan( GpuOrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, pushedFilters, rapidsConf, metrics) } - - override def equals(obj: Any): Boolean = obj match { - case o: GpuOrcScan => - super.equals(o) && dataSchema == o.dataSchema && options == o.options && - equivalentFilters(pushedFilters, o.pushedFilters) && rapidsConf == o.rapidsConf - case _ => false - } - - override def hashCode(): Int = getClass.hashCode() - - override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) - } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } -object GpuOrcScan { +object GpuOrcScanBase { def tagSupport(scanMeta: ScanMeta[OrcScan]): Unit = { val scan = scanMeta.wrapped val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 102743c5f53..3f97b4d5bfa 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1504,8 +1504,8 @@ object GpuOverrides { .map(r => r.wrap(scan, conf, parent, r).asInstanceOf[ScanMeta[INPUT]]) .getOrElse(new RuleNotFoundScanMeta(scan, conf, parent)) - val scans : Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Seq( - scan[CSVScan]( + val commonScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = Seq( + GpuOverrides.scan[CSVScan]( "CSV parsing", (a, conf, p, r) => new ScanMeta[CSVScan](a, conf, p, r) { override def tagSelfForGpu(): Unit = GpuCSVScan.tagSupport(this) @@ -1521,45 +1521,10 @@ object GpuOverrides { a.dataFilters, conf.maxReadBatchSizeRows, conf.maxReadBatchSizeBytes) - }), - scan[ParquetScan]( - "Parquet parsing", - (a, conf, p, r) => new ScanMeta[ParquetScan](a, conf, p, r) { - override def tagSelfForGpu(): Unit = GpuParquetScan.tagSupport(this) - - override def convertToGpu(): Scan = - GpuParquetScan(a.sparkSession, - a.hadoopConf, - a.fileIndex, - a.dataSchema, - a.readDataSchema, - a.readPartitionSchema, - a.pushedFilters, - a.options, - a.partitionFilters, - a.dataFilters, - conf) - }), - scan[OrcScan]( - "ORC parsing", - (a, conf, p, r) => new ScanMeta[OrcScan](a, conf, p, r) { - override def tagSelfForGpu(): Unit = - GpuOrcScan.tagSupport(this) + })).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap - override def convertToGpu(): Scan = - GpuOrcScan(a.sparkSession, - a.hadoopConf, - a.fileIndex, - a.dataSchema, - a.readDataSchema, - a.readPartitionSchema, - a.options, - a.pushedFilters, - a.partitionFilters, - a.dataFilters, - conf) - }) - ).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap + val scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = + commonScans ++ ShimLoader.getSparkShims.getScans def wrapPart[INPUT <: Partitioning]( part: INPUT, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index 544eb44b5c1..8d5b2fccf04 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -48,7 +48,6 
@@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.datasources.{PartitionedFile, PartitioningAwareFileIndex} @@ -60,52 +59,30 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{StringType, StructType, TimestampType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -case class GpuParquetScan( +abstract class GpuParquetScanBase( sparkSession: SparkSession, hadoopConf: Configuration, - fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, pushedFilters: Array[Filter], - options: CaseInsensitiveStringMap, - partitionFilters: Seq[Expression], - dataFilters: Seq[Expression], rapidsConf: RapidsConf) - extends FileScan with ScanWithMetrics { + extends ScanWithMetrics { - override def isSplitable(path: Path): Boolean = true + def isSplitableBase(path: Path): Boolean = true - override def createReaderFactory(): PartitionReaderFactory = { + def createReaderFactoryBase(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( new SerializableConfiguration(hadoopConf)) GpuParquetPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, readDataSchema, readPartitionSchema, pushedFilters, rapidsConf, metrics) } - - override def equals(obj: Any): Boolean = obj match { - case p: GpuParquetScan => - super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) && rapidsConf == p.rapidsConf - case _ => false - } - - override def hashCode(): Int = getClass.hashCode() - - override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) - } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } -object GpuParquetScan { +object GpuParquetScanBase { def tagSupport(scanMeta: ScanMeta[ParquetScan]): Unit = { val scan = scanMeta.wrapped val schema = StructType(scan.readDataSchema ++ scan.readPartitionSchema) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala index 3115f08825e..a8450079afc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala @@ -63,7 +63,7 @@ object GpuReadOrcFileFormat { if (fsse.relation.options.getOrElse("mergeSchema", "false").toBoolean) { meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet") } - GpuOrcScan.tagSupport( + GpuOrcScanBase.tagSupport( fsse.sqlContext.sparkSession, fsse.requiredSchema, meta diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala index d468348776c..0d0ea9a05f6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala @@ -60,7 +60,7 @@ class GpuReadParquetFileFormat extends ParquetFileFormat with GpuReadFileFormatW object GpuReadParquetFileFormat { def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { val fsse = meta.wrapped - GpuParquetScan.tagSupport( + GpuParquetScanBase.tagSupport( fsse.sqlContext.sparkSession, fsse.requiredSchema, meta diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index fe347eccc65..c95dc764f53 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} @@ -65,6 +66,8 @@ trait SparkShims { def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] + def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] + def getScalaUDFAsExpression( function: AnyRef, dataType: DataType,
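
For reference, every test change above hinges on one Spark SQL conf: `spark.sql.sources.useV1SourceList` decides which file formats stay on the DataSource V1 read path, so parametrizing it with `""` versus the format name runs each test once against the V2 `BatchScanExec` plan (the path that the new shim `GpuParquetScan`/`GpuOrcScan` classes and the `getScans` hook replace) and once against the V1 `FileSourceScanExec` plan. Below is a minimal standalone sketch of that toggle, outside the plugin and its test harness; it assumes only a local Spark 3.x session and uses an arbitrary temporary path, not the integration-test fixtures.

```python
# Minimal sketch (not part of the patch) of what the new v1_enabled_list test
# parameter exercises: the same parquet read is planned through the DataSource
# V1 path (FileSourceScanExec) when "parquet" is in
# spark.sql.sources.useV1SourceList, and through the V2 path (BatchScanExec)
# when the list is empty.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
data_path = "/tmp/v1_v2_demo_parquet"  # arbitrary local path for the demo
spark.range(10).write.mode("overwrite").parquet(data_path)

for v1_enabled_list in ["", "parquet"]:
    # "" routes the parquet source to V2; "parquet" keeps it on V1
    spark.conf.set("spark.sql.sources.useV1SourceList", v1_enabled_list)
    spark.read.parquet(data_path).explain()  # BatchScan vs. FileScan parquet
```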