Skip to content

Commit

Permalink
Fix regression in AQE optimizations (#4354)
Browse files Browse the repository at this point in the history
* Fix regression in AQE

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* simplify test

* sign-off

Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove authored Dec 14, 2021
1 parent bc4cd09 commit 3c8e5df
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
plan: SparkPlan,
parent: Option[SparkPlan]): SparkPlan = plan match {

case GpuBringBackToHost(child) if parent.isEmpty =>
case bb @ GpuBringBackToHost(child) if parent.isEmpty =>
// This is hacky but we need to remove the GpuBringBackToHost from the final
// query stage, if there is one. It gets inserted by
// GpuTransitionOverrides.insertColumnarFromGpu around columnar adaptive
// plans when we are writing to columnar formats on the GPU. It would be nice to avoid
// inserting it in the first place but we just don't have enough context
// at the time GpuTransitionOverrides is applying rules.
child
optimizeAdaptiveTransitions(child, Some(bb))

// HostColumnarToGpu(RowToColumnarExec(..)) => GpuRowToColumnarExec(..)
case HostColumnarToGpu(r2c: RowToColumnarExec, goal) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,31 @@ class AdaptiveQueryExecSuite
}, conf)
}

// Repro for https://github.com/NVIDIA/spark-rapids/issues/4351: writing parquet
// from an adaptive (AQE) plan that contains a limit.
// NOTE(review): presumably this exercises removal of GpuBringBackToHost from the
// final query stage when writing columnar formats — confirm against the fix in
// GpuTransitionOverrides.
test("Write parquet from AQE shuffle with limit") {
  logError("Write parquet from AQE shuffle with limit")

  // AQE must be on for the plan to contain adaptive query stages.
  val testConf = new SparkConf().set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")

  withGpuSparkSession({ spark =>
    import spark.implicits._

    // Seed an input parquet file with 100 integer rows in a single column "a".
    val inputPath = new File(TEST_FILES_ROOT, "AvoidTransitionInput.parquet").getAbsolutePath
    (0 until 100).toDF("a").write.mode(SaveMode.Overwrite).parquet(inputPath)

    // Read the data back, apply a limit, and write it out again; the test passes
    // if the round trip completes without error.
    val outPath = new File(TEST_FILES_ROOT, "AvoidTransitionOutput.parquet").getAbsolutePath
    spark.read
      .parquet(inputPath)
      .limit(100)
      .write
      .mode(SaveMode.Overwrite)
      .parquet(outPath)
  }, testConf)
}



test("Exchange reuse") {
logError("Exchange reuse")
assumeSpark301orLater
Expand Down

0 comments on commit 3c8e5df

Please sign in to comment.