Support replacing ORC format when Hive is configured (NVIDIA#1220)
Signed-off-by: Jason Lowe <jlowe@nvidia.com>
jlowe authored Dec 1, 2020
1 parent 7e63584 commit 23d4067
Showing 5 changed files with 33 additions and 16 deletions.
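
This commit lets the RAPIDS Accelerator recognize Spark's Hive ORC file format (org.apache.spark.sql.hive.orc.OrcFileFormat, selected for example via spark.sql.orc.impl=hive) in addition to the native ORC format, so such reads and writes can still be replaced with GPU implementations instead of falling back to the CPU. A minimal sketch of the scenario this enables follows; the plugin class name, session settings, and paths are illustrative assumptions and not part of the commit itself:

// Sketch only: assumes Spark with the RAPIDS Accelerator jars on the classpath
// and a GPU available; the app name and output path are placeholders.
import org.apache.spark.sql.SparkSession

object HiveOrcRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("hive-orc-round-trip")
      .config("spark.plugins", "com.nvidia.spark.SQLPlugin") // RAPIDS Accelerator plugin
      .config("spark.sql.orc.impl", "hive")                  // use the Hive ORC implementation
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "name")
    // With spark.sql.orc.impl=hive the ORC source resolves to the Hive format class;
    // after this commit that class is recognized and the read/write below become GPU-eligible.
    df.write.mode("overwrite").orc("/tmp/hive_orc_demo")
    spark.read.orc("/tmp/hive_orc_demo").show()

    spark.stop()
  }
}

The integration tests below exercise exactly this by parametrizing spark.sql.orc.impl over "native" and "hive", and the plugin changes match ORC formats with a predicate (GpuOrcFileFormat.isSparkOrcFormat) rather than a single type pattern.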
6 changes: 4 additions & 2 deletions integration_tests/src/main/python/orc_test.py
@@ -30,10 +30,12 @@ def read_orc_sql(data_path):
 @pytest.mark.parametrize('name', ['timestamp-date-test.orc'])
 @pytest.mark.parametrize('read_func', [read_orc_df, read_orc_sql])
 @pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
-def test_basic_read(std_input_path, name, read_func, v1_enabled_list):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl):
     assert_gpu_and_cpu_are_equal_collect(
             read_func(std_input_path + '/' + name),
-            conf={'spark.sql.sources.useV1SourceList': v1_enabled_list})
+            conf={'spark.sql.sources.useV1SourceList': v1_enabled_list,
+                'spark.sql.orc.impl': orc_impl})
 
 orc_gens_list = [[byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
18 changes: 12 additions & 6 deletions integration_tests/src/main/python/orc_write_test.py
@@ -28,13 +28,15 @@
     pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/140'))]
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
-def test_write_round_trip(spark_tmp_path, orc_gens):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_round_trip(spark_tmp_path, orc_gens, orc_impl):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.orc(path),
             lambda spark, path: spark.read.orc(path),
-            data_path)
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl})
 
 orc_part_write_gens = [
     byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen, boolean_gen,
@@ -70,10 +72,12 @@ def test_compress_write_round_trip(spark_tmp_path, compress):
             conf={'spark.sql.orc.compression.codec': compress})
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
-def test_write_save_table(spark_tmp_path, orc_gens, spark_tmp_table_factory):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_save_table(spark_tmp_path, orc_gens, orc_impl, spark_tmp_table_factory):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
-    all_confs={'spark.sql.sources.useV1SourceList': "orc"}
+    all_confs={'spark.sql.sources.useV1SourceList': "orc",
+               "spark.sql.orc.impl": orc_impl}
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.format("orc").mode('overwrite').option("path", path).saveAsTable(spark_tmp_table_factory.get()),
             lambda spark, path: spark.read.orc(path),
@@ -88,13 +92,15 @@ def write_orc_sql_from(spark, df, data_path, write_to_table):
 
 @pytest.mark.parametrize('orc_gens', orc_write_gens_list, ids=idfn)
 @pytest.mark.parametrize('ts_type', ["TIMESTAMP_MICROS", "TIMESTAMP_MILLIS"])
-def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, spark_tmp_table_factory):
+@pytest.mark.parametrize('orc_impl', ["native", "hive"])
+def test_write_sql_save_table(spark_tmp_path, orc_gens, ts_type, orc_impl, spark_tmp_table_factory):
     gen_list = [('_c' + str(i), gen) for i, gen in enumerate(orc_gens)]
     data_path = spark_tmp_path + '/ORC_DATA'
     assert_gpu_and_cpu_writes_are_equal_collect(
             lambda spark, path: write_orc_sql_from(spark, gen_df(spark, gen_list).coalesce(1), path, spark_tmp_table_factory.get()),
             lambda spark, path: spark.read.orc(path),
-            data_path)
+            data_path,
+            conf={'spark.sql.orc.impl': orc_impl})
 
 @allow_non_gpu('DataWritingCommandExec')
 @pytest.mark.parametrize('codec', ['zlib', 'lzo'])
@@ -34,10 +34,9 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec}
-import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
@@ -300,7 +299,7 @@ final class InsertIntoHadoopFsRelationCommandMeta(
       case _: JsonFileFormat =>
         willNotWorkOnGpu("JSON output is not supported")
         None
-      case _: OrcFileFormat =>
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) =>
         GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.options)
       case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark, cmd.options, cmd.query.schema)
@@ -357,9 +356,9 @@ final class CreateDataSourceTableAsSelectCommandMeta(
     // Note that the data source V2 always fallsback to the V1 currently.
     // If that changes then this will start failing because we don't have a mapping.
     gpuProvider = origProvider.getConstructor().newInstance() match {
-      case format: OrcFileFormat =>
+      case f: FileFormat if GpuOrcFileFormat.isSparkOrcFormat(f) =>
         GpuOrcFileFormat.tagGpuSupport(this, spark, cmd.table.storage.properties)
-      case format: ParquetFileFormat =>
+      case _: ParquetFileFormat =>
         GpuParquetFileFormat.tagGpuSupport(this, spark,
           cmd.table.storage.properties, cmd.query.schema)
       case ds =>
Expand Down
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
 import org.apache.spark.sql.execution.{ExecSubqueryExpression, ExplainUtils, FileSourceScanExec, SQLExecution}
 import org.apache.spark.sql.execution.datasources.{BucketingUtils, DataSourceStrategy, DataSourceUtils, FileFormat, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -551,7 +550,7 @@ object GpuFileSourceScanExec {
   def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
     meta.wrapped.relation.fileFormat match {
       case _: CSVFileFormat => GpuReadCSVFileFormat.tagSupport(meta)
-      case _: OrcFileFormat => GpuReadOrcFileFormat.tagSupport(meta)
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) => GpuReadOrcFileFormat.tagSupport(meta)
       case _: ParquetFileFormat => GpuReadParquetFileFormat.tagSupport(meta)
       case f =>
         meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
@@ -561,7 +560,7 @@ object GpuFileSourceScanExec {
   def convertFileFormat(format: FileFormat): FileFormat = {
     format match {
       case _: CSVFileFormat => new GpuReadCSVFileFormat
-      case _: OrcFileFormat => new GpuReadOrcFileFormat
+      case f if GpuOrcFileFormat.isSparkOrcFormat(f) => new GpuReadOrcFileFormat
       case _: ParquetFileFormat => new GpuReadParquetFileFormat
       case f =>
         throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
@@ -27,10 +27,21 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcUtils}
 import org.apache.spark.sql.types.StructType
 
 object GpuOrcFileFormat extends Logging {
+  // The classname used when Spark is configured to use the Hive implementation for ORC.
+  // Spark is not always compiled with Hive support so we cannot import from Spark jars directly.
+  private val HIVE_IMPL_CLASS = "org.apache.spark.sql.hive.orc.OrcFileFormat"
+
+  def isSparkOrcFormat(format: FileFormat): Boolean = format match {
+    case _: OrcFileFormat => true
+    case f if f.getClass.getCanonicalName.equals(HIVE_IMPL_CLASS) => true
+    case _ => false
+  }
+
   def tagGpuSupport(meta: RapidsMeta[_, _, _],
                     spark: SparkSession,
                     options: Map[String, String]): Option[GpuOrcFileFormat] = {
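
The Hive implementation is matched by class name rather than by type because Spark is not always built with Hive support, so org.apache.spark.sql.hive.orc.OrcFileFormat cannot be referenced at compile time. For comparison, a reflective check is another way to express the same idea; the sketch below is hypothetical and not part of this commit, and the string comparison above has the advantage of never touching the class loader:

import org.apache.spark.sql.execution.datasources.FileFormat

object HiveOrcFormatCheck {
  // Hypothetical alternative to the getCanonicalName comparison: resolve the Hive ORC
  // format class reflectively and use isInstance. This would also match subclasses, but
  // requires catching ClassNotFoundException on Spark builds without Hive support.
  def isHiveOrcFormatReflective(format: FileFormat): Boolean =
    try {
      Class.forName("org.apache.spark.sql.hive.orc.OrcFileFormat").isInstance(format)
    } catch {
      case _: ClassNotFoundException => false
    }
}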
