Skip to content

Commit

Permalink
don't count it if shuffle after bucketed read before join
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Graves <tgraves@nvidia.com>
  • Loading branch information
tgravescs committed Sep 16, 2020
1 parent 1040bee commit 4252d47
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
8 changes: 6 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def do_join(spark):

@ignore_order
@allow_non_gpu('DataWritingCommandExec')
def test_join_bucketed_table():
@pytest.mark.parametrize('repartition', ["true", "false"], ids=idfn)
def test_join_bucketed_table(repartition):
def do_join(spark):
data = [("http://fooblog.com/blog-entry-116.html", "https://fooblog.com/blog-entry-116.html"),
("http://fooblog.com/blog-entry-116.html", "http://fooblog.com/blog-entry-116.html")]
Expand All @@ -175,6 +176,9 @@ def do_join(spark):
feature.write.bucketBy(400, 'Url').sortBy('Url').format('parquet').mode('overwrite')\
.saveAsTable('featuretable')
testurls = spark.sql("SELECT Url, Count FROM featuretable")
return testurls.join(resolved, "Url", "inner")
if (repartition == "true"):
return testurls.repartition(20).join(resolved, "Url", "inner")
else:
return testurls.join(resolved, "Url", "inner")
assert_gpu_and_cpu_are_equal_collect(do_join, conf={'spark.sql.autoBroadcastJoinThreshold': '-1'})

Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,10 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
} else {
false :: Nil
}
case _: ShuffleExchangeExec =>
// if we find a shuffle before a scan then it doesn't matter if its
// a bucketed read
false :: Nil
case _ =>
childPlans.flatMap(_.findBucketedReads())
}
Expand Down Expand Up @@ -486,7 +490,8 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
// then we need to make sure that all of them run on the CPU instead
if (bucketedReads || !shuffleExchanges.forall(canThisBeReplaced)) {
val errMsg = if (bucketedReads) {
"can't support shuffle on the GPU with bucketed reads!"
"can't support shuffle on the GPU when doing a join that reads directly from a " +
"bucketed table!"
} else {
"other exchanges that feed the same join are on the CPU, and GPU hashing is " +
"not consistent with the CPU version"
Expand Down

0 comments on commit 4252d47

Please sign in to comment.