fixUpJoinConsistency rule now works when AQE is enabled (#676)
* fixUpJoinConsistency rule now works with AQE

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add comma to error message

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Improved validation checks and error messages

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* bug fix: walk tree once to find shuffle exchanges and query stages

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* code simplification

Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Sep 8, 2020
1 parent edadbbd commit 221e1c5
Showing 4 changed files with 119 additions and 16 deletions.
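
The thread that ties the four diffs below together: the reason a plan node must stay on the CPU is now accumulated in a TreeNodeTag[Set[String]] stored on the Spark plan itself (the tag previously held a single String, so a second reason overwrote the first), and the exchange metas replay those reasons when AQE plans a new query stage. A minimal sketch of the pattern, using only Spark's TreeNode tag API (the tag name comes from the diff; the helper names are this sketch's own):

import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")

// write side: accumulate each reason rather than overwriting the previous one
def tagForCpu(plan: SparkPlan, because: String): Unit =
  plan.setTagValue(gpuSupportedTag,
    plan.getTagValue(gpuSupportedTag).getOrElse(Set.empty) + because)

// read side: during AQE query stage planning, replay every stored reason
def reasonsForCpu(plan: SparkPlan): Set[String] =
  plan.getTagValue(gpuSupportedTag).getOrElse(Set.empty)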
@@ -17,6 +17,10 @@
package com.nvidia.spark.rapids

import org.apache.spark.SparkConf
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions.{col, upper}

class JoinsSuite extends SparkQueryCompareTestSuite {

@@ -97,4 +101,65 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
mixedDfWithNulls, mixedDfWithNulls, sortBeforeRepart = true) {
(A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
}

test("fixUpJoinConsistencyIfNeeded AQE on") {
// this test is only valid in Spark 3.0.1 and later, where AQE works with the plugin
val isValidTestForSparkVersion = ShimLoader.getSparkShims.getSparkShimVersion match {
case SparkShimVersion(3, 0, 0) => false
case DatabricksShimVersion(3, 0, 0) => false
case _ => true
}
assume(isValidTestForSparkVersion)
testFixUpJoinConsistencyIfNeeded(true)
}

test("fixUpJoinConsistencyIfNeeded AQE off") {
testFixUpJoinConsistencyIfNeeded(false)
}

private def testFixUpJoinConsistencyIfNeeded(aqe: Boolean): Unit = {

val conf = shuffledJoinConf.clone()
.set("spark.sql.adaptive.enabled", String.valueOf(aqe))
.set("spark.rapids.sql.test.allowedNonGpu",
"BroadcastHashJoinExec,SortMergeJoinExec,SortExec,Upper")
.set("spark.rapids.sql.incompatibleOps.enabled", "false") // force UPPER onto CPU

withGpuSparkSession(spark => {
import spark.implicits._

def createStringDF(name: String, upper: Boolean = false): DataFrame = {
val countryNames = (0 until 1000).map(i => s"country_$i")
if (upper) {
countryNames.map(_.toUpperCase).toDF(name)
} else {
countryNames.toDF(name)
}
}

val left = createStringDF("c1")
.join(createStringDF("c2"), col("c1") === col("c2"))

val right = createStringDF("c3")
.join(createStringDF("c4"), col("c3") === col("c4"))

val join = left.join(right, upper(col("c1")) === col("c4"))

// call collect so that we get the final executed plan when AQE is on
join.collect()

val shuffleExec = TestUtils
.findOperator(join.queryExecution.executedPlan, _.isInstanceOf[ShuffleExchangeExec])
.get

val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
val reasons = shuffleExec.getTagValue(gpuSupportedTag).getOrElse(Set.empty)
assert(reasons.contains(
"other exchanges that feed the same join are on the CPU, and GPU " +
"hashing is not consistent with the CPU version"))

}, conf)

}

}
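
A note on the collect() call above: with AQE enabled, the executed plan starts out as an adaptive wrapper whose final shape exists only after the query runs, which is why the test materializes the join before inspecting the plan. A hedged sketch of unwrapping such a plan (AdaptiveSparkPlanExec and its executedPlan accessor are Spark 3.0 APIs; this stand-in is simpler than the plugin's TestUtils.findOperator):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

def finalExecutedPlan(df: DataFrame): SparkPlan =
  df.queryExecution.executedPlan match {
    // under AQE the root is adaptive; its final plan is available after execution
    case adaptive: AdaptiveSparkPlanExec => adaptive.executedPlan
    case plan => plan
  }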
54 changes: 45 additions & 9 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.DataType

@@ -117,7 +118,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](

private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None

val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")
val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")

/**
* Call this to indicate that this should not be replaced with a GPU enabled version
@@ -128,7 +129,9 @@
// annotate the real spark plan with the reason as well so that the information is available
// during query stage planning when AQE is on
wrapped match {
case p: SparkPlan => p.setTagValue(gpuSupportedTag, because)
case p: SparkPlan =>
p.setTagValue(gpuSupportedTag,
p.getTagValue(gpuSupportedTag).getOrElse(Set.empty) + because)
case _ =>
}
}
@@ -429,9 +432,13 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
}

private def findShuffleExchanges(): Seq[SparkPlanMeta[ShuffleExchangeExec]] = wrapped match {
private def findShuffleExchanges(): Seq[Either[
SparkPlanMeta[QueryStageExec],
SparkPlanMeta[ShuffleExchangeExec]]] = wrapped match {
case _: ShuffleQueryStageExec =>
Left(this.asInstanceOf[SparkPlanMeta[QueryStageExec]]) :: Nil
case _: ShuffleExchangeExec =>
this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]] :: Nil
Right(this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]]) :: Nil
case bkj: BroadcastHashJoinExec => ShimLoader.getSparkShims.getBuildSide(bkj) match {
case GpuBuildLeft => childPlans(1).findShuffleExchanges()
case GpuBuildRight => childPlans(0).findShuffleExchanges()
@@ -440,13 +447,42 @@
}

private def makeShuffleConsistent(): Unit = {
val exchanges = findShuffleExchanges()
if (!exchanges.forall(_.canThisBeReplaced)) {
exchanges.foreach(_.willNotWorkOnGpu("other exchanges that feed the same join are" +
" on the CPU and GPU hashing is not consistent with the CPU version"))
// during query execution when AQE is enabled, the plan could consist of a mixture of
// ShuffleExchangeExec nodes for exchanges that have not started executing yet, and
// ShuffleQueryStageExec nodes for exchanges that have already started executing. This code
// attempts to tag ShuffleExchangeExec nodes for CPU if other exchanges (either
// ShuffleExchangeExec or ShuffleQueryStageExec nodes) were also tagged for CPU.
val shuffleExchanges = findShuffleExchanges()

def canThisBeReplaced(plan: Either[
SparkPlanMeta[QueryStageExec],
SparkPlanMeta[ShuffleExchangeExec]]): Boolean = {
plan match {
case Left(qs) => qs.wrapped.plan match {
case _: GpuExec => true
case ReusedExchangeExec(_, _: GpuExec) => true
case _ => false
}
case Right(e) => e.canThisBeReplaced
}
}

// if we can't convert all exchanges to GPU then we need to make sure that all of them
// run on the CPU instead
if (!shuffleExchanges.forall(canThisBeReplaced)) {
// tag any exchanges that have not been converted to query stages yet
shuffleExchanges.filter(_.isRight)
.foreach(_.right.get.willNotWorkOnGpu("other exchanges that feed the same join are" +
" on the CPU, and GPU hashing is not consistent with the CPU version"))
// verify that no query stages already got converted to GPU
if (shuffleExchanges.filter(_.isLeft).exists(canThisBeReplaced)) {
throw new IllegalStateException("Join needs to run on CPU but at least one input " +
"query stage ran on GPU")
}
}
}


private def fixUpJoinConsistencyIfNeeded(): Unit = {
childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
wrapped match {
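The Either returned by findShuffleExchanges is the heart of the fix: Left marks a shuffle that AQE has already compiled into a query stage (its placement is final), Right marks an exchange that can still be re-planned. A self-contained toy model of makeShuffleConsistent's decision logic (the types here are simplified stand-ins, not the plugin's classes):

// Left = query stage that already executed, so its placement is fixed;
// Right = exchange that has not run yet and can still be forced onto the CPU
case class ExecutedStage(onGpu: Boolean)
case class PlannedExchange(var onGpu: Boolean)

def makeConsistent(shuffles: Seq[Either[ExecutedStage, PlannedExchange]]): Unit = {
  def onGpu(s: Either[ExecutedStage, PlannedExchange]): Boolean = s.fold(_.onGpu, _.onGpu)
  if (!shuffles.forall(onGpu)) {
    // exchanges that have not executed yet fall back to the CPU
    shuffles.collect { case Right(e) => e }.foreach(_.onGpu = false)
    // a stage that already ran on the GPU cannot be undone, so fail fast,
    // mirroring the IllegalStateException thrown by the real rule
    if (shuffles.collect { case Left(s) => s }.exists(_.onGpu)) {
      throw new IllegalStateException("join needs to run on the CPU but an " +
        "input query stage already ran on the GPU")
    }
  }
}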
@@ -222,14 +222,10 @@ class GpuBroadcastMeta(
willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
"with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
}
} else {
// when AQE is enabled and we are planning a new query stage, parent will be None so
// we need to look at meta-data previously stored on the spark plan
wrapped.getTagValue(gpuSupportedTag) match {
case Some(reason) => willNotWorkOnGpu(reason)
case None => // this broadcast is supported on GPU
}
}
// when AQE is enabled and we are planning a new query stage, we need to look at meta-data
// previously stored on the spark plan to determine whether this exchange can run on GPU
wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
}

override def convertToGpu(): GpuExec = {
@@ -49,6 +49,12 @@ class GpuShuffleMeta(
override val childParts: scala.Seq[PartMeta[_]] =
Seq(GpuOverrides.wrapPart(shuffle.outputPartitioning, conf, Some(this)))

override def tagPlanForGpu(): Unit = {
// when AQE is enabled and we are planning a new query stage, we need to look at meta-data
// previously stored on the spark plan to determine whether this exchange can run on GPU
wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
}

override def convertToGpu(): GpuExec =
ShimLoader.getSparkShims.getGpuShuffleExchangeExec(
childParts(0).convertToGpu(),