From 83df4488a781fc6ad886eeec79b2672ca029586e Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 8 Sep 2020 12:56:07 -0600
Subject: [PATCH] fixUpJoinConsistency rule now works when AQE is enabled
 (#676)

* fixUpJoinConsistency rule now works with AQE

Signed-off-by: Andy Grove

* Add comma to error message

Signed-off-by: Andy Grove

* Improved validation checks and error messages

Signed-off-by: Andy Grove

* bug fix: walk tree once to find shuffle exchanges and query stages

Signed-off-by: Andy Grove

* code simplification

Signed-off-by: Andy Grove
---
 .../com/nvidia/spark/rapids/JoinsSuite.scala  | 65 +++++++++++++++++++
 .../com/nvidia/spark/rapids/RapidsMeta.scala  | 54 ++++++++++++---
 .../execution/GpuBroadcastExchangeExec.scala  | 10 +--
 .../execution/GpuShuffleExchangeExec.scala    |  6 ++
 4 files changed, 119 insertions(+), 16 deletions(-)

diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
index bd4b82ca2a7..4656e8c2a81 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/JoinsSuite.scala
@@ -17,6 +17,10 @@
 package com.nvidia.spark.rapids
 
 import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.functions.{col, upper}
 
 class JoinsSuite extends SparkQueryCompareTestSuite {
 
@@ -97,4 +101,65 @@ class JoinsSuite extends SparkQueryCompareTestSuite {
     mixedDfWithNulls, mixedDfWithNulls, sortBeforeRepart = true) {
     (A, B) => A.join(B, A("longs") === B("longs"), "LeftAnti")
   }
+
+  test("fixUpJoinConsistencyIfNeeded AQE on") {
+    // this test is only valid in Spark 3.0.1 and later due to AQE supporting the plugin
+    val isValidTestForSparkVersion = ShimLoader.getSparkShims.getSparkShimVersion match {
+      case SparkShimVersion(3, 0, 0) => false
+      case DatabricksShimVersion(3, 0, 0) => false
+      case _ => true
+    }
+    assume(isValidTestForSparkVersion)
+    testFixUpJoinConsistencyIfNeeded(true)
+  }
+
+  test("fixUpJoinConsistencyIfNeeded AQE off") {
+    testFixUpJoinConsistencyIfNeeded(false)
+  }
+
+  private def testFixUpJoinConsistencyIfNeeded(aqe: Boolean) {
+
+    val conf = shuffledJoinConf.clone()
+      .set("spark.sql.adaptive.enabled", String.valueOf(aqe))
+      .set("spark.rapids.sql.test.allowedNonGpu",
+        "BroadcastHashJoinExec,SortMergeJoinExec,SortExec,Upper")
+      .set("spark.rapids.sql.incompatibleOps.enabled", "false") // force UPPER onto CPU
+
+    withGpuSparkSession(spark => {
+      import spark.implicits._
+
+      def createStringDF(name: String, upper: Boolean = false): DataFrame = {
+        val countryNames = (0 until 1000).map(i => s"country_$i")
+        if (upper) {
+          countryNames.map(_.toUpperCase).toDF(name)
+        } else {
+          countryNames.toDF(name)
+        }
+      }
+
+      val left = createStringDF("c1")
+        .join(createStringDF("c2"), col("c1") === col("c2"))
+
+      val right = createStringDF("c3")
+        .join(createStringDF("c4"), col("c3") === col("c4"))
+
+      val join = left.join(right, upper(col("c1")) === col("c4"))
+
+      // call collect so that we get the final executed plan when AQE is on
+      join.collect()
+
+      val shuffleExec = TestUtils
+        .findOperator(join.queryExecution.executedPlan, _.isInstanceOf[ShuffleExchangeExec])
+        .get
+
+      val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
+      val reasons = shuffleExec.getTagValue(gpuSupportedTag).getOrElse(Set.empty)
+
+      assert(reasons.contains(
+        "other exchanges that feed the same join are on the CPU, and GPU " +
+          "hashing is not consistent with the CPU version"))
+
+    }, conf)
+
+  }
 }
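The test above pulls failure reasons off the executed plan through a `TreeNodeTag`, the same channel `RapidsMeta` uses to persist decisions across AQE re-planning. Below is a minimal standalone sketch of that mechanism; the tag name and the reason strings are illustrative, and only the `TreeNodeTag`/`setTagValue`/`getTagValue` API is taken from Spark itself.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

object TagSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("tag-sketch").getOrCreate()
  import spark.implicits._

  val plan = Seq(1, 2, 3).toDF("a").queryExecution.executedPlan

  // a tag is a typed, named slot on a plan node; the plugin stores a Set[String]
  // so that multiple "will not work on GPU" reasons can accumulate on one node
  val tag = TreeNodeTag[Set[String]]("example.reasons")
  plan.setTagValue(tag, plan.getTagValue(tag).getOrElse(Set.empty) + "first reason")
  plan.setTagValue(tag, plan.getTagValue(tag).getOrElse(Set.empty) + "second reason")

  // both reasons survive on the node, e.g. into AQE query stage planning
  assert(plan.getTagValue(tag).contains(Set("first reason", "second reason")))
  spark.stop()
}
```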
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index 17340564150..f86e6908e5e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.DataWritingCommand
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.types.DataType
 
@@ -117,7 +118,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
 
   private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None
 
-  val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")
+  val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
 
   /**
    * Call this to indicate that this should not be replaced with a GPU enabled version
@@ -128,7 +129,9 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
     // annotate the real spark plan with the reason as well so that the information is available
     // during query stage planning when AQE is on
     wrapped match {
-      case p: SparkPlan => p.setTagValue(gpuSupportedTag, because)
+      case p: SparkPlan =>
+        p.setTagValue(gpuSupportedTag,
+          p.getTagValue(gpuSupportedTag).getOrElse(Set.empty) + because)
       case _ =>
     }
   }
@@ -429,9 +432,13 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
     wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
   }
 
-  private def findShuffleExchanges(): Seq[SparkPlanMeta[ShuffleExchangeExec]] = wrapped match {
+  private def findShuffleExchanges(): Seq[Either[
+      SparkPlanMeta[QueryStageExec],
+      SparkPlanMeta[ShuffleExchangeExec]]] = wrapped match {
+    case _: ShuffleQueryStageExec =>
+      Left(this.asInstanceOf[SparkPlanMeta[QueryStageExec]]) :: Nil
     case _: ShuffleExchangeExec =>
-      this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]] :: Nil
+      Right(this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]]) :: Nil
     case bkj: BroadcastHashJoinExec => ShimLoader.getSparkShims.getBuildSide(bkj) match {
       case GpuBuildLeft => childPlans(1).findShuffleExchanges()
       case GpuBuildRight => childPlans(0).findShuffleExchanges()
@@ -440,13 +447,42 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
   }
 
   private def makeShuffleConsistent(): Unit = {
-    val exchanges = findShuffleExchanges()
-    if (!exchanges.forall(_.canThisBeReplaced)) {
-      exchanges.foreach(_.willNotWorkOnGpu("other exchanges that feed the same join are" +
-        " on the CPU and GPU hashing is not consistent with the CPU version"))
+    // during query execution when AQE is enabled, the plan could consist of a mixture of
+    // ShuffleExchangeExec nodes for exchanges that have not started executing yet, and
+    // ShuffleQueryStageExec nodes for exchanges that have already started executing. This code
+    // attempts to tag ShuffleExchangeExec nodes for CPU if other exchanges (either
+    // ShuffleExchangeExec or ShuffleQueryStageExec nodes) were also tagged for CPU.
+    val shuffleExchanges = findShuffleExchanges()
+
+    def canThisBeReplaced(plan: Either[
+        SparkPlanMeta[QueryStageExec],
+        SparkPlanMeta[ShuffleExchangeExec]]): Boolean = {
+      plan match {
+        case Left(qs) => qs.wrapped.plan match {
+          case _: GpuExec => true
+          case ReusedExchangeExec(_, _: GpuExec) => true
+          case _ => false
+        }
+        case Right(e) => e.canThisBeReplaced
+      }
+    }
+
+    // if we can't convert all exchanges to GPU then we need to make sure that all of them
+    // run on the CPU instead
+    if (!shuffleExchanges.forall(canThisBeReplaced)) {
+      // tag any exchanges that have not been converted to query stages yet
+      shuffleExchanges.filter(_.isRight)
+        .foreach(_.right.get.willNotWorkOnGpu("other exchanges that feed the same join are" +
+          " on the CPU, and GPU hashing is not consistent with the CPU version"))
+      // verify that no query stages already got converted to GPU
+      if (shuffleExchanges.filter(_.isLeft).exists(canThisBeReplaced)) {
+        throw new IllegalStateException("Join needs to run on CPU but at least one input " +
+          "query stage ran on GPU")
+      }
     }
   }
 
   private def fixUpJoinConsistencyIfNeeded(): Unit = {
     childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
     wrapped match {
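To make the control flow in `makeShuffleConsistent` easier to follow outside the plugin, here is a self-contained sketch with plain stand-ins: `Left` marks an exchange that has already become a query stage (its placement is final), `Right` one that is still only planned and can be re-tagged. `PendingExchange` and the function names are hypothetical, not plugin types.

```scala
// stand-in for a planned-but-not-yet-executed shuffle exchange
final case class PendingExchange(var onGpu: Boolean, var reasons: Set[String] = Set.empty)

object ConsistencySketch extends App {
  def makeConsistent(exchanges: Seq[Either[Boolean, PendingExchange]]): Unit = {
    def canBeReplaced(e: Either[Boolean, PendingExchange]): Boolean = e match {
      case Left(stageRanOnGpu) => stageRanOnGpu // already executing; placement is final
      case Right(pending)      => pending.onGpu // still planned; can be re-tagged
    }
    if (!exchanges.forall(canBeReplaced)) {
      // only exchanges that are not query stages yet can still be moved to the CPU
      exchanges.collect { case Right(p) => p }.foreach { p =>
        p.onGpu = false
        p.reasons += "other exchanges that feed the same join are on the CPU, and GPU " +
          "hashing is not consistent with the CPU version"
      }
      // a query stage that already ran on the GPU cannot be undone
      if (exchanges.collect { case Left(g) => g }.exists(identity)) {
        throw new IllegalStateException("Join needs to run on CPU but at least one " +
          "input query stage ran on GPU")
      }
    }
  }

  val pending = PendingExchange(onGpu = true)
  makeConsistent(Seq(Left(false), Right(pending)))
  assert(!pending.onGpu && pending.reasons.nonEmpty) // pending exchange moved to CPU
}
```

With `Seq(Left(false), Right(...))` the pending exchange is moved to the CPU; with `Seq(Left(true), Right(PendingExchange(onGpu = false)))` the sketch throws, mirroring the `IllegalStateException` guard in the patch.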
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index 40178bc518c..e147928871c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -222,14 +222,10 @@ class GpuBroadcastMeta(
         willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
           "with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
       }
-    } else {
-      // when AQE is enabled and we are planning a new query stage, parent will be None so
-      // we need to look at meta-data previously stored on the spark plan
-      wrapped.getTagValue(gpuSupportedTag) match {
-        case Some(reason) => willNotWorkOnGpu(reason)
-        case None => // this broadcast is supported on GPU
-      }
     }
+    // when AQE is enabled and we are planning a new query stage, we need to look at meta-data
+    // previously stored on the spark plan to determine whether this exchange can run on GPU
+    wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
   }
 
   override def convertToGpu(): GpuExec = {
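Both exchange metas replay the stored reasons with `getTagValue(...).foreach(_.foreach(willNotWorkOnGpu))`, walking the `Option[Set[String]]` that the tag now holds. A small sketch of that idiom, with a stand-in `willNotWorkOnGpu`:

```scala
object ReplaySketch extends App {
  // stand-in for the meta's willNotWorkOnGpu
  def willNotWorkOnGpu(because: String): Unit = println(s"not on GPU: $because")

  val tagged: Option[Set[String]] = Some(Set("reason one", "reason two"))
  val untagged: Option[Set[String]] = None

  // the outer foreach fires only when the tag exists; the inner one replays each reason
  tagged.foreach(_.foreach(willNotWorkOnGpu))   // prints both reasons
  untagged.foreach(_.foreach(willNotWorkOnGpu)) // no-op: the node was never tagged
}
```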
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
index 4b01fda5cab..27ebacb4a96 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala
@@ -49,6 +49,12 @@ class GpuShuffleMeta(
   override val childParts: scala.Seq[PartMeta[_]] =
     Seq(GpuOverrides.wrapPart(shuffle.outputPartitioning, conf, Some(this)))
 
+  override def tagPlanForGpu(): Unit = {
+    // when AQE is enabled and we are planning a new query stage, we need to look at meta-data
+    // previously stored on the spark plan to determine whether this exchange can run on GPU
+    wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu))
+  }
+
   override def convertToGpu(): GpuExec =
     ShimLoader.getSparkShims.getGpuShuffleExchangeExec(
       childParts(0).convertToGpu(),
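To exercise this path outside the test harness, a session configured roughly as follows (e.g. in spark-shell) should reproduce the scenario; the config keys mirror the test above, while the local master and plugin wiring are assumptions about the deployment.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// assumed deployment details: local master and the RAPIDS plugin jar on the classpath
val conf = new SparkConf()
  .set("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .set("spark.sql.adaptive.enabled", "true")                // exercise the AQE code path
  .set("spark.rapids.sql.incompatibleOps.enabled", "false") // keep upper() on the CPU

val spark = SparkSession.builder().master("local[*]").config(conf).getOrCreate()
// a join keyed on upper(col) forces one shuffle onto the CPU; with this patch the
// other exchanges feeding the same join are tagged for the CPU as well, instead of
// failing later with CPU/GPU hash partitioning that does not line up
```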