From 7bac3a6439c10efb1961d3c4ba028128d9dca249 Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 19 Jun 2024 09:44:48 +0800 Subject: [PATCH] [FEA] Introduce low shuffle merge. (#10979) * feat: Introduce low shuffle merge. Signed-off-by: liurenjie1024 * fix * Test databricks parallel * Test more databricks parallel * Fix comments * Config && scala 2.13 * Revert * Fix comments * scala 2.13 * Revert unnecessary changes * Revert "Revert unnecessary changes" This reverts commit 9fa4cf268cc3fce4d2732e04cb33eb53e4859c99. * restore change --------- Signed-off-by: liurenjie1024 --- aggregator/pom.xml | 4 + .../GpuDeltaParquetFileFormatUtils.scala | 160 +++ .../nvidia/spark/rapids/delta/deltaUDFs.scala | 83 +- .../delta/delta24x/Delta24xProvider.scala | 5 +- .../GpuDelta24xParquetFileFormat.scala | 61 +- .../delta/delta24x/MergeIntoCommandMeta.scala | 58 +- .../delta24x/GpuLowShuffleMergeCommand.scala | 1084 +++++++++++++++++ .../rapids/GpuLowShuffleMergeCommand.scala | 1083 ++++++++++++++++ .../delta/GpuDeltaParquetFileFormat.scala | 63 +- .../shims/MergeIntoCommandMetaShim.scala | 101 +- .../advanced_configs.md | 6 + .../delta_lake_low_shuffle_merge_test.py | 165 +++ .../main/python/delta_lake_merge_common.py | 155 +++ .../src/main/python/delta_lake_merge_test.py | 127 +- pom.xml | 10 + scala2.13/aggregator/pom.xml | 4 + scala2.13/pom.xml | 10 + scala2.13/sql-plugin/pom.xml | 4 + sql-plugin/pom.xml | 4 + .../com/nvidia/spark/rapids/RapidsConf.scala | 28 + 20 files changed, 3061 insertions(+), 154 deletions(-) create mode 100644 delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala create mode 100644 delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala create mode 100644 delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala create mode 100644 integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py create mode 100644 integration_tests/src/main/python/delta_lake_merge_common.py diff --git a/aggregator/pom.xml b/aggregator/pom.xml index 22bfe11105e..8cf881419c9 100644 --- a/aggregator/pom.xml +++ b/aggregator/pom.xml @@ -94,6 +94,10 @@ com.google.flatbuffers ${rapids.shade.package}.com.google.flatbuffers + + org.roaringbitmap + ${rapids.shade.package}.org.roaringbitmap + diff --git a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala new file mode 100644 index 00000000000..101a82da830 --- /dev/null +++ b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormatUtils.scala @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta + +import ai.rapids.cudf.{ColumnVector => CudfColumnVector, Scalar, Table} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.GpuColumnVector +import org.roaringbitmap.longlong.{PeekableLongIterator, Roaring64Bitmap} + +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + + +object GpuDeltaParquetFileFormatUtils { + /** + * Row number of the row in the file. When used with [[FILE_PATH_COL]] together, it can be used + * as unique id of a row in file. Currently to correctly calculate this, the caller needs to + * set both [[isSplitable]] to false, and [[RapidsConf.PARQUET_READER_TYPE]] to "PERFILE". + */ + val METADATA_ROW_IDX_COL: String = "__metadata_row_index" + val METADATA_ROW_IDX_FIELD: StructField = StructField(METADATA_ROW_IDX_COL, LongType, + nullable = false) + + val METADATA_ROW_DEL_COL: String = "__metadata_row_del" + val METADATA_ROW_DEL_FIELD: StructField = StructField(METADATA_ROW_DEL_COL, BooleanType, + nullable = false) + + + /** + * File path of the file that the row came from. + */ + val FILE_PATH_COL: String = "_metadata_file_path" + val FILE_PATH_FIELD: StructField = StructField(FILE_PATH_COL, StringType, nullable = false) + + /** + * Add a metadata column to the iterator. Currently only support [[METADATA_ROW_IDX_COL]]. + */ + def addMetadataColumnToIterator( + schema: StructType, + delVector: Option[Roaring64Bitmap], + input: Iterator[ColumnarBatch], + maxBatchSize: Int): Iterator[ColumnarBatch] = { + val metadataRowIndexCol = schema.fieldNames.indexOf(METADATA_ROW_IDX_COL) + val delRowIdx = schema.fieldNames.indexOf(METADATA_ROW_DEL_COL) + if (metadataRowIndexCol == -1 && delRowIdx == -1) { + return input + } + var rowIndex = 0L + input.map { batch => + withResource(batch) { _ => + val rowIdxCol = if (metadataRowIndexCol == -1) { + None + } else { + Some(metadataRowIndexCol) + } + + val delRowIdx2 = if (delRowIdx == -1) { + None + } else { + Some(delRowIdx) + } + val newBatch = addMetadataColumns(rowIdxCol, delRowIdx2, delVector,maxBatchSize, + rowIndex, batch) + rowIndex += batch.numRows() + newBatch + } + } + } + + private def addMetadataColumns( + rowIdxPos: Option[Int], + delRowIdx: Option[Int], + delVec: Option[Roaring64Bitmap], + maxBatchSize: Int, + rowIdxStart: Long, + batch: ColumnarBatch): ColumnarBatch = { + val rowIdxCol = rowIdxPos.map { _ => + withResource(Scalar.fromLong(rowIdxStart)) { start => + GpuColumnVector.from(CudfColumnVector.sequence(start, batch.numRows()), + METADATA_ROW_IDX_FIELD.dataType) + } + } + + closeOnExcept(rowIdxCol) { rowIdxCol => + + val delVecCol = delVec.map { delVec => + withResource(Scalar.fromBool(false)) { s => + withResource(CudfColumnVector.fromScalar(s, batch.numRows())) { c => + var table = new Table(c) + val posIter = new RoaringBitmapIterator( + delVec.getLongIteratorFrom(rowIdxStart), + rowIdxStart, + rowIdxStart + batch.numRows(), + ).grouped(Math.min(maxBatchSize, batch.numRows())) + + for (posChunk <- posIter) { + withResource(CudfColumnVector.fromLongs(posChunk: _*)) { poses => + withResource(Scalar.fromBool(true)) { s => + table = withResource(table) { _ => + Table.scatter(Array(s), poses, table) + } + } + } + } + + withResource(table) { _ => + GpuColumnVector.from(table.getColumn(0).incRefCount(), + METADATA_ROW_DEL_FIELD.dataType) + } + } + } + } + + closeOnExcept(delVecCol) { delVecCol => + // Replace row_idx column + val columns = new Array[ColumnVector](batch.numCols()) + for (i <- 0 until batch.numCols()) { + if (rowIdxPos.contains(i)) { + columns(i) = rowIdxCol.get + } else if (delRowIdx.contains(i)) { + columns(i) = delVecCol.get + } else { + columns(i) = batch.column(i) match { + case gpuCol: GpuColumnVector => gpuCol.incRefCount() + case col => col + } + } + } + + new ColumnarBatch(columns, batch.numRows()) + } + } + } +} + +class RoaringBitmapIterator(val inner: PeekableLongIterator, val start: Long, val end: Long) + extends Iterator[Long] { + + override def hasNext: Boolean = { + inner.hasNext && inner.peekNext() < end + } + + override def next(): Long = { + inner.next() - start + } +} diff --git a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala index 6b2c63407d7..9893545a4ad 100644 --- a/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala +++ b/delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/deltaUDFs.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,19 @@ package com.nvidia.spark.rapids.delta +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + import ai.rapids.cudf.{ColumnVector, Scalar, Table} import ai.rapids.cudf.Table.DuplicateKeepOption import com.nvidia.spark.RapidsUDF import com.nvidia.spark.rapids.Arm.withResource +import org.roaringbitmap.longlong.Roaring64Bitmap +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.types.{BinaryType, DataType, SQLUserDefinedType, UserDefinedType} import org.apache.spark.util.AccumulatorV2 class GpuDeltaRecordTouchedFileNameUDF(accum: AccumulatorV2[String, java.util.Set[String]]) @@ -73,3 +80,77 @@ class GpuDeltaMetricUpdateUDF(metric: SQLMetric) } } } + +class GpuDeltaNoopUDF extends Function1[Boolean, Boolean] with RapidsUDF with Serializable { + override def apply(v1: Boolean): Boolean = v1 + + override def evaluateColumnar(numRows: Int, args: ColumnVector*): ColumnVector = { + require(args.length == 1) + args(0).incRefCount() + } +} + +@SQLUserDefinedType(udt = classOf[RoaringBitmapUDT]) +case class RoaringBitmapWrapper(inner: Roaring64Bitmap) { + def serializeToBytes(): Array[Byte] = { + withResource(new ByteArrayOutputStream()) { bout => + withResource(new DataOutputStream(bout)) { dao => + inner.serialize(dao) + } + bout.toByteArray + } + } +} + +object RoaringBitmapWrapper { + def deserializeFromBytes(bytes: Array[Byte]): RoaringBitmapWrapper = { + withResource(new ByteArrayInputStream(bytes)) { bin => + withResource(new DataInputStream(bin)) { din => + val ret = RoaringBitmapWrapper(new Roaring64Bitmap) + ret.inner.deserialize(din) + ret + } + } + } +} + +class RoaringBitmapUDT extends UserDefinedType[RoaringBitmapWrapper] { + + override def sqlType: DataType = BinaryType + + override def serialize(obj: RoaringBitmapWrapper): Any = { + obj.serializeToBytes() + } + + override def deserialize(datum: Any): RoaringBitmapWrapper = { + datum match { + case b: Array[Byte] => RoaringBitmapWrapper.deserializeFromBytes(b) + case t => throw new IllegalArgumentException(s"t: ${t.getClass}") + } + } + + override def userClass: Class[RoaringBitmapWrapper] = classOf[RoaringBitmapWrapper] + + override def typeName: String = "RoaringBitmap" +} + +object RoaringBitmapUDAF extends Aggregator[Long, RoaringBitmapWrapper, RoaringBitmapWrapper] { + override def zero: RoaringBitmapWrapper = RoaringBitmapWrapper(new Roaring64Bitmap()) + + override def reduce(b: RoaringBitmapWrapper, a: Long): RoaringBitmapWrapper = { + b.inner.addLong(a) + b + } + + override def merge(b1: RoaringBitmapWrapper, b2: RoaringBitmapWrapper): RoaringBitmapWrapper = { + val ret = b1.inner.clone() + ret.or(b2.inner) + RoaringBitmapWrapper(ret) + } + + override def finish(reduction: RoaringBitmapWrapper): RoaringBitmapWrapper = reduction + + override def bufferEncoder: Encoder[RoaringBitmapWrapper] = ExpressionEncoder() + + override def outputEncoder: Encoder[RoaringBitmapWrapper] = ExpressionEncoder() +} diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala index d3f952b856c..f90f31300e5 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,7 +74,8 @@ object Delta24xProvider extends DeltaIOProvider { override def getReadFileFormat(format: FileFormat): FileFormat = { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] - GpuDelta24xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable) + GpuDelta24xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable, + cpuFormat.disablePushDowns, cpuFormat.broadcastDvMap) } override def convertToGpu( diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala index 709df7e9416..ef579d78e6f 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDelta24xParquetFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,18 +16,32 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormat +import java.net.URI + +import com.nvidia.spark.rapids.{GpuMetric, RapidsConf} +import com.nvidia.spark.rapids.delta.{GpuDeltaParquetFileFormat, RoaringBitmapWrapper} +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.addMetadataColumnToIterator +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.delta.{DeltaColumnMappingMode, IdMapping} +import org.apache.spark.sql.delta.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuDelta24xParquetFileFormat( metadata: Metadata, - isSplittable: Boolean) extends GpuDeltaParquetFileFormat { + isSplittable: Boolean, + disablePushDown: Boolean, + broadcastDvMap: Option[Broadcast[Map[URI, DeletionVectorDescriptorWithFilterType]]]) + extends GpuDeltaParquetFileFormat { override val columnMappingMode: DeltaColumnMappingMode = metadata.columnMappingMode override val referenceSchema: StructType = metadata.schema @@ -46,6 +60,47 @@ case class GpuDelta24xParquetFileFormat( options: Map[String, String], path: Path): Boolean = isSplittable + override def buildReaderWithPartitionValuesAndMetrics( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration, + metrics: Map[String, GpuMetric], + alluxioPathReplacementMap: Option[Map[String, String]]) + : PartitionedFile => Iterator[InternalRow] = { + + + val dataReader = super.buildReaderWithPartitionValuesAndMetrics( + sparkSession, + dataSchema, + partitionSchema, + requiredSchema, + if (disablePushDown) Seq.empty else filters, + options, + hadoopConf, + metrics, + alluxioPathReplacementMap) + + val delVecs = broadcastDvMap + val maxDelVecScatterBatchSize = RapidsConf + .DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE + .get(sparkSession.sessionState.conf) + + (file: PartitionedFile) => { + val input = dataReader(file) + val dv = delVecs.flatMap(_.value.get(new URI(file.filePath.toString()))) + .map(dv => RoaringBitmapWrapper.deserializeFromBytes(dv.descriptor.inlineData).inner) + addMetadataColumnToIterator(prepareSchema(requiredSchema), + dv, + input.asInstanceOf[Iterator[ColumnarBatch]], + maxDelVecScatterBatchSize) + .asInstanceOf[Iterator[InternalRow]] + } + } + /** * We sometimes need to replace FileFormat within LogicalPlans, so we have to override * `equals` to ensure file format changes are captured diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala index 4b4dfb624b5..8ce813ef011 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,14 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta} +import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RapidsReaderType, RunnableCommandMeta} import com.nvidia.spark.rapids.delta.RapidsDeltaUtils +import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.commands.MergeIntoCommand import org.apache.spark.sql.delta.rapids.GpuDeltaLog -import org.apache.spark.sql.delta.rapids.delta24x.GpuMergeIntoCommand +import org.apache.spark.sql.delta.rapids.delta24x.{GpuLowShuffleMergeCommand, GpuMergeIntoCommand} import org.apache.spark.sql.execution.command.RunnableCommand class MergeIntoCommandMeta( @@ -30,12 +31,12 @@ class MergeIntoCommandMeta( conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) { + extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) with Logging { override def tagSelfForGpu(): Unit = { if (!conf.isDeltaWriteEnabled) { willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + - s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { // https://github.com/NVIDIA/spark-rapids/issues/8415 @@ -48,14 +49,43 @@ class MergeIntoCommandMeta( } override def convertToGpu(): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge disabled since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } + } diff --git a/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala new file mode 100644 index 00000000000..9c27d28ebd3 --- /dev/null +++ b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuLowShuffleMergeCommand.scala @@ -0,0 +1,1084 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta24x + +import java.net.URI +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import com.nvidia.spark.rapids.{GpuOverrides, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.RapidsConf.DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD +import com.nvidia.spark.rapids.delta._ +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils._ +import com.nvidia.spark.rapids.shims.FileSourceScanExecMeta +import org.roaringbitmap.longlong.Roaring64Bitmap + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, CaseWhen, Expression, Literal, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeAction, DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedBySourceDeleteClause, DeltaMergeIntoNotMatchedBySourceUpdateClause, DeltaMergeIntoNotMatchedClause, DeltaMergeIntoNotMatchedInsertClause, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DeltaParquetFileFormat, DeltaTableUtils, DeltaUDF, NoMapping, OptimisticTransaction, RowIndexFilterType} +import org.apache.spark.sql.delta.DeltaOperations.MergePredicate +import org.apache.spark.sql.delta.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, DeletionVectorDescriptor, FileAction} +import org.apache.spark.sql.delta.commands.DeltaCommand +import org.apache.spark.sql.delta.rapids.{GpuDeltaLog, GpuOptimisticTransactionBase} +import org.apache.spark.sql.delta.rapids.delta24x.MergeExecutor.{toDeletionVector, totalBytesAndDistinctPartitionValues, INCR_METRICS_COL, INCR_METRICS_FIELD, ROW_DROPPED_COL, ROW_DROPPED_FIELD, SOURCE_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_FIELD, TARGET_ROW_PRESENT_COL, TARGET_ROW_PRESENT_FIELD} +import org.apache.spark.sql.delta.schema.ImplicitMetadataOperation +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.{AnalysisHelper, DeltaFileOperations} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} + +/** + * GPU version of Delta Lake's low shuffle merge implementation. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * Different from the original implementation, it optimized writing touched unmodified target files. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows + * without copying unmodified rows. + * + * Phase 3: Read the touched files again and write new files with unmodified rows in target table, + * trying to keep its original order and avoid shuffle as much as possible. + * + * Phase 4: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clause. + * @param migratedSchema The final schema of the target - may be changed by schema evolution. + */ +case class GpuLowShuffleMergeCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + + import SQLMetrics._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private[delta] lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetRowsMatchedUpdated" -> createMetric(sc, "number of target rows updated when matched"), + "numTargetRowsMatchedDeleted" -> createMetric(sc, "number of target rows deleted when matched"), + "numTargetRowsNotMatchedBySourceUpdated" -> createMetric(sc, + "number of target rows updated when not matched by source"), + "numTargetRowsNotMatchedBySourceDeleted" -> createMetric(sc, + "number of target rows deleted when not matched by source"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files")) + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + protected def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && + notMatchedClauses.length == 1 + + override def run(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.lowshufflemerge") { + val startTime = System.nanoTime() + val result = gpuDeltaLog.withNewTransaction { deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + + val (executor, fallback) = { + val context = MergeExecutorContext(this, spark, deltaTxn, rapidsConf) + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + (new InsertOnlyMergeExecutor(context), false) + } else { + val executor = new LowShuffleMergeExecutor(context) + (executor, executor.shouldFallback()) + } + } + + if (fallback) { + None + } else { + Some(runLowShuffleMerge(spark, startTime, deltaTxn, executor)) + } + } + + result match { + case Some(row) => row + case None => + // We should rollback to normal gpu + new GpuMergeIntoCommand(source, target, gpuDeltaLog, condition, matchedClauses, + notMatchedClauses, notMatchedBySourceClauses, migratedSchema)(rapidsConf) + .run(spark) + } + } + } + + + private def runLowShuffleMerge( + spark: SparkSession, + startTime: Long, + deltaTxn: GpuOptimisticTransactionBase, + mergeExecutor: MergeExecutor): Seq[Row] = { + val deltaActions = mergeExecutor.execute() + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan").value} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + // We do not support notMatchedBySourcePredicates yet and fall back to CPU + // See https://github.com/NVIDIA/spark-rapids/issues/8415 + notMatchedBySourcePredicates = Seq.empty[MergePredicate] + )) + + // Record metrics + val stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + notMatchedBySourceClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + + spark.sharedState.cacheManager.recacheByPlan(spark, target) + + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + private[delta] def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } + + /** Expressions to increment SQL metrics */ + private[delta] def makeMetricUpdateUDF(name: String, deterministic: Boolean = false) + : Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } +} + +/** + * Context merge execution. + */ +case class MergeExecutorContext(cmd: GpuLowShuffleMergeCommand, + spark: SparkSession, + deltaTxn: OptimisticTransaction, + rapidsConf: RapidsConf) + +trait MergeExecutor extends AnalysisHelper with PredicateHelper with Logging { + + val context: MergeExecutorContext + + + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = context.cmd.target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + def execute(): Seq[FileAction] + + protected def targetOutputCols: Seq[NamedExpression] = { + context.deltaTxn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Build a DataFrame using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + protected def buildTargetDFWithFiles(files: Seq[AddFile]): DataFrame = { + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + context.deltaTxn.deltaLog.createDataFrame(context.deltaTxn.snapshot, files) + .queryExecution + .analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file + // format because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (context.deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = context.deltaTxn.deltaLog.fileFormat( + context.deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, context.deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.getOrElse(newAttrib.name, + throw new AnalysisException( + s"Could not find ${newAttrib.name} among the existing target output " + + targetOutputCols.mkString(","))).asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Dataset.ofRows(context.spark, Project(aliases, plan)) + } + + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded(df: DataFrame): DataFrame = { + val partitionColumns = context.deltaTxn.metadata.partitionColumns + // TODO: We should remove this method and use optimized write instead, see + // https://github.com/NVIDIA/spark-rapids/issues/10417 + if (partitionColumns.nonEmpty && context.spark.conf.get(DeltaSQLConf + .MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + protected def sourceDF: DataFrame = { + // UDF to increment metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + } + + /** Whether this merge statement has no insert (NOT MATCHED) clause. */ + protected def hasNoInserts: Boolean = context.cmd.notMatchedClauses.isEmpty + + +} + +/** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ +class InsertOnlyMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + override def execute(): Seq[FileAction] = { + context.cmd.recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = context.cmd.makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = targetOutputCols.map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = context.cmd.notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(context.cmd.notMatchedClauses.head.condition + .getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(context.cmd.condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(context.cmd.target.outputSet)) + val dataSkippedFiles = context.deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = buildTargetDFWithFiles(dataSkippedFiles) + + val insertDf = sourceDF.join(targetDF, new Column(context.cmd.condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = context.deltaTxn + .writeFiles(repartitionIfNeeded(insertDf, + )) + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + context.cmd.metrics("numTargetFilesRemoved") += 0 + context.cmd.metrics("numTargetBytesRemoved") += 0 + context.cmd.metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + } +} + + +/** + * This is an optimized algorithm for merge statement, where we avoid shuffling the unmodified + * target data. + * + * The algorithm is as follows: + * 1. Find touched target files in the target table by joining the source and target data, with + * collecting joined row identifiers as (`__metadata_file_path`, `__metadata_row_idx`) pairs. + * 2. Read the touched files again and write new files with updated and/or inserted rows + * without coping unmodified data from target table, but filtering target table with collected + * rows mentioned above. + * 3. Read the touched files again, filtering unmodified rows with collected row identifiers + * collected in first step, and saving them without shuffle. + */ +class LowShuffleMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + // UDFs to update metrics + private val incrSourceRowCountExpr: Expression = context.cmd. + makeMetricUpdateUDF("numSourceRowsInSecondScan") + private val incrUpdatedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsUpdated") + private val incrUpdatedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedUpdated") + private val incrUpdatedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated") + private val incrInsertedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsInserted") + private val incrDeletedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsDeleted") + private val incrDeletedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedDeleted") + private val incrDeletedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted") + + private def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def deleteOutput(incrExpr: Expression): Seq[Expression] = { + targetOutputCols :+ + TrueLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def clauseOutput(clause: DeltaMergeIntoClause): Seq[Expression] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr)) + case _: DeltaMergeIntoMatchedDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr)) + case i: DeltaMergeIntoNotMatchedInsertClause => + insertOutput(i.resolvedActions, incrInsertedCountExpr) + case u: DeltaMergeIntoNotMatchedBySourceUpdateClause => + updateOutput(u.resolvedActions, + And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr)) + case _: DeltaMergeIntoNotMatchedBySourceDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr)) + } + + private def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + clause.condition.getOrElse(TrueLiteral) + } + + /** + * Though low shuffle merge algorithm performs better than traditional merge algorithm in some + * cases, there are some case we should fallback to traditional merge executor: + * + * 1. Low shuffle merge algorithm requires generating metadata columns such as + * [[METADATA_ROW_IDX_COL]], [[METADATA_ROW_DEL_COL]], which only implemented on + * [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]]. That means we need to fallback to + * this normal executor when [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]] is disabled + * for some reason. + * 2. Low shuffle merge algorithm currently needs to broadcast deletion vector, which may + * introduce extra overhead. It maybe better to fallback to this algorithm when the changeset + * it too large. + */ + private[delta] def shouldFallback(): Boolean = { + // Trying to detect if we can execute finding touched files. + val touchFilePlanOverrideSucceed = verifyGpuPlan(planForFindingTouchedFiles()) { planMeta => + def check(meta: SparkPlanMeta[SparkPlan]): Boolean = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => scan + .asInstanceOf[FileSourceScanExecMeta] + .wrapped + .schema + .fieldNames + .contains(METADATA_ROW_IDX_COL) && scan.canThisBeReplaced + case m => m.childPlans.exists(check) + } + } + + check(planMeta) + } + if (!touchFilePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for finding touched files " + + "plan, fallback to tradition merge.") + return true + } + + // Trying to detect if we can execute the merge plan. + val mergePlanOverrideSucceed = verifyGpuPlan(planForMergeExecution(touchedFiles)) { planMeta => + var overrideCount = 0 + def count(meta: SparkPlanMeta[SparkPlan]): Unit = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => + if (scan.asInstanceOf[FileSourceScanExecMeta] + .wrapped.schema.fieldNames.contains(METADATA_ROW_DEL_COL) && scan.canThisBeReplaced) { + overrideCount += 1 + } + case m => m.childPlans.foreach(count) + } + } + + count(planMeta) + overrideCount == 2 + } + + if (!mergePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for merge plan, fallback to " + + "tradition merge.") + return true + } + + val deletionVectorSize = touchedFiles.values.map(_._1.serializedSizeInBytes()).sum + val maxDelVectorSize = context.rapidsConf + .get(DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD) + if (deletionVectorSize > maxDelVectorSize) { + logWarning( + s"""Low shuffle merge can't be executed because broadcast deletion vector count + |$deletionVectorSize is large than max value $maxDelVectorSize """.stripMargin) + return true + } + + false + } + + private def verifyGpuPlan(input: DataFrame)(checkPlanMeta: SparkPlanMeta[SparkPlan] => Boolean) + : Boolean = { + val overridePlan = GpuOverrides.wrapAndTagPlan(input.queryExecution.sparkPlan, + context.rapidsConf) + checkPlanMeta(overridePlan) + } + + override def execute(): Seq[FileAction] = { + val newFiles = context.cmd.withStatusCode("DELTA", + s"Rewriting ${touchedFiles.size} files and saving modified data") { + val df = planForMergeExecution(touchedFiles) + context.deltaTxn.writeFiles(df) + } + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + context.cmd.metrics("numTargetChangeFileBytes") += newFiles.collect { + case f: AddCDCFile => f.size + } + .sum + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + context.cmd.metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + context.cmd.metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + touchedFiles.values.map(_._2).map(_.remove).toSeq ++ newFiles + } + + private lazy val dataSkippedFiles: Seq[AddFile] = { + // Skip data based on the merge condition + val targetOnlyPredicates = splitConjunctivePredicates(context.cmd.condition) + .filter(_.references.subsetOf(context.cmd.target.outputSet)) + context.deltaTxn.filterFiles(targetOnlyPredicates) + } + + private lazy val dataSkippedTargetDF: DataFrame = { + addRowIndexMetaColumn(buildTargetDFWithFiles(dataSkippedFiles)) + } + + private lazy val touchedFiles: Map[String, (Roaring64Bitmap, AddFile)] = this.findTouchedFiles() + + private def planForFindingTouchedFiles(): DataFrame = { + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - METADATA_ROW_IDX column to identify target row in file + // - FILE_PATH_COL the target file name the row is from to later identify the files touched + // by matched rows + val targetDF = dataSkippedTargetDF.withColumn(FILE_PATH_COL, input_file_name()) + + sourceDF.join(targetDF, new Column(context.cmd.condition), "inner") + } + + private def planForMergeExecution(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + getModifiedDF(touchedFiles).unionAll(getUnmodifiedDF(touchedFiles)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles(): Map[String, (Roaring64Bitmap, AddFile)] = + context.cmd.recordMergeOperation(sqlMetricName = "scanTimeMs") { + context.spark.udf.register("row_index_set", udaf(RoaringBitmapUDAF)) + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = planForFindingTouchedFiles() + .select(col(FILE_PATH_COL), col(METADATA_ROW_IDX_COL)) + .groupBy(FILE_PATH_COL) + .agg( + expr(s"row_index_set($METADATA_ROW_IDX_COL) as row_idxes"), + count("*").as("count")) + .collect().map(row => { + val filename = row.getAs[String](FILE_PATH_COL) + val rowIdxSet = row.getAs[RoaringBitmapWrapper]("row_idxes").inner + val count = row.getAs[Long]("count") + (filename, (rowIdxSet, count)) + }) + .toMap + + val duplicateCount = { + val distinctMatchedRowCounts = collectTouchedFiles.values + .map(_._1.getLongCardinality).sum + val allMatchedRowCounts = collectTouchedFiles.values.map(_._2).sum + allMatchedRowCounts - distinctMatchedRowCounts + } + + val hasMultipleMatches = duplicateCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = context.cmd.matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + context.cmd.matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(context.spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = collectTouchedFiles.keys.toSeq + + val nameToAddFileMap = context.cmd.generateCandidateFileMap( + context.cmd.targetDeltaLog.dataPath, + dataSkippedFiles) + + val touchedAddFiles = touchedFileNames.map(f => + context.cmd.getTouchedFile(context.cmd.targetDeltaLog.dataPath, f, nameToAddFileMap)) + .map(f => (DeltaFileOperations + .absolutePath(context.cmd.targetDeltaLog.dataPath.toString, f.path) + .toString, f)).toMap + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. + // We need to scan the source table once to get the correct + // metric here. + if (context.cmd.metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || dataSkippedTargetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + context.cmd.metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(touchedAddFiles.values.toSeq) + context.cmd.metrics("numTargetFilesRemoved") += touchedAddFiles.size + context.cmd.metrics("numTargetBytesRemoved") += removedBytes + context.cmd.metrics("numTargetPartitionsRemovedFrom") += removedPartitions + + collectTouchedFiles.map(kv => (kv._1, (kv._2._1, touchedAddFiles(kv._1)))) + } + + + /** + * Modify original data frame to insert + * [[GpuDeltaParquetFileFormatUtils.METADATA_ROW_IDX_COL]]. + */ + private def addRowIndexMetaColumn(baseDF: DataFrame): DataFrame = { + val rowIdxAttr = AttributeReference( + METADATA_ROW_IDX_COL, + METADATA_ROW_IDX_FIELD.dataType, + METADATA_ROW_IDX_FIELD.nullable)() + + val newPlan = baseDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_IDX_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + .copy(isSplittable = false, disablePushDowns = true) + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowIdxAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowIdxAttr + p.copy(projectList = newProjectList) + } + + Dataset.ofRows(context.spark, newPlan) + } + + /** + * The result is scanning target table with touched files, and added an extra + * [[METADATA_ROW_DEL_COL]] to indicate whether filtered by joining with source table in first + * step. + */ + private def getTouchedTargetDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + // Generate a new target dataframe that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val baseTargetDF = buildTargetDFWithFiles(touchedFiles.values.map(_._2).toSeq) + + val newPlan = { + val rowDelAttr = AttributeReference( + METADATA_ROW_DEL_COL, + METADATA_ROW_DEL_FIELD.dataType, + METADATA_ROW_DEL_FIELD.nullable)() + + baseTargetDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_DEL_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = { + val oldFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + val dvs = touchedFiles.map(kv => (new URI(kv._1), + DeletionVectorDescriptorWithFilterType(toDeletionVector(kv._2._1), + RowIndexFilterType.UNKNOWN))) + val broadcastDVs = context.spark.sparkContext.broadcast(dvs) + + oldFormat.copy(isSplittable = false, + broadcastDvMap = Some(broadcastDVs), + disablePushDowns = true) + } + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowDelAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowDelAttr + p.copy(projectList = newProjectList) + } + } + + val df = Dataset.ofRows(context.spark, newPlan) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + + df + } + + /** + * Generate a plan by calculating modified rows. It's computed by joining source and target + * tables, where target table has been filtered by (`__metadata_file_name`, + * `__metadata_row_idx`) pairs collected in first step. + * + * Schema of `modifiedDF`: + * + * targetSchema + ROW_DROPPED_COL + TARGET_ROW_PRESENT_COL + + * SOURCE_ROW_PRESENT_COL + INCR_METRICS_COL + * INCR_METRICS_COL + * + * It consists of several parts: + * + * 1. Unmatched source rows which are inserted + * 2. Unmatched source rows which are deleted + * 3. Target rows which are updated + * 4. Target rows which are deleted + */ + private def getModifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + val sourceDF = this.sourceDF + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + + val targetDF = getTouchedTargetDF(touchedFiles) + + val joinedDF = { + val joinType = if (hasNoInserts && + context.spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "inner" + } else { + "leftOuter" + } + val matchedTargetDF = targetDF.filter(METADATA_ROW_DEL_COL) + .drop(METADATA_ROW_DEL_COL) + + sourceDF.join(matchedTargetDF, new Column(context.cmd.condition), joinType) + } + + val modifiedRowsSchema = context.deltaTxn.metadata.schema + .add(ROW_DROPPED_FIELD) + .add(TARGET_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(SOURCE_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(INCR_METRICS_FIELD) + + // Here we generate a case when statement to handle all cases: + // CASE + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // END + + val notMatchedConditions = context.cmd.notMatchedClauses.map(clauseCondition) + val notMatchedExpr = { + val deletedNotMatchedRow = { + targetOutputCols :+ + Literal.TrueLiteral :+ + Literal.FalseLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.notMatchedClauses.isEmpty) { + // If there no `WHEN NOT MATCHED` clause, we should just delete not matched row + deletedNotMatchedRow + } else { + val notMatchedOutputs = context.cmd.notMatchedClauses.map(clauseOutput) + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(notMatchedConditions.zip(notMatchedOutputs.map(_(idx))), + deletedNotMatchedRow(idx)) + } + } + } + + val matchedConditions = context.cmd.matchedClauses.map(clauseCondition) + val matchedOutputs = context.cmd.matchedClauses.map(clauseOutput) + val matchedExprs = { + val notMatchedRow = { + targetOutputCols :+ + Literal.FalseLiteral :+ + Literal.TrueLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.matchedClauses.isEmpty) { + // If there is not matched clause, this is insert only, we should delete this row. + notMatchedRow + } else { + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(matchedConditions.zip(matchedOutputs.map(_(idx))), + notMatchedRow(idx)) + } + } + } + + val sourceRowHasNoMatch = col(TARGET_ROW_PRESENT_COL).isNull.expr + + val modifiedCols = modifiedRowsSchema.zipWithIndex.map { case (col, idx) => + val caseWhen = CaseWhen( + Seq(sourceRowHasNoMatch -> notMatchedExpr(idx)), + matchedExprs(idx)) + Column(Alias(caseWhen, col.name)()) + } + + val modifiedDF = { + + // Make this a udf to avoid catalyst to be too aggressive to even remove the join! + val noopRowDroppedCol = udf(new GpuDeltaNoopUDF()).apply(!col(ROW_DROPPED_COL)) + + val modifiedDF = joinedDF.select(modifiedCols: _*) + // This will not filter anything since they always return true, but we need to avoid + // catalyst from optimizing these udf + .filter(noopRowDroppedCol && col(INCR_METRICS_COL)) + .drop(ROW_DROPPED_COL, INCR_METRICS_COL, TARGET_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_COL) + + repartitionIfNeeded(modifiedDF) + } + + modifiedDF + } + + private def getUnmodifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + getTouchedTargetDF(touchedFiles) + .filter(!col(METADATA_ROW_DEL_COL)) + .drop(TARGET_ROW_PRESENT_COL, METADATA_ROW_DEL_COL) + } +} + + +object MergeExecutor { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val FILE_PATH_COL: String = GpuDeltaParquetFileFormatUtils.FILE_PATH_COL + val SOURCE_ROW_PRESENT_COL: String = "_source_row_present_" + val SOURCE_ROW_PRESENT_FIELD: StructField = StructField(SOURCE_ROW_PRESENT_COL, BooleanType, + nullable = false) + val TARGET_ROW_PRESENT_COL: String = "_target_row_present_" + val TARGET_ROW_PRESENT_FIELD: StructField = StructField(TARGET_ROW_PRESENT_COL, BooleanType, + nullable = false) + val ROW_DROPPED_COL: String = GpuDeltaMergeConstants.ROW_DROPPED_COL + val ROW_DROPPED_FIELD: StructField = StructField(ROW_DROPPED_COL, BooleanType, nullable = false) + val INCR_METRICS_COL: String = "_incr_metrics_" + val INCR_METRICS_FIELD: StructField = StructField(INCR_METRICS_COL, BooleanType, nullable = false) + val INCR_ROW_COUNT_COL: String = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL: Literal = Literal(null, StringType) + + private[delta] def toDeletionVector(bitmap: Roaring64Bitmap): DeletionVectorDescriptor = { + DeletionVectorDescriptor.inlineInLog(RoaringBitmapWrapper(bitmap).serializeToBytes(), + bitmap.getLongCardinality) + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + private[delta] def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala new file mode 100644 index 00000000000..fddebda33bd --- /dev/null +++ b/delta-lake/delta-spark341db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuLowShuffleMergeCommand.scala @@ -0,0 +1,1083 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.net.URI +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import com.databricks.sql.io.RowIndexFilterType +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.DeltaOperations.MergePredicate +import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.DeletionVectorDescriptorWithFilterType +import com.databricks.sql.transaction.tahoe.actions.{AddCDCFile, AddFile, DeletionVectorDescriptor, FileAction} +import com.databricks.sql.transaction.tahoe.commands.DeltaCommand +import com.databricks.sql.transaction.tahoe.rapids.MergeExecutor.{toDeletionVector, totalBytesAndDistinctPartitionValues, FILE_PATH_COL, INCR_METRICS_COL, INCR_METRICS_FIELD, ROW_DROPPED_COL, ROW_DROPPED_FIELD, SOURCE_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_FIELD, TARGET_ROW_PRESENT_COL, TARGET_ROW_PRESENT_FIELD} +import com.databricks.sql.transaction.tahoe.schema.ImplicitMetadataOperation +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.databricks.sql.transaction.tahoe.util.{AnalysisHelper, DeltaFileOperations} +import com.nvidia.spark.rapids.{GpuOverrides, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.RapidsConf.DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD +import com.nvidia.spark.rapids.delta._ +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.{METADATA_ROW_DEL_COL, METADATA_ROW_DEL_FIELD, METADATA_ROW_IDX_COL, METADATA_ROW_IDX_FIELD} +import com.nvidia.spark.rapids.shims.FileSourceScanExecMeta +import org.roaringbitmap.longlong.Roaring64Bitmap + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, CaseWhen, Expression, Literal, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.plans.logical.{DeltaMergeAction, DeltaMergeIntoClause, DeltaMergeIntoMatchedClause, DeltaMergeIntoMatchedDeleteClause, DeltaMergeIntoMatchedUpdateClause, DeltaMergeIntoNotMatchedBySourceClause, DeltaMergeIntoNotMatchedBySourceDeleteClause, DeltaMergeIntoNotMatchedBySourceUpdateClause, DeltaMergeIntoNotMatchedClause, DeltaMergeIntoNotMatchedInsertClause, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{BooleanType, LongType, StringType, StructField, StructType} + +/** + * GPU version of Delta Lake's low shuffle merge implementation. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * Different from the original implementation, it optimized writing touched unmodified target files. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows + * without copying unmodified rows. + * + * Phase 3: Read the touched files again and write new files with unmodified rows in target table, + * trying to keep its original order and avoid shuffle as much as possible. + * + * Phase 4: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clause. + * @param migratedSchema The final schema of the target - may be changed by schema evolution. + */ +case class GpuLowShuffleMergeCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand with PredicateHelper with AnalysisHelper with ImplicitMetadataOperation { + + import SQLMetrics._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetRowsMatchedUpdated" -> createMetric(sc, "number of target rows updated when matched"), + "numTargetRowsMatchedDeleted" -> createMetric(sc, "number of target rows deleted when matched"), + "numTargetRowsNotMatchedBySourceUpdated" -> createMetric(sc, + "number of target rows updated when not matched by source"), + "numTargetRowsNotMatchedBySourceDeleted" -> createMetric(sc, + "number of target rows deleted when not matched by source"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createMetric(sc, "time taken to rewrite the matched files")) + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + protected def isSingleInsertOnly: Boolean = matchedClauses.isEmpty && + notMatchedClauses.length == 1 + + override def run(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.lowshufflemerge") { + val startTime = System.nanoTime() + val result = gpuDeltaLog.withNewTransaction { deltaTxn => + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + + val (executor, fallback) = { + val context = MergeExecutorContext(this, spark, deltaTxn, rapidsConf) + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + (new InsertOnlyMergeExecutor(context), false) + } else { + val executor = new LowShuffleMergeExecutor(context) + (executor, executor.shouldFallback()) + } + } + + if (fallback) { + None + } else { + Some(runLowShuffleMerge(spark, startTime, deltaTxn, executor)) + } + } + + result match { + case Some(row) => row + case None => + // We should rollback to normal gpu + new GpuMergeIntoCommand(source, target, gpuDeltaLog, condition, matchedClauses, + notMatchedClauses, notMatchedBySourceClauses, migratedSchema)(rapidsConf) + .run(spark) + } + } + } + + + private def runLowShuffleMerge( + spark: SparkSession, + startTime: Long, + deltaTxn: GpuOptimisticTransactionBase, + mergeExecutor: MergeExecutor): Seq[Row] = { + val deltaActions = mergeExecutor.execute() + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows").value} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan").value} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commit( + deltaActions, + DeltaOperations.Merge( + Option(condition), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + // We do not support notMatchedBySourcePredicates yet and fall back to CPU + // See https://github.com/NVIDIA/spark-rapids/issues/8415 + notMatchedBySourcePredicates = Seq.empty[MergePredicate] + )) + + // Record metrics + val stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + + spark.sharedState.cacheManager.recacheByPlan(spark, target) + + // This is needed to make the SQL metrics visible in the Spark UI. Also this needs + // to be outside the recordMergeOperation because this method will update some metric. + val executionId = spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(spark.sparkContext, executionId, metrics.values.toSeq) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } + + /** Expressions to increment SQL metrics */ + def makeMetricUpdateUDF(name: String, deterministic: Boolean = false) + : Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } +} + +/** + * Context merge execution. + */ +case class MergeExecutorContext(cmd: GpuLowShuffleMergeCommand, + spark: SparkSession, + deltaTxn: OptimisticTransaction, + rapidsConf: RapidsConf) + +trait MergeExecutor extends AnalysisHelper with PredicateHelper with Logging { + + val context: MergeExecutorContext + + + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = context.cmd.target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + def execute(): Seq[FileAction] + + protected def targetOutputCols: Seq[NamedExpression] = { + context.deltaTxn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)()) + } + } + + /** + * Build a DataFrame using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + protected def buildTargetDFWithFiles(files: Seq[AddFile]): DataFrame = { + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (context.cmd.conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + context.deltaTxn.deltaLog.createDataFrame(context.deltaTxn.snapshot, files) + .queryExecution + .analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file + // format because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (context.deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = context.deltaTxn.deltaLog.fileFormat( + context.deltaTxn.deltaLog.unsafeVolatileSnapshot.protocol, context.deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.getOrElse(newAttrib.name, + throw new AnalysisException( + s"Could not find ${newAttrib.name} among the existing target output " + + targetOutputCols.mkString(","))).asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Dataset.ofRows(context.spark, Project(aliases, plan)) + } + + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded(df: DataFrame): DataFrame = { + val partitionColumns = context.deltaTxn.metadata.partitionColumns + // TODO: We should remove this method and use optimized write instead, see + // https://github.com/NVIDIA/spark-rapids/issues/10417 + if (partitionColumns.nonEmpty && context.spark.conf.get(DeltaSQLConf + .MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + protected def sourceDF: DataFrame = { + // UDF to increment metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + } + + /** Whether this merge statement has no insert (NOT MATCHED) clause. */ + protected def hasNoInserts: Boolean = context.cmd.notMatchedClauses.isEmpty + + +} + +/** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ +class InsertOnlyMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + override def execute(): Seq[FileAction] = { + context.cmd.recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = context.cmd.makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = context.cmd.makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = targetOutputCols.map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = context.cmd.notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = Dataset.ofRows(context.spark, context.cmd.source) + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(context.cmd.notMatchedClauses.head.condition + .getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(context.cmd.condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(context.cmd.target.outputSet)) + val dataSkippedFiles = context.deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = buildTargetDFWithFiles(dataSkippedFiles) + + val insertDf = sourceDF.join(targetDF, new Column(context.cmd.condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = context.deltaTxn + .writeFiles(repartitionIfNeeded(insertDf, + )) + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + context.cmd.metrics("numTargetFilesRemoved") += 0 + context.cmd.metrics("numTargetBytesRemoved") += 0 + context.cmd.metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + } +} + + +/** + * This is an optimized algorithm for merge statement, where we avoid shuffling the unmodified + * target data. + * + * The algorithm is as follows: + * 1. Find touched target files in the target table by joining the source and target data, with + * collecting joined row identifiers as (`__metadata_file_path`, `__metadata_row_idx`) pairs. + * 2. Read the touched files again and write new files with updated and/or inserted rows + * without coping unmodified data from target table, but filtering target table with collected + * rows mentioned above. + * 3. Read the touched files again, filtering unmodified rows with collected row identifiers + * collected in first step, and saving them without shuffle. + */ +class LowShuffleMergeExecutor(override val context: MergeExecutorContext) extends MergeExecutor { + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + // UDFs to update metrics + private val incrSourceRowCountExpr: Expression = context.cmd. + makeMetricUpdateUDF("numSourceRowsInSecondScan") + private val incrUpdatedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsUpdated") + private val incrUpdatedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedUpdated") + private val incrUpdatedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated") + private val incrInsertedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsInserted") + private val incrDeletedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsDeleted") + private val incrDeletedMatchedCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsMatchedDeleted") + private val incrDeletedNotMatchedBySourceCountExpr: Expression = context.cmd + .makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted") + + private def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def deleteOutput(incrExpr: Expression): Seq[Expression] = { + targetOutputCols :+ + TrueLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrExpr: Expression) + : Seq[Expression] = { + resolvedActions.map(_.expr) :+ + Literal.FalseLiteral :+ + UnresolvedAttribute(TARGET_ROW_PRESENT_COL) :+ + UnresolvedAttribute(SOURCE_ROW_PRESENT_COL) :+ + incrExpr + } + + private def clauseOutput(clause: DeltaMergeIntoClause): Seq[Expression] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr)) + case _: DeltaMergeIntoMatchedDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr)) + case i: DeltaMergeIntoNotMatchedInsertClause => + insertOutput(i.resolvedActions, incrInsertedCountExpr) + case u: DeltaMergeIntoNotMatchedBySourceUpdateClause => + updateOutput(u.resolvedActions, + And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr)) + case _: DeltaMergeIntoNotMatchedBySourceDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr)) + } + + private def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + clause.condition.getOrElse(TrueLiteral) + } + + /** + * Though low shuffle merge algorithm performs better than traditional merge algorithm in some + * cases, there are some case we should fallback to traditional merge executor: + * + * 1. Low shuffle merge algorithm requires generating metadata columns such as + * [[METADATA_ROW_IDX_COL]], [[METADATA_ROW_DEL_COL]], which only implemented on + * [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]]. That means we need to fallback to + * this normal executor when [[org.apache.spark.sql.rapids.GpuFileSourceScanExec]] is disabled + * for some reason. + * 2. Low shuffle merge algorithm currently needs to broadcast deletion vector, which may + * introduce extra overhead. It maybe better to fallback to this algorithm when the changeset + * it too large. + */ + def shouldFallback(): Boolean = { + // Trying to detect if we can execute finding touched files. + val touchFilePlanOverrideSucceed = verifyGpuPlan(planForFindingTouchedFiles()) { planMeta => + def check(meta: SparkPlanMeta[SparkPlan]): Boolean = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => scan + .asInstanceOf[FileSourceScanExecMeta] + .wrapped + .schema + .fieldNames + .contains(METADATA_ROW_IDX_COL) && scan.canThisBeReplaced + case m => m.childPlans.exists(check) + } + } + + check(planMeta) + } + if (!touchFilePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for finding touched files " + + "plan, fallback to tradition merge.") + return true + } + + // Trying to detect if we can execute the merge plan. + val mergePlanOverrideSucceed = verifyGpuPlan(planForMergeExecution(touchedFiles)) { planMeta => + var overrideCount = 0 + def count(meta: SparkPlanMeta[SparkPlan]): Unit = { + meta match { + case scan if scan.isInstanceOf[FileSourceScanExecMeta] => + if (scan.asInstanceOf[FileSourceScanExecMeta] + .wrapped.schema.fieldNames.contains(METADATA_ROW_DEL_COL) && scan.canThisBeReplaced) { + overrideCount += 1 + } + case m => m.childPlans.foreach(count) + } + } + + count(planMeta) + overrideCount == 2 + } + + if (!mergePlanOverrideSucceed) { + logWarning("Unable to override file scan for low shuffle merge for merge plan, fallback to " + + "tradition merge.") + return true + } + + val deletionVectorSize = touchedFiles.values.map(_._1.serializedSizeInBytes()).sum + val maxDelVectorSize = context.rapidsConf + .get(DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD) + if (deletionVectorSize > maxDelVectorSize) { + logWarning( + s"""Low shuffle merge can't be executed because broadcast deletion vector count + |$deletionVectorSize is large than max value $maxDelVectorSize """.stripMargin) + return true + } + + false + } + + private def verifyGpuPlan(input: DataFrame)(checkPlanMeta: SparkPlanMeta[SparkPlan] => Boolean) + : Boolean = { + val overridePlan = GpuOverrides.wrapAndTagPlan(input.queryExecution.sparkPlan, + context.rapidsConf) + checkPlanMeta(overridePlan) + } + + override def execute(): Seq[FileAction] = { + val newFiles = context.cmd.withStatusCode("DELTA", + s"Rewriting ${touchedFiles.size} files and saving modified data") { + val df = planForMergeExecution(touchedFiles) + context.deltaTxn.writeFiles(df) + } + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + context.cmd.metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + context.cmd.metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + context.cmd.metrics("numTargetChangeFileBytes") += newFiles.collect { + case f: AddCDCFile => f.size + } + .sum + context.cmd.metrics("numTargetBytesAdded") += addedBytes + context.cmd.metrics("numTargetPartitionsAddedTo") += addedPartitions + + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + context.cmd.metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + context.cmd.metrics("numTargetRowsDeleted").set(actualRowsDeleted) + } + + touchedFiles.values.map(_._2).map(_.remove).toSeq ++ newFiles + } + + private lazy val dataSkippedFiles: Seq[AddFile] = { + // Skip data based on the merge condition + val targetOnlyPredicates = splitConjunctivePredicates(context.cmd.condition) + .filter(_.references.subsetOf(context.cmd.target.outputSet)) + context.deltaTxn.filterFiles(targetOnlyPredicates) + } + + private lazy val dataSkippedTargetDF: DataFrame = { + addRowIndexMetaColumn(buildTargetDFWithFiles(dataSkippedFiles)) + } + + private lazy val touchedFiles: Map[String, (Roaring64Bitmap, AddFile)] = this.findTouchedFiles() + + private def planForFindingTouchedFiles(): DataFrame = { + + // Apply inner join to between source and target using the merge condition to find matches + // In addition, we attach two columns + // - METADATA_ROW_IDX column to identify target row in file + // - FILE_PATH_COL the target file name the row is from to later identify the files touched + // by matched rows + val targetDF = dataSkippedTargetDF.withColumn(FILE_PATH_COL, input_file_name()) + + sourceDF.join(targetDF, new Column(context.cmd.condition), "inner") + } + + private def planForMergeExecution(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + getModifiedDF(touchedFiles).unionAll(getUnmodifiedDF(touchedFiles)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles(): Map[String, (Roaring64Bitmap, AddFile)] = + context.cmd.recordMergeOperation(sqlMetricName = "scanTimeMs") { + context.spark.udf.register("row_index_set", udaf(RoaringBitmapUDAF)) + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = planForFindingTouchedFiles() + .select(col(FILE_PATH_COL), col(METADATA_ROW_IDX_COL)) + .groupBy(FILE_PATH_COL) + .agg( + expr(s"row_index_set($METADATA_ROW_IDX_COL) as row_idxes"), + count("*").as("count")) + .collect().map(row => { + val filename = row.getAs[String](FILE_PATH_COL) + val rowIdxSet = row.getAs[RoaringBitmapWrapper]("row_idxes").inner + val count = row.getAs[Long]("count") + (filename, (rowIdxSet, count)) + }) + .toMap + + val duplicateCount = { + val distinctMatchedRowCounts = collectTouchedFiles.values + .map(_._1.getLongCardinality).sum + val allMatchedRowCounts = collectTouchedFiles.values.map(_._2).sum + allMatchedRowCounts - distinctMatchedRowCounts + } + + val hasMultipleMatches = duplicateCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = context.cmd.matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + context.cmd.matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(context.spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = collectTouchedFiles.keys.toSeq + + val nameToAddFileMap = context.cmd.generateCandidateFileMap( + context.cmd.targetDeltaLog.dataPath, + dataSkippedFiles) + + val touchedAddFiles = touchedFileNames.map(f => + context.cmd.getTouchedFile(context.cmd.targetDeltaLog.dataPath, f, nameToAddFileMap)) + .map(f => (DeltaFileOperations + .absolutePath(context.cmd.targetDeltaLog.dataPath.toString, f.path) + .toString, f)).toMap + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. + // We need to scan the source table once to get the correct + // metric here. + if (context.cmd.metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || dataSkippedTargetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + context.cmd.metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + context.cmd.metrics("numTargetFilesBeforeSkipping") += context.deltaTxn.snapshot.numOfFiles + context.cmd.metrics("numTargetBytesBeforeSkipping") += context.deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + context.cmd.metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + context.cmd.metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + context.cmd.metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(touchedAddFiles.values.toSeq) + context.cmd.metrics("numTargetFilesRemoved") += touchedAddFiles.size + context.cmd.metrics("numTargetBytesRemoved") += removedBytes + context.cmd.metrics("numTargetPartitionsRemovedFrom") += removedPartitions + + collectTouchedFiles.map(kv => (kv._1, (kv._2._1, touchedAddFiles(kv._1)))) + } + + + /** + * Modify original data frame to insert + * [[GpuDeltaParquetFileFormatUtils.METADATA_ROW_IDX_COL]]. + */ + private def addRowIndexMetaColumn(baseDF: DataFrame): DataFrame = { + val rowIdxAttr = AttributeReference( + METADATA_ROW_IDX_COL, + METADATA_ROW_IDX_FIELD.dataType, + METADATA_ROW_IDX_FIELD.nullable)() + + val newPlan = baseDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_IDX_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + .copy(isSplittable = false, disablePushDowns = true) + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowIdxAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowIdxAttr + p.copy(projectList = newProjectList) + } + + Dataset.ofRows(context.spark, newPlan) + } + + /** + * The result is scanning target table with touched files, and added an extra + * [[METADATA_ROW_DEL_COL]] to indicate whether filtered by joining with source table in first + * step. + */ + private def getTouchedTargetDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]) + : DataFrame = { + // Generate a new target dataframe that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val baseTargetDF = buildTargetDFWithFiles(touchedFiles.values.map(_._2).toSeq) + + val newPlan = { + val rowDelAttr = AttributeReference( + METADATA_ROW_DEL_COL, + METADATA_ROW_DEL_FIELD.dataType, + METADATA_ROW_DEL_FIELD.nullable)() + + baseTargetDF.queryExecution.analyzed.transformUp { + case r@LogicalRelation(fs: HadoopFsRelation, _, _, _) => + val newSchema = StructType(fs.dataSchema.fields).add(METADATA_ROW_DEL_FIELD) + + // This is required to ensure that row index is correctly calculated. + val newFileFormat = { + val oldFormat = fs.fileFormat.asInstanceOf[DeltaParquetFileFormat] + val dvs = touchedFiles.map(kv => (new URI(kv._1), + DeletionVectorDescriptorWithFilterType(toDeletionVector(kv._2._1), + RowIndexFilterType.UNKNOWN))) + val broadcastDVs = context.spark.sparkContext.broadcast(dvs) + + oldFormat.copy(isSplittable = false, + broadcastDvMap = Some(broadcastDVs), + disablePushDowns = true) + } + + val newFs = fs.copy(dataSchema = newSchema, fileFormat = newFileFormat)(context.spark) + + val newOutput = r.output :+ rowDelAttr + r.copy(relation = newFs, output = newOutput) + case p@Project(projectList, _) => + val newProjectList = projectList :+ rowDelAttr + p.copy(projectList = newProjectList) + } + } + + val df = Dataset.ofRows(context.spark, newPlan) + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + + df + } + + /** + * Generate a plan by calculating modified rows. It's computed by joining source and target + * tables, where target table has been filtered by (`__metadata_file_name`, + * `__metadata_row_idx`) pairs collected in first step. + * + * Schema of `modifiedDF`: + * + * targetSchema + ROW_DROPPED_COL + TARGET_ROW_PRESENT_COL + + * SOURCE_ROW_PRESENT_COL + INCR_METRICS_COL + * INCR_METRICS_COL + * + * It consists of several parts: + * + * 1. Unmatched source rows which are inserted + * 2. Unmatched source rows which are deleted + * 3. Target rows which are updated + * 4. Target rows which are deleted + */ + private def getModifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + val sourceDF = this.sourceDF + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + + val targetDF = getTouchedTargetDF(touchedFiles) + + val joinedDF = { + val joinType = if (hasNoInserts && + context.spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "inner" + } else { + "leftOuter" + } + val matchedTargetDF = targetDF.filter(METADATA_ROW_DEL_COL) + .drop(METADATA_ROW_DEL_COL) + + sourceDF.join(matchedTargetDF, new Column(context.cmd.condition), joinType) + } + + val modifiedRowsSchema = context.deltaTxn.metadata.schema + .add(ROW_DROPPED_FIELD) + .add(TARGET_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(SOURCE_ROW_PRESENT_FIELD.copy(nullable = true)) + .add(INCR_METRICS_FIELD) + + // Here we generate a case when statement to handle all cases: + // CASE + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // WHEN + // CASE WHEN + // + // WHEN + // + // ELSE + // + // END + + val notMatchedConditions = context.cmd.notMatchedClauses.map(clauseCondition) + val notMatchedExpr = { + val deletedNotMatchedRow = { + targetOutputCols :+ + Literal.TrueLiteral :+ + Literal.FalseLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.notMatchedClauses.isEmpty) { + // If there no `WHEN NOT MATCHED` clause, we should just delete not matched row + deletedNotMatchedRow + } else { + val notMatchedOutputs = context.cmd.notMatchedClauses.map(clauseOutput) + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(notMatchedConditions.zip(notMatchedOutputs.map(_(idx))), + deletedNotMatchedRow(idx)) + } + } + } + + val matchedConditions = context.cmd.matchedClauses.map(clauseCondition) + val matchedOutputs = context.cmd.matchedClauses.map(clauseOutput) + val matchedExprs = { + val notMatchedRow = { + targetOutputCols :+ + Literal.FalseLiteral :+ + Literal.TrueLiteral :+ + Literal(null) :+ + Literal.TrueLiteral + } + if (context.cmd.matchedClauses.isEmpty) { + // If there is not matched clause, this is insert only, we should delete this row. + notMatchedRow + } else { + modifiedRowsSchema.zipWithIndex.map { + case (_, idx) => + CaseWhen(matchedConditions.zip(matchedOutputs.map(_(idx))), + notMatchedRow(idx)) + } + } + } + + val sourceRowHasNoMatch = col(TARGET_ROW_PRESENT_COL).isNull.expr + + val modifiedCols = modifiedRowsSchema.zipWithIndex.map { case (col, idx) => + val caseWhen = CaseWhen( + Seq(sourceRowHasNoMatch -> notMatchedExpr(idx)), + matchedExprs(idx)) + new Column(Alias(caseWhen, col.name)()) + } + + val modifiedDF = { + + // Make this a udf to avoid catalyst to be too aggressive to even remove the join! + val noopRowDroppedCol = udf(new GpuDeltaNoopUDF()).apply(!col(ROW_DROPPED_COL)) + + val modifiedDF = joinedDF.select(modifiedCols: _*) + // This will not filter anything since they always return true, but we need to avoid + // catalyst from optimizing these udf + .filter(noopRowDroppedCol && col(INCR_METRICS_COL)) + .drop(ROW_DROPPED_COL, INCR_METRICS_COL, TARGET_ROW_PRESENT_COL, SOURCE_ROW_PRESENT_COL) + + repartitionIfNeeded(modifiedDF) + } + + modifiedDF + } + + private def getUnmodifiedDF(touchedFiles: Map[String, (Roaring64Bitmap, AddFile)]): DataFrame = { + getTouchedTargetDF(touchedFiles) + .filter(!col(METADATA_ROW_DEL_COL)) + .drop(TARGET_ROW_PRESENT_COL, METADATA_ROW_DEL_COL) + } +} + + +object MergeExecutor { + + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val FILE_PATH_COL: String = GpuDeltaParquetFileFormatUtils.FILE_PATH_COL + val SOURCE_ROW_PRESENT_COL: String = "_source_row_present_" + val SOURCE_ROW_PRESENT_FIELD: StructField = StructField(SOURCE_ROW_PRESENT_COL, BooleanType, + nullable = false) + val TARGET_ROW_PRESENT_COL: String = "_target_row_present_" + val TARGET_ROW_PRESENT_FIELD: StructField = StructField(TARGET_ROW_PRESENT_COL, BooleanType, + nullable = false) + val ROW_DROPPED_COL: String = GpuDeltaMergeConstants.ROW_DROPPED_COL + val ROW_DROPPED_FIELD: StructField = StructField(ROW_DROPPED_COL, BooleanType, nullable = false) + val INCR_METRICS_COL: String = "_incr_metrics_" + val INCR_METRICS_FIELD: StructField = StructField(INCR_METRICS_COL, BooleanType, nullable = false) + val INCR_ROW_COUNT_COL: String = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL: Literal = Literal(null, StringType) + + def toDeletionVector(bitmap: Roaring64Bitmap): DeletionVectorDescriptor = { + DeletionVectorDescriptor.inlineInLog(RoaringBitmapWrapper(bitmap).serializeToBytes(), + bitmap.getLongCardinality) + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} \ No newline at end of file diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala index 969d005b573..604ed826397 100644 --- a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/GpuDeltaParquetFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,20 +16,32 @@ package com.nvidia.spark.rapids.delta +import java.net.URI + import com.databricks.sql.transaction.tahoe.{DeltaColumnMappingMode, DeltaParquetFileFormat, IdMapping} -import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.IS_ROW_DELETED_COLUMN_NAME -import com.nvidia.spark.rapids.SparkPlanMeta +import com.databricks.sql.transaction.tahoe.DeltaParquetFileFormat.{DeletionVectorDescriptorWithFilterType, IS_ROW_DELETED_COLUMN_NAME} +import com.nvidia.spark.rapids.{GpuMetric, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormatUtils.addMetadataColumnToIterator +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch case class GpuDeltaParquetFileFormat( override val columnMappingMode: DeltaColumnMappingMode, override val referenceSchema: StructType, - isSplittable: Boolean) extends GpuDeltaParquetFileFormatBase { + isSplittable: Boolean, + disablePushDown: Boolean, + broadcastDvMap: Option[Broadcast[Map[URI, DeletionVectorDescriptorWithFilterType]]] +) extends GpuDeltaParquetFileFormatBase { if (columnMappingMode == IdMapping) { val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED @@ -44,6 +56,46 @@ case class GpuDeltaParquetFileFormat( sparkSession: SparkSession, options: Map[String, String], path: Path): Boolean = isSplittable + + override def buildReaderWithPartitionValuesAndMetrics( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration, + metrics: Map[String, GpuMetric], + alluxioPathReplacementMap: Option[Map[String, String]]) + : PartitionedFile => Iterator[InternalRow] = { + + val dataReader = super.buildReaderWithPartitionValuesAndMetrics( + sparkSession, + dataSchema, + partitionSchema, + requiredSchema, + filters, + options, + hadoopConf, + metrics, + alluxioPathReplacementMap) + + val delVecs = broadcastDvMap + val maxDelVecScatterBatchSize = RapidsConf + .DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE + .get(sparkSession.sessionState.conf) + + (file: PartitionedFile) => { + val input = dataReader(file) + val dv = delVecs.flatMap(_.value.get(new URI(file.filePath.toString()))) + .map(dv => RoaringBitmapWrapper.deserializeFromBytes(dv.descriptor.inlineData).inner) + addMetadataColumnToIterator(prepareSchema(requiredSchema), + dv, + input.asInstanceOf[Iterator[ColumnarBatch]], + maxDelVecScatterBatchSize + ).asInstanceOf[Iterator[InternalRow]] + } + } } object GpuDeltaParquetFileFormat { @@ -60,6 +112,7 @@ object GpuDeltaParquetFileFormat { } def convertToGpu(fmt: DeltaParquetFileFormat): GpuDeltaParquetFileFormat = { - GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable) + GpuDeltaParquetFileFormat(fmt.columnMappingMode, fmt.referenceSchema, fmt.isSplittable, + fmt.disablePushDowns, fmt.broadcastDvMap) } } diff --git a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala index 8e13a9e4b5a..5a2b4e7b52e 100644 --- a/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala +++ b/delta-lake/delta-spark341db/src/main/scala/com/nvidia/spark/rapids/delta/shims/MergeIntoCommandMetaShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,14 @@ package com.nvidia.spark.rapids.delta.shims import com.databricks.sql.transaction.tahoe.commands.{MergeIntoCommand, MergeIntoCommandEdge} -import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuMergeIntoCommand} -import com.nvidia.spark.rapids.RapidsConf +import com.databricks.sql.transaction.tahoe.rapids.{GpuDeltaLog, GpuLowShuffleMergeCommand, GpuMergeIntoCommand} +import com.nvidia.spark.rapids.{RapidsConf, RapidsReaderType} import com.nvidia.spark.rapids.delta.{MergeIntoCommandEdgeMeta, MergeIntoCommandMeta} +import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.command.RunnableCommand -object MergeIntoCommandMetaShim { +object MergeIntoCommandMetaShim extends Logging { def tagForGpu(meta: MergeIntoCommandMeta, mergeCmd: MergeIntoCommand): Unit = { // see https://github.com/NVIDIA/spark-rapids/issues/8415 for more information if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { @@ -39,26 +40,82 @@ object MergeIntoCommandMetaShim { } def convertToGpu(mergeCmd: MergeIntoCommand, conf: RapidsConf): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge disabled since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } def convertToGpu(mergeCmd: MergeIntoCommandEdge, conf: RapidsConf): RunnableCommand = { - GpuMergeIntoCommand( - mergeCmd.source, - mergeCmd.target, - new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), - mergeCmd.condition, - mergeCmd.matchedClauses, - mergeCmd.notMatchedClauses, - mergeCmd.notMatchedBySourceClauses, - mergeCmd.migratedSchema)(conf) + // TODO: Currently we only support low shuffler merge only when parquet per file read is enabled + // due to the limitation of implementing row index metadata column. + if (conf.isDeltaLowShuffleMergeEnabled) { + if (conf.isParquetPerFileReadEnabled) { + GpuLowShuffleMergeCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } else { + logWarning(s"""Low shuffle merge is still disable since ${RapidsConf.PARQUET_READER_TYPE} is + not set to ${RapidsReaderType.PERFILE}. Falling back to classic merge.""") + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } + } else { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } } } diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 3231b7b3069..941ab4046e6 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -73,6 +73,12 @@ Name | Description | Default Value | Applicable at spark.rapids.sql.csv.read.double.enabled|CSV reading is not 100% compatible when reading doubles.|true|Runtime spark.rapids.sql.csv.read.float.enabled|CSV reading is not 100% compatible when reading floats.|true|Runtime spark.rapids.sql.decimalOverflowGuarantees|FOR TESTING ONLY. DO NOT USE IN PRODUCTION. Please see the decimal section of the compatibility documents for more information on this config.|true|Runtime +spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold|Currently we need to broadcast deletion vector to all executors to perform low shuffle merge. When we detect the deletion vector broadcast size is larger than this value, we will fallback to normal shuffle merge.|20971520|Runtime +spark.rapids.sql.delta.lowShuffleMerge.enabled|Option to turn on the low shuffle merge for Delta Lake. Currently there are some limitations for this feature: +1. We only support Databricks Runtime 13.3 and Deltalake 2.4. +2. The file scan mode must be set to PERFILE +3. The deletion vector size must be smaller than spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold +|false|Runtime spark.rapids.sql.detectDeltaCheckpointQueries|Queries against Delta Lake _delta_log checkpoint Parquet files are not efficient on the GPU. When this option is enabled, the plugin will attempt to detect these queries and fall back to the CPU.|true|Runtime spark.rapids.sql.detectDeltaLogQueries|Queries against Delta Lake _delta_log JSON files are not efficient on the GPU. When this option is enabled, the plugin will attempt to detect these queries and fall back to the CPU.|true|Runtime spark.rapids.sql.fast.sample|Option to turn on fast sample. If enable it is inconsistent with CPU sample because of GPU sample algorithm is inconsistent with CPU.|false|Runtime diff --git a/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py b/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py new file mode 100644 index 00000000000..6935ee13751 --- /dev/null +++ b/integration_tests/src/main/python/delta_lake_low_shuffle_merge_test.py @@ -0,0 +1,165 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyspark.sql.functions as f +import pytest + +from delta_lake_merge_common import * +from marks import * +from pyspark.sql.types import * +from spark_session import is_databricks133_or_later, spark_version + +delta_merge_enabled_conf = copy_and_update(delta_writes_enabled_conf, + {"spark.rapids.sql.command.MergeIntoCommand": "true", + "spark.rapids.sql.command.MergeIntoCommandEdge": "true", + "spark.rapids.sql.delta.lowShuffleMerge.enabled": "true", + "spark.rapids.sql.format.parquet.reader.type": "PERFILE"}) + +@allow_non_gpu("ColumnarToRowExec", *delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_low_shuffle_merge_when_gpu_file_scan_override_failed(spark_tmp_path, + spark_tmp_table_factory, + use_cdf, num_slices): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" + + conf = copy_and_update(delta_merge_enabled_conf, + { + "spark.rapids.sql.exec.FileSourceScanExec": "false", + # Disable auto broadcast join due to this issue: + # https://github.com/NVIDIA/spark-rapids/issues/10973 + "spark.sql.autoBroadcastJoinThreshold": "-1" + }) + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, False, conf=conf) + + + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("table_ranges", [(range(20), range(10)), # partial insert of source + (range(5), range(5)), # no-op insert + (range(10), range(20, 30)) # full insert of source + ], ids=idfn) +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("partition_columns", [None, ["a"], ["b"], ["a", "b"]], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices): + do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, + table_ranges, use_cdf, partition_columns, + num_slices, False, delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("table_ranges", [(range(10), range(20)), # partial delete of target + (range(5), range(5)), # full delete of target + (range(10), range(20, 30)) # no-op delete + ], ids=idfn) +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("partition_columns", [None, ["a"], ["b"], ["a", "b"]], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices): + do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): + do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, False, delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("merge_sql", [ + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN MATCHED AND s.b > 'q' THEN UPDATE SET d.a = s.a / 2, d.b = s.b" \ + " WHEN NOT MATCHED THEN INSERT *", + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN NOT MATCHED AND s.b > 'q' THEN INSERT *", + "MERGE INTO {dest_table} d USING {src_table} s ON d.a == s.a" \ + " WHEN MATCHED AND s.b > 'a' AND s.b < 'g' THEN UPDATE SET d.a = s.a / 2, d.b = s.b" \ + " WHEN MATCHED AND s.b > 'g' AND s.b < 'z' THEN UPDATE SET d.a = s.a / 4, d.b = concat('extra', s.b)" \ + " WHEN NOT MATCHED AND s.b > 'b' AND s.b < 'f' THEN INSERT *" \ + " WHEN NOT MATCHED AND s.b > 'f' AND s.b < 'z' THEN INSERT (b) VALUES ('not here')" ], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, merge_sql, num_slices): + do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) +def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): + do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, + use_cdf, + num_slices, + False, + delta_merge_enabled_conf) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order +@pytest.mark.skipif(not ((is_databricks_runtime() and is_databricks133_or_later()) or + (not is_databricks_runtime() and spark_version().startswith("3.4"))), + reason="Delta Lake Low Shuffle Merge only supports Databricks 13.3 or OSS " + "delta 2.4") +@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) +def test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf): + do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + delta_merge_enabled_conf) + diff --git a/integration_tests/src/main/python/delta_lake_merge_common.py b/integration_tests/src/main/python/delta_lake_merge_common.py new file mode 100644 index 00000000000..e6e9676625d --- /dev/null +++ b/integration_tests/src/main/python/delta_lake_merge_common.py @@ -0,0 +1,155 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyspark.sql.functions as f +import string +from pyspark.sql.types import * + +from asserts import * +from data_gen import * +from delta_lake_utils import * +from spark_session import is_databricks_runtime + +# Databricks changes the number of files being written, so we cannot compare logs +num_slices_to_test = [10] if is_databricks_runtime() else [1, 10] + + +def make_df(spark, gen, num_slices): + return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase), + SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices) + + +def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, check_func, + partition_columns=None): + data_path = spark_tmp_path + "/DELTA_DATA" + src_table = spark_tmp_table_factory.get() + + def setup_tables(spark): + setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns) + src_table_func(spark).createOrReplaceTempView(src_table) + + def do_merge(spark, path): + dest_table = spark_tmp_table_factory.get() + read_delta_path(spark, path).createOrReplaceTempView(dest_table) + return spark.sql(merge_sql.format(src_table=src_table, dest_table=dest_table)).collect() + with_cpu_session(setup_tables) + check_func(data_path, do_merge) + + +def assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, + compare_logs, partition_columns=None, conf=None): + assert conf is not None, "conf must be set" + + def read_data(spark, path): + read_func = read_delta_path_with_cdf if use_cdf else read_delta_path + df = read_func(spark, path) + return df.sort(df.columns) + + def checker(data_path, do_merge): + cpu_path = data_path + "/CPU" + gpu_path = data_path + "/GPU" + # compare resulting dataframe from the merge operation (some older Spark versions return empty here) + cpu_result = with_cpu_session(lambda spark: do_merge(spark, cpu_path), conf=conf) + gpu_result = with_gpu_session(lambda spark: do_merge(spark, gpu_path), conf=conf) + assert_equal(cpu_result, gpu_result) + # compare merged table data results, read both via CPU to make sure GPU write can be read by CPU + cpu_result = with_cpu_session(lambda spark: read_data(spark, cpu_path).collect(), conf=conf) + gpu_result = with_cpu_session(lambda spark: read_data(spark, gpu_path).collect(), conf=conf) + assert_equal(cpu_result, gpu_result) + # Using partition columns involves sorting, and there's no guarantees on the task + # partitioning due to random sampling. + if compare_logs and not partition_columns: + with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path)) + delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, checker, partition_columns) + + +def do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, compare_logs, + conf): + src_range, dest_range = table_ranges + src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) + dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN NOT MATCHED THEN INSERT *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + partition_columns, conf=conf) + + +def do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, compare_logs, + conf): + src_range, dest_range = table_ranges + src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) + dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN DELETE" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + partition_columns, conf=conf) + + +def do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, use_cdf, + num_slices, compare_logs, conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) + dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), range(100)), string_gen, seed=1, num_slices=num_slices) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ + " WHEN MATCHED AND {dest_table}.a > 100 THEN UPDATE SET *" + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, compare_logs, + conf=conf) + + +def do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + conf): + # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous + src_table_func = lambda spark: spark.range(10).withColumn("x", f.col("id") + 1) \ + .select(f.col("id"), (f.col("x") + 1).alias("x")) \ + .drop_duplicates(["id"]) \ + .limit(10) + dest_table_func = lambda spark: spark.range(5).withColumn("x", f.col("id") + 1) + merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.id == {src_table}.id" \ + " WHEN MATCHED THEN UPDATE SET {dest_table}.x = {src_table}.x + 2" \ + " WHEN NOT MATCHED AND {src_table}.x < 7 THEN INSERT *" + + assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, + src_table_func, dest_table_func, merge_sql, + compare_logs=False, conf=conf) diff --git a/integration_tests/src/main/python/delta_lake_merge_test.py b/integration_tests/src/main/python/delta_lake_merge_test.py index 0880db16434..5c3bb915ddb 100644 --- a/integration_tests/src/main/python/delta_lake_merge_test.py +++ b/integration_tests/src/main/python/delta_lake_merge_test.py @@ -14,66 +14,17 @@ import pyspark.sql.functions as f import pytest -import string -from asserts import * -from data_gen import * -from delta_lake_utils import * +from delta_lake_merge_common import * from marks import * from pyspark.sql.types import * from spark_session import is_before_spark_320, is_databricks_runtime, spark_version -# Databricks changes the number of files being written, so we cannot compare logs -num_slices_to_test = [10] if is_databricks_runtime() else [1, 10] delta_merge_enabled_conf = copy_and_update(delta_writes_enabled_conf, {"spark.rapids.sql.command.MergeIntoCommand": "true", "spark.rapids.sql.command.MergeIntoCommandEdge": "true"}) -def make_df(spark, gen, num_slices): - return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase), - SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices) - -def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, check_func, - partition_columns=None): - data_path = spark_tmp_path + "/DELTA_DATA" - src_table = spark_tmp_table_factory.get() - def setup_tables(spark): - setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns) - src_table_func(spark).createOrReplaceTempView(src_table) - def do_merge(spark, path): - dest_table = spark_tmp_table_factory.get() - read_delta_path(spark, path).createOrReplaceTempView(dest_table) - return spark.sql(merge_sql.format(src_table=src_table, dest_table=dest_table)).collect() - with_cpu_session(setup_tables) - check_func(data_path, do_merge) - -def assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, - compare_logs, partition_columns=None, - conf=delta_merge_enabled_conf): - def read_data(spark, path): - read_func = read_delta_path_with_cdf if use_cdf else read_delta_path - df = read_func(spark, path) - return df.sort(df.columns) - def checker(data_path, do_merge): - cpu_path = data_path + "/CPU" - gpu_path = data_path + "/GPU" - # compare resulting dataframe from the merge operation (some older Spark versions return empty here) - cpu_result = with_cpu_session(lambda spark: do_merge(spark, cpu_path), conf=conf) - gpu_result = with_gpu_session(lambda spark: do_merge(spark, gpu_path), conf=conf) - assert_equal(cpu_result, gpu_result) - # compare merged table data results, read both via CPU to make sure GPU write can be read by CPU - cpu_result = with_cpu_session(lambda spark: read_data(spark, cpu_path).collect(), conf=conf) - gpu_result = with_cpu_session(lambda spark: read_data(spark, gpu_path).collect(), conf=conf) - assert_equal(cpu_result, gpu_result) - # Using partition columns involves sorting, and there's no guarantees on the task - # partitioning due to random sampling. - if compare_logs and not partition_columns: - with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path)) - delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, checker, partition_columns) @allow_non_gpu(delta_write_fallback_allow, *delta_meta_allow) @delta_lake @@ -162,16 +113,9 @@ def test_delta_merge_partial_fallback_via_conf(spark_tmp_path, spark_tmp_table_f @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, use_cdf, partition_columns, num_slices): - src_range, dest_range = table_ranges - src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) - dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN NOT MATCHED THEN INSERT *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs, - partition_columns) + do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, + table_ranges, use_cdf, partition_columns, + num_slices, num_slices == 1, delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -186,16 +130,9 @@ def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_facto @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, use_cdf, partition_columns, num_slices): - src_range, dest_range = table_ranges - src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices) - dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED THEN DELETE" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs, - partition_columns) + do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges, + use_cdf, partition_columns, num_slices, num_slices == 1, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -204,15 +141,9 @@ def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, + num_slices, num_slices == 1, delta_merge_enabled_conf) + @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -232,13 +163,10 @@ def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, us " WHEN NOT MATCHED AND s.b > 'f' AND s.b < 'z' THEN INSERT (b) VALUES ('not here')" ], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, merge_sql, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices) - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, + merge_sql, num_slices, num_slices == 1, + delta_merge_enabled_conf) + @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -247,15 +175,10 @@ def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_facto @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) @pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn) def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b")) - dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), range(100)), string_gen, seed=1, num_slices=num_slices) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \ - " WHEN MATCHED AND {dest_table}.a > 100 THEN UPDATE SET *" - # Non-deterministic input for each task means we can only reliably compare record counts when using only one task - compare_logs = num_slices == 1 - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs) + do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, + spark_tmp_table_factory, use_cdf, + num_slices, num_slices == 1, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake @@ -263,18 +186,8 @@ def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spa @pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") @pytest.mark.parametrize("use_cdf", [True, False], ids=idfn) def test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf): - # Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous - src_table_func = lambda spark: spark.range(10).withColumn("x", f.col("id") + 1)\ - .select(f.col("id"), (f.col("x") + 1).alias("x"))\ - .drop_duplicates(["id"])\ - .limit(10) - dest_table_func = lambda spark: spark.range(5).withColumn("x", f.col("id") + 1) - merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.id == {src_table}.id" \ - " WHEN MATCHED THEN UPDATE SET {dest_table}.x = {src_table}.x + 2" \ - " WHEN NOT MATCHED AND {src_table}.x < 7 THEN INSERT *" - - assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf, - src_table_func, dest_table_func, merge_sql, compare_logs=False) + do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf, + delta_merge_enabled_conf) @allow_non_gpu(*delta_meta_allow) @delta_lake diff --git a/pom.xml b/pom.xml index 06947857521..3ff87c3cb97 100644 --- a/pom.xml +++ b/pom.xml @@ -733,6 +733,7 @@ --> -Xlint:all,-serial,-path,-try,-processing|-Werror 1.16.0 + 1.0.6 ${ucx.baseVersion} true @@ -1016,6 +1017,15 @@ ${alluxio.client.version} provided + + + org.roaringbitmap + RoaringBitmap + ${roaringbitmap.version} + compile + org.scalatest scalatest_${scala.binary.version} diff --git a/scala2.13/aggregator/pom.xml b/scala2.13/aggregator/pom.xml index 1d70c76f037..198b62d5fa6 100644 --- a/scala2.13/aggregator/pom.xml +++ b/scala2.13/aggregator/pom.xml @@ -94,6 +94,10 @@ com.google.flatbuffers ${rapids.shade.package}.com.google.flatbuffers + + org.roaringbitmap + ${rapids.shade.package}.org.roaringbitmap + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index cbc4aecbd26..e32a64f0529 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -733,6 +733,7 @@ --> -Xlint:all,-serial,-path,-try,-processing|-Werror 1.16.0 + 1.0.6 ${ucx.baseVersion} true @@ -1016,6 +1017,15 @@ ${alluxio.client.version} provided + + + org.roaringbitmap + RoaringBitmap + ${roaringbitmap.version} + compile + org.scalatest scalatest_${scala.binary.version} diff --git a/scala2.13/sql-plugin/pom.xml b/scala2.13/sql-plugin/pom.xml index df3532a3592..eb6f240a3f6 100644 --- a/scala2.13/sql-plugin/pom.xml +++ b/scala2.13/sql-plugin/pom.xml @@ -97,6 +97,10 @@ org.alluxio alluxio-shaded-client + + org.roaringbitmap + RoaringBitmap + org.mockito mockito-core diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml index 961e6f08372..08657a9d40b 100644 --- a/sql-plugin/pom.xml +++ b/sql-plugin/pom.xml @@ -97,6 +97,10 @@ org.alluxio alluxio-shaded-client + + org.roaringbitmap + RoaringBitmap + org.mockito mockito-core diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 8ea1641fb4a..5203e926efa 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2274,6 +2274,32 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .integerConf .createWithDefault(1024) + val DELTA_LOW_SHUFFLE_MERGE_SCATTER_DEL_VECTOR_BATCH_SIZE = + conf("spark.rapids.sql.delta.lowShuffleMerge.deletion.scatter.max.size") + .doc("Option to set max batch size when scattering deletion vector") + .internal() + .integerConf + .createWithDefault(32 * 1024) + + val DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD = + conf("spark.rapids.sql.delta.lowShuffleMerge.deletionVector.broadcast.threshold") + .doc("Currently we need to broadcast deletion vector to all executors to perform low " + + "shuffle merge. When we detect the deletion vector broadcast size is larger than this " + + "value, we will fallback to normal shuffle merge.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(20 * 1024 * 1024) + + val ENABLE_DELTA_LOW_SHUFFLE_MERGE = + conf("spark.rapids.sql.delta.lowShuffleMerge.enabled") + .doc("Option to turn on the low shuffle merge for Delta Lake. Currently there are some " + + "limitations for this feature: \n" + + "1. We only support Databricks Runtime 13.3 and Deltalake 2.4. \n" + + s"2. The file scan mode must be set to ${RapidsReaderType.PERFILE} \n" + + "3. The deletion vector size must be smaller than " + + s"${DELTA_LOW_SHUFFLE_MERGE_DEL_VECTOR_BROADCAST_THRESHOLD.key} \n") + .booleanConf + .createWithDefault(false) + private def printSectionHeader(category: String): Unit = println(s"\n### $category") @@ -3083,6 +3109,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val testGetJsonObjectSaveRows: Int = get(TEST_GET_JSON_OBJECT_SAVE_ROWS) + lazy val isDeltaLowShuffleMergeEnabled: Boolean = get(ENABLE_DELTA_LOW_SHUFFLE_MERGE) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because