From a7cdaa9eff17293cae9cee04d16b07d924fd40eb Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 30 May 2024 08:49:53 -0700 Subject: [PATCH] Added Shim for BatchScanExec to Support Spark 4.0 [databricks] (#10944) * Added shim for BatchScanExec to support Spark 4.0 Signed-off-by: Raza Jafri * fixed the failing shim --------- Signed-off-by: Raza Jafri --- .../spark/rapids/shims/GpuBatchScanExec.scala | 1 - .../rapids/shims/BatchScanExecMeta.scala | 52 +--- .../rapids/shims/BatchScanExecMetaBase.scala | 81 ++++++ .../rapids/shims/BatchScanExecMeta.scala | 38 +++ .../spark/rapids/shims/GpuBatchScanExec.scala | 269 ++++++++++++++++++ 5 files changed, 389 insertions(+), 52 deletions(-) create mode 100644 sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMetaBase.scala create mode 100644 sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala create mode 100644 sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index 39f42d8b833..5fb252524fd 100644 --- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -22,7 +22,6 @@ {"spark": "343"} {"spark": "350"} {"spark": "351"} -{"spark": "400"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala index 4bbc4644241..4b29de25bf0 100644 --- a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala +++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala @@ -17,68 +17,18 @@ /*** spark-rapids-shim-json-lines {"spark": "350"} {"spark": "351"} -{"spark": "400"} spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec class BatchScanExecMeta(p: BatchScanExec, conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends SparkPlanMeta[BatchScanExec](p, conf, parent, rule) { - // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart - // if possible. Instead regarding filters as childExprs of current Meta, we create - // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of - // BatchScanExec is independent from the replacement of the runtime filters. It is - // possible that the BatchScanExec is on the CPU, while the dynamic runtime filters - // are on the GPU. And vice versa. 
- private lazy val runtimeFilters = { - val convertBroadcast = (bc: SubqueryBroadcastExec) => { - val meta = GpuOverrides.wrapAndTagPlan(bc, conf) - meta.tagForExplain() - meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec] - } - wrapped.runtimeFilters.map { filter => - filter.transformDown { - case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => - inSub.plan match { - case bc: SubqueryBroadcastExec => - dpe.copy(inSub.copy(plan = convertBroadcast(bc))) - case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) => - dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc)))) - case _ => - dpe - } - } - } - } - - override val childExprs: Seq[BaseExprMeta[_]] = { - // We want to leave the runtime filters as CPU expressions - p.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) - } - - override val childScans: scala.Seq[ScanMeta[_]] = - Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this))) - - override def tagPlanForGpu(): Unit = { - if (!p.runtimeFilters.isEmpty && !childScans.head.supportsRuntimeFilters) { - willNotWorkOnGpu("runtime filtering (DPP) is not supported for this scan") - } - } - - override def convertToCpu(): SparkPlan = { - val cpu = wrapped.copy(runtimeFilters = runtimeFilters) - cpu.copyTagsFrom(wrapped) - cpu - } - + extends BatchScanExecMetaBase(p, conf, parent, rule) { override def convertToGpu(): GpuExec = { val spj = p.spjParams GpuBatchScanExec(p.output, childScans.head.convertToGpu(), runtimeFilters, diff --git a/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMetaBase.scala b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMetaBase.scala new file mode 100644 index 00000000000..914702a289c --- /dev/null +++ b/sql-plugin/src/main/spark350/scala/com/nvidia/spark/rapids/shims/BatchScanExecMetaBase.scala @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "350"} +{"spark": "351"} +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + +abstract class BatchScanExecMetaBase(p: BatchScanExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[BatchScanExec](p, conf, parent, rule) { + // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart + // if possible. Instead regarding filters as childExprs of current Meta, we create + // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of + // BatchScanExec is independent from the replacement of the runtime filters. It is + // possible that the BatchScanExec is on the CPU, while the dynamic runtime filters + // are on the GPU. And vice versa. 
+ protected lazy val runtimeFilters = { + val convertBroadcast = (bc: SubqueryBroadcastExec) => { + val meta = GpuOverrides.wrapAndTagPlan(bc, conf) + meta.tagForExplain() + meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec] + } + wrapped.runtimeFilters.map { filter => + filter.transformDown { + case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => + inSub.plan match { + case bc: SubqueryBroadcastExec => + dpe.copy(inSub.copy(plan = convertBroadcast(bc))) + case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) => + dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc)))) + case _ => + dpe + } + } + } + } + + override val childExprs: Seq[BaseExprMeta[_]] = { + // We want to leave the runtime filters as CPU expressions + p.output.map(GpuOverrides.wrapExpr(_, conf, Some(this))) + } + + override val childScans: scala.Seq[ScanMeta[_]] = + Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this))) + + override def tagPlanForGpu(): Unit = { + if (!p.runtimeFilters.isEmpty && !childScans.head.supportsRuntimeFilters) { + willNotWorkOnGpu("runtime filtering (DPP) is not supported for this scan") + } + } + + override def convertToCpu(): SparkPlan = { + val cpu = wrapped.copy(runtimeFilters = runtimeFilters) + cpu.copyTagsFrom(wrapped) + cpu + } +} diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala new file mode 100644 index 00000000000..e6c26eb65b8 --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/BatchScanExecMeta.scala @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "400"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + +class BatchScanExecMeta(p: BatchScanExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends BatchScanExecMetaBase(p, conf, parent, rule) { + override def convertToGpu(): GpuExec = { + val spj = p.spjParams + GpuBatchScanExec(p.output, childScans.head.convertToGpu(), runtimeFilters, + p.ordering, p.table, spj) + } +} diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala new file mode 100644 index 00000000000..3c2b649339b --- /dev/null +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package com.nvidia.spark.rapids.shims
+
+import com.google.common.base.Objects
+import com.nvidia.spark.rapids.GpuScan
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DynamicPruningExpression, Expression, Literal, RowOrdering, SortOrder}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, StoragePartitionJoinParams}
+import org.apache.spark.sql.internal.SQLConf
+
+case class GpuBatchScanExec(
+    output: Seq[AttributeReference],
+    @transient scan: GpuScan,
+    runtimeFilters: Seq[Expression] = Seq.empty,
+    ordering: Option[Seq[SortOrder]] = None,
+    @transient table: Table,
+    spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
+  ) extends GpuBatchScanExecBase(scan, runtimeFilters) {
+
+  @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
+  // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
+  override def equals(other: Any): Boolean = other match {
+    case other: GpuBatchScanExec =>
+      this.batch != null && this.batch == other.batch &&
+          this.runtimeFilters == other.runtimeFilters &&
+          this.spjParams == other.spjParams
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = Objects.hashCode(batch, runtimeFilters)
+
+  @transient override lazy val inputPartitions: Seq[InputPartition] =
+    batch.planInputPartitions()
+
+  @transient override protected lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
+    val dataSourceFilters = runtimeFilters.flatMap {
+      case DynamicPruningExpression(e) => DataSourceStrategyUtils.translateRuntimeFilter(e)
+      case _ => None
+    }
+
+    if (dataSourceFilters.nonEmpty) {
+      val originalPartitioning = outputPartitioning
+
+      // the cast is safe as runtime filters are only assigned if the scan can be filtered
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
+      filterableScan.filter(dataSourceFilters.toArray)
+
+      // call toBatch again to get filtered partitions
+      val newPartitions = scan.toBatch.planInputPartitions()
+
+      originalPartitioning match {
+        case p: KeyGroupedPartitioning =>
+          if (newPartitions.exists(!_.isInstanceOf[HasPartitionKey])) {
+            throw new SparkException("Data source must have preserved the original partitioning " +
+              "during runtime filtering: not all partitions implement HasPartitionKey after " +
+              "filtering")
+          }
+
+          val newPartitionValues = newPartitions.map(partition =>
+            InternalRowComparableWrapper(partition.asInstanceOf[HasPartitionKey], p.expressions))
+            .toSet
+          val oldPartitionValues = p.partitionValues
+            .map(partition => InternalRowComparableWrapper(partition, p.expressions)).toSet
+          // We require the new number of partition values to be equal or less than the old number
+          // of partition values here. In the case of less than, empty partitions will be added for
+          // those missing values that are not present in the new input partitions.
+          if (oldPartitionValues.size < newPartitionValues.size) {
+            throw new SparkException("During runtime filtering, data source must either report " +
+              "the same number of partition values, or a subset of partition values from the " +
+              s"original. Before: ${oldPartitionValues.size} partition values. " +
+              s"After: ${newPartitionValues.size} partition values")
+          }
+
+          if (!newPartitionValues.forall(oldPartitionValues.contains)) {
+            throw new SparkException("During runtime filtering, data source must not report new " +
+              "partition values that are not present in the original partitioning.")
+          }
+          groupPartitions(newPartitions)
+            .map(_.groupedParts.map(_.parts)).getOrElse(Seq.empty)
+
+        case _ =>
+          // no validation is needed as the data source did not report any specific partitioning
+          newPartitions.map(Seq(_))
+      }
+
+    } else {
+      partitions
+    }
+  }
+
+  override def outputPartitioning: Partitioning = {
+    super.outputPartitioning match {
+      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
+        // We allow duplicated partition values if
+        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
+          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+        }
+        val expressions = spjParams.joinKeyPositions match {
+          case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
+          case _ => k.expressions
+        }
+        k.copy(expressions = expressions, numPartitions = newPartValues.length,
+          partitionValues = newPartValues)
+      case p => p
+    }
+  }
+
+  override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory()
+
+  override lazy val inputRDD: RDD[InternalRow] = {
+    val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) {
+      // return an empty RDD with 1 partition if dynamic filtering removed the only split
+      sparkContext.parallelize(Array.empty[InternalRow], 1)
+    } else {
+      val finalPartitions = outputPartitioning match {
+        case p: KeyGroupedPartitioning =>
+          assert(spjParams.keyGroupedPartitioning.isDefined)
+          val expressions = spjParams.keyGroupedPartitioning.get
+
+          // Re-group the input partitions if we are projecting on a subset of join keys
+          val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
+            case Some(projectPositions) =>
+              val projectedExpressions = projectPositions.map(i => expressions(i))
+              val parts = filteredPartitions.flatten.groupBy(part => {
+                val row = part.asInstanceOf[HasPartitionKey].partitionKey()
+                val projectedRow = KeyGroupedPartitioning.project(
+                  expressions, projectPositions, row)
+                InternalRowComparableWrapper(projectedRow, projectedExpressions)
+              }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
+              (parts, projectedExpressions)
+            case _ =>
+              val groupedParts = filteredPartitions.map(splits => {
+                assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+                (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+              })
+              (groupedParts, expressions)
+          }
+
+          // Also re-group the partitions if we are reducing compatible partition expressions
+          val finalGroupedPartitions = spjParams.reducers match {
+            case Some(reducers) =>
+              val result = groupedPartitions.groupBy { case (row, _) =>
+                KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+              }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
+              val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
+                partExpressions.map(_.dataType))
+              result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+            case _ => groupedPartitions
+          }
+
+          // When partially clustered, the input partitions are not grouped by partition
+          // values. Here we'll need to check `commonPartitionValues` and decide how to group
+          // and replicate splits within a partition.
+          if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
+            // A mapping from the common partition values to how many splits the partition
+            // should contain.
+            val commonPartValuesMap = spjParams.commonPartitionValues
+              .get
+              .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
+              .toMap
+            val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
+              // `commonPartValuesMap` should contain the part value since it's the super set.
+              val numSplits = commonPartValuesMap
+                .get(InternalRowComparableWrapper(partValue, partExpressions))
+              assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+                "common partition values from Spark plan")
+
+              val newSplits = if (spjParams.replicatePartitions) {
+                // We need to also replicate partitions according to the other side of join
+                Seq.fill(numSplits.get)(splits)
+              } else {
+                // Not grouping by partition values: this could be the side with partially
+                // clustered distribution. Because of dynamic filtering, we'll need to check if
+                // the final number of splits of a partition is smaller than the original
+                // number, and fill with empty splits if so. This is necessary so that both
+                // sides of a join will have the same number of partitions & splits.
+                splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+              }
+              (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
+            }
+
+            // Now fill missing partition keys with empty partitions
+            val partitionMapping = nestGroupedPartitions.toMap
+            spjParams.commonPartitionValues.get.flatMap {
+              case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not present.
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, partExpressions),
+                  Seq.fill(numSplits)(Seq.empty))
+            }
+          } else {
+            // either `commonPartitionValues` is not defined, or it is defined but
+            // `applyPartialClustering` is false.
+            val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
+              InternalRowComparableWrapper(partValue, partExpressions) -> splits
+            }.toMap
+
+            // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+            // could exist duplicated partition values, as partition grouping is not done
+            // at the beginning and postponed to this method. It is important to use unique
+            // partition values here so that grouped partitions won't get duplicated.
+ p.uniquePartitionValues.map { partValue => + // Use empty partition for those partition values that are not present + partitionMapping.getOrElse( + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) + } + } + + case _ => filteredPartitions + } + + new DataSourceRDD( + sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) + } + postDriverMetrics() + rdd + } + + override def keyGroupedPartitioning: Option[Seq[Expression]] = + spjParams.keyGroupedPartitioning + + override def doCanonicalize(): GpuBatchScanExec = { + this.copy( + output = output.map(QueryPlan.normalizeExpressions(_, output)), + runtimeFilters = QueryPlan.normalizePredicates( + runtimeFilters.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)), + output)) + } + + override def simpleString(maxFields: Int): String = { + val truncatedOutputString = truncatedString(output, "[", ", ", "]", maxFields) + val runtimeFiltersString = s"RuntimeFilters: ${runtimeFilters.mkString("[", ",", "]")}" + val result = s"$nodeName$truncatedOutputString ${scan.description()} $runtimeFiltersString" + redact(result) + } + + override def nodeName: String = { + s"GpuBatchScan ${table.name()}".trim + } +}
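Note on the shim layering this patch introduces: the shared BatchScanExecMetaBase added under sql-plugin/src/main/spark350 is built for the 3.5.0, 3.5.1, and 4.0.0 shims (per its spark-rapids-shim-json-lines header), while each shim keeps only a thin concrete BatchScanExecMeta whose convertToGpu builds the version-specific GpuBatchScanExec. The sketch below is a minimal, self-contained illustration of that pattern, not part of the patch; the names ScanMetaBase, ScanMeta350, ScanMeta400, CpuScan, and GpuScanNode are simplified stand-ins rather than real plugin classes.

// Minimal, self-contained sketch of the shim layering used above.
// Every name here is an illustrative stand-in, not a real spark-rapids class.
object ShimLayeringSketch {
  trait CpuScan { def name: String }
  trait GpuScanNode { def describe: String }

  // Shared base, analogous to BatchScanExecMetaBase under spark350, which is
  // built for the 350, 351 and 400 shims: common tagging logic lives here once.
  abstract class ScanMetaBase(val wrapped: CpuScan) {
    def tagForGpu(): Unit = ()        // placeholder for the runtime-filter (DPP) checks
    def convertToGpu(): GpuScanNode   // each shim supplies the version-specific build
  }

  // spark350/spark351 shim: builds the GPU node with the 3.5.x constructor shape.
  class ScanMeta350(scan: CpuScan) extends ScanMetaBase(scan) {
    override def convertToGpu(): GpuScanNode = new GpuScanNode {
      val describe = s"GpuBatchScan[3.5.x](${wrapped.name})"
    }
  }

  // spark400 shim: reuses the same base but is free to follow the Spark 4.0 API shape.
  class ScanMeta400(scan: CpuScan) extends ScanMetaBase(scan) {
    override def convertToGpu(): GpuScanNode = new GpuScanNode {
      val describe = s"GpuBatchScan[4.0](${wrapped.name})"
    }
  }

  def main(args: Array[String]): Unit = {
    val cpu = new CpuScan { val name = "BatchScanExec" }
    println(new ScanMeta400(cpu).convertToGpu().describe)
  }
}

The payoff of this split is that a future Spark 4.0 API change only needs to be absorbed in the spark400 classes, while the runtime-filter (DPP) handling in the base is written once and shared by the 3.5.x and 4.0 shims.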