
Commit b4d996b

Update GpuDataSourceScanExec and GpuBroadcastExchangeExec to fix audit issues (#1760)

Signed-off-by: Niranjan Artal <nartal@nvidia.com>
nartal1 authored Feb 19, 2021
1 parent 5758be5 commit b4d996b
Showing 5 changed files with 18 additions and 3 deletions.
@@ -48,6 +48,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.WindowInPandasExec
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.execution.python.GpuWindowInPandasExecMetaBase
@@ -400,6 +401,9 @@ class Spark300Shims extends SparkShims {
new FileScanRDD(sparkSession, readFunction, filePartitions)
}

+// Hardcoded for Spark-3.0.*
+override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int = 100
+
override def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition = {
FilePartition(index, files)
}
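Background (editor's note, not part of the diff): Spark 3.0.x truncates scan metadata values at a fixed 100 characters and exposes no SQLConf entry for it, so this shim mirrors that constant; Spark 3.1.1 makes the limit configurable, which the Spark311Shims override in the next file reads from SQLConf. A minimal sketch of the version split this method abstracts (illustrative helper only; it compiles against Spark 3.1.1, where that accessor exists):

// Illustrative sketch, not part of this commit.
def maxMetadataValueLengthFor(sparkVersion: String, conf: SQLConf): Int =
  if (sparkVersion.startsWith("3.0")) 100   // Spark 3.0.x: fixed limit, no SQLConf entry
  else conf.maxMetadataStringLength         // Spark 3.1.1: configurable via SQLConf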
@@ -96,6 +96,9 @@ class Spark311Shims extends Spark301Shims {
}
}

+override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int =
+  sqlConf.maxMetadataStringLength
+
def exprs311: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
GpuOverrides.expr[Cast](
"Convert a column of one type of data into another type",
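A hedged usage sketch (not part of this commit; assumes a SparkSession in scope as spark, and that SQLConf.maxMetadataStringLength is backed by the key spark.sql.maxMetadataStringLength in Spark 3.1.1):

// Illustrative only: raise the truncation limit for the current session;
// the override above then returns the new value instead of the default 100.
spark.conf.set("spark.sql.maxMetadataStringLength", "500")
// spark.sessionState.conf.maxMetadataStringLength now yields 500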
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.types._
@@ -138,6 +139,8 @@ trait SparkShims {
readFunction: (PartitionedFile) => Iterator[InternalRow],
filePartitions: Seq[FilePartition]): RDD[InternalRow]

+def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int
+
def copyParquetBatchScanExec(
batchScanExec: GpuBatchScanExec,
queryUsesInputFile: Boolean): GpuBatchScanExec
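Editor's note, not part of the diff: declaring the method on the SparkShims trait obliges each version-specific shim (e.g. Spark300Shims and Spark311Shims above) to supply an implementation, and callers reach the active one through ShimLoader, as the GpuDataSourceScanExec change below does. A hypothetical new shim would override it the same way:

// Hypothetical sketch only (other abstract members of SparkShims elided):
// class SparkXyzShims extends SparkShims {
//   override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int =
//     sqlConf.maxMetadataStringLength
//   ...
// }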
@@ -16,6 +16,7 @@

package org.apache.spark.sql.rapids

+import com.nvidia.spark.rapids.ShimLoader
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.Path

@@ -40,7 +41,8 @@ trait GpuDataSourceScanExec extends LeafExecNode {
// Metadata that describes more details of this scan.
protected def metadata: Map[String, String]

-protected val maxMetadataValueLength = 100
+protected val maxMetadataValueLength = ShimLoader.getSparkShims
+  .getFileSourceMaxMetadataValueLength(sqlContext.sessionState.conf)

override def simpleString(maxFields: Int): String = {
val metadataEntries = metadata.toSeq.sorted.map {
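For context (editor's sketch, not part of the diff): maxMetadataValueLength feeds the metadata truncation in simpleString; the snippet below shows the kind of abbreviation it controls, using the StringUtils import already present in this file (the helper name is illustrative):

import org.apache.commons.lang3.StringUtils

// Illustrative only: shorten a metadata value to the shim-provided limit,
// e.g. a long "Location" path gets a trailing "..." once it exceeds maxLen.
def truncateMetadataValue(value: String, maxLen: Int): String =
  StringUtils.abbreviate(value, maxLen)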
@@ -29,7 +29,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

-import org.apache.spark.SparkException
+import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
@@ -265,7 +265,10 @@ abstract class GpuBroadcastExchangeExecBase(
@transient
private val timeout: Long = SQLConf.get.broadcastTimeout

-val _runId: UUID = UUID.randomUUID()
+// Cancelling a SQL statement from Spark ThriftServer needs to cancel
+// its related broadcast sub-jobs. So set the run id to job group id if exists.
+val _runId: UUID = Option(sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+  .map(UUID.fromString).getOrElse(UUID.randomUUID)

@transient
lazy val relationFuture: Future[Broadcast[Any]] = {
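Why the _runId change matters (editor's sketch, not part of the diff): the Spark ThriftServer runs each statement under a job group, so reusing that group id as the broadcast run id lets a statement-level cancel reach the broadcast sub-jobs as well. A hedged sketch of that flow, assuming a SparkContext in scope as sc (the group id must parse as a UUID for the UUID.fromString call above to succeed):

import java.util.UUID

// Illustrative only: run a statement under a job group, then cancel the whole group.
val groupId = UUID.randomUUID().toString
sc.setJobGroup(groupId, "SQL statement submitted via ThriftServer", interruptOnCancel = true)
// ... execute the query; GpuBroadcastExchangeExecBase adopts groupId as _runId ...
sc.cancelJobGroup(groupId)  // interrupts the statement's jobs, including broadcast sub-jobs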