[SPARK-44897][SQL] Propagating local properties to subquery broadcast exec

### What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-32748 previously proposed propagating these local properties to the subquery broadcast exec threads, but it was reverted on the grounds that local properties would already be propagated to the broadcast threads.
I believe this is not always true. When a separate `BroadcastExchangeExec` is the first to compute the broadcast, the properties do reach the broadcast threads correctly. However, when the `SubqueryBroadcastExec` is the first to compute the broadcast, the local properties seen by the broadcast threads can be wrong, because `SubqueryBroadcastExec` did not propagate its local properties to its `Future`'s thread.
It is difficult to write a unit test that reproduces this behavior because `BroadcastExchangeExec` is usually the first to compute the broadcast variable. However, by adding a `Thread.sleep(10)` to `SubqueryBroadcastExec.doPrepare` after `relationFuture` is initialized, the added test will consistently fail without the fix.
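
To make the failure mode concrete, here is a minimal, JDK-only sketch (illustrative names, not Spark code) of why a cached pool thread does not see the submitter's thread-local values unless they are explicitly captured and re-installed, which is conceptually the pattern this fix applies via `SQLExecution.withThreadLocalCaptured`:

```scala
import java.util.concurrent.Executors

// Illustrative sketch only; not Spark's implementation.
object LocalPropertySketch {
  // Stand-in for Spark's per-thread local properties.
  private val localProp = new ThreadLocal[String]

  def main(args: Array[String]): Unit = {
    val pool = Executors.newCachedThreadPool()
    localProp.set("confValue1") // set on the submitting ("driver") thread

    // Naive hand-off: the pool thread has its own, empty thread-local,
    // so the property set above is invisible here.
    pool.submit(new Runnable {
      override def run(): Unit = println(s"naive: ${localProp.get}") // prints "naive: null"
    })

    // Capture on the submitting thread, re-install inside the task.
    val captured = localProp.get
    pool.submit(new Runnable {
      override def run(): Unit = {
        localProp.set(captured)
        println(s"captured: ${localProp.get}") // prints "captured: confValue1"
      }
    })

    pool.shutdown()
  }
}
```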

### Why are the changes needed?
Local properties are not propagated correctly to the `SubqueryBroadcastExec` broadcast threads.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
The following test reproduces the bug and verifies the fix, given this sleep added to `SubqueryBroadcastExec.doPrepare`:
```scala
protected override def doPrepare(): Unit = {
  relationFuture
  Thread.sleep(10)
}
```
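
Referencing `relationFuture` in `doPrepare` kicks off the subquery broadcast computation, and the sleep that follows gives it a head start, so `SubqueryBroadcastExec` rather than `BroadcastExchangeExec` is the first to compute the broadcast and the missing propagation becomes observable.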

```scala
  test("SPARK-44897 propagate local properties to subquery broadcast execution thread") {
    withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1") {
      withTable("a", "b") {
        val confKey = "spark.sql.y"
        val confValue1 = UUID.randomUUID().toString()
        val confValue2 = UUID.randomUUID().toString()
        Seq((confValue1, "1")).toDF("key", "value")
          .write
          .format("parquet")
          .partitionBy("key")
          .mode("overwrite")
          .saveAsTable("a")
        val df1 = spark.table("a")

        def generateBroadcastDataFrame(confKey: String, confValue: String): Dataset[String] = {
          val df = spark.range(1).mapPartitions { _ =>
            Iterator(TaskContext.get.getLocalProperty(confKey))
          }.filter($"value".contains(confValue)).as("c")
          df.hint("broadcast")
        }

        // set local property and assert
        val df2 = generateBroadcastDataFrame(confKey, confValue1)
        spark.sparkContext.setLocalProperty(confKey, confValue1)
        val checkDF = df1.join(df2).where($"a.key" === $"c.value").select($"a.key", $"c.value")
        val checks = checkDF.collect()
        assert(checks.forall(_.toSeq == Seq(confValue1, confValue1)))

        // change local property and re-assert
        Seq((confValue2, "1")).toDF("key", "value")
          .write
          .format("parquet")
          .partitionBy("key")
          .mode("overwrite")
          .saveAsTable("b")
        val df3 = spark.table("b")
        val df4 = generateBroadcastDataFrame(confKey, confValue2)
        spark.sparkContext.setLocalProperty(confKey, confValue2)
        val checks2DF = df3.join(df4).where($"b.key" === $"c.value").select($"b.key", $"c.value")
        val checks2 = checks2DF.collect()
        assert(checks2.forall(_.toSeq == Seq(confValue2, confValue2)))
        assert(checks2.nonEmpty)
      }
    }
  }
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #42587 from ChenMichael/SPARK-44897-local-property-propagation-to-subquery-broadcast-exec.

Authored-by: Michael Chen <mike.chen@workday.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 4a48562)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Michael Chen authored and cloud-fan committed Aug 28, 2023
1 parent 9b00d36 commit 2f4a712
Showing 1 changed file with 10 additions and 5 deletions.
**SubqueryBroadcastExec.scala**

```diff
@@ -17,7 +17,9 @@

 package org.apache.spark.sql.execution

-import scala.concurrent.{ExecutionContext, Future}
+import java.util.concurrent.{Future => JFuture}
+
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration

 import org.apache.spark.rdd.RDD
@@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.util.ThreadUtils

 /**
@@ -70,10 +73,11 @@ case class SubqueryBroadcastExec(
   }

   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      session, SubqueryBroadcastExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(session, executionId) {
@@ -104,7 +108,7 @@ case class SubqueryBroadcastExec(

         rows
       }
-    }(SubqueryBroadcastExec.executionContext)
+    }
   }

   protected override def doPrepare(): Unit = {
@@ -127,5 +131,6 @@

 object SubqueryBroadcastExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16))
+    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning",
+      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
 }
```
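
Two things change here: the subquery's `relationFuture` now goes through `SQLExecution.withThreadLocalCaptured`, which captures the submitting thread's local properties and re-installs them on the pool thread before running the body; and the `dynamicpruning` pool is sized by `StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD` instead of a hard-coded 16, which is what lets the test above pin the thread pool to a single thread and make the stale-property behavior deterministic.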
