diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index 3a3bcf6811f3..4485fbdb7ead 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -96,7 +96,7 @@ class Spark300Shims extends SparkShims { override def getGpuShuffleExchangeExec( outputPartitioning: Partitioning, child: SparkPlan, - canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = { + cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = { GpuShuffleExchangeExec(outputPartitioning, child) } @@ -108,21 +108,21 @@ class Spark300Shims extends SparkShims { override def isGpuHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuHashJoin => true - case p => false + case _ => false } } override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuBroadcastHashJoinExec => true - case p => false + case _ => false } } override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuShuffledHashJoinExec => true - case p => false + case _ => false } } @@ -381,7 +381,7 @@ class Spark300Shims extends SparkShims { override def getFileScanRDD( sparkSession: SparkSession, - readFunction: (PartitionedFile) => Iterator[InternalRow], + readFunction: PartitionedFile => Iterator[InternalRow], filePartitions: Seq[FilePartition]): RDD[InternalRow] = { new FileScanRDD(sparkSession, readFunction, filePartitions) } diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala index 55cc8ec8ee12..e3aca38466f0 100644 --- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala +++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuShuffleExchangeExecBase} import org.apache.spark.sql.types.DataType @@ -102,7 +102,8 @@ class Spark301Shims extends Spark300Shims { override def getGpuShuffleExchangeExec( outputPartitioning: Partitioning, child: SparkPlan, - canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = { + cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = { + val canChangeNumPartitions = cpuShuffle.forall(_.canChangeNumPartitions) GpuShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions) } diff --git a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala index c1812fe9b28b..6a2ae17565cf 100644 --- a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala +++ b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec import org.apache.spark.sql.execution.python.WindowInPandasExec @@ -216,7 +217,8 @@ class Spark301dbShims extends Spark301Shims { override def getGpuShuffleExchangeExec( outputPartitioning: Partitioning, child: SparkPlan, - canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = { + cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = { + val canChangeNumPartitions = cpuShuffle.forall(_.canChangeNumPartitions) GpuShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions) } diff --git a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala index cff0366eecff..9280e71d36d2 100644 --- a/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala +++ b/shims/spark310/src/main/scala/com/nvidia/spark/rapids/shims/spark310/Spark310Shims.scala @@ -25,16 +25,18 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, ShuffleManagerShimBase} -import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase +import org.apache.spark.sql.rapids.execution.{GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase} import org.apache.spark.sql.rapids.shims.spark310._ import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} @@ -79,21 +81,21 @@ class Spark310Shims extends Spark301Shims { override def isGpuHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuHashJoin => true - case p => false + case _ => false } } override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuBroadcastHashJoinExec => true - case p => false + case _ => false } } override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = { plan match { case _: GpuShuffledHashJoinExec => true - case p => false + case _ => false } } @@ -289,4 +291,11 @@ class Spark310Shims extends Spark301Shims { GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver) } + override def getGpuShuffleExchangeExec( + outputPartitioning: Partitioning, + child: SparkPlan, + cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = { + val shuffleOrigin = cpuShuffle.map(_.shuffleOrigin).getOrElse(ENSURE_REQUIREMENTS) + GpuShuffleExchangeExec(outputPartitioning, child, shuffleOrigin) + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 09b8ef2f3df3..b0a9848e6913 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase} import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase} @@ -101,7 +102,7 @@ trait SparkShims { def getGpuShuffleExchangeExec( outputPartitioning: Partitioning, child: SparkPlan, - canChangeNumPartitions: Boolean = true): GpuShuffleExchangeExecBase + cpuShuffle: Option[ShuffleExchangeExec] = None): GpuShuffleExchangeExecBase def getGpuShuffleExchangeExec( queryStage: ShuffleQueryStageExec): GpuShuffleExchangeExecBase diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala index 641cd8aec377..a93f94612796 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala @@ -65,7 +65,7 @@ class GpuShuffleMeta( ShimLoader.getSparkShims.getGpuShuffleExchangeExec( childParts(0).convertToGpu(), childPlans(0).convertIfNeeded(), - shuffle.canChangeNumPartitions) + Some(shuffle)) } /**