From 2112f7cf47e19e5ad3ad03300db65b39e446e657 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Tue, 17 Nov 2020 10:19:17 -0600 Subject: [PATCH] Support saveAsTable for writing orc and parquet (#1134) * start saveAsTable * Add GpuDataSource * columnar ifle format * Update to GpuFileFormat * fix typo * logging * more logging * change format parquet * fix classof * fix run to runColumnar * using original providing instance for end * remove unneeded code and pass in providers so don't calculate twice * create shim for SchemaUtils checkSchemaColumnNameDuplication Signed-off-by: Thomas Graves * fix typo with checkSchemaColumnNameDuplication * fix name * fix calling * fix anothername * fix none * Fix provider vs FileFormat * split read/write tests * Write a bunch more tests for orc and parquet writing Signed-off-by: Thomas Graves * cleanup and csv test * Add more test * Add bucket write test Signed-off-by: Thomas Graves * remove debug logs Signed-off-by: Thomas Graves * Update for spark 3.1.0 --- integration_tests/src/main/python/asserts.py | 36 + integration_tests/src/main/python/csv_test.py | 12 +- integration_tests/src/main/python/orc_test.py | 54 +- .../src/main/python/orc_write_test.py | 119 ++++ .../src/main/python/parquet_test.py | 75 -- .../src/main/python/parquet_write_test.py | 239 +++++++ .../rapids/shims/spark300/Spark300Shims.scala | 8 + .../shims/spark300/GpuSchemaUtils.scala | 31 + .../rapids/shims/spark310/Spark310Shims.scala | 9 + .../shims/spark310/GpuSchemaUtils.scala | 31 + .../nvidia/spark/rapids/GpuOverrides.scala | 58 +- .../com/nvidia/spark/rapids/SparkShims.scala | 6 + ...CreateDataSourceTableAsSelectCommand.scala | 130 ++++ .../spark/sql/rapids/GpuDataSource.scala | 655 ++++++++++++++++++ 14 files changed, 1333 insertions(+), 130 deletions(-) create mode 100644 integration_tests/src/main/python/orc_write_test.py create mode 100644 integration_tests/src/main/python/parquet_write_test.py create mode 100644 shims/spark300/src/main/scala/org/apache/spark/sql/rapids/shims/spark300/GpuSchemaUtils.scala create mode 100644 shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuSchemaUtils.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala create mode 100644 sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index 1115466fef0..4cea2e22793 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -238,6 +238,42 @@ def assert_gpu_and_cpu_writes_are_equal_iterator(write_func, read_func, base_pat """ _assert_gpu_and_cpu_writes_are_equal(write_func, read_func, base_path, False, conf=conf) +def assert_gpu_fallback_write(write_func, + read_func, + base_path, + cpu_fallback_class_name, + conf={}): + conf = _prep_incompat_conf(conf) + + print('### CPU RUN ###') + cpu_start = time.time() + cpu_path = base_path + '/CPU' + with_cpu_session(lambda spark : write_func(spark, cpu_path), conf=conf) + cpu_end = time.time() + print('### GPU RUN ###') + jvm = spark_jvm() + jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.startCapture() + gpu_start = time.time() + gpu_path = base_path + '/GPU' + with_gpu_session(lambda spark : write_func(spark, gpu_path), conf=conf) + gpu_end = time.time() + jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertCapturedAndGpuFellBack(cpu_fallback_class_name, 2000) + print('### 
WRITE: GPU TOOK {} CPU TOOK {} ###'.format( + gpu_end - gpu_start, cpu_end - cpu_start)) + + (cpu_bring_back, cpu_collect_type) = _prep_func_for_compare( + lambda spark: read_func(spark, cpu_path), True) + (gpu_bring_back, gpu_collect_type) = _prep_func_for_compare( + lambda spark: read_func(spark, gpu_path), True) + + from_cpu = with_cpu_session(cpu_bring_back, conf=conf) + from_gpu = with_cpu_session(gpu_bring_back, conf=conf) + if should_sort_locally(): + from_cpu.sort(key=_RowCmp) + from_gpu.sort(key=_RowCmp) + + assert_equal(from_cpu, from_gpu) + def assert_gpu_fallback_collect(func, cpu_fallback_class_name, conf={}): diff --git a/integration_tests/src/main/python/csv_test.py b/integration_tests/src/main/python/csv_test.py index 27fbc2b382c..609a4d449bd 100644 --- a/integration_tests/src/main/python/csv_test.py +++ b/integration_tests/src/main/python/csv_test.py @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_fallback_write from datetime import datetime, timezone from data_gen import * from marks import * @@ -245,3 +245,13 @@ def test_input_meta(spark_tmp_path, v1_enabled_list): 'input_file_block_start()', 'input_file_block_length()'), conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) + +@allow_non_gpu('DataWritingCommandExec') +def test_csv_save_as_table_fallback(spark_tmp_path, spark_tmp_table_factory): + gen = TimestampGen() + data_path = spark_tmp_path + '/CSV_DATA' + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("csv").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.csv(path), + data_path, + 'DataWritingCommandExec') diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py index 957620250ca..661c124a9fd 100644 --- a/integration_tests/src/main/python/orc_test.py +++ b/integration_tests/src/main/python/orc_test.py @@ -14,7 +14,7 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect from datetime import date, datetime, timezone from data_gen import * from marks import * @@ -50,7 +50,7 @@ def test_orc_fallback(spark_tmp_path, read_func, disable_conf): gen_list = [('_c' + str(i), gen) for i, gen in enumerate(data_gens)] gen = StructGen(gen_list, nullable=False) - data_path = spark_tmp_path + '/PARQUET_DATA' + data_path = spark_tmp_path + '/ORC_DATA' reader = read_func(data_path) with_cpu_session( lambda spark : gen_df(spark, gen).write.orc(data_path)) @@ -151,55 +151,6 @@ def test_merge_schema_read(spark_tmp_path, v1_enabled_list): lambda spark : spark.read.option('mergeSchema', 'true').orc(data_path), conf={'spark.sql.sources.useV1SourceList': v1_enabled_list}) -orc_write_gens_list = [ - [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), - TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc))], - pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/139')), - pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))] - -@pytest.mark.parametrize('orc_gens', 
orc_write_gens_list, ids=idfn) -def test_write_round_trip(spark_tmp_path, orc_gens): - gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] - data_path = spark_tmp_path + '/ORC_DATA' - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path), - lambda spark, path: spark.read.orc(path), - data_path) - -orc_part_write_gens = [ - byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, - # Some file systems have issues with UTF8 strings so to help the test pass even there - StringGen('(\\w| ){0,50}'), - # Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with - # date_gen - DateGen(start=date(1590, 1, 1)), - # Once https://github.com/NVIDIA/spark-rapids/issues/140 is fixed replace this with - # timestamp_gen - TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc))] - -# There are race conditions around when individual files are read in for partitioned data -@ignore_order -@pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn) -def test_part_write_round_trip(spark_tmp_path, orc_gen): - gen_list = [('a', RepeatSeqGen(orc_gen, 10)), - ('b', orc_gen)] - data_path = spark_tmp_path + '/ORC_DATA' - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').orc(path), - lambda spark, path: spark.read.orc(path), - data_path) - -orc_write_compress_options = ['none', 'uncompressed', 'snappy'] -@pytest.mark.parametrize('compress', orc_write_compress_options) -def test_compress_write_round_trip(spark_tmp_path, compress): - data_path = spark_tmp_path + '/ORC_DATA' - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path : binary_op_df(spark, long_gen).coalesce(1).write.orc(path), - lambda spark, path : spark.read.orc(path), - data_path, - conf={'spark.sql.orc.compression.codec': compress}) - @pytest.mark.xfail( condition=not(is_before_spark_310()), reason='https://github.com/NVIDIA/spark-rapids/issues/576') @@ -218,3 +169,4 @@ def test_input_meta(spark_tmp_path): 'input_file_name()', 'input_file_block_start()', 'input_file_block_length()')) + diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py new file mode 100644 index 00000000000..4633315202a --- /dev/null +++ b/integration_tests/src/main/python/orc_write_test.py @@ -0,0 +1,119 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
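+
+# Write-side ORC tests, split out of orc_test.py and extended here with
+# saveAsTable/CTAS coverage plus fallback cases (unsupported compression
+# codecs and bucketed writes) that are expected to stay on the CPU
+# DataWritingCommandExec.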
+ +import pytest + +from asserts import assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_write +from datetime import date, datetime, timezone +from data_gen import * +from marks import * +from pyspark.sql.types import * + +orc_write_gens_list = [ + [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, + string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)), + TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc))], + pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/139')), + pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))] + +@pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn) +def test_write_round_trip(spark_tmp_path, orc_gens): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] + data_path = spark_tmp_path + '/ORC_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path), + lambda spark, path: spark.read.orc(path), + data_path) + +orc_part_write_gens = [ + byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen, + # Some file systems have issues with UTF8 strings so to help the test pass even there + StringGen('(\\w| ){0,50}'), + # Once https://github.com/NVIDIA/spark-rapids/issues/139 is fixed replace this with + # date_gen + DateGen(start=date(1590, 1, 1)), + # Once https://github.com/NVIDIA/spark-rapids/issues/140 is fixed replace this with + # timestamp_gen + TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc))] + +# There are race conditions around when individual files are read in for partitioned data +@ignore_order +@pytest.mark.parametrize('orc_gen', orc_part_write_gens, ids=idfn) +def test_part_write_round_trip(spark_tmp_path, orc_gen): + gen_list = [('a', RepeatSeqGen(orc_gen, 10)), + ('b', orc_gen)] + data_path = spark_tmp_path + '/ORC_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').orc(path), + lambda spark, path: spark.read.orc(path), + data_path) + +orc_write_compress_options = ['none', 'uncompressed', 'snappy'] +@pytest.mark.parametrize('compress', orc_write_compress_options) +def test_compress_write_round_trip(spark_tmp_path, compress): + data_path = spark_tmp_path + '/ORC_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path : binary_op_df(spark, long_gen).coalesce(1).write.orc(path), + lambda spark, path : spark.read.orc(path), + data_path, + conf={'spark.sql.orc.compression.codec': compress}) + +@pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn) +def test_write_save_table(spark_tmp_path, orc_gens, spark_tmp_table_factory): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] + data_path = spark_tmp_path + '/ORC_DATA' + all_confs={'spark.sql.sources.useV1SourceList': "orc"} + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + conf=all_confs) + +def write_orc_sql_from(spark, df, data_path, write_to_table): + tmp_view_name = 'tmp_view_{}'.format(random.randint(0, 1000000)) + df.createOrReplaceTempView(tmp_view_name) + write_cmd = 'CREATE TABLE `{}` USING ORC location \'{}\' AS SELECT * from `{}`'.format(write_to_table, data_path, 
tmp_view_name) + spark.sql(write_cmd) + +@pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn) +@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"]) +def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, spark_tmp_table_factory): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)] + data_path = spark_tmp_path + '/ORC_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: write_orc_sql_from(spark, gen_df(spark, gen_list).coalesce(1), path, spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path) + +@allow_non_gpu('DataWritingCommandExec') +@pytest.mark.parametrize('codec', ['zlib', 'lzo']) +def test_orc_write_compression_fallback(spark_tmp_path, codec, spark_tmp_table_factory): + gen = TimestampGen() + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.orc.compression.codec': codec} + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + 'DataWritingCommandExec', + conf=all_confs) + +@allow_non_gpu('DataWritingCommandExec') +def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/ORC_DATA' + assert_gpu_fallback_write( + lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('orc').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.orc(path), + data_path, + 'DataWritingCommandExec') diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index b33aa910a31..b082a0bab2c 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -332,81 +332,6 @@ def test_read_merge_schema_from_conf(spark_tmp_path, v1_enabled_list, reader_con lambda spark : spark.read.parquet(data_path), conf=all_confs) -parquet_write_gens_list = [ - [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - string_gen, boolean_gen, date_gen, timestamp_gen]] - -@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn) -@pytest.mark.parametrize('reader_confs', reader_opt_confs) -@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) -@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"]) -def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type, reader_confs): - gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] - data_path = spark_tmp_path + '/PARQUET_DATA' - all_confs = reader_confs.copy() - all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.parquet.outputTimestampType': ts_type}) - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path), - lambda spark, path: spark.read.parquet(path), - data_path, - conf=all_confs) - -@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS']) -@pytest.mark.parametrize('ts_rebase', ['CORRECTED']) -@ignore_order -def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): - gen = TimestampGen() - data_path = spark_tmp_path + '/PARQUET_DATA' - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: 
unary_op_df(spark, gen).write.parquet(path), - lambda spark, path: spark.read.parquet(path), - data_path, - conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, - 'spark.sql.parquet.outputTimestampType': ts_type}) - -parquet_part_write_gens = [ - byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, - # Some file systems have issues with UTF8 strings so to help the test pass even there - StringGen('(\\w| ){0,50}'), - boolean_gen, date_gen, timestamp_gen] - -# There are race conditions around when individual files are read in for partitioned data -@ignore_order -@pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn) -@pytest.mark.parametrize('reader_confs', reader_opt_confs) -@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) -@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS']) -def test_part_write_round_trip(spark_tmp_path, parquet_gen, v1_enabled_list, ts_type, reader_confs): - gen_list = [('a', RepeatSeqGen(parquet_gen, 10)), - ('b', parquet_gen)] - data_path = spark_tmp_path + '/PARQUET_DATA' - all_confs = reader_confs.copy() - all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.parquet.outputTimestampType': ts_type}) - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').parquet(path), - lambda spark, path: spark.read.parquet(path), - data_path, - conf=all_confs) - -parquet_write_compress_options = ['none', 'uncompressed', 'snappy'] -@pytest.mark.parametrize('compress', parquet_write_compress_options) -@pytest.mark.parametrize('reader_confs', reader_opt_confs) -@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) -def test_compress_write_round_trip(spark_tmp_path, compress, v1_enabled_list, reader_confs): - data_path = spark_tmp_path + '/PARQUET_DATA' - all_confs = reader_confs.copy() - all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, - 'spark.sql.parquet.compression.codec': compress}) - assert_gpu_and_cpu_writes_are_equal_collect( - lambda spark, path : binary_op_df(spark, long_gen).coalesce(1).write.parquet(path), - lambda spark, path : spark.read.parquet(path), - data_path, - conf=all_confs) - @pytest.mark.parametrize('reader_confs', reader_opt_confs) @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) def test_input_meta(spark_tmp_path, v1_enabled_list, reader_confs): diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py new file mode 100644 index 00000000000..a3014dc31af --- /dev/null +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
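+
+# Write-side Parquet tests, split out of parquet_test.py and extended with
+# saveAsTable/CTAS coverage and fallback cases (LEGACY datetime rebase,
+# INT96 timestamps, unsupported codecs, writeLegacyFormat and bucketed
+# writes) that are expected to stay on the CPU DataWritingCommandExec.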
+ +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_writes_are_equal_collect, assert_gpu_fallback_write +from datetime import date, datetime, timezone +from data_gen import * +from marks import * +from pyspark.sql.types import * +from spark_session import with_cpu_session, with_gpu_session, is_before_spark_310 + +# test with original parquet file reader, the multi-file parallel reader for cloud, and coalesce file reader for +# non-cloud +original_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'PERFILE'} +multithreaded_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'MULTITHREADED'} +coalesce_parquet_file_reader_conf={'spark.rapids.sql.format.parquet.reader.type': 'COALESCING'} +reader_opt_confs = [original_parquet_file_reader_conf, multithreaded_parquet_file_reader_conf, + coalesce_parquet_file_reader_conf] + +parquet_write_gens_list = [ + [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, + string_gen, boolean_gen, date_gen, timestamp_gen]] + +@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn) +@pytest.mark.parametrize('reader_confs', reader_opt_confs) +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"]) +def test_write_round_trip(spark_tmp_path, parquet_gens, v1_enabled_list, ts_type, reader_confs): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs = reader_confs.copy() + all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, + 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', + 'spark.sql.parquet.outputTimestampType': ts_type}) + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + +@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS']) +@pytest.mark.parametrize('ts_rebase', ['CORRECTED']) +@ignore_order +def test_write_ts_millis(spark_tmp_path, ts_type, ts_rebase): + gen = TimestampGen() + data_path = spark_tmp_path + '/PARQUET_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: unary_op_df(spark, gen).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, + 'spark.sql.parquet.outputTimestampType': ts_type}) + +parquet_part_write_gens = [ + byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, + # Some file systems have issues with UTF8 strings so to help the test pass even there + StringGen('(\\w| ){0,50}'), + boolean_gen, date_gen, timestamp_gen] + +# There are race conditions around when individual files are read in for partitioned data +@ignore_order +@pytest.mark.parametrize('parquet_gen', parquet_part_write_gens, ids=idfn) +@pytest.mark.parametrize('reader_confs', reader_opt_confs) +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +@pytest.mark.parametrize('ts_type', ['TIMESTAMP_MILLIS', 'TIMESTAMP_MICROS']) +def test_part_write_round_trip(spark_tmp_path, parquet_gen, v1_enabled_list, ts_type, reader_confs): + gen_list = [('a', RepeatSeqGen(parquet_gen, 10)), + ('b', parquet_gen)] + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs = reader_confs.copy() + all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, + 
'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', + 'spark.sql.parquet.outputTimestampType': ts_type}) + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.partitionBy('a').parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + +parquet_write_compress_options = ['none', 'uncompressed', 'snappy'] +@pytest.mark.parametrize('compress', parquet_write_compress_options) +@pytest.mark.parametrize('reader_confs', reader_opt_confs) +@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) +def test_compress_write_round_trip(spark_tmp_path, compress, v1_enabled_list, reader_confs): + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs = reader_confs.copy() + all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list, + 'spark.sql.parquet.compression.codec': compress}) + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path : binary_op_df(spark, long_gen).coalesce(1).write.parquet(path), + lambda spark, path : spark.read.parquet(path), + data_path, + conf=all_confs) + +@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn) +@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"]) +def test_write_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_table_factory): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', + 'spark.sql.parquet.outputTimestampType': ts_type} + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + +def write_parquet_sql_from(spark, df, data_path, write_to_table): + tmp_view_name = 'tmp_view_{}'.format(random.randint(0, 1000000)) + df.createOrReplaceTempView(tmp_view_name) + write_cmd = 'CREATE TABLE `{}` USING PARQUET location \'{}\' AS SELECT * from `{}`'.format(write_to_table, data_path, tmp_view_name) + spark.sql(write_cmd) + +@pytest.mark.parametrize('parquet_gens', parquet_write_gens_list, ids=idfn) +@pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"]) +def test_write_sql_save_table(spark_tmp_path, parquet_gens, ts_type, spark_tmp_table_factory): + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', + 'spark.sql.parquet.outputTimestampType': ts_type} + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: write_parquet_sql_from(spark, gen_df(spark, gen_list).coalesce(1), path, spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=all_confs) + +def writeParquetUpgradeCatchException(spark, df, data_path, spark_tmp_table_factory, ts_rebase, ts_write): + spark.conf.set('spark.sql.legacy.parquet.datetimeRebaseModeInWrite', ts_rebase) + spark.conf.set('spark.sql.parquet.outputTimestampType', ts_write) + with pytest.raises(Exception) as e_info: + df.coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(spark_tmp_table_factory.get()) + assert e_info.match(r".*SparkUpgradeException.*") + +# TODO - 
https://github.com/NVIDIA/spark-rapids/issues/1130 to handle TIMESTAMP_MILLIS +parquet_ts_write_options = ['TIMESTAMP_MICROS'] + +@pytest.mark.parametrize('ts_write', parquet_ts_write_options) +@pytest.mark.parametrize('ts_rebase', ['EXCEPTION']) +def test_ts_write_fails_datetime_exception(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory): + gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) + data_path = spark_tmp_path + '/PARQUET_DATA' + with_gpu_session( + lambda spark : writeParquetUpgradeCatchException(spark, unary_op_df(spark, gen), data_path, spark_tmp_table_factory, ts_rebase, ts_write)) + +def writeParquetNoOverwriteCatchException(spark, df, data_path, table_name): + with pytest.raises(Exception) as e_info: + df.coalesce(1).write.format("parquet").option("path", data_path).saveAsTable(table_name) + assert e_info.match(r".*already exists.*") + +def test_ts_write_twice_fails_exception(spark_tmp_path, spark_tmp_table_factory): + gen = IntegerGen() + data_path = spark_tmp_path + '/PARQUET_DATA' + table_name = spark_tmp_table_factory.get() + with_gpu_session( + lambda spark : unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", data_path).saveAsTable(table_name)) + with_gpu_session( + lambda spark : writeParquetNoOverwriteCatchException(spark, unary_op_df(spark, gen), data_path, table_name)) + +parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS'] + +@allow_non_gpu('DataWritingCommandExec') +@pytest.mark.parametrize('ts_write', parquet_ts_write_options) +@pytest.mark.parametrize('ts_rebase', ['LEGACY']) +def test_parquet_write_legacy_fallback(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory): + gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': "CORRECTED", + 'spark.sql.parquet.outputTimestampType': ts_write} + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec', + conf=all_confs) + +@allow_non_gpu('DataWritingCommandExec') +@pytest.mark.parametrize('ts_write', ['INT96']) +@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'EXCEPTION', 'LEGACY']) +def test_parquet_write_int96_fallback(spark_tmp_path, ts_write, ts_rebase, spark_tmp_table_factory): + gen = TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)) + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase, + 'spark.sql.legacy.parquet.int96RebaseModeInWrite': "CORRECTED", + 'spark.sql.parquet.outputTimestampType': ts_write} + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec', + conf=all_confs) + +@allow_non_gpu('DataWritingCommandExec') +# note that others should fail as well but requires you to load the libraries for them +# 'lzo', 'brotli', 'lz4', 'zstd' should all fallback +@pytest.mark.parametrize('codec', ['gzip']) +def test_parquet_write_compression_fallback(spark_tmp_path, codec, 
spark_tmp_table_factory): + gen = IntegerGen() + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.parquet.compression.codec': codec} + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec', + conf=all_confs) + +@allow_non_gpu('DataWritingCommandExec') +def test_parquet_writeLegacyFormat_fallback(spark_tmp_path, spark_tmp_table_factory): + gen = IntegerGen() + data_path = spark_tmp_path + '/PARQUET_DATA' + all_confs={'spark.sql.parquet.writeLegacyFormat': 'true'} + assert_gpu_fallback_write( + lambda spark, path: unary_op_df(spark, gen).coalesce(1).write.format("parquet").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec', + conf=all_confs) + +@allow_non_gpu('DataWritingCommandExec') +def test_buckets_write_fallback(spark_tmp_path, spark_tmp_table_factory): + data_path = spark_tmp_path + '/PARQUET_DATA' + assert_gpu_fallback_write( + lambda spark, path: spark.range(10e4).write.bucketBy(4, "id").sortBy("id").format('parquet').mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()), + lambda spark, path: spark.read.parquet(path), + data_path, + 'DataWritingCommandExec') diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index 3044e3fee9d..bd46fb78bc1 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} @@ -400,4 +401,11 @@ class Spark300Shims extends SparkShims { exportColumnRdd: Boolean): GpuColumnarToRowExecParent = { GpuColumnarToRowExec(plan, exportColumnRdd) } + + override def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit = { + GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver) + } } diff --git a/shims/spark300/src/main/scala/org/apache/spark/sql/rapids/shims/spark300/GpuSchemaUtils.scala b/shims/spark300/src/main/scala/org/apache/spark/sql/rapids/shims/spark300/GpuSchemaUtils.scala new file mode 100644 index 00000000000..c3cae81cf15 --- /dev/null +++ b/shims/spark300/src/main/scala/org/apache/spark/sql/rapids/shims/spark300/GpuSchemaUtils.scala @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.shims.spark300 + +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +object GpuSchemaUtils { + + def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit = { + SchemaUtils.checkColumnNameDuplication(schema.map(_.name), colType, resolver) + } +} diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala index 2939ff18534..8ca234236dc 100644 --- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala @@ -21,6 +21,7 @@ import com.nvidia.spark.rapids.shims.spark301.Spark301Shims import com.nvidia.spark.rapids.spark310.RapidsShuffleManager import org.apache.spark.SparkEnv +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType @@ -279,4 +280,12 @@ class Spark310Shims extends Spark301Shims { GpuColumnarToRowExec(plan) } } + + override def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit = { + GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver) + } + } diff --git a/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuSchemaUtils.scala b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuSchemaUtils.scala new file mode 100644 index 00000000000..c71c77d6239 --- /dev/null +++ b/shims/spark310/src/main/scala/org/apache/spark/sql/rapids/shims/spark310/GpuSchemaUtils.scala @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids.shims.spark310 + +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +object GpuSchemaUtils { + + def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit = { + SchemaUtils.checkSchemaColumnNameDuplication(schema, colType, resolver) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 54d3681d0cf..66abfcac76e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec} +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -290,7 +290,7 @@ final class InsertIntoHadoopFsRelationCommandMeta( willNotWorkOnGpu("bucketing is not supported") } - val spark = SparkSession.active + val spark = SparkSession.active fileFormat = cmd.fileFormat match { case _: CSVFileFormat => @@ -332,6 +332,55 @@ final class InsertIntoHadoopFsRelationCommandMeta( } } +final class CreateDataSourceTableAsSelectCommandMeta( + cmd: CreateDataSourceTableAsSelectCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: ConfKeysAndIncompat) + extends DataWritingCommandMeta[CreateDataSourceTableAsSelectCommand](cmd, conf, parent, rule) { + + private var origProvider: Class[_] = _ + private var gpuProvider: Option[ColumnarFileFormat] = None + + override def tagSelfForGpu(): Unit = { + if (cmd.table.bucketSpec.isDefined) { + willNotWorkOnGpu("bucketing is not supported") + } + if (cmd.table.provider.isEmpty) { + willNotWorkOnGpu("provider must be defined") + } + + val spark = SparkSession.active + origProvider = + GpuDataSource.lookupDataSourceWithFallback(cmd.table.provider.get, spark.sessionState.conf) + // Note that the data source V2 always fallsback to the V1 currently. + // If that changes then this will start failing because we don't have a mapping. 
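+    // Map the original provider to a GPU ColumnarFileFormat. tagGpuSupport
+    // checks the write options and schema and tags this meta when something
+    // is unsupported; providers other than ORC and Parquet always fall back
+    // to the CPU DataWritingCommandExec.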
+ gpuProvider = origProvider.getConstructor().newInstance() match { + case format: OrcFileFormat => + GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties) + case format: ParquetFileFormat => + GpuParquetFileFormat.tagGpuSupport(this, spark, + cmd.table.storage.properties, cmd.query.schema) + case ds => + willNotWorkOnGpu(s"Data source class not supported: ${ds}") + None + } + } + + override def convertToGpu(): GpuDataWritingCommand = { + val newProvider = gpuProvider.getOrElse( + throw new IllegalStateException("fileFormat unexpected, tagSelfForGpu not called?")) + + GpuCreateDataSourceTableAsSelectCommand( + cmd.table, + cmd.mode, + cmd.query, + cmd.outputColumnNames, + origProvider, + newProvider) + } +} + object GpuOverrides { val FLOAT_DIFFERS_GROUP_INCOMPAT = "when enabling these, there may be extra groups produced for floating point grouping " + @@ -1809,7 +1858,10 @@ object GpuOverrides { DataWritingCommandRule[_ <: DataWritingCommand]] = Seq( dataWriteCmd[InsertIntoHadoopFsRelationCommand]( "Write to Hadoop filesystem", - (a, conf, p, r) => new InsertIntoHadoopFsRelationCommandMeta(a, conf, p, r)) + (a, conf, p, r) => new InsertIntoHadoopFsRelationCommandMeta(a, conf, p, r)), + dataWriteCmd[CreateDataSourceTableAsSelectCommand]( + "Create table with select command", + (a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r)) ).map(r => (r.getClassFor.asSubclass(classOf[DataWritingCommand]), r)).toMap def wrapPlan[INPUT <: SparkPlan]( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 846b4fde362..09b8ef2f3df 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType @@ -139,5 +140,10 @@ trait SparkShims { def copyFileSourceScanExec( scanExec: GpuFileSourceScanExec, queryUsesInputFile: Boolean): GpuFileSourceScanExec + + def checkColumnNameDuplication( + schema: StructType, + colType: String, + resolver: Resolver): Unit } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala new file mode 100644 index 00000000000..5b5509df512 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.net.URI + +import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuDataWritingCommand} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, CommandUtils} +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class GpuCreateDataSourceTableAsSelectCommand( + table: CatalogTable, + mode: SaveMode, + query: LogicalPlan, + outputColumnNames: Seq[String], + origProvider: Class[_], + gpuFileFormat: ColumnarFileFormat) + extends GpuDataWritingCommand { + + override def runColumnar(sparkSession: SparkSession, child: SparkPlan): Seq[ColumnarBatch] = { + assert(table.tableType != CatalogTableType.VIEW) + assert(table.provider.isDefined) + + val sessionState = sparkSession.sessionState + val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = table.identifier.copy(database = Some(db)) + val tableName = tableIdentWithDB.unquotedString + + if (sessionState.catalog.tableExists(tableIdentWithDB)) { + assert(mode != SaveMode.Overwrite, + s"Expect the table $tableName has been dropped when the save mode is Overwrite") + + if (mode == SaveMode.ErrorIfExists) { + throw new AnalysisException(s"Table $tableName already exists. You need to drop it first.") + } + if (mode == SaveMode.Ignore) { + // Since the table already exists and the save mode is Ignore, we will just return. + return Seq.empty + } + + saveDataIntoTable( + sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true) + } else { + assert(table.schema.isEmpty) + sparkSession.sessionState.catalog.validateTableLocation(table) + val tableLocation = if (table.tableType == CatalogTableType.MANAGED) { + Some(sessionState.catalog.defaultTablePath(table.identifier)) + } else { + table.storage.locationUri + } + val result = saveDataIntoTable( + sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) + val newTable = table.copy( + storage = table.storage.copy(locationUri = tableLocation), + // We will use the schema of resolved.relation as the schema of the table (instead of + // the schema of df). It is important since the nullability may be changed by the relation + // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). + schema = result.schema) + // Table location is already validated. No need to check it again during table creation. + sessionState.catalog.createTable(newTable, ignoreIfExists = false, validateLocation = false) + + result match { + case fs: HadoopFsRelation if table.partitionColumnNames.nonEmpty && + sparkSession.sqlContext.conf.manageFilesourcePartitions => + // Need to recover partitions into the metastore so our saved data is visible. + sessionState.executePlan(AlterTableRecoverPartitionsCommand(table.identifier)).toRdd + case _ => + } + } + + CommandUtils.updateTableStats(sparkSession, table) + + Seq.empty[ColumnarBatch] + } + + private def saveDataIntoTable( + session: SparkSession, + table: CatalogTable, + tableLocation: Option[URI], + physicalPlan: SparkPlan, + mode: SaveMode, + tableExists: Boolean): BaseRelation = { + // Create the relation based on the input logical plan: `query`. 
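+    // Unlike Spark's CreateDataSourceTableAsSelectCommand, this builds a
+    // GpuDataSource that carries both the original provider class and the
+    // GPU ColumnarFileFormat, so writeAndRead() plans a
+    // GpuInsertIntoHadoopFsRelationCommand instead of the CPU write command.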
+ val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) + val dataSource = GpuDataSource( + session, + className = table.provider.get, + partitionColumns = table.partitionColumnNames, + bucketSpec = table.bucketSpec, + options = table.storage.properties ++ pathOption, + catalogTable = if (tableExists) Some(table) else None, + origProvider = origProvider, + gpuFileFormat = gpuFileFormat) + try { + dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan) + } catch { + case ex: AnalysisException => + logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) + throw ex + } + } + + private val isPartitioned = table.partitionColumnNames.nonEmpty + + private val isBucketed = table.bucketSpec.nonEmpty + + // use same logic as GpuInsertIntoHadoopFsRelationCommand + override def requireSingleBatch: Boolean = isPartitioned || isBucketed +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala new file mode 100644 index 00000000000..1d3eca46679 --- /dev/null +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala @@ -0,0 +1,655 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.rapids + +import java.util.{Locale, ServiceConfigurationError, ServiceLoader} + +import scala.collection.JavaConverters._ +import scala.language.implicitConversions +import scala.util.{Failure, Success, Try} + +import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuParquetFileFormat, ShimLoader} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A truncated version of Spark DataSource that converts to use the GPU version of + * InsertIntoHadoopFsRelationCommand for FileFormats we support. + * This does not support DataSource V2 writing at this point because at the time of + * copying, it did not. + */ +case class GpuDataSource( + sparkSession: SparkSession, + className: String, + paths: Seq[String] = Nil, + userSpecifiedSchema: Option[StructType] = None, + partitionColumns: Seq[String] = Seq.empty, + bucketSpec: Option[BucketSpec] = None, + options: Map[String, String] = Map.empty, + catalogTable: Option[CatalogTable] = None, + origProvider: Class[_], + gpuFileFormat: ColumnarFileFormat) extends Logging { + + private def originalProvidingInstance() = origProvider.getConstructor().newInstance() + + private def newHadoopConfiguration(): Configuration = + sparkSession.sessionState.newHadoopConfWithOptions(options) + + private val caseInsensitiveOptions = CaseInsensitiveMap(options) + private val equality = sparkSession.sessionState.conf.resolver + + /** + * Whether or not paths should be globbed before being used to access files. + */ + def globPaths: Boolean = { + options.get(GpuDataSource.GLOB_PATHS_KEY) + .map(_ == "true") + .getOrElse(true) + } + + bucketSpec.map { bucket => + SchemaUtils.checkColumnNameDuplication( + bucket.bucketColumnNames, "in the bucket definition", equality) + SchemaUtils.checkColumnNameDuplication( + bucket.sortColumnNames, "in the sort definition", equality) + } + + /** + * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer + * it. 
In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510. + * This method will try to skip file scanning whether `userSpecifiedSchema` and + * `partitionColumns` are provided. Here are some code paths that use this method: + * 1. `spark.read` (no schema): Most amount of work. Infer both schema and partitioning columns + * 2. `spark.read.schema(userSpecifiedSchema)`: Parse partitioning columns, cast them to the + * dataTypes provided in `userSpecifiedSchema` if they exist or fallback to inferred + * dataType if they don't. + * 3. `spark.readStream.schema(userSpecifiedSchema)`: For streaming use cases, users have to + * provide the schema. Here, we also perform partition inference like 2, and try to use + * dataTypes in `userSpecifiedSchema`. All subsequent triggers for this stream will re-use + * this information, therefore calls to this method should be very cheap, i.e. there won't + * be any further inference in any triggers. + * + * @param format the file format object for this DataSource + * @param getFileIndex [[InMemoryFileIndex]] for getting partition schema and file list + * @return A pair of the data schema (excluding partition columns) and the schema of the partition + * columns. + */ + private def getOrInferFileFormatSchema( + format: FileFormat, + getFileIndex: () => InMemoryFileIndex): (StructType, StructType) = { + lazy val tempFileIndex = getFileIndex() + + val partitionSchema = if (partitionColumns.isEmpty) { + // Try to infer partitioning, because no DataSource in the read path provides the partitioning + // columns properly unless it is a Hive DataSource + tempFileIndex.partitionSchema + } else { + // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred + // partitioning + if (userSpecifiedSchema.isEmpty) { + val inferredPartitions = tempFileIndex.partitionSchema + inferredPartitions + } else { + val partitionFields = partitionColumns.map { partitionColumn => + userSpecifiedSchema.flatMap(_.find(c => equality(c.name, partitionColumn))).orElse { + val inferredPartitions = tempFileIndex.partitionSchema + val inferredOpt = inferredPartitions.find(p => equality(p.name, partitionColumn)) + if (inferredOpt.isDefined) { + logDebug( + s"""Type of partition column: $partitionColumn not found in specified schema + |for $format. + |User Specified Schema + |===================== + |${userSpecifiedSchema.orNull} + | + |Falling back to inferred dataType if it exists. + """.stripMargin) + } + inferredOpt + }.getOrElse { + throw new AnalysisException(s"Failed to resolve the schema for $format for " + + s"the partition column: $partitionColumn. It must be specified manually.") + } + } + StructType(partitionFields) + } + } + + val dataSchema = userSpecifiedSchema.map { schema => + StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) + }.orElse { + // Remove "path" option so that it is not added to the paths returned by + // `tempFileIndex.allFiles()`. + format.inferSchema( + sparkSession, + caseInsensitiveOptions - "path", + tempFileIndex.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $format. 
It must be specified manually.") + } + + // We just print a waring message if the data schema and partition schema have the duplicate + // columns. This is because we allow users to do so in the previous Spark releases and + // we have the existing tests for the cases (e.g., `ParquetHadoopFsRelationSuite`). + // See SPARK-18108 and SPARK-21144 for related discussions. + try { + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), + "in the data schema and the partition schema", + equality) + } catch { + case e: AnalysisException => logWarning(e.getMessage) + } + + (dataSchema, partitionSchema) + } + + + /** + * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this + * [[DataSource]] + * + * @param checkFilesExist Whether to confirm that the files exist when generating the + * non-streaming file based datasource. StructuredStreaming jobs already + * list file existence, and when generating incremental jobs, the batch + * is considered as a non-streaming file based data source. Since we know + * that files already exist, we don't need to check them again. + */ + def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = { + val relation = (originalProvidingInstance(), userSpecifiedSchema) match { + // TODO: Throw when too much is given. + case (dataSource: SchemaRelationProvider, Some(schema)) => + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) + case (dataSource: RelationProvider, None) => + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) + case (_: SchemaRelationProvider, None) => + throw new AnalysisException(s"A schema needs to be specified when using $className.") + case (dataSource: RelationProvider, Some(schema)) => + val baseRelation = + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) + if (baseRelation.schema != schema) { + throw new AnalysisException( + "The user-specified schema doesn't match the actual schema: " + + s"user-specified: ${schema.toDDL}, actual: ${baseRelation.schema.toDDL}. If " + + "you're using DataFrameReader.schema API or creating a table, please do not " + + "specify the schema. Or if you're scanning an existed table, please drop " + + "it and re-create it.") + } + baseRelation + + // We are reading from the results of a streaming query. Load files from the metadata log + // instead of listing them using HDFS APIs. + case (format: FileFormat, _) + if FileStreamSink.hasMetadata( + caseInsensitiveOptions.get("path").toSeq ++ paths, + newHadoopConfiguration(), + sparkSession.sessionState.conf) => + val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, + caseInsensitiveOptions, userSpecifiedSchema) + val dataSchema = userSpecifiedSchema.orElse { + // Remove "path" option so that it is not added to the paths returned by + // `fileCatalog.allFiles()`. + format.inferSchema( + sparkSession, + caseInsensitiveOptions - "path", + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException( + s"Unable to infer schema for $format at ${fileCatalog.allFiles().mkString(",")}. " + + "It must be specified manually") + } + + HadoopFsRelation( + fileCatalog, + partitionSchema = fileCatalog.partitionSchema, + dataSchema = dataSchema, + bucketSpec = None, + format, + caseInsensitiveOptions)(sparkSession) + + // This is a non-streaming file based datasource. 
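+      // Either reuse the catalog's partition information when filesource
+      // partition management is enabled for this table, or glob the input
+      // paths and infer the data/partition schema from the files themselves.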
+ case (format: FileFormat, _) =>
+ val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions &&
+ catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog &&
+ catalogTable.get.partitionColumnNames.nonEmpty
+ val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) {
+ val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes
+ val index = new CatalogFileIndex(
+ sparkSession,
+ catalogTable.get,
+ catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize))
+ (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema)
+ } else {
+ val globbedPaths = checkAndGlobPathIfNecessary(
+ checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
+ val index = createInMemoryFileIndex(globbedPaths)
+ val (resultDataSchema, resultPartitionSchema) =
+ getOrInferFileFormatSchema(format, () => index)
+ (index, resultDataSchema, resultPartitionSchema)
+ }
+
+ HadoopFsRelation(
+ fileCatalog,
+ partitionSchema = partitionSchema,
+ dataSchema = dataSchema.asNullable,
+ bucketSpec = bucketSpec,
+ format,
+ caseInsensitiveOptions)(sparkSession)
+
+ case _ =>
+ throw new AnalysisException(
+ s"$className is not a valid Spark SQL Data Source.")
+ }
+
+ relation match {
+ case hs: HadoopFsRelation =>
+ ShimLoader.getSparkShims.checkColumnNameDuplication(
+ hs.dataSchema,
+ "in the data schema",
+ equality)
+ ShimLoader.getSparkShims.checkColumnNameDuplication(
+ hs.partitionSchema,
+ "in the partition schema",
+ equality)
+ DataSourceUtils.verifySchema(hs.fileFormat, hs.dataSchema)
+ case _ =>
+ ShimLoader.getSparkShims.checkColumnNameDuplication(
+ relation.schema,
+ "in the data schema",
+ equality)
+ }
+
+ relation
+ }
+
+ /**
+ * Creates a command node to write the given [[LogicalPlan]] out to the given [[FileFormat]].
+ * The returned command is unresolved and needs to be analyzed.
+ */
+ private def planForWritingFileFormat(
+ format: ColumnarFileFormat,
+ mode: SaveMode,
+ data: LogicalPlan): GpuInsertIntoHadoopFsRelationCommand = {
+ // Don't glob paths for the write path. The contracts here are:
+ // 1. Only one output path can be specified on the write path;
+ // 2. Output path must be a legal HDFS style file system path;
+ // 3. It's OK that the output path doesn't exist yet;
+ val allPaths = paths ++ caseInsensitiveOptions.get("path")
+ val outputPath = if (allPaths.length == 1) {
+ val path = new Path(allPaths.head)
+ val fs = path.getFileSystem(newHadoopConfiguration())
+ path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ } else {
+ throw new IllegalArgumentException("Expected exactly one path to be specified, but " +
+ s"got: ${allPaths.mkString(", ")}")
+ }
+
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)
+
+ val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
+ sparkSession.table(tableIdent).queryExecution.analyzed.collect {
+ case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location
+ }.head
+ }
+ // For partitioned relation r, r.schema's column ordering can be different from the column
+ // ordering of data.logicalPlan (partition columns are all moved after data columns). This
+ // will be adjusted within InsertIntoHadoopFsRelation. 
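+ // The partition columns are passed as UnresolvedAttributes here; writeAndRead resolves them
+ // against the query output before running the command.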
+ GpuInsertIntoHadoopFsRelationCommand(
+ outputPath = outputPath,
+ staticPartitions = Map.empty,
+ ifPartitionNotExists = false,
+ partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
+ bucketSpec = bucketSpec,
+ fileFormat = format,
+ options = options,
+ query = data,
+ mode = mode,
+ catalogTable = catalogTable,
+ fileIndex = fileIndex,
+ outputColumnNames = data.output.map(_.name))
+ }
+
+ /**
+ * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]]
+ * for subsequent reads.
+ *
+ * @param mode The save mode for this writing.
+ * @param data The input query plan that produces the data to be written. Note that this plan
+ * is analyzed and optimized.
+ * @param outputColumnNames The original output column names of the input query plan. The
+ * optimizer may not preserve the case of the output column names, so we need
+ * this parameter instead of `data.output`.
+ * @param physicalPlan The physical plan of the input query plan. We should run the writing
+ * command with this physical plan instead of creating a new physical plan,
+ * so that the metrics can be correctly linked to the given physical plan and
+ * shown in the web UI.
+ */
+ def writeAndRead(
+ mode: SaveMode,
+ data: LogicalPlan,
+ outputColumnNames: Seq[String],
+ physicalPlan: SparkPlan): BaseRelation = {
+ val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
+ if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
+ throw new AnalysisException("Cannot save interval data type into external storage.")
+ }
+
+ // Currently only ColumnarFileFormat is supported
+ val cmd = planForWritingFileFormat(gpuFileFormat, mode, data)
+ val resolvedPartCols = cmd.partitionColumns.map { col =>
+ // The partition columns created in `planForWritingFileFormat` should always be
+ // `UnresolvedAttribute` with a single name part.
+ assert(col.isInstanceOf[UnresolvedAttribute])
+ val unresolved = col.asInstanceOf[UnresolvedAttribute]
+ assert(unresolved.nameParts.length == 1)
+ val name = unresolved.nameParts.head
+ outputColumns.find(a => equality(a.name, name)).getOrElse {
+ throw new AnalysisException(
+ s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]")
+ }
+ }
+ val resolved = cmd.copy(
+ partitionColumns = resolvedPartCols,
+ outputColumnNames = outputColumnNames)
+ resolved.runColumnar(sparkSession, physicalPlan)
+ // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
+ copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
+ }
+
+ /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */
+ private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = {
+ val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
+ new InMemoryFileIndex(
+ sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache)
+ }
+
+ /**
+ * Checks and returns files in all the paths.
+ */
+ private def checkAndGlobPathIfNecessary(
+ checkEmptyGlobPath: Boolean,
+ checkFilesExist: Boolean): Seq[Path] = {
+ val allPaths = caseInsensitiveOptions.get("path") ++ paths
+ GpuDataSource.checkAndGlobPathIfNecessary(allPaths.toSeq, newHadoopConfiguration(),
+ checkEmptyGlobPath, checkFilesExist, enableGlobbing = globPaths)
+ }
+ }
+}
+
+object GpuDataSource extends Logging {
+
+ /** A map to maintain backward compatibility in case we move data sources around. 
*/
+ private val backwardCompatibilityMap: Map[String, String] = {
+ val jdbc = classOf[JdbcRelationProvider].getCanonicalName
+ val json = classOf[JsonFileFormat].getCanonicalName
+ val parquet = classOf[GpuParquetFileFormat].getCanonicalName
+ val csv = classOf[CSVFileFormat].getCanonicalName
+ val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
+ val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
+ val nativeOrc = classOf[OrcFileFormat].getCanonicalName
+ val socket = classOf[TextSocketSourceProvider].getCanonicalName
+ val rate = classOf[RateStreamProvider].getCanonicalName
+
+ Map(
+ "org.apache.spark.sql.jdbc" -> jdbc,
+ "org.apache.spark.sql.jdbc.DefaultSource" -> jdbc,
+ "org.apache.spark.sql.execution.datasources.jdbc.DefaultSource" -> jdbc,
+ "org.apache.spark.sql.execution.datasources.jdbc" -> jdbc,
+ "org.apache.spark.sql.json" -> json,
+ "org.apache.spark.sql.json.DefaultSource" -> json,
+ "org.apache.spark.sql.execution.datasources.json" -> json,
+ "org.apache.spark.sql.execution.datasources.json.DefaultSource" -> json,
+ "org.apache.spark.sql.parquet" -> parquet,
+ "org.apache.spark.sql.parquet.DefaultSource" -> parquet,
+ "org.apache.spark.sql.execution.datasources.parquet" -> parquet,
+ "org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet,
+ "org.apache.spark.sql.hive.orc.DefaultSource" -> orc,
+ "org.apache.spark.sql.hive.orc" -> orc,
+ "org.apache.spark.sql.execution.datasources.orc.DefaultSource" -> nativeOrc,
+ "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
+ "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
+ "org.apache.spark.ml.source.libsvm" -> libsvm,
+ "com.databricks.spark.csv" -> csv,
+ "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
+ "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
+ )
+ }
+
+ /**
+ * Classes that were removed in Spark 2.0. Used to detect incompatible libraries for Spark 2.0.
+ */
+ private val spark2RemovedClasses = Set(
+ "org.apache.spark.sql.DataFrame",
+ "org.apache.spark.sql.sources.HadoopFsRelationProvider",
+ "org.apache.spark.Logging")
+
+ def lookupDataSourceWithFallback(className: String, conf: SQLConf): Class[_] = {
+ val cls = GpuDataSource.lookupDataSource(className, conf)
+ // `providingClass` is used for resolving data source relation for catalog tables.
+ // Because the catalog for data source V2 is still under development, we fall back all
+ // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
+ // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
+ // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
+ // instead of `providingClass`.
+ val fallbackCls = cls.newInstance() match {
+ case f: FileDataSourceV2 => f.fallbackFileFormat
+ case _ => cls
+ }
+ // convert to GPU version
+ fallbackCls
+ }
+
+ /** Given a provider name, look up the data source class definition. 
*/ + def lookupDataSource(provider: String, conf: SQLConf): Class[_] = { + val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match { + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" => + classOf[OrcDataSourceV2].getCanonicalName + case name if name.equalsIgnoreCase("orc") && + conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" => + "org.apache.spark.sql.hive.orc.OrcFileFormat" + case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled => + "org.apache.spark.sql.avro.AvroFileFormat" + case name => name + } + val provider2 = s"$provider1.DefaultSource" + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + try { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match { + // the provider format did not match any given registered aliases + case Nil => + try { + Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider1.startsWith("org.apache.spark.sql.hive.orc")) { + throw new AnalysisException( + "Hive built-in ORC data source must be used with Hive support enabled. " + + "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " + + "'native'") + } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || + provider1 == "com.databricks.spark.avro" || + provider1 == "org.apache.spark.sql.avro") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Avro is built-in but external data " + + "source module since Spark 2.4. Please deploy the application as per " + + "the deployment section of \"Apache Avro Data Source Guide\".") + } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { + throw new AnalysisException( + s"Failed to find data source: $provider1. Please deploy the application as " + + "per the deployment section of " + + "\"Structured Streaming + Kafka Integration Guide\".") + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider1. Please find packages at " + + "http://spark.apache.org/third-party-projects.html", + error) + } + } + } catch { + case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " + + "Please check if your library is compatible with Spark 2.0", e) + } else { + throw e + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input. If there is single datasource + // that has "org.apache.spark" package in the prefix, we use it considering it is an + // internal datasource within Spark. 
+ val sourceNames = sources.map(_.getClass.getName) + val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark")) + if (internalSources.size == 1) { + logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " + + s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).") + internalSources.head.getClass + } else { + throw new AnalysisException(s"Multiple sources found for $provider1 " + + s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.") + } + } + } catch { + case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] => + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getCause.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " + + "Please remove the incompatible library from classpath or upgrade it. " + + s"Error: ${e.getMessage}", e) + } else { + throw e + } + } + } + + /** + * The key in the "options" map for deciding whether or not to glob paths before use. + */ + val GLOB_PATHS_KEY = "__globPaths__" + + /** + * Checks and returns files in all the paths. + */ + private[sql] def checkAndGlobPathIfNecessary( + pathStrings: Seq[String], + hadoopConf: Configuration, + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean, + numThreads: Integer = 40, + enableGlobbing: Boolean): Seq[Path] = { + val qualifiedPaths = pathStrings.map { pathString => + val path = new Path(pathString) + val fs = path.getFileSystem(hadoopConf) + path.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Split the paths into glob and non glob paths, because we don't need to do an existence check + // for globbed paths. + val (globPaths, nonGlobPaths) = qualifiedPaths.partition(SparkHadoopUtil.get.isGlobPath) + + val globbedPaths = + try { + ThreadUtils.parmap(globPaths, "globPath", numThreads) { globPath => + val fs = globPath.getFileSystem(hadoopConf) + val globResult = if (enableGlobbing) { + SparkHadoopUtil.get.globPath(fs, globPath) + } else { + qualifiedPaths + } + + if (checkEmptyGlobPath && globResult.isEmpty) { + throw new AnalysisException(s"Path does not exist: $globPath") + } + + globResult + }.flatten + } catch { + case e: SparkException => throw e.getCause + } + + if (checkFilesExist) { + try { + ThreadUtils.parmap(nonGlobPaths, "checkPathsExist", numThreads) { path => + val fs = path.getFileSystem(hadoopConf) + if (!fs.exists(path)) { + throw new AnalysisException(s"Path does not exist: $path") + } + } + } catch { + case e: SparkException => throw e.getCause + } + } + + val allPaths = globbedPaths ++ nonGlobPaths + if (checkFilesExist) { + val (filteredOut, filteredIn) = allPaths.partition { path => + InMemoryFileIndex.shouldFilterOut(path.getName) + } + if (filteredIn.isEmpty) { + logWarning( + s"All paths were ignored:\n ${filteredOut.mkString("\n ")}") + } else { + logDebug( + s"Some paths were ignored:\n ${filteredOut.mkString("\n ")}") + } + } + + allPaths.toSeq + } + +}
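
Reviewer note (not part of the patch): a minimal sketch of how this write path is expected to be exercised once the plugin is loaded. The session configuration, table name, and data below are illustrative assumptions, not code from this change.

    // Illustrative only: with the RAPIDS plugin on the classpath, a saveAsTable into a
    // parquet (or orc) provider is expected to be planned through
    // GpuCreateDataSourceTableAsSelectCommand, which builds a GpuDataSource and writes
    // via GpuInsertIntoHadoopFsRelationCommand.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("gpu-saveAsTable-sketch")                       // hypothetical app name
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin")   // RAPIDS plugin entry point
      .config("spark.rapids.sql.enabled", "true")
      .getOrCreate()

    spark.range(0, 1000)
      .selectExpr("id", "id % 10 AS part")
      .write
      .format("parquet")
      .partitionBy("part")
      .mode("overwrite")
      .saveAsTable("gpu_ctas_example")                         // hypothetical table name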