From 8cc1e450b37ea7f46d56c313c58676a711ec91fa Mon Sep 17 00:00:00 2001
From: Thomas Graves
Date: Thu, 17 Sep 2020 07:55:14 -0500
Subject: [PATCH] Make shuffle run on CPU if we do a join where we read from
 bucketed table (#785)

* Make shuffle run on CPU if we do a join where we read from bucketed table

Signed-off-by: Thomas Graves
---
 .../src/main/python/join_test.py             | 20 ++++++++++
 .../com/nvidia/spark/rapids/RapidsMeta.scala | 38 +++++++++++++++----
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py
index 01fa612e2f5..e638eda57d1 100644
--- a/integration_tests/src/main/python/join_test.py
+++ b/integration_tests/src/main/python/join_test.py
@@ -162,3 +162,23 @@ def do_join(spark):
             .withColumnRenamed("b", "r_b").withColumnRenamed("c", "r_c")
         return left.join(broadcast(right), left.a.eqNullSafe(right.r_a), join_type)
     assert_gpu_and_cpu_are_equal_collect(do_join)
+
+@ignore_order
+@allow_non_gpu('DataWritingCommandExec')
+@pytest.mark.parametrize('repartition', ["true", "false"], ids=idfn)
+def test_join_bucketed_table(repartition):
+    def do_join(spark):
+        data = [("http://fooblog.com/blog-entry-116.html", "https://fooblog.com/blog-entry-116.html"),
+                ("http://fooblog.com/blog-entry-116.html", "http://fooblog.com/blog-entry-116.html")]
+        resolved = spark.sparkContext.parallelize(data).toDF(['Url','ResolvedUrl'])
+        feature_data = [("http://fooblog.com/blog-entry-116.html", "21")]
+        feature = spark.sparkContext.parallelize(feature_data).toDF(['Url','Count'])
+        feature.write.bucketBy(400, 'Url').sortBy('Url').format('parquet').mode('overwrite')\
+            .saveAsTable('featuretable')
+        testurls = spark.sql("SELECT Url, Count FROM featuretable")
+        if (repartition == "true"):
+            return testurls.repartition(20).join(resolved, "Url", "inner")
+        else:
+            return testurls.join(resolved, "Url", "inner")
+    assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.autoBroadcastJoinThreshold': '-1'})
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index ca7495c5c4c..f3f5b7425eb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
@@ -447,6 +447,21 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
     case _ => childPlans.flatMap(_.findShuffleExchanges())
   }
 
+  private def findBucketedReads(): Seq[Boolean] = wrapped match {
+    case f: FileSourceScanExec =>
+      if (f.bucketedScan) {
+        true :: Nil
+      } else {
+        false :: Nil
+      }
+    case _: ShuffleExchangeExec =>
+      // if we find a shuffle before a scan then it doesn't matter if its
+      // a bucketed read
+      false :: Nil
+    case _ =>
+      childPlans.flatMap(_.findBucketedReads())
+  }
+
   private def makeShuffleConsistent(): Unit = {
     // during query execution when AQE is enabled, the plan could consist of a mixture of
     // ShuffleExchangeExec nodes for exchanges that have not started executing yet, and
@@ -454,6 +469,9 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
     // attempts to tag ShuffleExchangeExec nodes for CPU if other exchanges (either
     // ShuffleExchangeExec or ShuffleQueryStageExec nodes) were also tagged for CPU.
     val shuffleExchanges = findShuffleExchanges()
+    // if any of the table reads are bucketed then we can't do the shuffle on the
+    // GPU because the hashing is different between the CPU and GPU
+    val bucketedReads = findBucketedReads().exists(_ == true)
 
     def canThisBeReplaced(plan: Either[
         SparkPlanMeta[QueryStageExec],
@@ -468,13 +486,18 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
       }
     }
 
-    // if we can't convert all exchanges to GPU then we need to make sure that all of them
-    // run on the CPU instead
-    if (!shuffleExchanges.forall(canThisBeReplaced)) {
+    // if we are reading from a bucketed table or if we can't convert all exchanges to GPU
+    // then we need to make sure that all of them run on the CPU instead
+    if (bucketedReads || !shuffleExchanges.forall(canThisBeReplaced)) {
+      val errMsg = if (bucketedReads) {
+        "can't support shuffle on the GPU when doing a join that reads directly from a " +
+          "bucketed table!"
+      } else {
+        "other exchanges that feed the same join are on the CPU, and GPU hashing is " +
+          "not consistent with the CPU version"
+      }
       // tag any exchanges that have not been converted to query stages yet
-      shuffleExchanges.filter(_.isRight)
-        .foreach(_.right.get.willNotWorkOnGpu("other exchanges that feed the same join are" +
-          " on the CPU, and GPU hashing is not consistent with the CPU version"))
+      shuffleExchanges.filter(_.isRight).foreach(_.right.get.willNotWorkOnGpu(errMsg))
       // verify that no query stages already got converted to GPU
       if (shuffleExchanges.filter(_.isLeft).exists(canThisBeReplaced)) {
         throw new IllegalStateException("Join needs to run on CPU but at least one input " +
@@ -483,7 +506,6 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
     }
   }
 
-
   private def fixUpJoinConsistencyIfNeeded(): Unit = {
     childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
     wrapped match {
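
For reference, the scenario the new test exercises can also be reproduced as a small standalone
PySpark script; the sketch below is an illustration only (it assumes a local Spark 3.x session,
and the script name, bucket count, and table name are arbitrary). Broadcast joins are disabled so
the join needs a shuffle, and one side of the join is read directly from a bucketed table, which
is the case this patch forces back onto the CPU shuffle.

# bucketed_join_repro.py -- hypothetical standalone sketch, not part of the patch itself
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .master("local[*]")
         .appName("bucketed-join-repro")
         # disable broadcast joins so the join requires a shuffle, as in the new test
         .config("spark.sql.autoBroadcastJoinThreshold", "-1")
         .getOrCreate())

resolved = spark.createDataFrame(
    [("http://fooblog.com/blog-entry-116.html", "https://fooblog.com/blog-entry-116.html"),
     ("http://fooblog.com/blog-entry-116.html", "http://fooblog.com/blog-entry-116.html")],
    ["Url", "ResolvedUrl"])

feature = spark.createDataFrame(
    [("http://fooblog.com/blog-entry-116.html", "21")], ["Url", "Count"])

# Writing with bucketBy creates a bucketed table; reading it back for the join can produce a
# bucketed file scan, whose bucket layout comes from Spark's CPU-side hash partitioning.
(feature.write.bucketBy(4, "Url").sortBy("Url")
    .format("parquet").mode("overwrite").saveAsTable("featuretable"))

testurls = spark.table("featuretable")
joined = testurls.join(resolved, "Url", "inner")
# If the RAPIDS plugin is enabled, this patch keeps the shuffle feeding the join on the CPU.
joined.explain()
print(joined.collect())
spark.stop()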