diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3a41b0553db54..189740e313207 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2784,7 +2784,15 @@ class SQLConf extends Serializable with Logging {
 
   def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED)
 
-  def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
+  def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
+
+  def numShufflePartitions: Int = {
+    if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) {
+      getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions)
+    } else {
+      defaultNumShufflePartitions
+    }
+  }
 
   def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)
 
@@ -2797,9 +2805,6 @@ class SQLConf extends Serializable with Logging {
 
   def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED)
 
-  def initialShufflePartitionNum: Int =
-    getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(numShufflePartitions)
-
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
 
   def maxBatchesToRetainInMemory: Int = getConf(MAX_BATCHES_TO_RETAIN_IN_MEMORY)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 28ef793ed62db..3242ac21ab324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -35,12 +35,6 @@ import org.apache.spark.sql.internal.SQLConf
  * the input partition ordering requirements are met.
  */
 case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
-  private def defaultNumPreShufflePartitions: Int =
-    if (conf.adaptiveExecutionEnabled && conf.coalesceShufflePartitionsEnabled) {
-      conf.initialShufflePartitionNum
-    } else {
-      conf.numShufflePartitions
-    }
 
   private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
@@ -57,7 +51,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
-          .getOrElse(defaultNumPreShufflePartitions)
+          .getOrElse(conf.numShufflePartitions)
         ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }
 
@@ -95,7 +89,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           // expected number of shuffle partitions. However, if it's smaller than
           // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
           // expected number of shuffle partitions.
-          math.max(nonShuffleChildrenNumPartitions.max, conf.numShufflePartitions)
+          math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions)
         } else {
           childrenNumPartitions.max
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 3d0ba05f76b71..9fa97bffa8910 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1021,4 +1021,20 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-31220 repartition obeys initialPartitionNum when adaptiveExecutionEnabled") {
+    Seq(true, false).foreach { enableAQE =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+        SQLConf.SHUFFLE_PARTITIONS.key -> "6",
+        SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "7") {
+        val partitionsNum = spark.range(10).repartition($"id").rdd.collectPartitions().length
+        if (enableAQE) {
+          assert(partitionsNum === 7)
+        } else {
+          assert(partitionsNum === 6)
+        }
+      }
+    }
+  }
 }
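
For illustration, a minimal spark-shell sketch of the behavior change, mirroring the new test above (a sketch only, not part of the patch; it assumes a session built with this patch, with `spark` as the ambient SparkSession and `$` from `spark.implicits._`):

    import spark.implicits._

    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
    spark.conf.set("spark.sql.shuffle.partitions", "6")
    spark.conf.set("spark.sql.adaptive.coalescePartitions.initialPartitionNum", "7")

    // With AQE and partition coalescing enabled, the shuffle introduced by
    // repartition($"id") now starts from initialPartitionNum (7) instead of
    // spark.sql.shuffle.partitions (6); with spark.sql.adaptive.enabled set
    // to false it stays at 6, as the test asserts.
    spark.range(10).repartition($"id").rdd.getNumPartitions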