diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
index 01e1cf4e279..e74f2c6753b 100644
--- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
+++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala
@@ -36,12 +36,13 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, RunnableCommand}
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
@@ -149,6 +150,10 @@ class Spark300Shims extends SparkShims {
   override def isShuffleExchangeLike(plan: SparkPlan): Boolean =
     plan.isInstanceOf[ShuffleExchangeExec]
 
+  override def getQueryStageRuntimeStatistics(plan: QueryStageExec): Statistics = {
+    Statistics(0)
+  }
+
   override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
     Seq(
       GpuOverrides.exec[WindowInPandasExec](
diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/GpuShuffleExchangeExec.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/GpuShuffleExchangeExec.scala
index 9712f62bd7b..006d91f74be 100644
--- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/GpuShuffleExchangeExec.scala
+++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/GpuShuffleExchangeExec.scala
@@ -37,7 +37,12 @@ case class GpuShuffleExchangeExec(
   }
 
   override def runtimeStatistics: Statistics = {
-    val dataSize = metrics("dataSize").value
-    Statistics(dataSize)
+    // note that Spark will only use the sizeInBytes statistic but making the rowCount
+    // available here means that we can more easily reference it in GpuOverrides when
+    // planning future query stages when AQE is on
+    Statistics(
+      sizeInBytes = metrics("dataSize").value,
+      rowCount = Some(metrics("numOutputRows").value)
+    )
   }
 }
diff --git a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
index 2d0aacb648d..b91dbac46c5 100644
--- a/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
+++ b/shims/spark301/src/main/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala
@@ -26,10 +26,11 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuShuffleExchangeExecBase}
@@ -126,6 +127,9 @@ class Spark301Shims extends Spark300Shims {
   override def isShuffleExchangeLike(plan: SparkPlan): Boolean =
     plan.isInstanceOf[ShuffleExchangeLike]
 
+  override def getQueryStageRuntimeStatistics(qs: QueryStageExec): Statistics =
+    qs.getRuntimeStatistics
+
   override def injectQueryStagePrepRule(
       extensions: SparkSessionExtensions,
       ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
diff --git a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/GpuShuffleExchangeExec.scala b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/GpuShuffleExchangeExec.scala
index 5c92da25aaf..ecf8fdbbb69 100644
--- a/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/GpuShuffleExchangeExec.scala
+++ b/shims/spark301db/src/main/scala/com/nvidia/spark/rapids/shims/spark301db/GpuShuffleExchangeExec.scala
@@ -39,7 +39,12 @@ case class GpuShuffleExchangeExec(
   }
 
   override def runtimeStatistics: Statistics = {
-    val dataSize = metrics("dataSize").value
-    Statistics(dataSize)
+    // note that Spark will only use the sizeInBytes statistic but making the rowCount
+    // available here means that we can more easily reference it in GpuOverrides when
+    // planning future query stages when AQE is on
+    Statistics(
+      sizeInBytes = metrics("dataSize").value,
+      rowCount = Some(metrics("numOutputRows").value)
+    )
   }
 }
diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/GpuShuffleExchangeExec.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/GpuShuffleExchangeExec.scala
index d2b0ca6ffba..7deafe47d9d 100644
--- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/GpuShuffleExchangeExec.scala
+++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/GpuShuffleExchangeExec.scala
@@ -38,7 +38,12 @@ case class GpuShuffleExchangeExec(
   }
 
   override def runtimeStatistics: Statistics = {
-    val dataSize = metrics("dataSize").value
-    Statistics(dataSize)
+    // note that Spark will only use the sizeInBytes statistic but making the rowCount
+    // available here means that we can more easily reference it in GpuOverrides when
+    // planning future query stages when AQE is on
+    Statistics(
+      sizeInBytes = metrics("dataSize").value,
+      rowCount = Some(metrics("numOutputRows").value)
+    )
   }
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index cc221a37996..9991171ced1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -28,12 +28,13 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExprId, NullOrdering, SortDirection, SortOrder}
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
@@ -82,6 +83,7 @@ trait SparkShims {
   def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
   def isBroadcastExchangeLike(plan: SparkPlan): Boolean
   def isShuffleExchangeLike(plan: SparkPlan): Boolean
+  def getQueryStageRuntimeStatistics(plan: QueryStageExec): Statistics
   def getRapidsShuffleManagerClass: String
   def getBuildSide(join: HashJoin): GpuBuildSide
   def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index 7f2a2ae4779..fd74438ec77 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -97,6 +97,24 @@ class AdaptiveQueryExecSuite
     collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn)
   }
 
+  test("get row counts from executed shuffle query stages") {
+    assumeSpark301orLater
+
+    skewJoinTest { spark =>
+      val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+        spark,
+        "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
+      val innerSmj = findTopLevelGpuShuffleHashJoin(innerAdaptivePlan)
+      val shuffleExchanges = ShimLoader.getSparkShims
+        .findOperators(innerAdaptivePlan, _.isInstanceOf[ShuffleQueryStageExec])
+        .map(_.asInstanceOf[ShuffleQueryStageExec])
+      assert(shuffleExchanges.length === 2)
+      val shim = ShimLoader.getSparkShims
+      val stats = shuffleExchanges.map(e => shim.getQueryStageRuntimeStatistics(e))
+      assert(stats.forall(_.rowCount.contains(1000)))
+    }
+  }
+
   test("skewed inner join optimization") {
     skewJoinTest { spark =>
       val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
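For context, a minimal sketch of how plugin code might consume the new shim method when planning future query stages under AQE. `StageSizeEstimator` and `averageRowSize` are hypothetical names introduced for illustration only; the patch itself adds just `getQueryStageRuntimeStatistics` and the `rowCount`/`sizeInBytes` fields shown above.

```scala
import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.execution.adaptive.QueryStageExec

object StageSizeEstimator {
  // Hypothetical helper, not part of this patch: estimate the average row size
  // of a completed query stage from the runtime statistics exposed through the
  // shim layer. On Spark 3.0.0 the shim returns Statistics(0) with no row
  // count, so this yields None there rather than dividing by zero.
  def averageRowSize(stage: QueryStageExec): Option[BigInt] = {
    val stats = ShimLoader.getSparkShims.getQueryStageRuntimeStatistics(stage)
    stats.rowCount.filter(_ > 0).map(stats.sizeInBytes / _)
  }
}
```

Because the Spark 3.0.0 shim cannot surface row counts, any consumer in GpuOverrides has to treat `rowCount` as optional, which is why the sketch returns `Option[BigInt]` instead of assuming a value is present.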