Unit test for GPU exchange re-use with AQE #674

Merged · 5 commits · Sep 15, 2020

Changes from all commits
@@ -22,11 +22,12 @@ import com.nvidia.spark.rapids.AdaptiveQueryExecSuite.TEST_FILES_ROOT
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.functions.when
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, GpuShuffledHashJoinBase}
@@ -73,18 +74,37 @@ class AdaptiveQueryExecSuite
     (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
   }
 
-  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[GpuShuffledHashJoinBase] = {
+  private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
     collect(plan) {
       case j: SortMergeJoinExec => j
     }
   }
 
+  private def findTopLevelGpuBroadcastHashJoin(plan: SparkPlan): Seq[GpuExec] = {
+    collect(plan) {
+      case j: GpuExec if ShimLoader.getSparkShims.isBroadcastExchangeLike(j) => j
+    }
+  }
+
+  private def findTopLevelGpuShuffleHashJoin(plan: SparkPlan): Seq[GpuShuffledHashJoinBase] = {
+    collect(plan) {
+      case j: GpuShuffledHashJoinBase => j
+    }
+  }
+
+  private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
+    collectWithSubqueries(plan) {
+      case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
+      case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
+    }
+  }
+
   test("skewed inner join optimization") {
     skewJoinTest { spark =>
       val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
-      val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
+      val innerSmj = findTopLevelGpuShuffleHashJoin(innerAdaptivePlan)
       checkSkewJoin(innerSmj, 2, 1)
     }
   }
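
A note on findReusedExchange above: under AQE every materialized exchange is wrapped in a query stage, and query stages are leaf nodes in the main plan tree, so a plain tree traversal never sees inside them. When Spark's exchange-reuse rule fires, the stage wraps a ReusedExchangeExec that points back at the original exchange, which is why the helper matches on the stage and extracts its child. A minimal sketch of the same idea using only TreeNode.collect rather than the suite's collectWithSubqueries (assumes Spark 3.0.x's two-field query-stage case classes and ignores subqueries):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
    import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

    // Query stages hide their wrapped plan from ordinary tree traversal,
    // so match the stage itself and pull the ReusedExchangeExec out of it.
    def reusedExchanges(plan: SparkPlan): Seq[ReusedExchangeExec] =
      plan.collect {
        case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
        case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
      }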
@@ -94,7 +114,7 @@ class AdaptiveQueryExecSuite
       val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
-      val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
+      val leftSmj = findTopLevelGpuShuffleHashJoin(leftAdaptivePlan)
       checkSkewJoin(leftSmj, 2, 0)
     }
   }
@@ -104,7 +124,7 @@ class AdaptiveQueryExecSuite
       val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
         spark,
         "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
-      val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
+      val rightSmj = findTopLevelGpuShuffleHashJoin(rightAdaptivePlan)
       checkSkewJoin(rightSmj, 0, 1)
     }
   }
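
The three skew tests above share one shape: run a join over deliberately skewed data, then assert how many partitions on each side were split (checkSkewJoin's two trailing arguments). Skew handling only fires when AQE's thresholds classify partitions as skewed; a hedged sketch of settings that make small test data trip the optimization (the values are illustrative, not the suite's):

    import org.apache.spark.SparkConf

    // Tiny thresholds so modest test data registers as "skewed" and the
    // skew-join split actually fires during adaptive re-planning.
    val skewConf = new SparkConf()
      .set("spark.sql.adaptive.enabled", "true")
      .set("spark.sql.adaptive.skewJoin.enabled", "true")
      .set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "1")
      .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "100")
      .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "100")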
@@ -241,6 +261,38 @@ class AdaptiveQueryExecSuite
     }, conf)
   }
 
+  test("Exchange reuse") {
+
+    assumeSpark301orLater
+
+    val conf = new SparkConf()
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+
+    withGpuSparkSession(spark => {
+      setupTestData(spark)
+
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(spark,
+        "SELECT value FROM testData join testData2 ON key = a " +
+        "join (SELECT value v from testData join testData3 ON key = a) on value = v")
+
+      // initial plan should have three SMJs
+      val smj = findTopLevelSortMergeJoin(plan)
+      assert(smj.size == 3)
+
+      // executed GPU plan replaces SMJ with SHJ
+      val shj = findTopLevelGpuShuffleHashJoin(adaptivePlan)
+      assert(shj.size == 3)
+
+      // one of the GPU exchanges should have been re-used
+      val ex = findReusedExchange(adaptivePlan)
+      assert(ex.size == 1)
+      assert(ShimLoader.getSparkShims.isShuffleExchangeLike(ex.head.child))
+      assert(ex.head.child.isInstanceOf[GpuExec])
+
+    }, conf)
+  }
+
   def skewJoinTest(fun: SparkSession => Unit) {
     assumeSpark301orLater
 
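
Why the test expects exactly one reused exchange: testData is the only table that appears twice in the query (once in the outer join, once inside the subquery), and both occurrences are shuffled the same way for a join on key = a, so AQE's reuse rule replaces the second identical shuffle with a ReusedExchangeExec whose child is the original GPU shuffle exchange. The same behavior can be observed on plain CPU Spark; a minimal self-contained sketch (hypothetical demo code, not part of this PR):

    import org.apache.spark.sql.SparkSession

    // Reproduce the test's query shape: table t is shuffled identically in
    // two joins, so the final adaptive plan should show a ReusedExchange.
    object ExchangeReuseDemo extends App {
      val spark = SparkSession.builder()
        .master("local[*]")
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.autoBroadcastJoinThreshold", "-1")
        .getOrCreate()
      import spark.implicits._

      (1 to 100).map(i => (i, i.toString)).toDF("key", "value")
        .createOrReplaceTempView("t")
      Seq((1, 1), (2, 2), (3, 1)).toDF("a", "b").createOrReplaceTempView("u")
      Seq((1, 10), (2, 20)).toDF("a", "b").createOrReplaceTempView("w")

      val df = spark.sql(
        "SELECT value FROM t JOIN u ON key = a " +
        "JOIN (SELECT value v FROM t JOIN w ON key = a) ON value = v")
      df.collect() // materialize so AQE finalizes the plan
      df.explain() // expect a ReusedExchange over t's repeated shuffle
    }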
@@ -313,4 +365,44 @@ class AdaptiveQueryExecSuite
     assert(rightSkew.length == rightSkewNum)
   }
 
+  private def setupTestData(spark: SparkSession): Unit = {
+    testData(spark)
+    testData2(spark)
+    testData3(spark)
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData(spark: SparkSession) {
+    import spark.implicits._
+    val data: Seq[(Int, String)] = (1 to 100).map(i => (i, i.toString))
+    val df = data.toDF("key", "value").repartition(6)
+    registerAsParquetTable(spark, df, "testData")
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData2(spark: SparkSession) {
+    import spark.implicits._
+    val df = Seq[(Int, Int)]((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
+      .toDF("a", "b")
+      .repartition(2)
+    registerAsParquetTable(spark, df, "testData2")
+  }
+
+  /** Ported from org.apache.spark.sql.test.SQLTestData */
+  private def testData3(spark: SparkSession) {
+    import spark.implicits._
+    val df = Seq[(Int, Option[Int])]((1, None), (2, Some(2)))
+      .toDF("a", "b")
+      .repartition(6)
+    registerAsParquetTable(spark, df, "testData3")
+  }
+
+  private def registerAsParquetTable(spark: SparkSession, df: Dataset[Row], name: String) {
+    val path = new File(TEST_FILES_ROOT, s"$name.parquet").getAbsolutePath
+    df.write
+      .mode(SaveMode.Overwrite)
+      .parquet(path)
+    spark.read.parquet(path).createOrReplaceTempView(name)
+  }
+
 }
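
One design note on registerAsParquetTable: registering the in-memory DataFrame directly as a temp view would typically plan a LocalTableScan, which skips the file-scan and shuffle paths these AQE tests are meant to exercise; round-tripping through Parquet makes the view file-backed. A condensed sketch of the same pattern (assuming a throwaway temp directory rather than the suite's TEST_FILES_ROOT):

    import java.nio.file.Files
    import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

    // Write the data out and read it back so the registered view is backed
    // by a Parquet scan instead of an in-memory LocalTableScan.
    def registerViaParquet(spark: SparkSession, df: DataFrame, name: String): Unit = {
      val path = Files.createTempDirectory("aqe-test").resolve(s"$name.parquet").toString
      df.write.mode(SaveMode.Overwrite).parquet(path)
      spark.read.parquet(path).createOrReplaceTempView(name)
    }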