diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py
index 661c124a9fd..716d27a65e7 100644
--- a/integration_tests/src/main/python/orc_test.py
+++ b/integration_tests/src/main/python/orc_test.py
@@ -30,10 +30,12 @@ def read_orc_sql(data_path):
 @pytest.mark.parametrize('name', ['timestamp-date-test.orc'])
 @pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql])
 @pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
-def test_basic_read(std_input_path, name, read_func, v1_enabled_list):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl):
     assert_gpu_and_cpu_are_equal_collect(
             read_func(std_input_path + '/' + name),
-            conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})
+            conf={'spark.sql.sources.useV1SourceList': v1_enabled_list,
+                  'spark.sql.orc.impl': orc_impl})
 
 orc_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
diff --git a/integration_tests/src/main/python/orc_write_test.py b/integration_tests/src/main/python/orc_write_test.py
index 93a8a948426..0d3bc94cd96 100644
--- a/integration_tests/src/main/python/orc_write_test.py
+++ b/integration_tests/src/main/python/orc_write_test.py
@@ -28,13 +28,15 @@
     pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))]
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
-def test_write_round_trip(spark_tmp_path, orc_gens):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_round_trip(spark_tmp_path, orc_gens, orc_impl):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
             lambda spark, path: spark.read.orc(path),
-            data_path)
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl})
 
 orc_part_write_gens = [
     byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
@@ -70,10 +72,12 @@ def test_compress_write_round_trip(spark_tmp_path, compress):
             conf={'spark.sql.orc.compression.codec': compress})
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
-def test_write_save_table(spark_tmp_path, orc_gens, spark_tmp_table_factory):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_save_table(spark_tmp_path, orc_gens, orc_impl, spark_tmp_table_factory):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
-    all_confs={'spark.sql.sources.useV1SourceList': "orc"}
+    all_confs={'spark.sql.sources.useV1SourceList': "orc",
+               "spark.sql.orc.impl": orc_impl}
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
             lambda spark, path: spark.read.orc(path),
@@ -88,13 +92,15 @@ def write_orc_sql_from(spark, df, data_path, write_to_table):
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
 @pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"])
-def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, spark_tmp_table_factory):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, orc_impl, spark_tmp_table_factory):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: write_orc_sql_from(spark, gen_df(spark, gen_list).coalesce(1), path, spark_tmp_table_factory.get()),
             lambda spark, path: spark.read.orc(path),
-            data_path)
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl})
 
 @allow_non_gpu('DataWritingCommandExec')
 @pytest.mark.parametrize('codec', ['zlib', 'lzo'])
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index a812300c37f..c847ae6a5aa 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -34,10 +34,9 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec}
-import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
@@ -300,7 +299,7 @@ final class InsertIntoHadoopFsRelationCommandMeta(
       case _: JsonFileFormat =>
         willNotWorkOnGpu("JSON output is not supported")
         None
-      case _: OrcFileFormat =>
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) =>
         GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options)
       case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
@@ -357,9 +356,9 @@ final class CreateDataSourceTableAsSelectCommandMeta(
     // Note that the data source V2 always fallsback to the V1 currently.
     // If that changes then this will start failing because we don't have a mapping.
     gpuProvider = origProvider.getConstructor().newInstance() match {
-      case format: OrcFileFormat =>
+      case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) =>
         GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties)
-      case format: ParquetFileFormat =>
+      case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties,
           cmd.query.schema)
       case ds =>
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
index ae283186836..6a3f274487a 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.execution.{ExecSubqueryExpression, ExplainUtils, FileSourceScanExec, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{BucketingUtils, DataSourceStrategy, DataSourceUtils, FileFormat, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -551,7 +550,7 @@ object GpuFileSourceScanExec {
   def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
     meta.wrapped.relation.fileFormat match {
       case _: CSVFileFormat => GpuReadCSVFileFormat.tagSupport(meta)
-      case _: OrcFileFormat => GpuReadOrcFileFormat.tagSupport(meta)
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) => GpuReadOrcFileFormat.tagSupport(meta)
       case _: ParquetFileFormat => GpuReadParquetFileFormat.tagSupport(meta)
       case f =>
         meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
@@ -561,7 +560,7 @@ object GpuFileSourceScanExec {
   def convertFileFormat(format: FileFormat): FileFormat = {
     format match {
       case _: CSVFileFormat => new GpuReadCSVFileFormat
-      case _: OrcFileFormat => new GpuReadOrcFileFormat
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) => new GpuReadOrcFileFormat
       case _: ParquetFileFormat => new GpuReadParquetFileFormat
       case f =>
         throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index dfcb3f0400d..80e11ff338d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -27,10 +27,21 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcUtils}
 import org.apache.spark.sql.types.StructType
 
 object GpuOrcFileFormat extends Logging {
+  // The classname used when Spark is configured to use the Hive implementation for ORC.
+  // Spark is not always compiled with Hive support so we cannot import from Spark jars directly.
+  private val HIVE_IMPL_CLASS = "org.apache.spark.sql.hive.orc.OrcFileFormat"
+
+  def isSparkOrcFormat(format: FileFormat): Boolean = format match {
+    case _: OrcFileFormat => true
+    case f if f.getClass.getCanonicalName.equals(HIVE_IMPL_CLASS) => true
+    case _ => false
+  }
+
   def tagGpuSupport(meta: RapidsMeta[_, _, _],
                     spark: SparkSession,
                     options: Map[String, String]): Option[GpuOrcFileFormat] = {
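Note on what the new 'orc_impl' parametrization exercises: spark.sql.orc.impl switches Spark between its native ORC reader/writer and the Hive-based org.apache.spark.sql.hive.orc.OrcFileFormat, which is why the plugin now matches the Hive class by name in isSparkOrcFormat instead of importing it. Below is a minimal standalone PySpark sketch of that toggle; it is not part of the diff, the output path and session setup are illustrative only, and the "hive" value assumes a Spark build that ships the Hive ORC classes.

# Standalone sketch (not part of the test suite): round-trip ORC data under both
# values of spark.sql.orc.impl, the same conf the new 'orc_impl' test parameter sets.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("orc-impl-sketch").getOrCreate()

for orc_impl in ["native", "hive"]:
    # Runtime SQL conf selecting the ORC implementation ("native" or "hive").
    spark.conf.set("spark.sql.orc.impl", orc_impl)
    df = spark.range(100).selectExpr("id", "CAST(id AS STRING) AS s")
    path = "/tmp/orc_impl_sketch_" + orc_impl  # illustrative output location
    df.coalesce(1).write.mode("overwrite").orc(path)
    assert spark.read.orc(path).count() == 100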