diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 9564c50d899..29ba139f3b8 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -22,10 +22,12 @@ import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.spark300.RapidsShuffleManager
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.HadoopFsRelation
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
@@ -193,4 +195,11 @@ class Spark300Shims extends SparkShims {
   override def getRapidsShuffleManagerClass: String = {
     classOf[RapidsShuffleManager].getCanonicalName
   }
+
+  override def injectQueryStagePrepRule(
+      extensions: SparkSessionExtensions,
+      ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
+    // Not supported in 3.0.0, but that is fine: AdaptiveSparkPlanExec in 3.0.0 will
+    // never allow us to replace an Exchange node, so exchanges simply stay on the CPU.
+  }
 }
diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
index 381e3020345..dcbd954eb06 100644
--- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
+++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
@@ -20,8 +20,11 @@ import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.shims.spark300.Spark300Shims
 import com.nvidia.spark.rapids.spark301.RapidsShuffleManager
 
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
 
 class Spark301Shims extends Spark300Shims {
 
@@ -47,4 +50,10 @@ class Spark301Shims extends Spark300Shims {
   override def getRapidsShuffleManagerClass: String = {
     classOf[RapidsShuffleManager].getCanonicalName
   }
+
+  override def injectQueryStagePrepRule(
+      extensions: SparkSessionExtensions,
+      ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
+    extensions.injectQueryStagePrepRule(ruleBuilder)
+  }
 }
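To make the shim split concrete, the sketch below (not part of this patch) shows the Spark 3.0.1 extension point being used directly. NoopQueryStagePrepRule and ExampleExtensions are hypothetical names introduced only for illustration; on 3.0.0 the Spark300Shims override above simply drops the registration.

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical rule, for illustration only: it leaves the plan untouched.
    case class NoopQueryStagePrepRule() extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = plan
    }

    // Registered through spark.sql.extensions, the same mechanism Plugin.scala uses.
    class ExampleExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        // Available on Spark 3.0.1+; on 3.0.0 the shim layer above turns the
        // equivalent plugin call into a no-op.
        extensions.injectQueryStagePrepRule(_ => NoopQueryStagePrepRule())
      }
    }

This is the same shape the plugin uses below: Plugin.scala passes a rule builder through ShimLoader so the call only reaches SparkSessionExtensions on versions that support it.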
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 81fd95f9286..71d22f58621 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1744,6 +1744,16 @@ object GpuOverrides {
   val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
     commonExecs ++ ShimLoader.getSparkShims.getExecs
 }
 
+/** Tag the initial plan when AQE is enabled */
+case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging {
+  override def apply(plan: SparkPlan) :SparkPlan = {
+    // Note that we disregard the GPU plan returned here and instead rely on the side
+    // effect of tagging the underlying SparkPlan.
+    GpuOverrides().apply(plan)
+    // Return the original plan, which was modified as a side effect of invoking GpuOverrides.
+    plan
+  }
+}
 case class GpuOverrides() extends Rule[SparkPlan] with Logging {
   override def apply(plan: SparkPlan) :SparkPlan = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
index 4705aa25348..9832b8e0513 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -50,6 +50,7 @@ class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
     logWarning("Installing extensions to enable rapids GPU SQL support." +
       s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false")
     extensions.injectColumnar(_ => ColumnarOverrideRules())
+    ShimLoader.getSparkShims.injectQueryStagePrepRule(extensions, _ => GpuQueryStagePrepOverrides())
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index 393307a05e3..259cdd2d1af 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.GpuOverrides.isStringLit
 import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
@@ -116,12 +117,21 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
   private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None
 
+  val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")
+
   /**
    * Call this to indicate that this should not be replaced with a GPU enabled version
    * @param because why it should not be replaced.
    */
-  final def willNotWorkOnGpu(because: String): Unit =
+  final def willNotWorkOnGpu(because: String): Unit = {
     cannotBeReplacedReasons.get.add(because)
+    // Annotate the underlying Spark plan with the reason as well, so the information is
+    // available during query stage planning when AQE is on.
+    wrapped match {
+      case p: SparkPlan => p.setTagValue(gpuSupportedTag, because)
+      case _ =>
+    }
+  }
 
   final def shouldBeRemoved(because: String): Unit =
     shouldBeRemovedReasons.get.add(because)
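The tag written by willNotWorkOnGpu is ordinary TreeNode metadata, so any later rule that sees the same plan node can read the reason back. A minimal round-trip sketch, assuming only the Spark TreeNodeTag API; the TagRoundTrip object and its helper are illustrative, not plugin code:

    import org.apache.spark.sql.catalyst.trees.TreeNodeTag
    import org.apache.spark.sql.execution.SparkPlan

    object TagRoundTrip {
      // The same key that RapidsMeta declares as gpuSupportedTag above.
      val gpuSupportedTag: TreeNodeTag[String] = TreeNodeTag[String]("rapids.gpu.supported")

      // Hypothetical helper: writes a reason onto a plan node and reads it back,
      // mirroring what willNotWorkOnGpu and GpuBroadcastMeta do between them.
      def tagAndRead(plan: SparkPlan): Option[String] = {
        plan.setTagValue(gpuSupportedTag, "example reason")
        plan.getTagValue(gpuSupportedTag) // Some("example reason")
      }
    }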
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index 878ada5e1a0..3628abdd2a4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -16,9 +16,11 @@
 
 package com.nvidia.spark.rapids
 
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
@@ -68,4 +70,8 @@ trait SparkShims {
       endMapIndex: Int,
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+  def injectQueryStagePrepRule(
+      extensions: SparkSessionExtensions,
+      rule: SparkSession => Rule[SparkPlan])
 }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index e0f0532ac1d..f6cf795629c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -217,9 +217,18 @@ class GpuBroadcastMeta(
       case _: BroadcastNestedLoopJoinExec => true
       case _ => false
     }
-    if (!parent.exists(isSupported)) {
-      willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
-        "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
+    if (parent.isDefined) {
+      if (!parent.exists(isSupported)) {
+        willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
+          "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
+      }
+    } else {
+      // When AQE is enabled and we are planning a new query stage, parent will be None,
+      // so we need to look at the metadata previously stored on the Spark plan.
+      wrapped.getTagValue(gpuSupportedTag) match {
+        case Some(reason) => willNotWorkOnGpu(reason)
+        case None => // this broadcast is supported on the GPU
+      }
     }
   }
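Taken together, these pieces only fire when both the SQL extensions and AQE are enabled: the prep rule tags the initial plan, and GpuBroadcastMeta consults those tags when a query stage is planned without a parent. A minimal session sketch that exercises this path, using the SQLExecPlugin class from Plugin.scala and the standard AQE flag; cluster-specific GPU setup (plugin jars, resource configs) is omitted:

    import org.apache.spark.sql.SparkSession

    object AqeSessionExample {
      // Sketch only: registers the extensions from Plugin.scala and enables AQE.
      val spark: SparkSession = SparkSession.builder()
        .appName("rapids-aqe-example")
        .config("spark.sql.extensions", "com.nvidia.spark.rapids.SQLExecPlugin")
        .config("spark.sql.adaptive.enabled", "true")
        .getOrCreate()
    }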