From f659eafec8e99d68cbb31a9b6b9c808f81ec1721 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 3 Sep 2020 13:50:52 -0600 Subject: [PATCH] Improve warnings about AQE nodes not supported on GPU (#647) * Improve warnings about AQE nodes not supported on GPU Signed-off-by: Andy Grove * Introduce new DoNotReplaceSparkPlanMeta rule Signed-off-by: Andy Grove --- .../nvidia/spark/rapids/GpuOverrides.scala | 10 ++++++-- .../com/nvidia/spark/rapids/RapidsMeta.scala | 23 ++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 3f97b4d5bfa..442671e49f1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand} @@ -1735,7 +1735,13 @@ object GpuOverrides { GpuCustomShuffleReaderExec(childPlans.head.convertIfNeeded(), exec.partitionSpecs) } - }) + }), + exec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan", (exec, conf, p, _) => + new DoNotReplaceSparkPlanMeta[AdaptiveSparkPlanExec](exec, conf, p)), + exec[BroadcastQueryStageExec]("Broadcast query stage", (exec, conf, p, _) => + new DoNotReplaceSparkPlanMeta[BroadcastQueryStageExec](exec, conf, p)), + exec[ShuffleQueryStageExec]("Shuffle query stage", (exec, conf, p, _) => + new DoNotReplaceSparkPlanMeta[ShuffleQueryStageExec](exec, conf, p)) ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = commonExecs ++ ShimLoader.getSparkShims.getExecs diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 13e28be1244..d5903f2b7bc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -224,6 +224,8 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( private def indent(append: StringBuilder, depth: Int): Unit = append.append(" " * depth) + def suppressWillWorkOnGpuInfo: Boolean = false + private def willWorkOnGpuInfo: String = cannotBeReplacedReasons match { case None => "NOT EVALUATED FOR GPU YET" case Some(v) if v.isEmpty => "could run on GPU" @@ -253,7 +255,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( * @param all should all the data be printed or just what does not work on the GPU? */ protected def print(strBuilder: StringBuilder, depth: Int, all: Boolean): Unit = { - if (all || !canThisBeReplaced) { + if ((all || !canThisBeReplaced) && !suppressWillWorkOnGpuInfo) { indent(strBuilder, depth) strBuilder.append(if (canThisBeReplaced) "*" else "!") @@ -570,6 +572,25 @@ final class RuleNotFoundSparkPlanMeta[INPUT <: SparkPlan]( throw new IllegalStateException("Cannot be converted to GPU") } +/** + * Metadata for `SparkPlan` that should not be replaced. + */ +final class DoNotReplaceSparkPlanMeta[INPUT <: SparkPlan]( + plan: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]]) + extends SparkPlanMeta[INPUT](plan, conf, parent, new NoRuleConfKeysAndIncompat) { + + /** We don't want to spam the user with messages about these operators */ + override def suppressWillWorkOnGpuInfo: Boolean = true + + override def tagPlanForGpu(): Unit = + willNotWorkOnGpu(s"there is no need to replace ${plan.getClass}") + + override def convertToGpu(): GpuExec = + throw new IllegalStateException("Cannot be converted to GPU") +} + /** * Base class for metadata around `Expression`. */