diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 04398faaeae..180accf98ae 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -48,6 +48,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
 import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.execution.python.WindowInPandasExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, GpuTimeSub, ShuffleManagerShimBase}
 import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
 import org.apache.spark.sql.rapids.execution.python.GpuWindowInPandasExecMetaBase
@@ -400,6 +401,9 @@ class Spark300Shims extends SparkShims {
     new FileScanRDD(sparkSession, readFunction, filePartitions)
   }
 
+  // Hardcoded for Spark-3.0.*
+  override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int = 100
+
   override def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition = {
     FilePartition(index, files)
   }
diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
index 4db19e8b0b7..eee1772b019 100644
--- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
+++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
@@ -96,6 +96,9 @@ class Spark311Shims extends Spark301Shims {
     }
   }
 
+  override def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int =
+    sqlConf.maxMetadataStringLength
+
   def exprs311: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
     GpuOverrides.expr[Cast](
       "Convert a column of one type of data into another type",
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index ccd8cc1881c..07ee8a42786 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
 import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
 import org.apache.spark.sql.types._
@@ -138,6 +139,8 @@ trait SparkShims {
       readFunction: (PartitionedFile) => Iterator[InternalRow],
       filePartitions: Seq[FilePartition]): RDD[InternalRow]
 
+  def getFileSourceMaxMetadataValueLength(sqlConf: SQLConf): Int
+
   def copyParquetBatchScanExec(
       batchScanExec: GpuBatchScanExec,
       queryUsesInputFile: Boolean): GpuBatchScanExec
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala
index 0acdc171ad8..01f23946222 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala
@@ -16,6 +16,7 @@
 
 package org.apache.spark.sql.rapids
 
+import com.nvidia.spark.rapids.ShimLoader
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.fs.Path
 
@@ -40,7 +41,8 @@ trait GpuDataSourceScanExec extends LeafExecNode {
   // Metadata that describes more details of this scan.
   protected def metadata: Map[String, String]
 
-  protected val maxMetadataValueLength = 100
+  protected val maxMetadataValueLength = ShimLoader.getSparkShims
+    .getFileSourceMaxMetadataValueLength(sqlContext.sessionState.conf)
 
   override def simpleString(maxFields: Int): String = {
     val metadataEntries = metadata.toSeq.sorted.map {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index b72b84296c9..ff761e769fd 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -29,7 +29,7 @@ import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
@@ -265,7 +265,10 @@ abstract class GpuBroadcastExchangeExecBase(
   @transient
   private val timeout: Long = SQLConf.get.broadcastTimeout
 
-  val _runId: UUID = UUID.randomUUID()
+  // Cancelling a SQL statement from Spark ThriftServer needs to cancel
+  // its related broadcast sub-jobs. So set the run id to job group id if exists.
+  val _runId: UUID = Option(sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+    .map(UUID.fromString).getOrElse(UUID.randomUUID)
 
   @transient
   lazy val relationFuture: Future[Broadcast[Any]] = {
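
The first four files form one change: GpuDataSourceScanExec no longer hardcodes the 100-character truncation limit for scan metadata values, but asks the version-specific shim for it via the new getFileSourceMaxMetadataValueLength hook, so the Spark 3.1.1 shim can return the configurable sqlConf.maxMetadataStringLength while the Spark 3.0.x shim keeps the value Spark 3.0 hardcodes. The following is a minimal, Spark-free sketch of that dispatch pattern; ConfLike, pickShims, and the two shim class names are hypothetical stand-ins for SQLConf, ShimLoader.getSparkShims, and the real shim classes, not the plugin's actual API.

// Minimal sketch (no Spark dependency) of the shim dispatch for the metadata limit.
object ShimSketch {
  // Stand-in for the relevant slice of SQLConf.
  trait ConfLike { def maxMetadataStringLength: Int }

  // Stand-in for the SparkShims trait extended in the diff above.
  trait Shims {
    def getFileSourceMaxMetadataValueLength(sqlConf: ConfLike): Int
  }

  // Spark 3.0.x exposes no config for this limit, so the shim hardcodes 100.
  class Spark300LikeShims extends Shims {
    override def getFileSourceMaxMetadataValueLength(sqlConf: ConfLike): Int = 100
  }

  // Spark 3.1.1 exposes the limit on SQLConf, so the shim defers to it.
  class Spark311LikeShims extends Shims {
    override def getFileSourceMaxMetadataValueLength(sqlConf: ConfLike): Int =
      sqlConf.maxMetadataStringLength
  }

  // Stand-in for ShimLoader.getSparkShims, which selects a shim by Spark version.
  def pickShims(sparkVersion: String): Shims =
    if (sparkVersion.startsWith("3.1")) new Spark311LikeShims else new Spark300LikeShims

  def main(args: Array[String]): Unit = {
    val conf = new ConfLike { def maxMetadataStringLength: Int = 250 }
    println(pickShims("3.0.1").getFileSourceMaxMetadataValueLength(conf)) // 100
    println(pickShims("3.1.1").getFileSourceMaxMetadataValueLength(conf)) // 250
  }
}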
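
The GpuBroadcastExchangeExec change is independent: the broadcast run ID now follows the caller's job group ID when one is set, so cancelling that job group (as the Spark ThriftServer does when a SQL statement is cancelled) also reaches the broadcast sub-jobs. Below is a minimal, Spark-free sketch of the fallback logic, where the jobGroupId parameter is a stand-in for sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID); like the patch itself, it assumes that any job group ID present parses as a UUID.

import java.util.UUID

// Minimal sketch (no Spark dependency) of the _runId derivation in the diff above.
object RunIdSketch {
  // Reuse the caller's job group ID as the run ID when one is set; otherwise
  // fall back to a fresh random UUID (the previous behavior).
  // Assumes the job group ID, if present, is a well-formed UUID string.
  def deriveRunId(jobGroupId: Option[String]): UUID =
    jobGroupId.map(UUID.fromString).getOrElse(UUID.randomUUID())

  def main(args: Array[String]): Unit = {
    // Hypothetical job group ID, as a ThriftServer session might set it.
    val groupId = "1b4e28ba-2fa1-11d2-883f-0016d3cca427"
    println(deriveRunId(Some(groupId))) // prints the group ID itself
    println(deriveRunId(None))          // prints a random UUID
  }
}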