Make shuffle run on CPU if we do a join where we read from bucketed table (NVIDIA#785)

* Make shuffle run on CPU if we do a join where we read from bucketed table

Signed-off-by: Thomas Graves <tgraves@nvidia.com>
tgravescs authored Sep 17, 2020
1 parent b99069b commit 2694936
Showing 2 changed files with 50 additions and 8 deletions.
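
For context, the scenario this change targets can be reproduced roughly as follows. This is a minimal PySpark sketch, not code from the plugin or the test suite: the app, table, and column names are made up, and broadcast joins are disabled (as in the new test below) so that a shuffle-based join is planned against a bucketed scan.

from pyspark.sql import SparkSession

# Minimal sketch of the scenario (illustrative names; not plugin code).
# Broadcast joins are disabled so Spark plans a shuffle-based join.
spark = (SparkSession.builder
    .appName("bucketed-join-sketch")
    .config("spark.sql.autoBroadcastJoinThreshold", "-1")
    .getOrCreate())

lookup = spark.createDataFrame([("a", 1), ("b", 2)], ["key", "value"])

# One side is written as a bucketed table, so its scan is a bucketed read.
lookup.write.bucketBy(4, "key").sortBy("key") \
    .format("parquet").mode("overwrite").saveAsTable("bucketed_lookup")

other = spark.createDataFrame([("a", "x"), ("c", "y")], ["key", "extra"])

# The exchange feeding this join must hash rows the same way the bucketed
# files were hashed, so with this commit the plugin keeps that shuffle on the CPU.
joined = spark.table("bucketed_lookup").join(other, "key", "inner")
joined.collect()
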
20 changes: 20 additions & 0 deletions integration_tests/src/main/python/join_test.py
@@ -162,3 +162,23 @@ def do_join(spark):
.withColumnRenamed("b", "r_b").withColumnRenamed("c", "r_c")
return left.join(broadcast(right), left.a.eqNullSafe(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order
@allow_non_gpu('DataWritingCommandExec')
@pytest.mark.parametrize('repartition', ["true", "false"], ids=idfn)
def test_join_bucketed_table(repartition):
def do_join(spark):
data = [("http://fooblog.com/blog-entry-116.html", "https://fooblog.com/blog-entry-116.html"),
("http://fooblog.com/blog-entry-116.html", "http://fooblog.com/blog-entry-116.html")]
resolved = spark.sparkContext.parallelize(data).toDF(['Url','ResolvedUrl'])
feature_data = [("http://fooblog.com/blog-entry-116.html", "21")]
feature = spark.sparkContext.parallelize(feature_data).toDF(['Url','Count'])
feature.write.bucketBy(400, 'Url').sortBy('Url').format('parquet').mode('overwrite')\
.saveAsTable('featuretable')
testurls = spark.sql("SELECT Url, Count FROM featuretable")
if (repartition == "true"):
return testurls.repartition(20).join(resolved, "Url", "inner")
else:
return testurls.join(resolved, "Url", "inner")
assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.autoBroadcastJoinThreshold': '-1'})

38 changes: 30 additions & 8 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec}
@@ -447,13 +447,31 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
case _ => childPlans.flatMap(_.findShuffleExchanges())
}

private def findBucketedReads(): Seq[Boolean] = wrapped match {
case f: FileSourceScanExec =>
if (f.bucketedScan) {
true :: Nil
} else {
false :: Nil
}
case _: ShuffleExchangeExec =>
// if we find a shuffle before a scan then it doesn't matter if it's
// a bucketed read
false :: Nil
case _ =>
childPlans.flatMap(_.findBucketedReads())
}

private def makeShuffleConsistent(): Unit = {
// during query execution when AQE is enabled, the plan could consist of a mixture of
// ShuffleExchangeExec nodes for exchanges that have not started executing yet, and
// ShuffleQueryStageExec nodes for exchanges that have already started executing. This code
// attempts to tag ShuffleExchangeExec nodes for CPU if other exchanges (either
// ShuffleExchangeExec or ShuffleQueryStageExec nodes) were also tagged for CPU.
val shuffleExchanges = findShuffleExchanges()
// if any of the table reads are bucketed then we can't do the shuffle on the
// GPU because the hashing is different between the CPU and GPU
val bucketedReads = findBucketedReads().exists(_ == true)

def canThisBeReplaced(plan: Either[
SparkPlanMeta[QueryStageExec],
@@ -468,13 +486,18 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
}
}

// if we can't convert all exchanges to GPU then we need to make sure that all of them
// run on the CPU instead
if (!shuffleExchanges.forall(canThisBeReplaced)) {
// if we are reading from a bucketed table or if we can't convert all exchanges to GPU
// then we need to make sure that all of them run on the CPU instead
if (bucketedReads || !shuffleExchanges.forall(canThisBeReplaced)) {
val errMsg = if (bucketedReads) {
"can't support shuffle on the GPU when doing a join that reads directly from a " +
"bucketed table!"
} else {
"other exchanges that feed the same join are on the CPU, and GPU hashing is " +
"not consistent with the CPU version"
}
// tag any exchanges that have not been converted to query stages yet
shuffleExchanges.filter(_.isRight)
.foreach(_.right.get.willNotWorkOnGpu("other exchanges that feed the same join are" +
" on the CPU, and GPU hashing is not consistent with the CPU version"))
shuffleExchanges.filter(_.isRight).foreach(_.right.get.willNotWorkOnGpu(errMsg))
// verify that no query stages already got converted to GPU
if (shuffleExchanges.filter(_.isLeft).exists(canThisBeReplaced)) {
throw new IllegalStateException("Join needs to run on CPU but at least one input " +
@@ -483,7 +506,6 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
}
}


private def fixUpJoinConsistencyIfNeeded(): Unit = {
childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
wrapped match {
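
To see the effect of the tagging above on a query like the earlier sketch, one can inspect the physical plan. This is a hedged check, not part of the commit: the _jdf access is an internal PySpark shortcut, and the exec name GpuShuffleExchange is assumed from the plugin's GPU operator naming.

# Illustrative verification for the `joined` DataFrame from the sketch above.
# With a bucketed read feeding the join, the exchange should show up as the
# CPU ShuffleExchangeExec rather than a GPU exchange node.
joined.explain()

plan = joined._jdf.queryExecution().executedPlan().toString()
print("shuffle stayed on CPU" if "GpuShuffleExchange" not in plan else "shuffle ran on GPU")
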
