Preliminary support for keeping broadcast exchanges on GPU when AQE is enabled #448

Merged · 6 commits · Jul 31, 2020

@@ -22,10 +22,12 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.spark300.RapidsShuffleManager

import org.apache.spark.SparkEnv
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}

@@ -193,4 +195,11 @@ class Spark300Shims extends SparkShims {
  override def getRapidsShuffleManagerClass: String = {
    classOf[RapidsShuffleManager].getCanonicalName
  }

  override def injectQueryStagePrepRule(
      extensions: SparkSessionExtensions,
      ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
    // Not supported in 3.0.0, but that doesn't matter because AdaptiveSparkPlanExec in 3.0.0
    // will never allow us to replace an Exchange node, so exchanges simply stay on the CPU.
Collaborator:
Should we throw an exception here to be sure of that assumption?

Contributor (author):
Well, this code will still be called when we run with 3.0.0, but we can't inject the rule because that feature isn't available in 3.0.0.

When the plugin runs against 3.0.0 with AQE on, our optimizer rules will only get applied to the children of any exchange nodes.
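
A middle ground between the silent no-op and throwing would be to log the situation once. A hedged sketch, not part of this PR, assuming the shim mixes in Spark's Logging trait (the diff does not show whether it does):

  override def injectQueryStagePrepRule(
      extensions: SparkSessionExtensions,
      ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
    // Hypothetical variant: Spark 3.0.0 has no SparkSessionExtensions.injectQueryStagePrepRule,
    // and its AdaptiveSparkPlanExec never lets an Exchange node be replaced, so rather than
    // throwing we could log that broadcast exchanges will stay on the CPU.
    logWarning("Query-stage preparation rules are not supported on Spark 3.0.0; " +
      "broadcast exchanges will not be kept on the GPU when AQE is enabled")
  }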

  }
}

@@ -20,8 +20,11 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark300.Spark300Shims
import com.nvidia.spark.rapids.spark301.RapidsShuffleManager

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

class Spark301Shims extends Spark300Shims {

@@ -47,4 +50,10 @@ class Spark301Shims extends Spark300Shims {
  override def getRapidsShuffleManagerClass: String = {
    classOf[RapidsShuffleManager].getCanonicalName
  }

  override def injectQueryStagePrepRule(
      extensions: SparkSessionExtensions,
      ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
    extensions.injectQueryStagePrepRule(ruleBuilder)
  }
}
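
On Spark 3.0.1 and later, the delegation above is just the standard SparkSessionExtensions hook. A minimal standalone sketch of the same wiring; the no-op rule and the local-mode session are illustrative stand-ins, not plugin code:

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Illustrative rule; in the plugin this role is played by GpuQueryStagePrepOverrides.
case class NoopPrepRule() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions { extensions =>
    // Registers a rule that AdaptiveSparkPlanExec applies while preparing query stages.
    extensions.injectQueryStagePrepRule(_ => NoopPrepRule())
  }
  .getOrCreate()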

@@ -1744,6 +1744,16 @@ object GpuOverrides {
  val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
    commonExecs ++ ShimLoader.getSparkShims.getExecs
}

/** Tag the initial plan when AQE is enabled. */
case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging {
  override def apply(plan: SparkPlan): SparkPlan = {
    // Note that we disregard the GPU plan returned here and instead rely on the side effect
    // of tagging the underlying SparkPlan.
    GpuOverrides().apply(plan)
    // Return the original plan, which has now been modified as a side effect of invoking
    // GpuOverrides.
    plan
  }
}
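
The tag-then-return pattern relies on Catalyst's TreeNodeTag side channel: the rule walks the plan purely for its side effects and hands back the original tree. A self-contained sketch of the same idea; the tag name, rule name, and Exchange check here are hypothetical:

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical tag; RapidsMeta uses "rapids.gpu.supported" for the real one.
object ExampleTags {
  val reasonTag = TreeNodeTag[String]("example.not.supported")
}

case class TagOnlyRule() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    // Visit every node for side effects only: attach a reason where needed...
    plan.foreach { node =>
      if (node.nodeName.contains("Exchange")) {
        node.setTagValue(ExampleTags.reasonTag, "example reason")
      }
    }
    // ...then return the original plan, exactly as GpuQueryStagePrepOverrides does.
    plan
  }
}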

case class GpuOverrides() extends Rule[SparkPlan] with Logging {
  override def apply(plan: SparkPlan): SparkPlan = {

@@ -50,6 +50,7 @@ class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
logWarning("Installing extensions to enable rapids GPU SQL support." +
s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false")
extensions.injectColumnar(_ => ColumnarOverrideRules())
ShimLoader.getSparkShims.injectQueryStagePrepRule(extensions, _ => GpuQueryStagePrepOverrides())
}
}

@@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.GpuOverrides.isStringLit
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand

@@ -116,12 +117,21 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](

  private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None

  val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")

  /**
   * Call this to indicate that this should not be replaced with a GPU enabled version.
   * @param because why it should not be replaced
   */
  final def willNotWorkOnGpu(because: String): Unit =
  final def willNotWorkOnGpu(because: String): Unit = {
    cannotBeReplacedReasons.get.add(because)
    // Annotate the real Spark plan with the reason as well, so that the information is
    // available during query-stage planning when AQE is on.
    wrapped match {
      case p: SparkPlan => p.setTagValue(gpuSupportedTag, because)
      case _ =>
    }
  }
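
Because the reason is stored on the wrapped SparkPlan itself rather than on the RapidsMeta wrapper, any later pass that sees the same plan node can read it back. A minimal round trip using the tag name from this diff; the helper functions are illustrative:

import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")

def recordReason(plan: SparkPlan, because: String): Unit =
  plan.setTagValue(gpuSupportedTag, because)

def readReason(plan: SparkPlan): Option[String] =
  plan.getTagValue(gpuSupportedTag) // Some(reason) if tagged, None otherwise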

  final def shouldBeRemoved(because: String): Unit =
    shouldBeRemovedReasons.get.add(because)

@@ -16,9 +16,11 @@

package com.nvidia.spark.rapids

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase

@@ -68,4 +70,8 @@ trait SparkShims {
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

  def injectQueryStagePrepRule(
      extensions: SparkSessionExtensions,
      rule: SparkSession => Rule[SparkPlan]): Unit
}

@@ -217,9 +217,18 @@ class GpuBroadcastMeta(
      case _: BroadcastNestedLoopJoinExec => true
      case _ => false
    }
    if (!parent.exists(isSupported)) {
      willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
        "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
    if (parent.isDefined) {
      if (!parent.exists(isSupported)) {
        willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
          "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
      }
    } else {
      // When AQE is enabled and we are planning a new query stage, parent will be None, so
      // we need to look at the metadata previously stored on the Spark plan.
      wrapped.getTagValue(gpuSupportedTag) match {
        case Some(reason) => willNotWorkOnGpu(reason)
        case None => // this broadcast is supported on the GPU
      }
    }
  }
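
Putting the two halves together: the injected prep rule runs GpuOverrides over the whole plan for its tagging side effects before AQE carves out query stages, and this check reads the tag back when the broadcast exchange is later planned in isolation with no parent meta to consult. That round trip is what lets a broadcast stay on the GPU only when the join it feeds was itself judged GPU-capable.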
