From f7c6c57c7cadf8de6967c99427ca6516211dcda2 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 22 Mar 2021 19:45:05 -0500 Subject: [PATCH] broadcast exchange can fail when job group set (#1988) Signed-off-by: Thomas Graves --- integration_tests/src/main/python/join_test.py | 16 +++++++++++++++- .../execution/GpuBroadcastExchangeExec.scala | 7 ++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index d68539a231e..957dae7053b 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -18,7 +18,7 @@ from conftest import is_databricks_runtime, is_emr_runtime from data_gen import * from marks import ignore_order, allow_non_gpu, incompat -from spark_session import with_spark_session, is_before_spark_310 +from spark_session import with_cpu_session, with_spark_session, is_before_spark_310 all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(), BooleanGen(), DateGen(), TimestampGen(), null_gen, @@ -80,6 +80,20 @@ def do_join(spark): return left.join(broadcast(right), left.a == right.r_a, join_type) assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf) +# local sort because of https://github.com/NVIDIA/spark-rapids/issues/84 +# After 3.1.0 is the min spark version we can drop this +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', all_gen, ids=idfn) +# Not all join types can be translated to a broadcast join, but this tests them to be sure we +# can handle what spark is doing +@pytest.mark.parametrize('join_type', ['Left', 'Right', 'Inner', 'LeftSemi', 'LeftAnti', 'Cross', 'FullOuter'], ids=idfn) +def test_broadcast_join_right_table_with_job_group(data_gen, join_type): + with_cpu_session(lambda spark : spark.sparkContext.setJobGroup("testjob1", "test", False)) + def do_join(spark): + left, right = create_df(spark, data_gen, 500, 250) 
+ return left.join(broadcast(right), left.a == right.r_a, join_type) + assert_gpu_and_cpu_are_equal_collect(do_join, conf=allow_negative_scale_of_decimal_conf) + # local sort because of https://github.com/NVIDIA/spark-rapids/issues/84 # After 3.1.0 is the min spark version we can drop this @ignore_order(local=True) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala index ff761e769fd..b72b84296c9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala @@ -29,7 +29,7 @@ import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.SparkException import org.apache.spark.broadcast.Broadcast import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD @@ -265,10 +265,7 @@ abstract class GpuBroadcastExchangeExecBase( @transient private val timeout: Long = SQLConf.get.broadcastTimeout - // Cancelling a SQL statement from Spark ThriftServer needs to cancel - // its related broadcast sub-jobs. So set the run id to job group id if exists. - val _runId: UUID = Option(sparkContext.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) - .map(UUID.fromString).getOrElse(UUID.randomUUID) + val _runId: UUID = UUID.randomUUID() @transient lazy val relationFuture: Future[Broadcast[Any]] = {