Skip to content

Commit

Permalink
broadcast exchange can fail when job group set (#1988)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Graves <tgraves@nvidia.com>
  • Loading branch information
tgravescs authored Mar 23, 2021
1 parent c785c04 commit f7c6c57
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
16 changes: 15 additions & 1 deletion integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from conftest import is_databricks_runtime, is_emr_runtime
from data_gen import *
from marks import ignore_order, allow_non_gpu, incompat
from spark_session import with_spark_session, is_before_spark_310
from spark_session import with_cpu_session, with_spark_session, is_before_spark_310

all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
BooleanGen(), DateGen(), TimestampGen(), null_gen,
Expand Down Expand Up @@ -80,6 +80,20 @@ def do_join(spark):
return left.join(broadcast(right), left.a == right.r_a, join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
# Not all join types can be translated to a broadcast join, but this tests them to be sure we
# can handle what spark is doing
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn)
def test_broadcast_join_right_table_with_job_group(data_gen, join_type):
    # Regression test for #1988: the broadcast exchange used to derive its run id
    # from the Spark job group id via UUID.fromString, which fails when the job
    # group id is not a valid UUID string. Set a job group before running the
    # broadcast join to ensure the exchange no longer fails in that case.
    with_cpu_session(lambda spark : spark.sparkContext.setJobGroup("testjob1", "test", False))
    def do_join(spark):
        # 500-row left table joined with a broadcast of a 250-row right table
        left, right = create_df(spark, data_gen, 500, 250)
        return left.join(broadcast(right), left.a == right.r_a, join_type)
    assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf)

# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.SparkException
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -265,10 +265,7 @@ abstract class GpuBroadcastExchangeExecBase(
// How long to wait for the broadcast to complete (spark.sql.broadcastTimeout).
// Transient: read from SQLConf on the driver; not serialized with the plan.
@transient
private val timeout: Long = SQLConf.get.broadcastTimeout

// Cancelling a SQL statement from Spark ThriftServer needs to cancel
// its related broadcast sub-jobs. So set the run id to job group id if exists.
val _runId: UUID = Option(sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
.map(UUID.fromString).getOrElse(UUID.randomUUID)
// Always use a fresh random id for the broadcast job. Previously this reused
// the job group id when one was set, but job group ids are arbitrary strings
// and UUID.fromString throws IllegalArgumentException for non-UUID values,
// which made the broadcast exchange fail whenever a job group was set (#1988).
val _runId: UUID = UUID.randomUUID()

@transient
lazy val relationFuture: Future[Broadcast[Any]] = {
Expand Down

0 comments on commit f7c6c57

Please sign in to comment.