Add getOrcSchemaString for OrcShims [databricks] (#5066)
* Add getOrcSchemaString for OrcShims

Signed-off-by: Bobby Wang <wbo4958@gmail.com>
wbo4958 authored Mar 28, 2022
1 parent 35d777e commit 0063053
Showing 6 changed files with 73 additions and 6 deletions.
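
The change follows the plugin's shim pattern: behavior shared across a range of Spark versions lives in a base trait, and each supported Spark version compiles exactly one OrcShims object that overrides whatever that version changed; here, Spark 3.3 renamed OrcUtils.orcTypeDescriptionString to OrcUtils.getOrcSchemaString. A minimal, self-contained sketch of the dispatch idea (illustrative names only, not the project's real classes):

// Simplified sketch of the version-shim pattern; no Spark dependency.
trait OrcSchemaShimBase {
  // Default for builds against Spark < 3.3, where the helper is still
  // named orcTypeDescriptionString.
  def getOrcSchemaString(dataType: String): String =
    s"orcTypeDescriptionString($dataType)"
}

// One object per Spark version line; the build compiles exactly one in.
object OrcShimsSpark32 extends OrcSchemaShimBase

object OrcShimsSpark33 extends OrcSchemaShimBase {
  // Spark 3.3 renamed the helper, so only this override differs.
  override def getOrcSchemaString(dataType: String): String =
    s"getOrcSchemaString($dataType)"
}

object ShimDemo extends App {
  println(OrcShimsSpark32.getOrcSchemaString("struct<a:int>")) // pre-3.3 path
  println(OrcShimsSpark33.getOrcSchemaString("struct<a:int>")) // 3.3+ path
}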
File 1 of 6: OrcShims311until320Base (shim base trait for Spark 3.1.1 until 3.2.0)

@@ -24,6 +24,9 @@ import org.apache.orc.{CompressionCodec, CompressionKind, DataReader, OrcFile, O
 import org.apache.orc.impl.{DataReaderProperties, OutStream, SchemaEvolution}
 import org.apache.orc.impl.RecordReaderImpl.SargApplier
 
+import org.apache.spark.sql.execution.datasources.orc.OrcUtils
+import org.apache.spark.sql.types.DataType
+
 trait OrcShims311until320Base {
 
   // read data to buffer
@@ -96,4 +99,10 @@ trait OrcShims311until320Base {
   def forcePositionalEvolution(conf:Configuration): Boolean = {
     false
   }
+
+  // orcTypeDescriptionString is renamed to getOrcSchemaString from 3.3+
+  def getOrcSchemaString(dt: DataType): String = {
+    OrcUtils.orcTypeDescriptionString(dt)
+  }
+
 }
File 2 of 6: OrcShims320untilAllBase (the former 320+ OrcShims object, now a base trait)

@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package com.nvidia.spark.rapids.shims
 
 import scala.collection.mutable.ArrayBuffer
@@ -27,8 +28,7 @@ import org.apache.orc.impl.RecordReaderImpl.SargApplier
 import org.apache.orc.impl.reader.StripePlanner
 import org.apache.orc.impl.writer.StreamOptions
 
-// 320+ ORC shims
-object OrcShims {
+trait OrcShims320untilAllBase {
 
   // the ORC Reader in non-CDH Spark is closeable
   def withReader[T <: Reader, V](r: T)(block: T => V): V = {
File 3 of 6: OrcShims for Spark 3.2.x (new file)

@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.shims
+
+import org.apache.spark.sql.execution.datasources.orc.OrcUtils
+import org.apache.spark.sql.types.DataType
+
+// 320+ ORC shims
+object OrcShims extends OrcShims320untilAllBase {
+
+  // orcTypeDescriptionString is renamed to getOrcSchemaString from 3.3+
+  def getOrcSchemaString(dt: DataType): String = {
+    OrcUtils.orcTypeDescriptionString(dt)
+  }
+
+}
File 4 of 6: OrcShims for Spark 3.3.0+ (new file)

@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.shims
+
+import org.apache.spark.sql.execution.datasources.orc.OrcUtils
+import org.apache.spark.sql.types.DataType
+
+// 330+ ORC shims
+object OrcShims extends OrcShims320untilAllBase {
+
+  // orcTypeDescriptionString is renamed to getOrcSchemaString from 3.3+
+  def getOrcSchemaString(dt: DataType): String = {
+    OrcUtils.getOrcSchemaString(dt)
+  }
+
+}
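
With both objects in place, call sites stay version-agnostic: they import com.nvidia.spark.rapids.shims.OrcShims, and the build's source-set selection decides which object answers. A hedged sketch of such a call site (the schema here is made up for illustration):

// Hypothetical call site; compiles against whichever OrcShims object
// the build selected, so Spark 3.2 and 3.3+ differ only inside the shim.
import com.nvidia.spark.rapids.shims.OrcShims
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val readSchema = StructType(Seq(StructField("a", IntegerType)))
val orcSchemaString: String = OrcShims.getOrcSchemaString(readSchema) // e.g. "struct<a:int>"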
File 5 of 6: GpuOrcFileFilterHandler (GPU ORC read path)

@@ -55,7 +55,6 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.sql.execution.datasources.PartitionedFile
-import org.apache.spark.sql.execution.datasources.orc.OrcUtils
 import org.apache.spark.sql.execution.datasources.rapids.OrcFiltersWrapper
 import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -930,9 +929,9 @@ private case class GpuOrcFileFilterHandler(
       partitionSchema: StructType,
       conf: Configuration): String = {
     val resultSchemaString = if (canPruneCols) {
-      OrcUtils.orcTypeDescriptionString(readDataSchema)
+      OrcShims.getOrcSchemaString(readDataSchema)
     } else {
-      OrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields))
+      OrcShims.getOrcSchemaString(StructType(dataSchema.fields ++ partitionSchema.fields))
     }
     OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     resultSchemaString
File 6 of 6: GpuOrcFileFormat (GPU ORC write path)

@@ -18,6 +18,7 @@ package org.apache.spark.sql.rapids
 
 import ai.rapids.cudf._
 import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.shims.OrcShims
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.orc.OrcConf
@@ -129,7 +130,7 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging {
 
     val conf = job.getConfiguration
 
-    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.orcTypeDescriptionString(dataSchema))
+    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcShims.getOrcSchemaString(dataSchema))
 
     conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
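
For reference, the string produced by getOrcSchemaString is handed to ORC through its standard configuration keys, as the two call sites above do with MAPRED_INPUT_SCHEMA and MAPRED_OUTPUT_SCHEMA. A standalone sketch of that round trip (assumes only hadoop-common and orc-core on the classpath; not project code):

// Write and read back the ORC output-schema key.
import org.apache.hadoop.conf.Configuration
import org.apache.orc.OrcConf

object OrcConfDemo extends App {
  val conf = new Configuration()
  // Equivalent to conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, schemaString)
  OrcConf.MAPRED_OUTPUT_SCHEMA.setString(conf, "struct<a:int,b:string>")
  println(OrcConf.MAPRED_OUTPUT_SCHEMA.getString(conf))
}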
