diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py
index 4c6116f5f48..8b8df12c2cb 100644
--- a/integration_tests/src/main/python/repart_test.py
+++ b/integration_tests/src/main/python/repart_test.py
@@ -267,3 +267,14 @@ def test_hash_repartition_exact(gen, num_parts):
             .withColumn('hashed', f.hash(*part_on))\
             .selectExpr('*', 'pmod(hashed, {})'.format(num_parts)),
         conf = allow_negative_scale_of_decimal_conf)
+
+# Test a query that should cause Spark to leverage getShuffleRDD
+@ignore_order(local=True)
+def test_union_with_filter():
+    def doit(spark):
+        dfa = spark.range(1, 100).withColumn("id2", f.col("id"))
+        dfb = dfa.groupBy("id").agg(f.size(f.collect_set("id2")).alias("idc"))
+        dfc = dfb.filter(f.col("idc") == 1).select("id")
+        return dfc.union(dfc)
+    conf = { "spark.sql.adaptive.enabled": "true" }
+    assert_gpu_and_cpu_are_equal_collect(doit, conf)
diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
new file mode 100644
index 00000000000..2adb6d96c10
--- /dev/null
+++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
+
+/** Utility methods for manipulating Catalyst classes involved in Adaptive Query Execution */
+object AQEUtils {
+  /** Return a new QueryStageExec reuse instance with updated output attributes */
+  def newReuseInstance(sqse: ShuffleQueryStageExec, newOutput: Seq[Attribute]): QueryStageExec = {
+    sqse.newReuseInstance(sqse.id, newOutput)
+  }
+}
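For reference, a rough Scala equivalent of the new Python test (hypothetical, not part of the patch; names are illustrative) shows the query shape that drives AQE down the getShuffleRDD path: the filter over the aggregation output plus the union of the two branches leads AQE to insert a shuffle read over the reused exchange.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_set, size}

// Hypothetical Scala sketch of test_union_with_filter: with AQE enabled, the
// filter over the aggregation and the union of the two branches cause Spark to
// wrap the (GPU) shuffle in a coalesced shuffle read, which calls getShuffleRDD.
object UnionWithFilterRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("union-with-filter-repro")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    val dfa = spark.range(1, 100).withColumn("id2", col("id"))
    val dfb = dfa.groupBy("id").agg(size(collect_set("id2")).alias("idc"))
    val dfc = dfb.filter(col("idc") === 1).select("id")
    dfc.union(dfc).collect()
    spark.stop()
  }
}
```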
diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
new file mode 100644
index 00000000000..2adb6d96c10
--- /dev/null
+++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
+
+/** Utility methods for manipulating Catalyst classes involved in Adaptive Query Execution */
+object AQEUtils {
+  /** Return a new QueryStageExec reuse instance with updated output attributes */
+  def newReuseInstance(sqse: ShuffleQueryStageExec, newOutput: Seq[Attribute]): QueryStageExec = {
+    sqse.newReuseInstance(sqse.id, newOutput)
+  }
+}
diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
index 12fe559cccc..b3ea5fbcafd 100644
--- a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}
 
 case class GpuShuffleExchangeExec(
     gpuOutputPartitioning: GpuPartitioning,
@@ -43,7 +43,7 @@ case class GpuShuffleExchangeExec(
   override def getShuffleRDD(
       partitionSpecs: Array[ShufflePartitionSpec],
       partitionSizes: Option[Array[Long]]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
   }
 
   override def runtimeStatistics: Statistics = {
diff --git a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
index 1c2596eaa29..60460525fa6 100644
--- a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}
 
 case class GpuShuffleExchangeExec(
     gpuOutputPartitioning: GpuPartitioning,
@@ -41,7 +41,7 @@ case class GpuShuffleExchangeExec(
   override def numPartitions: Int = shuffleDependencyColumnar.partitioner.numPartitions
 
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
   }
 
   override def runtimeStatistics: Statistics = {
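Each shim above now returns a real ShuffledBatchRDD instead of throwing. A hedged sketch of the calling convention follows (the coalescedRead wrapper is hypothetical; CoalescedPartitionSpec is the Spark class the coalesce-partitions rule actually passes to getShuffleRDD):

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec
import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec}

// Sketch of how AQE consumes getShuffleRDD: the coalesce-partitions rule hands the
// exchange specs that merge small shuffle partitions, and with the change above the
// GPU exchange answers with a ShuffledBatchRDD of columnar batches.
def coalescedRead(exchange: GpuShuffleExchangeExec): RDD[_] = {
  val specs: Array[ShufflePartitionSpec] = Array(
    CoalescedPartitionSpec(0, 2), // reducer partitions [0, 2) read by one task
    CoalescedPartitionSpec(2, 5)) // reducer partitions [2, 5) read by one task
  exchange.getShuffleRDD(specs)
}
```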
diff --git a/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
index 1ca57a75235..7528ba635b9 100644
--- a/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBaseWithMetrics
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBaseWithMetrics, ShuffledBatchRDD}
 
 case class GpuShuffleExchangeExec(
     gpuOutputPartitioning: GpuPartitioning,
@@ -42,7 +42,7 @@ case class GpuShuffleExchangeExec(
   override def numPartitions: Int = shuffleDependencyColumnar.partitioner.numPartitions
 
   override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
   }
 
   override def runtimeStatistics: Statistics = {
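The same entry point also serves AQE's skew-join mitigation, which splits one oversized reducer partition across several tasks by mapper ranges. A hedged sketch of the specs it passes (indices are illustrative; PartialReducerPartitionSpec is the real Spark class, whose constructor gains a trailing dataSize parameter in Spark 3.2):

```scala
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ShufflePartitionSpec}

// Sketch: split skewed reducer partition 3 into two tasks by mapper ranges.
// Constructor shown with the Spark 3.0/3.1 signature.
val skewSpecs: Array[ShufflePartitionSpec] = Array(
  PartialReducerPartitionSpec(3, 0, 8),  // reducer 3, mappers [0, 8)
  PartialReducerPartitionSpec(3, 8, 16)) // reducer 3, mappers [8, 16)
```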
diff --git a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
new file mode 100644
index 00000000000..df2aee9268c
--- /dev/null
+++ b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+
+/** Utility methods for manipulating Catalyst classes involved in Adaptive Query Execution */
+object AQEUtils {
+  /** Return a new QueryStageExec reuse instance with updated output attributes */
+  def newReuseInstance(sqse: ShuffleQueryStageExec, newOutput: Seq[Attribute]): QueryStageExec = {
+    val reusedExchange = ReusedExchangeExec(newOutput, sqse.shuffle)
+    ShuffleQueryStageExec(sqse.id, reusedExchange, sqse.originalPlan)
+  }
+}
diff --git a/sql-plugin/src/main/312db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/312db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
index 02bc0dda8ea..07bcefe93c7 100644
--- a/sql-plugin/src/main/312db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
+++ b/sql-plugin/src/main/312db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
-import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase
+import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBase, ShuffledBatchRDD}
 
 case class GpuShuffleExchangeExec(
     gpuOutputPartitioning: GpuPartitioning,
@@ -54,7 +54,7 @@ case class GpuShuffleExchangeExec(
   override def getShuffleRDD(
       partitionSpecs: Array[ShufflePartitionSpec],
       partitionSizes: Option[Array[Long]]): RDD[_] = {
-    throw new UnsupportedOperationException
+    new ShuffledBatchRDD(shuffleDependencyColumnar, metrics ++ readMetrics, partitionSpecs)
   }
 
   // DB SPECIFIC - throw if called since we don't know how its used
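The Databricks 3.1.2 AQEUtils shim above rebuilds the ReusedExchangeExec and query stage directly rather than delegating to newReuseInstance as the other shims do. The attribute surgery this enables is sketched below (the types are hypothetical illustrations: think of an aggregation buffer cached with a row-based type on the CPU side versus the columnar type the GPU shuffle emits):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{ArrayType, BinaryType, LongType}

// Hypothetical illustration of the fixup performed in GpuOverrides below: keep the
// cached attribute's name, exprId, nullability, and qualifier, but adopt the GPU
// output type so downstream operators see a compatible schema.
val cpuAttr = AttributeReference("buf", BinaryType)()
val gpuType = ArrayType(LongType)
val fixedAttr = AttributeReference(cpuAttr.name, gpuType, cpuAttr.nullable, cpuAttr.metadata)(
  cpuAttr.exprId, cpuAttr.qualifier)
```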
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 919ffc775e3..abeb3f1fb19 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 
 import ai.rapids.cudf.DType
 import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF}
-import com.nvidia.spark.rapids.shims.v2.{GpuSpecifiedWindowFrameMeta, GpuWindowExpressionMeta, OffsetWindowFunctionMeta}
+import com.nvidia.spark.rapids.shims.v2.{AQEUtils, GpuSpecifiedWindowFrameMeta, GpuWindowExpressionMeta, OffsetWindowFunctionMeta}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -548,6 +548,30 @@ object GpuOverrides extends Logging {
     }
   }
 
+  /**
+   * Searches the plan for ReusedExchangeExec instances containing a GPU shuffle where the
+   * output types between the two plan nodes do not match. In such a case the ReusedExchangeExec
+   * will be updated to match the GPU shuffle output types.
+   */
+  def fixupReusedExchangeExecs(plan: SparkPlan): SparkPlan = {
+    def outputTypesMatch(a: Seq[Attribute], b: Seq[Attribute]): Boolean =
+      a.corresponds(b)((x, y) => x.dataType == y.dataType)
+    plan.transformUp {
+      case sqse: ShuffleQueryStageExec =>
+        sqse.plan match {
+          case ReusedExchangeExec(output, gsee: GpuShuffleExchangeExecBase) if (
+              !outputTypesMatch(output, gsee.output)) =>
+            val newOutput = sqse.plan.output.zip(gsee.output).map { case (c, g) =>
+              assert(c.isInstanceOf[AttributeReference] && g.isInstanceOf[AttributeReference],
+                s"Expected AttributeReference but found $c and $g")
+              AttributeReference(c.name, g.dataType, c.nullable, c.metadata)(c.exprId, c.qualifier)
+            }
+            AQEUtils.newReuseInstance(sqse, newOutput)
+          case _ => sqse
+        }
+    }
+  }
+
   @scala.annotation.tailrec
   def extractLit(exp: Expression): Option[Literal] = exp match {
     case l: Literal => Some(l)
@@ -3910,7 +3934,11 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
     val updatedPlan = if (plan.conf.adaptiveExecutionEnabled) {
       // AQE can cause Spark to inject undesired CPU shuffles into the plan because GPU and CPU
       // distribution expressions are not semantically equal.
-      GpuOverrides.removeExtraneousShuffles(plan, conf)
+      val newPlan = GpuOverrides.removeExtraneousShuffles(plan, conf)
+
+      // AQE can cause ReusedExchangeExec instances to cache the wrong aggregation buffer type
+      // compared to the desired buffer type from a reused GPU shuffle.
+      GpuOverrides.fixupReusedExchangeExecs(newPlan)
     } else {
       plan
     }
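Putting it together, a condensed sketch of the post-processing flow that apply() now performs when AQE is enabled (assuming both helpers are visible to the caller; postProcessAqePlan is illustrative, not new API):

```scala
import com.nvidia.spark.rapids.{GpuOverrides, RapidsConf}
import org.apache.spark.sql.execution.SparkPlan

// Condensed view of the apply() change above: first strip the CPU shuffles AQE
// injected, then repair any ReusedExchangeExec whose cached output types disagree
// with the reused GPU shuffle.
def postProcessAqePlan(plan: SparkPlan, conf: RapidsConf): SparkPlan = {
  val noExtraShuffles = GpuOverrides.removeExtraneousShuffles(plan, conf)
  GpuOverrides.fixupReusedExchangeExecs(noExtraShuffles)
}
```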