
Expose row count statistics in GpuShuffleExchangeExec #1855

Merged (8 commits) on Mar 5, 2021
@@ -36,12 +36,13 @@ import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, FileScanRDD, HadoopFsRelation, InMemoryFileIndex, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.rapids.GpuPartitioningUtils
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
@@ -134,6 +135,10 @@ class Spark300Shims extends SparkShims {
   override def isShuffleExchangeLike(plan: SparkPlan): Boolean =
     plan.isInstanceOf[ShuffleExchangeExec]
 
+  override def getQueryStageRuntimeStatistics(plan: QueryStageExec): Statistics = {
+    Statistics(0)
+  }
+
   override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
     Seq(
       GpuOverrides.exec[WindowInPandasExec](
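On Spark 3.0.0 the shim falls back to an empty placeholder, so no row count is ever reported on that version; callers of getQueryStageRuntimeStatistics therefore have to treat the row count as optional. A minimal sketch (not part of the diff) of what that placeholder carries:

import org.apache.spark.sql.catalyst.plans.logical.Statistics

val placeholder = Statistics(0)  // sizeInBytes = 0, rowCount defaults to None
val rows: BigInt = placeholder.rowCount.getOrElse(BigInt(0))  // callers need a fallback when no count exists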
@@ -37,7 +37,12 @@ case class GpuShuffleExchangeExec(
   }
 
   override def runtimeStatistics: Statistics = {
-    val dataSize = metrics("dataSize").value
-    Statistics(dataSize)
+    // note that Spark will only use the sizeInBytes statistic but making the rowCount
+    // available here means that we can more easily reference it in GpuOverrides when
+    // planning future query stages when AQE is on
+    Statistics(
+      sizeInBytes = metrics("dataSize").value,
+      rowCount = Some(metrics("numOutputRows").value)
+    )
   }
 }
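The comment in the new code names the intended consumer: GpuOverrides, when it plans later query stages under AQE. A hedged sketch of how such a rule could read the new row count through the shim layer; the helper name estimatedRowCount is invented for illustration, the ShimLoader import path is assumed from the plugin's package layout, and getQueryStageRuntimeStatistics is the method added in this PR:

import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec

// Row count of an already-executed shuffle stage, if the running Spark version exposes it.
def estimatedRowCount(plan: SparkPlan): Option[BigInt] = plan match {
  case qs: ShuffleQueryStageExec =>
    ShimLoader.getSparkShims.getQueryStageRuntimeStatistics(qs).rowCount
  case _ => None
}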
@@ -26,10 +26,11 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuShuffleExchangeExecBase}
@@ -126,6 +127,9 @@ class Spark301Shims extends Spark300Shims {
   override def isShuffleExchangeLike(plan: SparkPlan): Boolean =
     plan.isInstanceOf[ShuffleExchangeLike]
 
+  override def getQueryStageRuntimeStatistics(qs: QueryStageExec): Statistics =
+    qs.getRuntimeStatistics
+
   override def injectQueryStagePrepRule(
       extensions: SparkSessionExtensions,
       ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
@@ -28,12 +28,13 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExprId, NullOrdering, SortDirection, SortOrder}
 import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.datasources.{FileIndex, FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins._
@@ -73,6 +74,7 @@ trait SparkShims {
   def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
   def isBroadcastExchangeLike(plan: SparkPlan): Boolean
   def isShuffleExchangeLike(plan: SparkPlan): Boolean
+  def getQueryStageRuntimeStatistics(plan: QueryStageExec): Statistics
  def getRapidsShuffleManagerClass: String
  def getBuildSide(join: HashJoin): GpuBuildSide
  def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide
@@ -97,6 +97,24 @@ class AdaptiveQueryExecSuite
     collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn)
   }
 
+  test("get row counts from executed shuffle query stages") {
+    assumeSpark301orLater
+
+    skewJoinTest { spark =>
+      val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+        spark,
+        "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
+      val innerSmj = findTopLevelGpuShuffleHashJoin(innerAdaptivePlan)
+      val shuffleExchanges = ShimLoader.getSparkShims
+        .findOperators(innerAdaptivePlan, _.isInstanceOf[ShuffleQueryStageExec])
+        .map(_.asInstanceOf[ShuffleQueryStageExec])
+      assert(shuffleExchanges.length == 2)
Review comment from a collaborator on the assertion above:
nit: triple equals produces better diagnostics:

scala> assert(1 == 0)
org.scalatest.exceptions.TestFailedException
  at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:528)
  at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:527)
  at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1387)
  at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:501)
  ... 59 elided

scala> assert(1 === 0)
org.scalatest.exceptions.TestFailedException: 1 did not equal 0
  at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:528)
  at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:527)
  at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1387)
  at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:501)
  ... 59 elided
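Applied to the assertion the comment refers to, the suggested change would read (a sketch, not part of the committed diff):

assert(shuffleExchanges.length === 2)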

+      val shim = ShimLoader.getSparkShims
+      val stats = shuffleExchanges.map(e => shim.getQueryStageRuntimeStatistics(e))
+      assert(stats.forall(_.rowCount.contains(1000)))
+    }
+  }
+
   test("skewed inner join optimization") {
     skewJoinTest { spark =>
       val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(