diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index d74c1ab2656..5bc0c218f83 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -22,11 +22,12 @@ import com.nvidia.spark.rapids.AdaptiveQueryExecSuite.TEST_FILES_ROOT
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.functions.when
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, GpuShuffledHashJoinBase}
@@ -73,18 +74,37 @@ class AdaptiveQueryExecSuite
     (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
   }
 
-  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[GpuShuffledHashJoinBase] = {
+  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
+    collect(plan) {
+      case j: SortMergeJoinExec => j
+    }
+  }
+
+  private def findTopLevelGpuBroadcastHashJoin(plan: SparkPlan): Seq[GpuExec] = {
+    collect(plan) {
+      case j: GpuExec if ShimLoader.getSparkShims.isBroadcastExchangeLike(j) => j
+    }
+  }
+
+  private def findTopLevelGpuShuffleHashJoin(plan: SparkPlan): Seq[GpuShuffledHashJoinBase] = {
     collect(plan) {
       case j: GpuShuffledHashJoinBase => j
     }
   }
 
+  private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
+    collectWithSubqueries(plan) {
+      case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
+      case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
+    }
+  }
+
   test("skewed inner join optimization") {
     skewJoinTest { spark =>
       val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-      val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
+      val innerSmj = findTopLevelGpuShuffleHashJoin(innerAdaptivePlan)
       checkSkewJoin(innerSmj, 2, 1)
     }
   }
@@ -94,7 +114,7 @@ class AdaptiveQueryExecSuite
       val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-      val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
+      val leftSmj = findTopLevelGpuShuffleHashJoin(leftAdaptivePlan)
       checkSkewJoin(leftSmj, 2, 0)
     }
   }
@@ -104,7 +124,7 @@ class AdaptiveQueryExecSuite
       val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-      val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
+      val rightSmj = findTopLevelGpuShuffleHashJoin(rightAdaptivePlan)
       checkSkewJoin(rightSmj, 0, 1)
     }
   }
@@ -241,6 +261,38 @@ class AdaptiveQueryExecSuite
     }, conf)
   }
 
+  test("Exchange reuse") {
+
+    assumeSpark301orLater
+
+    val conf = new SparkConf()
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+
+    withGpuSparkSession(spark => {
+      setupTestData(spark)
+
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark,
+        "SELECT value FROM testData join testData2 ON key = a " +
+        "join (SELECT value v from testData join testData3 ON key = a) on value = v")
+
+      // initial plan should have three SMJs
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+
+      // executed GPU plan replaces SMJ with SHJ
+      val shj = findTopLevelGpuShuffleHashJoin(adaptivePlan)
+      assert(shj.size == 3)
+
+      // one of the GPU exchanges should have been re-used
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+      assert(ShimLoader.getSparkShims.isShuffleExchangeLike(ex.head.child))
+      assert(ex.head.child.isInstanceOf[GpuExec])
+
+    }, conf)
+  }
+
   def skewJoinTest(fun: SparkSession => Unit) {
     assumeSpark301orLater
 
@@ -313,4 +365,44 @@ class AdaptiveQueryExecSuite
     assert(rightSkew.length == rightSkewNum)
   }
 
+  private def setupTestData(spark: SparkSession): Unit = {
+    testData(spark)
+    testData2(spark)
+    testData3(spark)
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData(spark: SparkSession) {
+    import spark.implicits._
+    val data: Seq[(Int, String)] = (1 to 100).map(i => (i, i.toString))
+    val df = data.toDF("key", "value")
+      .repartition(6)
+    registerAsParquetTable(spark, df, "testData")
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData2(spark: SparkSession) {
+    import spark.implicits._
+    val df = Seq[(Int, Int)]((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
+      .toDF("a", "b")
+      .repartition(2)
+    registerAsParquetTable(spark, df, "testData2")
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData3(spark: SparkSession) {
+    import spark.implicits._
+    val df = Seq[(Int, Option[Int])]((1, None), (2, Some(2)))
+      .toDF("a", "b")
+      .repartition(6)
+    registerAsParquetTable(spark, df, "testData3")
+  }
+
+  private def registerAsParquetTable(spark: SparkSession, df: Dataset[Row], name: String) {
+    val path = new File(TEST_FILES_ROOT, s"$name.parquet").getAbsolutePath
+    df.write
+      .mode(SaveMode.Overwrite)
+      .parquet(path)
+    spark.read.parquet(path).createOrReplaceTempView(name)
+  }
 }
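Note on what the new `test("Exchange reuse")` exercises: Spark's physical planner deduplicates identical `Exchange` subtrees (`spark.sql.exchange.reuse`, on by default), replacing each duplicate with a `ReusedExchangeExec` that points back at the original; the test asserts the same happens for the GPU shuffle exchanges produced under AQE. Below is a minimal CPU-only sketch of that underlying Spark behavior, runnable outside this suite. The object name, local session config, and three-way self-join query are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

object ExchangeReuseSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // disable broadcast joins so every join shuffles both of its inputs
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    (1 to 100).map(i => (i, i.toString)).toDF("key", "value")
      .createOrReplaceTempView("t")

    // Joining the same view twice on the same key yields identical shuffle
    // exchanges; with spark.sql.exchange.reuse=true (the default) the
    // physical plan renders the duplicate as a ReusedExchange node.
    val df = spark.sql(
      "SELECT a.value FROM t a JOIN t b ON a.key = b.key JOIN t c ON a.key = c.key")
    df.explain() // look for "ReusedExchange" in the printed plan

    spark.stop()
  }
}
```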