From 3b877976fdd97146676343f0e833444fda5a41f3 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 31 Oct 2023 10:07:40 -0500 Subject: [PATCH] Delta Lake 2.3.0 support (#9562) * Delta Lake 2.3.0 support Signed-off-by: Jason Lowe * Make merge UDFs consistent with Delta Lake 2.4 --------- Signed-off-by: Jason Lowe --- aggregator/pom.xml | 36 + delta-lake/README.md | 1 + .../sql/delta/rapids/DeltaRuntimeShim.scala | 6 +- delta-lake/delta-23x/pom.xml | 98 ++ .../delta/delta23x/DeleteCommandMeta.scala | 58 + .../delta/delta23x/Delta23xProvider.scala | 110 ++ .../GpuDelta23xParquetFileFormat.scala | 64 + .../delta/delta23x/GpuDeltaCatalog.scala | 180 +++ .../delta/delta23x/MergeIntoCommandMeta.scala | 61 + .../delta/delta23x/UpdateCommandMeta.scala | 52 + .../rapids/delta/shims/MetadataShims.scala | 20 + .../rapids/delta23x/Delta23xRuntimeShim.scala | 60 + .../delta23x/GpuCreateDeltaTableCommand.scala | 480 ++++++ .../rapids/delta23x/GpuDeleteCommand.scala | 381 +++++ .../rapids/delta23x/GpuMergeIntoCommand.scala | 1331 +++++++++++++++++ .../delta23x/GpuOptimisticTransaction.scala | 274 ++++ .../rapids/delta23x/GpuUpdateCommand.scala | 275 ++++ pom.xml | 6 + scala2.13/aggregator/pom.xml | 36 + scala2.13/delta-lake/delta-23x/pom.xml | 98 ++ scala2.13/pom.xml | 6 + 21 files changed, 3632 insertions(+), 1 deletion(-) create mode 100644 delta-lake/delta-23x/pom.xml create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala create mode 100644 delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala create mode 100644 delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala create mode 100644 scala2.13/delta-lake/delta-23x/pom.xml diff --git a/aggregator/pom.xml b/aggregator/pom.xml index 450be067764..6c72912a164 100644 --- a/aggregator/pom.xml +++ b/aggregator/pom.xml @@ -385,6 +385,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -408,6 +414,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -431,6 +443,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -471,6 +489,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -494,6 +518,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -534,6 +564,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + diff --git a/delta-lake/README.md b/delta-lake/README.md index 945e3508f07..ff9d2553e31 100644 --- a/delta-lake/README.md +++ b/delta-lake/README.md @@ -14,6 +14,7 @@ and directory contains the corresponding support code. | 2.0.x | Spark 3.2.x | `delta-20x` | | 2.1.x | Spark 3.3.x | `delta-21x` | | 2.2.x | Spark 3.3.x | `delta-22x` | +| 2.3.x | Spark 3.3.x | `delta-23x` | | 2.4.x | Spark 3.4.x | `delta-24x` | | Databricks 10.4 | Databricks 10.4 | `delta-spark321db` | | Databricks 11.3 | Databricks 11.3 | `delta-spark330db` | diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala index 65c9ac54094..91c55d090e5 100644 --- a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala +++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala @@ -53,7 +53,11 @@ object DeltaRuntimeShim { Try { DeltaUDF.getClass.getMethod("stringStringUdf", classOf[String => String]) }.map(_ => "org.apache.spark.sql.delta.rapids.delta21x.Delta21xRuntimeShim") - .getOrElse("org.apache.spark.sql.delta.rapids.delta22x.Delta22xRuntimeShim") + .orElse { + Try { + classOf[DeltaLog].getMethod("assertRemovable") + }.map(_ => "org.apache.spark.sql.delta.rapids.delta22x.Delta22xRuntimeShim") + }.getOrElse("org.apache.spark.sql.delta.rapids.delta23x.Delta23xRuntimeShim") } else if (VersionUtils.cmpSparkVersion(3, 5, 0) < 0) { "org.apache.spark.sql.delta.rapids.delta24x.Delta24xRuntimeShim" } else { diff --git a/delta-lake/delta-23x/pom.xml b/delta-lake/delta-23x/pom.xml new file mode 100644 index 00000000000..9b8cb489cb6 --- /dev/null +++ b/delta-lake/delta-23x/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-parent_2.12 + 23.12.0-SNAPSHOT + ../../pom.xml + + + rapids-4-spark-delta-23x_2.12 + RAPIDS Accelerator for Apache Spark Delta Lake 2.3.x Support + Delta Lake 2.3.x support for the RAPIDS Accelerator for Apache Spark + 23.12.0-SNAPSHOT + + + ../delta-lake/delta-23x + false + **/* + package + + + + + com.nvidia + rapids-4-spark-sql_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + provided + + + io.delta + delta-core_${scala.binary.version} + 2.3.0 + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-common-sources + generate-sources + + add-source + + + + + ${project.basedir}/../common/src/main/scala + ${project.basedir}/../common/src/main/delta-io/scala + + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala new file mode 100644 index 00000000000..49274eba3aa --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/DeleteCommandMeta.scala @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta} +import com.nvidia.spark.rapids.delta.RapidsDeltaUtils + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.delta.commands.{DeleteCommand, DeletionVectorUtils} +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.rapids.delta23x.GpuDeleteCommand +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.RunnableCommand + +class DeleteCommandMeta( + deleteCmd: DeleteCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends RunnableCommandMeta[DeleteCommand](deleteCmd, conf, parent, rule) { + + override def tagSelfForGpu(): Unit = { + if (!conf.isDeltaWriteEnabled) { + willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + val dvFeatureEnabled = DeletionVectorUtils.deletionVectorsWritable( + deleteCmd.deltaLog.unsafeVolatileSnapshot) + if (dvFeatureEnabled && deleteCmd.conf.getConf( + DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS)) { + // https://github.com/NVIDIA/spark-rapids/issues/8554 + willNotWorkOnGpu("Deletion vectors are not supported on GPU") + } + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), + Map.empty, SparkSession.active) + } + + override def convertToGpu(): RunnableCommand = { + GpuDeleteCommand( + new GpuDeltaLog(deleteCmd.deltaLog, conf), + deleteCmd.target, + deleteCmd.condition) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala new file mode 100644 index 00000000000..c32c0a0e02d --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/Delta23xProvider.scala @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.delta.DeltaIOProvider + +import org.apache.spark.sql.delta.DeltaParquetFileFormat +import org.apache.spark.sql.delta.DeltaParquetFileFormat.{IS_ROW_DELETED_COLUMN_NAME, ROW_INDEX_COLUMN_NAME} +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand} +import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} + +object Delta23xProvider extends DeltaIOProvider { + + override def getRunnableCommandRules: Map[Class[_ <: RunnableCommand], + RunnableCommandRule[_ <: RunnableCommand]] = { + Seq( + GpuOverrides.runnableCmd[DeleteCommand]( + "Delete rows from a Delta Lake table", + (a, conf, p, r) => new DeleteCommandMeta(a, conf, p, r)) + .disabledByDefault("Delta Lake delete support is experimental"), + GpuOverrides.runnableCmd[MergeIntoCommand]( + "Merge of a source query/table into a Delta table", + (a, conf, p, r) => new MergeIntoCommandMeta(a, conf, p, r)) + .disabledByDefault("Delta Lake merge support is experimental"), + GpuOverrides.runnableCmd[UpdateCommand]( + "Update rows in a Delta Lake table", + (a, conf, p, r) => new UpdateCommandMeta(a, conf, p, r)) + .disabledByDefault("Delta Lake update support is experimental") + ).map(r => (r.getClassFor.asSubclass(classOf[RunnableCommand]), r)).toMap + } + + override def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + val format = meta.wrapped.relation.fileFormat + if (format.getClass == classOf[DeltaParquetFileFormat]) { + val deltaFormat = format.asInstanceOf[DeltaParquetFileFormat] + val requiredSchema = meta.wrapped.requiredSchema + if (requiredSchema.exists(_.name == IS_ROW_DELETED_COLUMN_NAME)) { + meta.willNotWorkOnGpu( + s"reading metadata column $IS_ROW_DELETED_COLUMN_NAME is not supported") + } + if (requiredSchema.exists(_.name == ROW_INDEX_COLUMN_NAME)) { + meta.willNotWorkOnGpu( + s"reading metadata column $ROW_INDEX_COLUMN_NAME is not supported") + } + if (deltaFormat.hasDeletionVectorMap()) { + meta.willNotWorkOnGpu("deletion vectors are not supported") + } + GpuReadParquetFileFormat.tagSupport(meta) + } else { + meta.willNotWorkOnGpu(s"format ${format.getClass} is not supported") + } + } + + override def getReadFileFormat(format: FileFormat): FileFormat = { + val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] + GpuDelta23xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable) + } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicCreateTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala new file mode 100644 index 00000000000..0466a01506a --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDelta23xParquetFileFormat.scala @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.delta.GpuDeltaParquetFileFormat +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.delta.{DeltaColumnMappingMode, IdMapping} +import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +case class GpuDelta23xParquetFileFormat( + metadata: Metadata, + isSplittable: Boolean) extends GpuDeltaParquetFileFormat { + + override val columnMappingMode: DeltaColumnMappingMode = metadata.columnMappingMode + override val referenceSchema: StructType = metadata.schema + + if (columnMappingMode == IdMapping) { + val requiredReadConf = SQLConf.PARQUET_FIELD_ID_READ_ENABLED + require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredReadConf)), + s"${requiredReadConf.key} must be enabled to support Delta id column mapping mode") + val requiredWriteConf = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED + require(SparkSession.getActiveSession.exists(_.sessionState.conf.getConf(requiredWriteConf)), + s"${requiredWriteConf.key} must be enabled to support Delta id column mapping mode") + } + + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = isSplittable + + /** + * We sometimes need to replace FileFormat within LogicalPlans, so we have to override + * `equals` to ensure file format changes are captured + */ + override def equals(other: Any): Boolean = { + other match { + case ff: GpuDelta23xParquetFileFormat => + ff.columnMappingMode == columnMappingMode && + ff.referenceSchema == referenceSchema && + ff.isSplittable == isSplittable + case _ => false + } + } + + override def hashCode(): Int = getClass.getCanonicalName.hashCode() +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..d10b114299d --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/GpuDeltaCatalog.scala @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.RapidsConf +import java.util + +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.TableCreationModes +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaCatalogBase +import org.apache.spark.sql.delta.rapids.delta23x.GpuCreateDeltaTableCommand +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: DeltaCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with DeltaLogging { + + override val spark: SparkSession = cpuCatalog.spark + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + cpuCatalog.getExistingTableIfExists(table) + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + cpuCatalog.verifyTableAndSolidify(tableDesc, query) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, + schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala new file mode 100644 index 00000000000..8e1cd7dd490 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/MergeIntoCommandMeta.scala @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta} +import com.nvidia.spark.rapids.delta.RapidsDeltaUtils + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.delta.commands.MergeIntoCommand +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.rapids.delta23x.GpuMergeIntoCommand +import org.apache.spark.sql.execution.command.RunnableCommand + +class MergeIntoCommandMeta( + mergeCmd: MergeIntoCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends RunnableCommandMeta[MergeIntoCommand](mergeCmd, conf, parent, rule) { + + override def tagSelfForGpu(): Unit = { + if (!conf.isDeltaWriteEnabled) { + willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + if (mergeCmd.notMatchedBySourceClauses.nonEmpty) { + // https://github.com/NVIDIA/spark-rapids/issues/8415 + willNotWorkOnGpu("notMatchedBySourceClauses not supported on GPU") + } + val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) + val deltaLog = mergeCmd.targetFileIndex.deltaLog + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) + } + + override def convertToGpu(): RunnableCommand = { + GpuMergeIntoCommand( + mergeCmd.source, + mergeCmd.target, + new GpuDeltaLog(mergeCmd.targetFileIndex.deltaLog, conf), + mergeCmd.condition, + mergeCmd.matchedClauses, + mergeCmd.notMatchedClauses, + mergeCmd.notMatchedBySourceClauses, + mergeCmd.migratedSchema)(conf) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala new file mode 100644 index 00000000000..1f1bcd93137 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/delta23x/UpdateCommandMeta.scala @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.delta.delta23x + +import com.nvidia.spark.rapids.{DataFromReplacementRule, RapidsConf, RapidsMeta, RunnableCommandMeta} +import com.nvidia.spark.rapids.delta.RapidsDeltaUtils + +import org.apache.spark.sql.delta.commands.UpdateCommand +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.rapids.delta23x.GpuUpdateCommand +import org.apache.spark.sql.execution.command.RunnableCommand + +class UpdateCommandMeta( + updateCmd: UpdateCommand, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends RunnableCommandMeta[UpdateCommand](updateCmd, conf, parent, rule) { + + override def tagSelfForGpu(): Unit = { + if (!conf.isDeltaWriteEnabled) { + willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) + } + + override def convertToGpu(): RunnableCommand = { + GpuUpdateCommand( + new GpuDeltaLog(updateCmd.tahoeFileIndex.deltaLog, conf), + updateCmd.tahoeFileIndex, + updateCmd.target, + updateCmd.updateExpressions, + updateCmd.condition + ) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala new file mode 100644 index 00000000000..f4f9836e9f9 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/com/nvidia/spark/rapids/delta/shims/MetadataShims.scala @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.delta.shims + +import org.apache.spark.sql.delta.stats.UsesMetadataFields + +trait ShimUsesMetadataFields extends UsesMetadataFields \ No newline at end of file diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala new file mode 100644 index 00000000000..c12982d54fe --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/Delta23xRuntimeShim.scala @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import com.nvidia.spark.rapids.RapidsConf +import com.nvidia.spark.rapids.delta.DeltaProvider +import com.nvidia.spark.rapids.delta.delta23x.{Delta23xProvider, GpuDeltaCatalog} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog +import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase} +import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.util.Clock + +class Delta23xRuntimeShim extends DeltaRuntimeShim { + override def getDeltaProvider: DeltaProvider = Delta23xProvider + + override def startTransaction( + log: DeltaLog, + conf: RapidsConf, + clock: Clock): GpuOptimisticTransactionBase = { + new GpuOptimisticTransaction(log, conf)(clock) + } + + override def stringFromStringUdf(f: String => String): UserDefinedFunction = { + DeltaUDF.stringFromString(f) + } + + override def unsafeVolatileSnapshotFromLog(deltaLog: DeltaLog): Snapshot = { + deltaLog.unsafeVolatileSnapshot + } + + override def fileFormatFromLog(deltaLog: DeltaLog): FileFormat = + deltaLog.fileFormat(deltaLog.unsafeVolatileMetadata) + + override def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean = false + + override def getGpuDeltaCatalog( + cpuCatalog: DeltaCatalog, + rapidsConf: RapidsConf): StagingTableCatalog = { + new GpuDeltaCatalog(cpuCatalog, rapidsConf) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..83b6408d647 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,480 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{Action, Metadata, Protocol} +import org.apache.spark.sql.delta.commands.{DeltaCommand, TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + * @param protocol This is used to create a table with specific protocol version + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil, + protocol: Option[Protocol] = None)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + + // Execute write command for `deltaWriter` by + // - replacing the metadata new target table for DataFrameWriterV2 writer if it is a + // REPLACE or CREATE_OR_REPLACE command, + // - running the write procedure of DataFrameWriter command and returning the + // new created actions, + // - returning the Delta Operation type of this DataFrameWriter + def doDeltaWrite( + deltaWriter: WriteIntoDelta, + schema: StructType): (Seq[Action], DeltaOperations.Operation) = { + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, schema) + } + val actions = deltaWriter.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + (actions, op) + } + + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case deltaWriter: WriteIntoDelta => + if (!hasBeenExecuted(txn, sparkSession, Some(options))) { + val (actions, op) = doDeltaWrite(deltaWriter, deltaWriter.data.schema.asNullable) + txn.commit(actions, op) + } + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + val deltaWriter = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data) + if (!hasBeenExecuted(txn, sparkSession, Some(options))) { + val (actions, op) = doDeltaWrite(deltaWriter, other.schema.asNullable) + txn.commit(actions, op) + } + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + protocol.foreach { protocol => + txn.updateProtocol(protocol) + } + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. + operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog( + sparkSession, + tableWithLocation, + gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)), + txn) + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. + if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. + * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(spark, table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. */ + private def cleanupTableDefinition(spark: SparkSession, table: CatalogTable, snapshot: Snapshot) + : CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala new file mode 100644 index 00000000000..1ddd00bc046 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuDeleteCommand.scala @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeleteCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF + +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EqualNullSafe, Expression, If, Literal, Not} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction} +import org.apache.spark.sql.delta.actions.{Action, AddCDCFile, FileAction} +import org.apache.spark.sql.delta.commands.{DeleteCommandMetrics, DeleteMetric, DeletionVectorUtils, DeltaCommand} +import org.apache.spark.sql.delta.commands.DeleteCommand.{rewritingFilesMsg, FINDING_TOUCHED_FILES_MSG} +import org.apache.spark.sql.delta.commands.MergeIntoCommand.totalBytesAndDistinctPartitionValues +import org.apache.spark.sql.delta.files.TahoeBatchFileIndex +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.Utils +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.types.LongType + +/** + * GPU version of Delta Lake DeleteCommand. + * + * Performs a Delete based on the search condition + * + * Algorithm: + * 1) Scan all the files and determine which files have + * the rows that need to be deleted. + * 2) Traverse the affected files and rebuild the touched files. + * 3) Use the Delta protocol to atomically write the remaining rows to new files and remove + * the affected files that are identified in step 1. + */ +case class GpuDeleteCommand( + gpuDeltaLog: GpuDeltaLog, + target: LogicalPlan, + condition: Option[Expression]) + extends LeafRunnableCommand with DeltaCommand with DeleteCommandMetrics { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + override val output: Seq[Attribute] = Seq(AttributeReference("num_affected_rows", LongType)()) + + override lazy val metrics = createMetrics + + final override def run(sparkSession: SparkSession): Seq[Row] = { + val deltaLog = gpuDeltaLog.deltaLog + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.dml.delete") { + gpuDeltaLog.withNewTransaction { txn => + DeltaLog.assertRemovable(txn.snapshot) + if (hasBeenExecuted(txn, sparkSession)) { + sendDriverMetrics(sparkSession, metrics) + return Seq.empty + } + + val deleteActions = performDelete(sparkSession, deltaLog, txn) + txn.commitIfNeeded(deleteActions, DeltaOperations.Delete(condition.map(_.sql).toSeq)) + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + + // Adjust for deletes at partition boundaries. Deletes at partition boundaries is a metadata + // operation, therefore we don't actually have any information around how many rows were deleted + // While this info may exist in the file statistics, it's not guaranteed that we have these + // statistics. To avoid any performance regressions, we currently just return a -1 in such cases + if (metrics("numRemovedFiles").value > 0 && metrics("numDeletedRows").value == 0) { + Seq(Row(-1L)) + } else { + Seq(Row(metrics("numDeletedRows").value)) + } + } + + def performDelete( + sparkSession: SparkSession, + deltaLog: DeltaLog, + txn: OptimisticTransaction): Seq[Action] = { + import org.apache.spark.sql.delta.implicits._ + + var numRemovedFiles: Long = 0 + var numAddedFiles: Long = 0 + var numAddedChangeFiles: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + var numAddedBytes: Long = 0 + var changeFileBytes: Long = 0 + var numRemovedBytes: Long = 0 + var numFilesBeforeSkipping: Long = 0 + var numBytesBeforeSkipping: Long = 0 + var numFilesAfterSkipping: Long = 0 + var numBytesAfterSkipping: Long = 0 + var numPartitionsAfterSkipping: Option[Long] = None + var numPartitionsRemovedFrom: Option[Long] = None + var numPartitionsAddedTo: Option[Long] = None + var numDeletedRows: Option[Long] = None + var numCopiedRows: Option[Long] = None + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val deleteActions: Seq[Action] = condition match { + case None => + // Case 1: Delete the whole table if the condition is true + val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA) + val allFiles = txn.filterFiles(Nil, keepNumRecords = reportRowLevelMetrics) + + numRemovedFiles = allFiles.size + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + val (numBytes, numPartitions) = totalBytesAndDistinctPartitionValues(allFiles) + numRemovedBytes = numBytes + numFilesBeforeSkipping = numRemovedFiles + numBytesBeforeSkipping = numBytes + numFilesAfterSkipping = numRemovedFiles + numBytesAfterSkipping = numBytes + numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(allFiles) + + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numPartitions) + numPartitionsRemovedFrom = Some(numPartitions) + numPartitionsAddedTo = Some(0) + } + val operationTimestamp = System.currentTimeMillis() + allFiles.map(_.removeWithTimestamp(operationTimestamp)) + case Some(cond) => + val (metadataPredicates, otherPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + cond, txn.metadata.partitionColumns, sparkSession) + + numFilesBeforeSkipping = txn.snapshot.numOfFiles + numBytesBeforeSkipping = txn.snapshot.sizeInBytes + + if (otherPredicates.isEmpty) { + // Case 2: The condition can be evaluated using metadata only. + // Delete a set of files without the need of scanning any data files. + val operationTimestamp = System.currentTimeMillis() + val reportRowLevelMetrics = conf.getConf(DeltaSQLConf.DELTA_DML_METRICS_FROM_METADATA) + val candidateFiles = + txn.filterFiles(metadataPredicates, keepNumRecords = reportRowLevelMetrics) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + numRemovedFiles = candidateFiles.size + numRemovedBytes = candidateFiles.map(_.size).sum + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + numDeletedRows = getDeletedRowsFromAddFilesAndUpdateMetrics(candidateFiles) + + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + numPartitionsRemovedFrom = Some(numCandidatePartitions) + numPartitionsAddedTo = Some(0) + } + candidateFiles.map(_.removeWithTimestamp(operationTimestamp)) + } else { + // Case 3: Delete the rows based on the condition. + + // Should we write the DVs to represent the deleted rows? + val shouldWriteDVs = shouldWritePersistentDeletionVectors(sparkSession, txn) + + val candidateFiles = txn.filterFiles( + metadataPredicates ++ otherPredicates, + keepNumRecords = shouldWriteDVs) + // `candidateFiles` contains the files filtered using statistics and delete condition + // They may or may not contains any rows that need to be deleted. + + numFilesAfterSkipping = candidateFiles.size + val (numCandidateBytes, numCandidatePartitions) = + totalBytesAndDistinctPartitionValues(candidateFiles) + numBytesAfterSkipping = numCandidateBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsAfterSkipping = Some(numCandidatePartitions) + } + + val nameToAddFileMap = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + val fileIndex = new TahoeBatchFileIndex( + sparkSession, "delete", candidateFiles, deltaLog, deltaLog.dataPath, txn.snapshot) + if (shouldWriteDVs) { + // this should be unreachable because we fall back to CPU + // if deletion vectors are enabled. The tracking issue for adding deletion vector + // support is https://github.com/NVIDIA/spark-rapids/issues/8554 + throw new IllegalStateException("Deletion vectors are not supported on GPU") + } else { + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val deletedRowCount = metrics("numDeletedRows") + val deletedRowUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(deletedRowCount) + }.asNondeterministic() + val filesToRewrite = + withStatusCode("DELTA", FINDING_TOUCHED_FILES_MSG) { + if (candidateFiles.isEmpty) { + Array.empty[String] + } else { + data.filter(new Column(cond)) + .select(input_file_name()) + .filter(deletedRowUdf()) + .distinct() + .as[String] + .collect() + } + } + + numRemovedFiles = filesToRewrite.length + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + if (filesToRewrite.isEmpty) { + // Case 3.1: no row matches and no delete will be triggered + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(0) + numPartitionsAddedTo = Some(0) + } + Nil + } else { + // Case 3.2: some files need an update to remove the deleted files + // Do the second pass and just read the affected files + val baseRelation = buildBaseRelation( + sparkSession, txn, "delete", deltaLog.dataPath, filesToRewrite, nameToAddFileMap) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDF = Dataset.ofRows(sparkSession, newTarget) + val filterCond = Not(EqualNullSafe(cond, Literal.TrueLiteral)) + val rewrittenActions = rewriteFiles(txn, targetDF, filterCond, filesToRewrite.length) + val (changeFiles, rewrittenFiles) = rewrittenActions + .partition(_.isInstanceOf[AddCDCFile]) + numAddedFiles = rewrittenFiles.size + val removedFiles = filesToRewrite.map(f => + getTouchedFile(deltaLog.dataPath, f, nameToAddFileMap)) + val (removedBytes, removedPartitions) = + totalBytesAndDistinctPartitionValues(removedFiles) + numRemovedBytes = removedBytes + val (rewrittenBytes, rewrittenPartitions) = + totalBytesAndDistinctPartitionValues(rewrittenFiles) + numAddedBytes = rewrittenBytes + if (txn.metadata.partitionColumns.nonEmpty) { + numPartitionsRemovedFrom = Some(removedPartitions) + numPartitionsAddedTo = Some(rewrittenPartitions) + } + numAddedChangeFiles = changeFiles.size + changeFileBytes = changeFiles.collect { case f: AddCDCFile => f.size }.sum + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + numDeletedRows = Some(metrics("numDeletedRows").value) + numCopiedRows = + Some(metrics("numTouchedRows").value - metrics("numDeletedRows").value) + + val operationTimestamp = System.currentTimeMillis() + removeFilesFromPaths( + deltaLog, nameToAddFileMap, filesToRewrite, operationTimestamp) ++ rewrittenActions + } + } + } + } + metrics("numRemovedFiles").set(numRemovedFiles) + metrics("numAddedFiles").set(numAddedFiles) + val executionTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + metrics("executionTimeMs").set(executionTimeMs) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numAddedBytes").set(numAddedBytes) + metrics("numRemovedBytes").set(numRemovedBytes) + metrics("numFilesBeforeSkipping").set(numFilesBeforeSkipping) + metrics("numBytesBeforeSkipping").set(numBytesBeforeSkipping) + metrics("numFilesAfterSkipping").set(numFilesAfterSkipping) + metrics("numBytesAfterSkipping").set(numBytesAfterSkipping) + numPartitionsAfterSkipping.foreach(metrics("numPartitionsAfterSkipping").set) + numPartitionsAddedTo.foreach(metrics("numPartitionsAddedTo").set) + numPartitionsRemovedFrom.foreach(metrics("numPartitionsRemovedFrom").set) + numCopiedRows.foreach(metrics("numCopiedRows").set) + txn.registerSQLMetrics(sparkSession, metrics) + sendDriverMetrics(sparkSession, metrics) + + recordDeltaEvent( + deltaLog, + "delta.dml.delete.stats", + data = DeleteMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numFilesAfterSkipping, + numAddedFiles, + numRemovedFiles, + numAddedFiles, + numAddedChangeFiles = numAddedChangeFiles, + numFilesBeforeSkipping, + numBytesBeforeSkipping, + numFilesAfterSkipping, + numBytesAfterSkipping, + numPartitionsAfterSkipping, + numPartitionsAddedTo, + numPartitionsRemovedFrom, + numCopiedRows, + numDeletedRows, + numAddedBytes, + numRemovedBytes, + changeFileBytes = changeFileBytes, + scanTimeMs, + rewriteTimeMs) + ) + + if (deleteActions.nonEmpty) { + createSetTransaction(sparkSession, deltaLog).toSeq ++ deleteActions + } else { + Seq.empty + } + } + + /** + * Returns the list of `AddFile`s and `AddCDCFile`s that have been re-written. + */ + private def rewriteFiles( + txn: OptimisticTransaction, + baseData: DataFrame, + filterCondition: Expression, + numFilesToRewrite: Long): Seq[FileAction] = { + val shouldWriteCdc = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(txn.metadata) + + // number of total rows that we have seen / are either copying or deleting (sum of both). + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(numTouchedRows) + }.asNondeterministic() + + withStatusCode( + "DELTA", rewritingFilesMsg(numFilesToRewrite)) { + val dfToWrite = if (shouldWriteCdc) { + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + // The logic here ends up being surprisingly elegant, with all source rows ending up in + // the output. Recall that we flipped the user-provided delete condition earlier, before the + // call to `rewriteFiles`. All rows which match this latest `filterCondition` are retained + // as table data, while all rows which don't match are removed from the rewritten table data + // but do get included in the output as CDC events. + baseData + .filter(numTouchedRowsUdf()) + .withColumn( + CDC_TYPE_COLUMN_NAME, + new Column(If(filterCondition, CDC_TYPE_NOT_CDC, CDC_TYPE_DELETE)) + ) + } else { + baseData + .filter(numTouchedRowsUdf()) + .filter(new Column(filterCondition)) + } + + txn.writeFiles(dfToWrite) + } + } + + def shouldWritePersistentDeletionVectors( + spark: SparkSession, txn: OptimisticTransaction): Boolean = { + // DELETE with DVs only enabled for tests. + Utils.isTesting && + spark.conf.get(DeltaSQLConf.DELETE_USE_PERSISTENT_DELETION_VECTORS) && + DeletionVectorUtils.deletionVectorsWritable(txn.snapshot) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala new file mode 100644 index 00000000000..c4eef959208 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuMergeIntoCommand.scala @@ -0,0 +1,1331 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from MergeIntoCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.nvidia.spark.rapids.{BaseExprMeta, GpuOverrides, RapidsConf} +import com.nvidia.spark.rapids.delta._ + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, BasePredicate, Expression, Literal, NamedExpression, PredicateHelper, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.DeltaCommand +import org.apache.spark.sql.delta.commands.merge.MergeIntoMaterializeSource +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.{ImplicitMetadataOperation, SchemaUtils} +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.util.{AnalysisHelper, SetAccumulator} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataTypes, LongType, StringType, StructType} + +case class GpuMergeDataSizes( + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + rows: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + files: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + bytes: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + partitions: Option[Long] = None) + +/** + * Represents the state of a single merge clause: + * - merge clause's (optional) predicate + * - action type (insert, update, delete) + * - action's expressions + */ +case class GpuMergeClauseStats( + condition: Option[String], + actionType: String, + actionExpr: Seq[String]) + +object GpuMergeClauseStats { + def apply(mergeClause: DeltaMergeIntoClause): GpuMergeClauseStats = { + GpuMergeClauseStats( + condition = mergeClause.condition.map(_.sql), + mergeClause.clauseType.toLowerCase(), + actionExpr = mergeClause.actions.map(_.sql)) + } +} + +/** State for a GPU merge operation */ +case class GpuMergeStats( + // Merge condition expression + conditionExpr: String, + + // Expressions used in old MERGE stats, now always Null + updateConditionExpr: String, + updateExprs: Seq[String], + insertConditionExpr: String, + insertExprs: Seq[String], + deleteConditionExpr: String, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED/NOT MATCHED BY SOURCE + matchedStats: Seq[GpuMergeClauseStats], + notMatchedStats: Seq[GpuMergeClauseStats], + notMatchedBySourceStats: Seq[GpuMergeClauseStats], + + // Timings + executionTimeMs: Long, + scanTimeMs: Long, + rewriteTimeMs: Long, + + // Data sizes of source and target at different stages of processing + source: GpuMergeDataSizes, + targetBeforeSkipping: GpuMergeDataSizes, + targetAfterSkipping: GpuMergeDataSizes, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + sourceRowsInSecondScan: Option[Long], + + // Data change sizes + targetFilesRemoved: Long, + targetFilesAdded: Long, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFilesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetChangeFileBytes: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesRemoved: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetBytesAdded: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsRemovedFrom: Option[Long], + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + targetPartitionsAddedTo: Option[Long], + targetRowsCopied: Long, + targetRowsUpdated: Long, + targetRowsMatchedUpdated: Long, + targetRowsNotMatchedBySourceUpdated: Long, + targetRowsInserted: Long, + targetRowsDeleted: Long, + targetRowsMatchedDeleted: Long, + targetRowsNotMatchedBySourceDeleted: Long, + + // MergeMaterializeSource stats + materializeSourceReason: Option[String] = None, + @JsonDeserialize(contentAs = classOf[java.lang.Long]) + materializeSourceAttempts: Option[Long] = None +) + +object GpuMergeStats { + + def fromMergeSQLMetrics( + metrics: Map[String, SQLMetric], + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + isPartitioned: Boolean): GpuMergeStats = { + + def metricValueIfPartitioned(metricName: String): Option[Long] = { + if (isPartitioned) Some(metrics(metricName).value) else None + } + + GpuMergeStats( + // Merge condition expression + conditionExpr = condition.sql, + + // Newer expressions used in MERGE with any number of MATCHED/NOT MATCHED/ + // NOT MATCHED BY SOURCE + matchedStats = matchedClauses.map(GpuMergeClauseStats(_)), + notMatchedStats = notMatchedClauses.map(GpuMergeClauseStats(_)), + notMatchedBySourceStats = notMatchedBySourceClauses.map(GpuMergeClauseStats(_)), + + // Timings + executionTimeMs = metrics("executionTimeMs").value, + scanTimeMs = metrics("scanTimeMs").value, + rewriteTimeMs = metrics("rewriteTimeMs").value, + + // Data sizes of source and target at different stages of processing + source = GpuMergeDataSizes(rows = Some(metrics("numSourceRows").value)), + targetBeforeSkipping = + GpuMergeDataSizes( + files = Some(metrics("numTargetFilesBeforeSkipping").value), + bytes = Some(metrics("numTargetBytesBeforeSkipping").value)), + targetAfterSkipping = + GpuMergeDataSizes( + files = Some(metrics("numTargetFilesAfterSkipping").value), + bytes = Some(metrics("numTargetBytesAfterSkipping").value), + partitions = metricValueIfPartitioned("numTargetPartitionsAfterSkipping")), + sourceRowsInSecondScan = + metrics.get("numSourceRowsInSecondScan").map(_.value).filter(_ >= 0), + + // Data change sizes + targetFilesAdded = metrics("numTargetFilesAdded").value, + targetChangeFilesAdded = metrics.get("numTargetChangeFilesAdded").map(_.value), + targetChangeFileBytes = metrics.get("numTargetChangeFileBytes").map(_.value), + targetFilesRemoved = metrics("numTargetFilesRemoved").value, + targetBytesAdded = Some(metrics("numTargetBytesAdded").value), + targetBytesRemoved = Some(metrics("numTargetBytesRemoved").value), + targetPartitionsRemovedFrom = metricValueIfPartitioned("numTargetPartitionsRemovedFrom"), + targetPartitionsAddedTo = metricValueIfPartitioned("numTargetPartitionsAddedTo"), + targetRowsCopied = metrics("numTargetRowsCopied").value, + targetRowsUpdated = metrics("numTargetRowsUpdated").value, + targetRowsMatchedUpdated = metrics("numTargetRowsMatchedUpdated").value, + targetRowsNotMatchedBySourceUpdated = metrics("numTargetRowsNotMatchedBySourceUpdated").value, + targetRowsInserted = metrics("numTargetRowsInserted").value, + targetRowsDeleted = metrics("numTargetRowsDeleted").value, + targetRowsMatchedDeleted = metrics("numTargetRowsMatchedDeleted").value, + targetRowsNotMatchedBySourceDeleted = metrics("numTargetRowsNotMatchedBySourceDeleted").value, + + // Deprecated fields + updateConditionExpr = null, + updateExprs = null, + insertConditionExpr = null, + insertExprs = null, + deleteConditionExpr = null) + } +} + +/** + * GPU version of Delta Lake's MergeIntoCommand. + * + * Performs a merge of a source query/table into a Delta table. + * + * Issues an error message when the ON search_condition of the MERGE statement can match + * a single row from the target table with multiple rows of the source table-reference. + * + * Algorithm: + * + * Phase 1: Find the input files in target that are touched by the rows that satisfy + * the condition and verify that no two source rows match with the same target row. + * This is implemented as an inner-join using the given condition. See [[findTouchedFiles]] + * for more details. + * + * Phase 2: Read the touched files again and write new files with updated and/or inserted rows. + * + * Phase 3: Use the Delta protocol to atomically remove the touched files and add the new files. + * + * @param source Source data to merge from + * @param target Target table to merge into + * @param gpuDeltaLog Delta log to use + * @param condition Condition for a source row to match with a target row + * @param matchedClauses All info related to matched clauses. + * @param notMatchedClauses All info related to not matched clauses. + * @param notMatchedBySourceClauses All info related to not matched by source clauses. + * @param migratedSchema The final schema of the target - may be changed by schema + * evolution. + */ +case class GpuMergeIntoCommand( + @transient source: LogicalPlan, + @transient target: LogicalPlan, + @transient gpuDeltaLog: GpuDeltaLog, + condition: Expression, + matchedClauses: Seq[DeltaMergeIntoMatchedClause], + notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause], + notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause], + migratedSchema: Option[StructType])( + @transient val rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand + with PredicateHelper + with AnalysisHelper + with ImplicitMetadataOperation + with MergeIntoMaterializeSource { + + import GpuMergeIntoCommand._ + + import SQLMetrics._ + import org.apache.spark.sql.delta.commands.cdc.CDCReader._ + + override val otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override val canMergeSchema: Boolean = conf.getConf(DeltaSQLConf.DELTA_SCHEMA_AUTO_MIGRATE) + override val canOverwriteSchema: Boolean = false + + override val output: Seq[Attribute] = Seq( + AttributeReference("num_affected_rows", LongType)(), + AttributeReference("num_updated_rows", LongType)(), + AttributeReference("num_deleted_rows", LongType)(), + AttributeReference("num_inserted_rows", LongType)()) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + @transient private lazy val targetDeltaLog: DeltaLog = gpuDeltaLog.deltaLog + /** + * Map to get target output attributes by name. + * The case sensitivity of the map is set accordingly to Spark configuration. + */ + @transient private lazy val targetOutputAttributesMap: Map[String, Attribute] = { + val attrMap: Map[String, Attribute] = target + .outputSet.view + .map(attr => attr.name -> attr).toMap + if (conf.caseSensitiveAnalysis) { + attrMap + } else { + CaseInsensitiveMap(attrMap) + } + } + + /** Whether this merge statement has only a single insert (NOT MATCHED) clause. */ + private def isSingleInsertOnly: Boolean = + matchedClauses.isEmpty && notMatchedBySourceClauses.isEmpty && notMatchedClauses.length == 1 + /** Whether this merge statement has no insert (NOT MATCHED) clause. */ + private def hasNoInserts: Boolean = notMatchedClauses.isEmpty + + // We over-count numTargetRowsDeleted when there are multiple matches; + // this is the amount of the overcount, so we can subtract it to get a correct final metric. + private var multipleMatchDeleteOnlyOvercount: Option[Long] = None + + override lazy val metrics = Map[String, SQLMetric]( + "numSourceRows" -> createMetric(sc, "number of source rows"), + "numSourceRowsInSecondScan" -> + createMetric(sc, "number of source rows (during repeated scan)"), + "numTargetRowsCopied" -> createMetric(sc, "number of target rows rewritten unmodified"), + "numTargetRowsInserted" -> createMetric(sc, "number of inserted rows"), + "numTargetRowsUpdated" -> createMetric(sc, "number of updated rows"), + "numTargetRowsMatchedUpdated" -> + createMetric(sc, "number of rows updated by a matched clause"), + "numTargetRowsNotMatchedBySourceUpdated" -> + createMetric(sc, "number of rows updated by a not matched by source clause"), + "numTargetRowsDeleted" -> createMetric(sc, "number of deleted rows"), + "numTargetRowsMatchedDeleted" -> + createMetric(sc, "number of rows deleted by a matched clause"), + "numTargetRowsNotMatchedBySourceDeleted" -> + createMetric(sc, "number of rows deleted by a not matched by source clause"), + "numTargetFilesBeforeSkipping" -> createMetric(sc, "number of target files before skipping"), + "numTargetFilesAfterSkipping" -> createMetric(sc, "number of target files after skipping"), + "numTargetFilesRemoved" -> createMetric(sc, "number of files removed to target"), + "numTargetFilesAdded" -> createMetric(sc, "number of files added to target"), + "numTargetChangeFilesAdded" -> + createMetric(sc, "number of change data capture files generated"), + "numTargetChangeFileBytes" -> + createMetric(sc, "total size of change data capture files generated"), + "numTargetBytesBeforeSkipping" -> createMetric(sc, "number of target bytes before skipping"), + "numTargetBytesAfterSkipping" -> createMetric(sc, "number of target bytes after skipping"), + "numTargetBytesRemoved" -> createMetric(sc, "number of target bytes removed"), + "numTargetBytesAdded" -> createMetric(sc, "number of target bytes added"), + "numTargetPartitionsAfterSkipping" -> + createMetric(sc, "number of target partitions after skipping"), + "numTargetPartitionsRemovedFrom" -> + createMetric(sc, "number of target partitions from which files were removed"), + "numTargetPartitionsAddedTo" -> + createMetric(sc, "number of target partitions to which files were added"), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files")) + + override def run(spark: SparkSession): Seq[Row] = { + metrics("executionTimeMs").set(0) + metrics("scanTimeMs").set(0) + metrics("rewriteTimeMs").set(0) + + if (migratedSchema.isDefined) { + // Block writes of void columns in the Delta log. Currently void columns are not properly + // supported and are dropped on read, but this is not enough for merge command that is also + // reading the schema from the Delta log. Until proper support we prefer to fail merge + // queries that add void columns. + val newNullColumn = SchemaUtils.findNullTypeColumn(migratedSchema.get) + if (newNullColumn.isDefined) { + throw new AnalysisException( + s"""Cannot add column '${newNullColumn.get}' with type 'void'. Please explicitly specify a + |non-void type.""".stripMargin.replaceAll("\n", " ") + ) + } + } + val (materializeSource, _) = shouldMaterializeSource(spark, source, isSingleInsertOnly) + if (!materializeSource) { + runMerge(spark) + } else { + // If it is determined that source should be materialized, wrap the execution with retries, + // in case the data of the materialized source is lost. + runWithMaterializedSourceLostRetries( + spark, targetDeltaLog, metrics, runMerge) + } + } + + protected def runMerge(spark: SparkSession): Seq[Row] = { + recordDeltaOperation(targetDeltaLog, "delta.dml.merge") { + val startTime = System.nanoTime() + gpuDeltaLog.withNewTransaction { deltaTxn => + if (hasBeenExecuted(deltaTxn, spark)) { + sendDriverMetrics(spark, metrics) + return Seq.empty + } + if (target.schema.size != deltaTxn.metadata.schema.size) { + throw DeltaErrors.schemaChangedSinceAnalysis( + atAnalysis = target.schema, latestSchema = deltaTxn.metadata.schema) + } + + if (canMergeSchema) { + updateMetadata( + spark, deltaTxn, migratedSchema.getOrElse(target.schema), + deltaTxn.metadata.partitionColumns, deltaTxn.metadata.configuration, + isOverwriteMode = false, rearrangeOnly = false) + } + + // If materialized, prepare the DF reading the materialize source + // Otherwise, prepare a regular DF from source plan. + val materializeSourceReason = prepareSourceDFAndReturnMaterializeReason( + spark, + source, + condition, + matchedClauses, + notMatchedClauses, + isSingleInsertOnly) + + val deltaActions = { + if (isSingleInsertOnly && spark.conf.get(DeltaSQLConf.MERGE_INSERT_ONLY_ENABLED)) { + writeInsertsOnlyWhenNoMatchedClauses(spark, deltaTxn) + } else { + val filesToRewrite = findTouchedFiles(spark, deltaTxn) + val newWrittenFiles = withStatusCode("DELTA", "Writing merged data") { + writeAllChanges(spark, deltaTxn, filesToRewrite) + } + filesToRewrite.map(_.remove) ++ newWrittenFiles + } + } + + val finalActions = createSetTransaction(spark, targetDeltaLog).toSeq ++ deltaActions + // Metrics should be recorded before commit (where they are written to delta logs). + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + deltaTxn.registerSQLMetrics(spark, metrics) + + // This is a best-effort sanity check. + if (metrics("numSourceRowsInSecondScan").value >= 0 && + metrics("numSourceRows").value != metrics("numSourceRowsInSecondScan").value) { + log.warn(s"Merge source has ${metrics("numSourceRows")} rows in initial scan but " + + s"${metrics("numSourceRowsInSecondScan")} rows in second scan") + if (conf.getConf(DeltaSQLConf.MERGE_FAIL_IF_SOURCE_CHANGED)) { + throw DeltaErrors.sourceNotDeterministicInMergeException(spark) + } + } + + deltaTxn.commitIfNeeded( + finalActions, + DeltaOperations.Merge( + Option(condition.sql), + matchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedClauses.map(DeltaOperations.MergePredicate(_)), + notMatchedBySourceClauses.map(DeltaOperations.MergePredicate(_)))) + + // Record metrics + var stats = GpuMergeStats.fromMergeSQLMetrics( + metrics, + condition, + matchedClauses, + notMatchedClauses, + notMatchedBySourceClauses, + deltaTxn.metadata.partitionColumns.nonEmpty) + stats = stats.copy( + materializeSourceReason = Some(materializeSourceReason.toString), + materializeSourceAttempts = Some(attempt)) + + recordDeltaEvent(targetDeltaLog, "delta.dml.merge.stats", data = stats) + + } + spark.sharedState.cacheManager.recacheByPlan(spark, target) + } + sendDriverMetrics(spark, metrics) + Seq(Row(metrics("numTargetRowsUpdated").value + metrics("numTargetRowsDeleted").value + + metrics("numTargetRowsInserted").value, metrics("numTargetRowsUpdated").value, + metrics("numTargetRowsDeleted").value, metrics("numTargetRowsInserted").value)) + } + + /** + * Find the target table files that contain the rows that satisfy the merge condition. This is + * implemented as an inner-join between the source query/table and the target table using + * the merge condition. + */ + private def findTouchedFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[AddFile] = recordMergeOperation(sqlMetricName = "scanTimeMs") { + + // Accumulator to collect all the distinct touched files + val touchedFilesAccum = new SetAccumulator[String]() + spark.sparkContext.register(touchedFilesAccum, TOUCHED_FILES_ACCUM_NAME) + + // UDFs to records touched files names and add them to the accumulator + val recordTouchedFileName = DeltaUDF.intFromString( + new GpuDeltaRecordTouchedFileNameUDF(touchedFilesAccum)).asNondeterministic() + + // Prune non-matching files if we don't need to collect them for NOT MATCHED BY SOURCE clauses. + val dataSkippedFiles = + if (notMatchedBySourceClauses.isEmpty) { + val targetOnlyPredicates = + splitConjunctivePredicates(condition).filter(_.references.subsetOf(target.outputSet)) + deltaTxn.filterFiles(targetOnlyPredicates) + } else { + deltaTxn.filterFiles() + } + + // UDF to increment metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val sourceDF = getSourceDF() + .filter(new Column(incrSourceRowCountExpr)) + + // Join the source and target table using the merge condition to find touched files. An inner + // join collects all candidate files for MATCHED clauses, a right outer join also includes + // candidates for NOT MATCHED BY SOURCE clauses. + // In addition, we attach two columns + // - a monotonically increasing row id for target rows to later identify whether the same + // target row is modified by multiple user or not + // - the target file name the row is from to later identify the files touched by matched rows + val joinType = if (notMatchedBySourceClauses.isEmpty) "inner" else "right_outer" + val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles) + .withColumn(ROW_ID_COL, monotonically_increasing_id()) + .withColumn(FILE_NAME_COL, input_file_name()) + val joinToFindTouchedFiles = sourceDF.join(targetDF, new Column(condition), joinType) + + // Process the matches from the inner join to record touched files and find multiple matches + val collectTouchedFiles = joinToFindTouchedFiles + .select(col(ROW_ID_COL), recordTouchedFileName(col(FILE_NAME_COL)).as("one")) + + // Calculate frequency of matches per source row + val matchedRowCounts = collectTouchedFiles.groupBy(ROW_ID_COL).agg(sum("one").as("count")) + + // Get multiple matches and simultaneously collect (using touchedFilesAccum) the file names + // multipleMatchCount = # of target rows with more than 1 matching source row (duplicate match) + // multipleMatchSum = total # of duplicate matched rows + import org.apache.spark.sql.delta.implicits._ + val (multipleMatchCount, multipleMatchSum) = matchedRowCounts + .filter("count > 1") + .select(coalesce(count(new Column("*")), lit(0)), coalesce(sum("count"), lit(0))) + .as[(Long, Long)] + .collect() + .head + + val hasMultipleMatches = multipleMatchCount > 0 + + // Throw error if multiple matches are ambiguous or cannot be computed correctly. + val canBeComputedUnambiguously = { + // Multiple matches are not ambiguous when there is only one unconditional delete as + // all the matched row pairs in the 2nd join in `writeAllChanges` will get deleted. + val isUnconditionalDelete = matchedClauses.headOption match { + case Some(DeltaMergeIntoMatchedDeleteClause(None)) => true + case _ => false + } + matchedClauses.size == 1 && isUnconditionalDelete + } + + if (hasMultipleMatches && !canBeComputedUnambiguously) { + throw DeltaErrors.multipleSourceRowMatchingTargetRowInMergeException(spark) + } + + if (hasMultipleMatches) { + // This is only allowed for delete-only queries. + // This query will count the duplicates for numTargetRowsDeleted in Job 2, + // because we count matches after the join and not just the target rows. + // We have to compensate for this by subtracting the duplicates later, + // so we need to record them here. + val duplicateCount = multipleMatchSum - multipleMatchCount + multipleMatchDeleteOnlyOvercount = Some(duplicateCount) + } + + // Get the AddFiles using the touched file names. + val touchedFileNames = touchedFilesAccum.value.iterator().asScala.toSeq + logTrace(s"findTouchedFiles: matched files:\n\t${touchedFileNames.mkString("\n\t")}") + + val nameToAddFileMap = generateCandidateFileMap(targetDeltaLog.dataPath, dataSkippedFiles) + val touchedAddFiles = touchedFileNames.map(f => + getTouchedFile(targetDeltaLog.dataPath, f, nameToAddFileMap)) + + // When the target table is empty, and the optimizer optimized away the join entirely + // numSourceRows will be incorrectly 0. We need to scan the source table once to get the correct + // metric here. + if (metrics("numSourceRows").value == 0 && + (dataSkippedFiles.isEmpty || targetDF.take(1).isEmpty)) { + val numSourceRows = sourceDF.count() + metrics("numSourceRows").set(numSourceRows) + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + val (removedBytes, removedPartitions) = totalBytesAndDistinctPartitionValues(touchedAddFiles) + metrics("numTargetFilesRemoved") += touchedAddFiles.size + metrics("numTargetBytesRemoved") += removedBytes + metrics("numTargetPartitionsRemovedFrom") += removedPartitions + touchedAddFiles + } + + /** + * This is an optimization of the case when there is no update clause for the merge. + * We perform an left anti join on the source data to find the rows to be inserted. + * + * This will currently only optimize for the case when there is a _single_ notMatchedClause. + */ + private def writeInsertsOnlyWhenNoMatchedClauses( + spark: SparkSession, + deltaTxn: OptimisticTransaction + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + + // UDFs to update metrics + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRows") + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted") + + val outputColNames = getTargetOutputCols(deltaTxn).map(_.name) + // we use head here since we know there is only a single notMatchedClause + val outputExprs = notMatchedClauses.head.resolvedActions.map(_.expr) + val outputCols = outputExprs.zip(outputColNames).map { case (expr, name) => + new Column(Alias(expr, name)()) + } + + // source DataFrame + val sourceDF = getSourceDF() + .filter(new Column(incrSourceRowCountExpr)) + .filter(new Column(notMatchedClauses.head.condition.getOrElse(Literal.TrueLiteral))) + + // Skip data based on the merge condition + val conjunctivePredicates = splitConjunctivePredicates(condition) + val targetOnlyPredicates = + conjunctivePredicates.filter(_.references.subsetOf(target.outputSet)) + val dataSkippedFiles = deltaTxn.filterFiles(targetOnlyPredicates) + + // target DataFrame + val targetDF = buildTargetPlanWithFiles(spark, deltaTxn, dataSkippedFiles) + + val insertDf = sourceDF.join(targetDF, new Column(condition), "leftanti") + .select(outputCols: _*) + .filter(new Column(incrInsertedCountExpr)) + + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, insertDf, deltaTxn.metadata.partitionColumns)) + .filter { + // In some cases (e.g. insert-only when all rows are matched, insert-only with an empty + // source, insert-only with an unsatisfied condition) we can write out an empty insertDf. + // This is hard to catch before the write without collecting the DF ahead of time. Instead, + // we can just accept only the AddFiles that actually add rows or + // when we don't know the number of records + case a: AddFile => a.numLogicalRecords.forall(_ > 0) + case _ => true + } + + // Update metrics + metrics("numTargetFilesBeforeSkipping") += deltaTxn.snapshot.numOfFiles + metrics("numTargetBytesBeforeSkipping") += deltaTxn.snapshot.sizeInBytes + val (afterSkippingBytes, afterSkippingPartitions) = + totalBytesAndDistinctPartitionValues(dataSkippedFiles) + metrics("numTargetFilesAfterSkipping") += dataSkippedFiles.size + metrics("numTargetBytesAfterSkipping") += afterSkippingBytes + metrics("numTargetPartitionsAfterSkipping") += afterSkippingPartitions + metrics("numTargetFilesRemoved") += 0 + metrics("numTargetBytesRemoved") += 0 + metrics("numTargetPartitionsRemovedFrom") += 0 + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + newFiles + } + + /** + * Write new files by reading the touched files and updating/inserting data using the source + * query/table. This is implemented using a full|right-outer-join using the merge condition. + * + * Note that unlike the insert-only code paths with just one control column INCR_ROW_COUNT_COL, + * this method has two additional control columns ROW_DROPPED_COL for dropping deleted rows and + * CDC_TYPE_COL_NAME used for handling CDC when enabled. + */ + private def writeAllChanges( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + filesToRewrite: Seq[AddFile] + ): Seq[FileAction] = recordMergeOperation(sqlMetricName = "rewriteTimeMs") { + import org.apache.spark.sql.catalyst.expressions.Literal.{TrueLiteral, FalseLiteral} + + val cdcEnabled = DeltaConfigs.CHANGE_DATA_FEED.fromMetaData(deltaTxn.metadata) + + var targetOutputCols = getTargetOutputCols(deltaTxn) + var outputRowSchema = deltaTxn.metadata.schema + + // When we have duplicate matches (only allowed when the whenMatchedCondition is a delete with + // no match condition) we will incorrectly generate duplicate CDC rows. + // Duplicate matches can be due to: + // - Duplicate rows in the source w.r.t. the merge condition + // - A target-only or source-only merge condition, which essentially turns our join into a cross + // join with the target/source satisfiying the merge condition. + // These duplicate matches are dropped from the main data output since this is a delete + // operation, but the duplicate CDC rows are not removed by default. + // See https://github.com/delta-io/delta/issues/1274 + + // We address this specific scenario by adding row ids to the target before performing our join. + // There should only be one CDC delete row per target row so we can use these row ids to dedupe + // the duplicate CDC delete rows. + + // We also need to address the scenario when there are duplicate matches with delete and we + // insert duplicate rows. Here we need to additionally add row ids to the source before the + // join to avoid dropping these valid duplicate inserted rows and their corresponding cdc rows. + + // When there is an insert clause, we set SOURCE_ROW_ID_COL=null for all delete rows because we + // need to drop the duplicate matches. + val isDeleteWithDuplicateMatchesAndCdc = multipleMatchDeleteOnlyOvercount.nonEmpty && cdcEnabled + + // Generate a new target dataframe that has same output attributes exprIds as the target plan. + // This allows us to apply the existing resolved update/insert expressions. + val baseTargetDF = buildTargetPlanWithFiles(spark, deltaTxn, filesToRewrite) + val joinType = if (hasNoInserts && + spark.conf.get(DeltaSQLConf.MERGE_MATCHED_ONLY_ENABLED)) { + "rightOuter" + } else { + "fullOuter" + } + + logDebug(s"""writeAllChanges using $joinType join: + | source.output: ${source.outputSet} + | target.output: ${target.outputSet} + | condition: $condition + | newTarget.output: ${baseTargetDF.queryExecution.logical.outputSet} + """.stripMargin) + + // UDFs to update metrics + // Make UDFs that appear in the custom join processor node deterministic, as they always + // return true and update a metric. Catalyst precludes non-deterministic UDFs that are not + // allowed outside a very specific set of Catalyst nodes (Project, Filter, Window, Aggregate). + val incrSourceRowCountExpr = makeMetricUpdateUDF("numSourceRowsInSecondScan", + deterministic = true) + val incrUpdatedCountExpr = makeMetricUpdateUDF("numTargetRowsUpdated", deterministic = true) + val incrUpdatedMatchedCountExpr = makeMetricUpdateUDF("numTargetRowsMatchedUpdated", + deterministic = true) + val incrUpdatedNotMatchedBySourceCountExpr = + makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceUpdated", deterministic = true) + val incrInsertedCountExpr = makeMetricUpdateUDF("numTargetRowsInserted", deterministic = true) + val incrNoopCountExpr = makeMetricUpdateUDF("numTargetRowsCopied", deterministic = true) + val incrDeletedCountExpr = makeMetricUpdateUDF("numTargetRowsDeleted", deterministic = true) + val incrDeletedMatchedCountExpr = makeMetricUpdateUDF("numTargetRowsMatchedDeleted", + deterministic = true) + val incrDeletedNotMatchedBySourceCountExpr = + makeMetricUpdateUDF("numTargetRowsNotMatchedBySourceDeleted", deterministic = true) + + // Apply an outer join to find both, matches and non-matches. We are adding two boolean fields + // with value `true`, one to each side of the join. Whether this field is null or not after + // the outer join, will allow us to identify whether the resultant joined row was a + // matched inner result or an unmatched result with null on one side. + // We add row IDs to the targetDF if we have a delete-when-matched clause with duplicate + // matches and CDC is enabled, and additionally add row IDs to the source if we also have an + // insert clause. See above at isDeleteWithDuplicateMatchesAndCdc definition for more details. + var sourceDF = getSourceDF() + .withColumn(SOURCE_ROW_PRESENT_COL, new Column(incrSourceRowCountExpr)) + var targetDF = baseTargetDF + .withColumn(TARGET_ROW_PRESENT_COL, lit(true)) + if (isDeleteWithDuplicateMatchesAndCdc) { + targetDF = targetDF.withColumn(TARGET_ROW_ID_COL, monotonically_increasing_id()) + if (notMatchedClauses.nonEmpty) { // insert clause + sourceDF = sourceDF.withColumn(SOURCE_ROW_ID_COL, monotonically_increasing_id()) + } + } + val joinedDF = sourceDF.join(targetDF, new Column(condition), joinType) + val joinedPlan = joinedDF.queryExecution.analyzed + + def resolveOnJoinedPlan(exprs: Seq[Expression]): Seq[Expression] = { + tryResolveReferencesForExpressions(spark, exprs, joinedPlan) + } + + // ==== Generate the expressions to process full-outer join output and generate target rows ==== + // If there are N columns in the target table, there will be N + 3 columns after processing + // - N columns for target table + // - ROW_DROPPED_COL to define whether the generated row should dropped or written + // - INCR_ROW_COUNT_COL containing a UDF to update the output row row counter + // - CDC_TYPE_COLUMN_NAME containing the type of change being performed in a particular row + + // To generate these N + 3 columns, we will generate N + 3 expressions and apply them to the + // rows in the joinedDF. The CDC column will be either used for CDC generation or dropped before + // performing the final write, and the other two will always be dropped after executing the + // metrics UDF and filtering on ROW_DROPPED_COL. + + // We produce rows for both the main table data (with CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC), + // and rows for the CDC data which will be output to CDCReader.CDC_LOCATION. + // See [[CDCReader]] for general details on how partitioning on the CDC type column works. + + // In the following functions `updateOutput`, `deleteOutput` and `insertOutput`, we + // produce a Seq[Expression] for each intended output row. + // Depending on the clause and whether CDC is enabled, we output between 0 and 3 rows, as a + // Seq[Seq[Expression]] + + // There is one corner case outlined above at isDeleteWithDuplicateMatchesAndCdc definition. + // When we have a delete-ONLY merge with duplicate matches we have N + 4 columns: + // N target cols, TARGET_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, CDC_TYPE_COLUMN_NAME + // When we have a delete-when-matched merge with duplicate matches + an insert clause, we have + // N + 5 columns: + // N target cols, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, ROW_DROPPED_COL, INCR_ROW_COUNT_COL, + // CDC_TYPE_COLUMN_NAME + // These ROW_ID_COL will always be dropped before the final write. + + if (isDeleteWithDuplicateMatchesAndCdc) { + targetOutputCols = targetOutputCols :+ UnresolvedAttribute(TARGET_ROW_ID_COL) + outputRowSchema = outputRowSchema.add(TARGET_ROW_ID_COL, DataTypes.LongType) + if (notMatchedClauses.nonEmpty) { // there is an insert clause, make SRC_ROW_ID_COL=null + targetOutputCols = targetOutputCols :+ Alias(Literal(null), SOURCE_ROW_ID_COL)() + outputRowSchema = outputRowSchema.add(SOURCE_ROW_ID_COL, DataTypes.LongType) + } + } + + if (cdcEnabled) { + outputRowSchema = outputRowSchema + .add(ROW_DROPPED_COL, DataTypes.BooleanType) + .add(INCR_ROW_COUNT_COL, DataTypes.BooleanType) + .add(CDC_TYPE_COLUMN_NAME, DataTypes.StringType) + } + + def updateOutput(resolvedActions: Seq[DeltaMergeAction], incrMetricExpr: Expression) + : Seq[Seq[Expression]] = { + val updateExprs = { + // Generate update expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val mainDataOutput = resolvedActions.map(_.expr) :+ FalseLiteral :+ + incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL + if (cdcEnabled) { + // For update preimage, we have do a no-op copy with ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_PREIMAGE and INCR_ROW_COUNT_COL as a no-op + // (because the metric will be incremented in `mainDataOutput`) + val preImageOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_PREIMAGE) + // For update postimage, we have the same expressions as for mainDataOutput but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_UPDATE_POSTIMAGE + val postImageOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ + Literal(CDC_TYPE_UPDATE_POSTIMAGE) + Seq(mainDataOutput, preImageOutput, postImageOutput) + } else { + Seq(mainDataOutput) + } + } + updateExprs.map(resolveOnJoinedPlan) + } + + def deleteOutput(incrMetricExpr: Expression): Seq[Seq[Expression]] = { + val deleteExprs = { + // Generate expressions to set the ROW_DELETED_COL = true and CDC_TYPE_COLUMN_NAME = + // CDC_TYPE_NOT_CDC + val mainDataOutput = targetOutputCols :+ TrueLiteral :+ incrMetricExpr :+ + CDC_TYPE_NOT_CDC_LITERAL + if (cdcEnabled) { + // For delete we do a no-op copy with ROW_DELETED_COL = false, INCR_ROW_COUNT_COL as a + // no-op (because the metric will be incremented in `mainDataOutput`) and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_DELETE + val deleteCdcOutput = targetOutputCols :+ FalseLiteral :+ TrueLiteral :+ CDC_TYPE_DELETE + Seq(mainDataOutput, deleteCdcOutput) + } else { + Seq(mainDataOutput) + } + } + deleteExprs.map(resolveOnJoinedPlan) + } + + def insertOutput(resolvedActions: Seq[DeltaMergeAction], incrMetricExpr: Expression) + : Seq[Seq[Expression]] = { + // Generate insert expressions and set ROW_DELETED_COL = false and + // CDC_TYPE_COLUMN_NAME = CDC_TYPE_NOT_CDC + val insertExprs = resolvedActions.map(_.expr) + val mainDataOutput = resolveOnJoinedPlan( + if (isDeleteWithDuplicateMatchesAndCdc) { + // Must be delete-when-matched merge with duplicate matches + insert clause + // Therefore we must keep the target row id and source row id. Since this is a not-matched + // clause we know the target row-id will be null. See above at + // isDeleteWithDuplicateMatchesAndCdc definition for more details. + insertExprs :+ + Alias(Literal(null), TARGET_ROW_ID_COL)() :+ UnresolvedAttribute(SOURCE_ROW_ID_COL) :+ + FalseLiteral :+ incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL + } else { + insertExprs :+ FalseLiteral :+ incrMetricExpr :+ CDC_TYPE_NOT_CDC_LITERAL + } + ) + if (cdcEnabled) { + // For insert we have the same expressions as for mainDataOutput, but with + // INCR_ROW_COUNT_COL as a no-op (because the metric will be incremented in + // `mainDataOutput`), and CDC_TYPE_COLUMN_NAME = CDC_TYPE_INSERT + val insertCdcOutput = mainDataOutput.dropRight(2) :+ TrueLiteral :+ Literal(CDC_TYPE_INSERT) + Seq(mainDataOutput, insertCdcOutput) + } else { + Seq(mainDataOutput) + } + } + + def clauseOutput(clause: DeltaMergeIntoClause): Seq[Seq[Expression]] = clause match { + case u: DeltaMergeIntoMatchedUpdateClause => + updateOutput(u.resolvedActions, And(incrUpdatedCountExpr, incrUpdatedMatchedCountExpr)) + case _: DeltaMergeIntoMatchedDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedMatchedCountExpr)) + case i: DeltaMergeIntoNotMatchedInsertClause => + insertOutput(i.resolvedActions, incrInsertedCountExpr) + case u: DeltaMergeIntoNotMatchedBySourceUpdateClause => + updateOutput( + u.resolvedActions, + And(incrUpdatedCountExpr, incrUpdatedNotMatchedBySourceCountExpr)) + case _: DeltaMergeIntoNotMatchedBySourceDeleteClause => + deleteOutput(And(incrDeletedCountExpr, incrDeletedNotMatchedBySourceCountExpr)) + } + + def clauseCondition(clause: DeltaMergeIntoClause): Expression = { + // if condition is None, then expression always evaluates to true + val condExpr = clause.condition.getOrElse(TrueLiteral) + resolveOnJoinedPlan(Seq(condExpr)).head + } + + val targetRowHasNoMatch = resolveOnJoinedPlan(Seq(col(SOURCE_ROW_PRESENT_COL).isNull.expr)).head + val sourceRowHasNoMatch = resolveOnJoinedPlan(Seq(col(TARGET_ROW_PRESENT_COL).isNull.expr)).head + val matchedConditions = matchedClauses.map(clauseCondition) + val matchedOutputs = matchedClauses.map(clauseOutput) + val notMatchedConditions = notMatchedClauses.map(clauseCondition) + val notMatchedOutputs = notMatchedClauses.map(clauseOutput) + val notMatchedBySourceConditions = notMatchedBySourceClauses.map(clauseCondition) + val notMatchedBySourceOutputs = notMatchedBySourceClauses.map(clauseOutput) + val noopCopyOutput = + resolveOnJoinedPlan(targetOutputCols :+ FalseLiteral :+ incrNoopCountExpr :+ + CDC_TYPE_NOT_CDC_LITERAL) + val deleteRowOutput = + resolveOnJoinedPlan(targetOutputCols :+ TrueLiteral :+ TrueLiteral :+ + CDC_TYPE_NOT_CDC_LITERAL) + var outputDF = addMergeJoinProcessor(spark, joinedPlan, outputRowSchema, + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + notMatchedBySourceConditions = notMatchedBySourceConditions, + notMatchedBySourceOutputs = notMatchedBySourceOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput) + + if (isDeleteWithDuplicateMatchesAndCdc) { + // When we have a delete when matched clause with duplicate matches we have to remove + // duplicate CDC rows. This scenario is further explained at + // isDeleteWithDuplicateMatchesAndCdc definition. + + // To remove duplicate CDC rows generated by the duplicate matches we dedupe by + // TARGET_ROW_ID_COL since there should only be one CDC delete row per target row. + // When there is an insert clause in addition to the delete clause we additionally dedupe by + // SOURCE_ROW_ID_COL and CDC_TYPE_COLUMN_NAME to avoid dropping valid duplicate inserted rows + // and their corresponding CDC rows. + val columnsToDedupeBy = if (notMatchedClauses.nonEmpty) { // insert clause + Seq(TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL, CDC_TYPE_COLUMN_NAME) + } else { + Seq(TARGET_ROW_ID_COL) + } + outputDF = outputDF + .dropDuplicates(columnsToDedupeBy) + .drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL, TARGET_ROW_ID_COL, SOURCE_ROW_ID_COL) + } else { + outputDF = outputDF.drop(ROW_DROPPED_COL, INCR_ROW_COUNT_COL) + } + + logDebug("writeAllChanges: join output plan:\n" + outputDF.queryExecution) + + // Write to Delta + val newFiles = deltaTxn + .writeFiles(repartitionIfNeeded(spark, outputDF, deltaTxn.metadata.partitionColumns)) + + // Update metrics + val (addedBytes, addedPartitions) = totalBytesAndDistinctPartitionValues(newFiles) + metrics("numTargetFilesAdded") += newFiles.count(_.isInstanceOf[AddFile]) + metrics("numTargetChangeFilesAdded") += newFiles.count(_.isInstanceOf[AddCDCFile]) + metrics("numTargetChangeFileBytes") += newFiles.collect{ case f: AddCDCFile => f.size }.sum + metrics("numTargetBytesAdded") += addedBytes + metrics("numTargetPartitionsAddedTo") += addedPartitions + if (multipleMatchDeleteOnlyOvercount.isDefined) { + // Compensate for counting duplicates during the query. + val actualRowsDeleted = + metrics("numTargetRowsDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsDeleted >= 0) + metrics("numTargetRowsDeleted").set(actualRowsDeleted) + val actualRowsMatchedDeleted = + metrics("numTargetRowsMatchedDeleted").value - multipleMatchDeleteOnlyOvercount.get + assert(actualRowsMatchedDeleted >= 0) + metrics("numTargetRowsMatchedDeleted").set(actualRowsMatchedDeleted) + } + + newFiles + } + + private def addMergeJoinProcessor( + spark: SparkSession, + joinedPlan: LogicalPlan, + outputRowSchema: StructType, + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedBySourceConditions: Seq[Expression], + notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression]): Dataset[Row] = { + def wrap(e: Expression): BaseExprMeta[Expression] = { + GpuOverrides.wrapExpr(e, rapidsConf, None) + } + + val targetRowHasNoMatchMeta = wrap(targetRowHasNoMatch) + val sourceRowHasNoMatchMeta = wrap(sourceRowHasNoMatch) + val matchedConditionsMetas = matchedConditions.map(wrap) + val matchedOutputsMetas = matchedOutputs.map(_.map(_.map(wrap))) + val notMatchedConditionsMetas = notMatchedConditions.map(wrap) + val notMatchedOutputsMetas = notMatchedOutputs.map(_.map(_.map(wrap))) + val notMatchedBySourceConditionsMetas = notMatchedConditions.map(wrap) + val notMatchedBySourceOutputsMetas = notMatchedOutputs.map(_.map(_.map(wrap))) + val noopCopyOutputMetas = noopCopyOutput.map(wrap) + val deleteRowOutputMetas = deleteRowOutput.map(wrap) + val allMetas = Seq(targetRowHasNoMatchMeta, sourceRowHasNoMatchMeta) ++ + matchedConditionsMetas ++ matchedOutputsMetas.flatten.flatten ++ + notMatchedConditionsMetas ++ notMatchedOutputsMetas.flatten.flatten ++ + notMatchedBySourceConditionsMetas ++ notMatchedBySourceOutputsMetas.flatten.flatten ++ + noopCopyOutputMetas ++ deleteRowOutputMetas + allMetas.foreach(_.tagForGpu()) + val canReplace = allMetas.forall(_.canExprTreeBeReplaced) && rapidsConf.isOperatorEnabled( + "spark.rapids.sql.exec.RapidsProcessDeltaMergeJoinExec", false, false) + if (rapidsConf.shouldExplainAll || (rapidsConf.shouldExplain && !canReplace)) { + val exprExplains = allMetas.map(_.explain(rapidsConf.shouldExplainAll)) + val execWorkInfo = if (canReplace) { + "will run on GPU" + } else { + "cannot run on GPU because not all merge processing expressions can be replaced" + } + logWarning(s" $execWorkInfo:\n" + + s" ${exprExplains.mkString(" ")}") + } + + if (canReplace) { + val processedJoinPlan = RapidsProcessDeltaMergeJoin( + joinedPlan, + outputRowSchema.toAttributes, + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + notMatchedBySourceConditions = notMatchedBySourceConditions, + notMatchedBySourceOutputs = notMatchedBySourceOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput) + Dataset.ofRows(spark, processedJoinPlan) + } else { + val joinedRowEncoder = RowEncoder(joinedPlan.schema) + val outputRowEncoder = RowEncoder(outputRowSchema).resolveAndBind() + + val processor = new JoinedRowProcessor( + targetRowHasNoMatch = targetRowHasNoMatch, + sourceRowHasNoMatch = sourceRowHasNoMatch, + matchedConditions = matchedConditions, + matchedOutputs = matchedOutputs, + notMatchedConditions = notMatchedConditions, + notMatchedOutputs = notMatchedOutputs, + notMatchedBySourceConditions = notMatchedBySourceConditions, + notMatchedBySourceOutputs = notMatchedBySourceOutputs, + noopCopyOutput = noopCopyOutput, + deleteRowOutput = deleteRowOutput, + joinedAttributes = joinedPlan.output, + joinedRowEncoder = joinedRowEncoder, + outputRowEncoder = outputRowEncoder) + + Dataset.ofRows(spark, joinedPlan).mapPartitions(processor.processPartition)(outputRowEncoder) + } + } + + /** + * Build a new logical plan using the given `files` that has the same output columns (exprIds) + * as the `target` logical plan, so that existing update/insert expressions can be applied + * on this new plan. + */ + private def buildTargetPlanWithFiles( + spark: SparkSession, + deltaTxn: OptimisticTransaction, + files: Seq[AddFile]): DataFrame = { + val targetOutputCols = getTargetOutputCols(deltaTxn) + val targetOutputColsMap = { + val colsMap: Map[String, NamedExpression] = targetOutputCols.view + .map(col => col.name -> col).toMap + if (conf.caseSensitiveAnalysis) { + colsMap + } else { + CaseInsensitiveMap(colsMap) + } + } + + val plan = { + // We have to do surgery to use the attributes from `targetOutputCols` to scan the table. + // In cases of schema evolution, they may not be the same type as the original attributes. + val original = + deltaTxn.deltaLog.createDataFrame(deltaTxn.snapshot, files).queryExecution.analyzed + val transformed = original.transform { + case LogicalRelation(base, _, catalogTbl, isStreaming) => + LogicalRelation( + base, + // We can ignore the new columns which aren't yet AttributeReferences. + targetOutputCols.collect { case a: AttributeReference => a }, + catalogTbl, + isStreaming) + } + + // In case of schema evolution & column mapping, we would also need to rebuild the file format + // because under column mapping, the reference schema within DeltaParquetFileFormat + // that is used to populate metadata needs to be updated + if (deltaTxn.metadata.columnMappingMode != NoMapping) { + val updatedFileFormat = deltaTxn.deltaLog.fileFormat(deltaTxn.metadata) + DeltaTableUtils.replaceFileFormat(transformed, updatedFileFormat) + } else { + transformed + } + } + + // For each plan output column, find the corresponding target output column (by name) and + // create an alias + val aliases = plan.output.map { + case newAttrib: AttributeReference => + val existingTargetAttrib = targetOutputColsMap.get(newAttrib.name) + .getOrElse { + throw DeltaErrors.failedFindAttributeInOutputColumns( + newAttrib.name, targetOutputCols.mkString(",")) + }.asInstanceOf[AttributeReference] + + if (existingTargetAttrib.exprId == newAttrib.exprId) { + // It's not valid to alias an expression to its own exprId (this is considered a + // non-unique exprId by the analyzer), so we just use the attribute directly. + newAttrib + } else { + Alias(newAttrib, existingTargetAttrib.name)(exprId = existingTargetAttrib.exprId) + } + } + + Dataset.ofRows(spark, Project(aliases, plan)) + } + + /** Expressions to increment SQL metrics */ + private def makeMetricUpdateUDF(name: String, deterministic: Boolean = false): Expression = { + // only capture the needed metric in a local variable + val metric = metrics(name) + var u = DeltaUDF.boolean(new GpuDeltaMetricUpdateUDF(metric)) + if (!deterministic) { + u = u.asNondeterministic() + } + u.apply().expr + } + + private def getTargetOutputCols(txn: OptimisticTransaction): Seq[NamedExpression] = { + txn.metadata.schema.map { col => + targetOutputAttributesMap + .get(col.name) + .map { a => + AttributeReference(col.name, col.dataType, col.nullable)(a.exprId) + } + .getOrElse(Alias(Literal(null), col.name)() + ) + } + } + + /** + * Repartitions the output DataFrame by the partition columns if table is partitioned + * and `merge.repartitionBeforeWrite.enabled` is set to true. + */ + protected def repartitionIfNeeded( + spark: SparkSession, + df: DataFrame, + partitionColumns: Seq[String]): DataFrame = { + if (partitionColumns.nonEmpty && spark.conf.get(DeltaSQLConf.MERGE_REPARTITION_BEFORE_WRITE)) { + df.repartition(partitionColumns.map(col): _*) + } else { + df + } + } + + /** + * Execute the given `thunk` and return its result while recording the time taken to do it. + * + * @param sqlMetricName name of SQL metric to update with the time taken by the thunk + * @param thunk the code to execute + */ + private def recordMergeOperation[A](sqlMetricName: String)(thunk: => A): A = { + val startTimeNs = System.nanoTime() + val r = thunk + val timeTakenMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + if (sqlMetricName != null && timeTakenMs > 0) { + metrics(sqlMetricName) += timeTakenMs + } + r + } +} + +object GpuMergeIntoCommand { + /** + * Spark UI will track all normal accumulators along with Spark tasks to show them on Web UI. + * However, the accumulator used by `MergeIntoCommand` can store a very large value since it + * tracks all files that need to be rewritten. We should ask Spark UI to not remember it, + * otherwise, the UI data may consume lots of memory. Hence, we use the prefix `internal.metrics.` + * to make this accumulator become an internal accumulator, so that it will not be tracked by + * Spark UI. + */ + val TOUCHED_FILES_ACCUM_NAME = "internal.metrics.MergeIntoDelta.touchedFiles" + + val ROW_ID_COL = "_row_id_" + val TARGET_ROW_ID_COL = "_target_row_id_" + val SOURCE_ROW_ID_COL = "_source_row_id_" + val FILE_NAME_COL = "_file_name_" + val SOURCE_ROW_PRESENT_COL = "_source_row_present_" + val TARGET_ROW_PRESENT_COL = "_target_row_present_" + val ROW_DROPPED_COL = GpuDeltaMergeConstants.ROW_DROPPED_COL + val INCR_ROW_COUNT_COL = "_incr_row_count_" + + // Some Delta versions use Literal(null) which translates to a literal of NullType instead + // of the Literal(null, StringType) which is needed, so using a fixed version here + // rather than the version from Delta Lake. + val CDC_TYPE_NOT_CDC_LITERAL = Literal(null, StringType) + + /** + * @param targetRowHasNoMatch whether a joined row is a target row with no match in the source + * table + * @param sourceRowHasNoMatch whether a joined row is a source row with no match in the target + * table + * @param matchedConditions condition for each match clause + * @param matchedOutputs corresponding output for each match clause. for each clause, we + * have 1-3 output rows, each of which is a sequence of expressions + * to apply to the joined row + * @param notMatchedConditions condition for each not-matched clause + * @param notMatchedOutputs corresponding output for each not-matched clause. for each clause, + * we have 1-2 output rows, each of which is a sequence of + * expressions to apply to the joined row + * @param notMatchedBySourceConditions condition for each not-matched-by-source clause + * @param notMatchedBySourceOutputs corresponding output for each not-matched-by-source + * clause. for each clause, we have 1-3 output rows, each of + * which is a sequence of expressions to apply to the joined + * row + * @param noopCopyOutput no-op expression to copy a target row to the output + * @param deleteRowOutput expression to drop a row from the final output. this is used for + * source rows that don't match any not-matched clauses + * @param joinedAttributes schema of our outer-joined dataframe + * @param joinedRowEncoder joinedDF row encoder + * @param outputRowEncoder final output row encoder + */ + class JoinedRowProcessor( + targetRowHasNoMatch: Expression, + sourceRowHasNoMatch: Expression, + matchedConditions: Seq[Expression], + matchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedConditions: Seq[Expression], + notMatchedOutputs: Seq[Seq[Seq[Expression]]], + notMatchedBySourceConditions: Seq[Expression], + notMatchedBySourceOutputs: Seq[Seq[Seq[Expression]]], + noopCopyOutput: Seq[Expression], + deleteRowOutput: Seq[Expression], + joinedAttributes: Seq[Attribute], + joinedRowEncoder: ExpressionEncoder[Row], + outputRowEncoder: ExpressionEncoder[Row]) extends Serializable { + + private def generateProjection(exprs: Seq[Expression]): UnsafeProjection = { + UnsafeProjection.create(exprs, joinedAttributes) + } + + private def generatePredicate(expr: Expression): BasePredicate = { + GeneratePredicate.generate(expr, joinedAttributes) + } + + def processPartition(rowIterator: Iterator[Row]): Iterator[Row] = { + + val targetRowHasNoMatchPred = generatePredicate(targetRowHasNoMatch) + val sourceRowHasNoMatchPred = generatePredicate(sourceRowHasNoMatch) + val matchedPreds = matchedConditions.map(generatePredicate) + val matchedProjs = matchedOutputs.map(_.map(generateProjection)) + val notMatchedPreds = notMatchedConditions.map(generatePredicate) + val notMatchedProjs = notMatchedOutputs.map(_.map(generateProjection)) + val notMatchedBySourcePreds = notMatchedBySourceConditions.map(generatePredicate) + val notMatchedBySourceProjs = notMatchedBySourceOutputs.map(_.map(generateProjection)) + val noopCopyProj = generateProjection(noopCopyOutput) + val deleteRowProj = generateProjection(deleteRowOutput) + val outputProj = UnsafeProjection.create(outputRowEncoder.schema) + + // this is accessing ROW_DROPPED_COL. If ROW_DROPPED_COL is not in outputRowEncoder.schema + // then CDC must be disabled and it's the column after our output cols + def shouldDeleteRow(row: InternalRow): Boolean = { + row.getBoolean( + outputRowEncoder.schema.getFieldIndex(ROW_DROPPED_COL) + .getOrElse(outputRowEncoder.schema.fields.size) + ) + } + + def processRow(inputRow: InternalRow): Iterator[InternalRow] = { + // Identify which set of clauses to execute: matched, not-matched or not-matched-by-source + val (predicates, projections, noopAction) = if (targetRowHasNoMatchPred.eval(inputRow)) { + // Target row did not match any source row, so update the target row. + (notMatchedBySourcePreds, notMatchedBySourceProjs, noopCopyProj) + } else if (sourceRowHasNoMatchPred.eval(inputRow)) { + // Source row did not match with any target row, so insert the new source row + (notMatchedPreds, notMatchedProjs, deleteRowProj) + } else { + // Source row matched with target row, so update the target row + (matchedPreds, matchedProjs, noopCopyProj) + } + + // find (predicate, projection) pair whose predicate satisfies inputRow + val pair = (predicates zip projections).find { + case (predicate, _) => predicate.eval(inputRow) + } + + pair match { + case Some((_, projections)) => + projections.map(_.apply(inputRow)).iterator + case None => Iterator(noopAction.apply(inputRow)) + } + } + + val toRow = joinedRowEncoder.createSerializer() + val fromRow = outputRowEncoder.createDeserializer() + rowIterator + .map(toRow) + .flatMap(processRow) + .filter(!shouldDeleteRow(_)) + .map { notDeletedInternalRow => + fromRow(outputProj(notDeletedInternalRow)) + } + } + } + + /** Count the number of distinct partition values among the AddFiles in the given set. */ + def totalBytesAndDistinctPartitionValues(files: Seq[FileAction]): (Long, Int) = { + val distinctValues = new mutable.HashSet[Map[String, String]]() + var bytes = 0L + val iter = files.collect { case a: AddFile => a }.iterator + while (iter.hasNext) { + val file = iter.next() + distinctValues += file.partitionValues + bytes += file.size + } + // If the only distinct value map is an empty map, then it must be an unpartitioned table. + // Return 0 in that case. + val numDistinctValues = + if (distinctValues.size == 1 && distinctValues.head.isEmpty) 0 else distinctValues.size + (bytes, numDistinctValues) + } +} diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala new file mode 100644 index 00000000000..38ee8a786c0 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuOptimisticTransaction.scala @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from OptimisticTransaction.scala and TransactionalWrite.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import java.net.URI + +import scala.collection.mutable.ListBuffer + +import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.delta._ +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.{AddFile, FileAction} +import org.apache.spark.sql.delta.constraints.{Constraint, Constraints} +import org.apache.spark.sql.delta.rapids.GpuOptimisticTransactionBase +import org.apache.spark.sql.delta.schema.InvariantViolationException +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormatWriter} +import org.apache.spark.sql.functions.to_json +import org.apache.spark.sql.rapids.{BasicColumnarWriteJobStatsTracker, ColumnarWriteJobStatsTracker, GpuFileFormatWriter, GpuWriteJobStatsTracker} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.{Clock, SerializableConfiguration} + +/** + * Used to perform a set of reads in a transaction and then commit a set of updates to the + * state of the log. All reads from the DeltaLog, MUST go through this instance rather + * than directly to the DeltaLog otherwise they will not be check for logical conflicts + * with concurrent updates. + * + * This class is not thread-safe. + * + * @param deltaLog The Delta Log for the table this transaction is modifying. + * @param snapshot The snapshot that this transaction is reading at. + * @param rapidsConf RAPIDS Accelerator config settings. + */ +class GpuOptimisticTransaction + (deltaLog: DeltaLog, snapshot: Snapshot, rapidsConf: RapidsConf) + (implicit clock: Clock) + extends GpuOptimisticTransactionBase(deltaLog, snapshot, rapidsConf)(clock) { + + /** Creates a new OptimisticTransaction. + * + * @param deltaLog The Delta Log for the table this transaction is modifying. + * @param rapidsConf RAPIDS Accelerator config settings + */ + def this(deltaLog: DeltaLog, rapidsConf: RapidsConf)(implicit clock: Clock) = { + this(deltaLog, deltaLog.update(), rapidsConf) + } + + private def getGpuStatsColExpr( + statsDataSchema: Seq[Attribute], + statsCollection: GpuStatisticsCollection): Expression = { + Dataset.ofRows(spark, LocalRelation(statsDataSchema)) + .select(to_json(statsCollection.statsCollector)) + .queryExecution.analyzed.expressions.head + } + + /** Return the pair of optional stats tracker and stats collection class */ + private def getOptionalGpuStatsTrackerAndStatsCollection( + output: Seq[Attribute], + partitionSchema: StructType, data: DataFrame): ( + Option[GpuDeltaJobStatisticsTracker], + Option[GpuStatisticsCollection]) = { + if (spark.sessionState.conf.getConf(DeltaSQLConf.DELTA_COLLECT_STATS)) { + + val (statsDataSchema, statsCollectionSchema) = getStatsSchema(output, partitionSchema) + + val indexedCols = DeltaConfigs.DATA_SKIPPING_NUM_INDEXED_COLS.fromMetaData(metadata) + val prefixLength = + spark.sessionState.conf.getConf(DeltaSQLConf.DATA_SKIPPING_STRING_PREFIX_LENGTH) + val tableSchema = { + // If collecting stats using the table schema, then pass in statsCollectionSchema. + // Otherwise pass in statsDataSchema to collect stats using the DataFrame schema. + if (spark.sessionState.conf.getConf(DeltaSQLConf + .DELTA_COLLECT_STATS_USING_TABLE_SCHEMA)) { + statsCollectionSchema.toStructType + } else { + statsDataSchema.toStructType + } + } + + val _spark = spark + + val statsCollection = new GpuStatisticsCollection { + override val spark = _spark + override val deletionVectorsSupported = false + override val tableDataSchema = tableSchema + override val dataSchema = statsDataSchema.toStructType + override val numIndexedCols = indexedCols + override val stringPrefixLength: Int = prefixLength + } + + val statsColExpr = getGpuStatsColExpr(statsDataSchema, statsCollection) + + val statsSchema = statsCollection.statCollectionSchema + val explodedDataSchema = statsCollection.explodedDataSchema + val batchStatsToRow = (batch: ColumnarBatch, row: InternalRow) => { + GpuStatisticsCollection.batchStatsToRow(statsSchema, explodedDataSchema, batch, row) + } + (Some(new GpuDeltaJobStatisticsTracker(statsDataSchema, statsColExpr, batchStatsToRow)), + Some(statsCollection)) + } else { + (None, None) + } + } + + override def writeFiles( + inputData: Dataset[_], + writeOptions: Option[DeltaOptions], + additionalConstraints: Seq[Constraint]): Seq[FileAction] = { + hasWritten = true + + val spark = inputData.sparkSession + val (data, partitionSchema) = performCDCPartition(inputData) + val outputPath = deltaLog.dataPath + + val (normalizedQueryExecution, output, generatedColumnConstraints, _) = + normalizeData(deltaLog, data) + + // Build a new plan with a stub GpuDeltaWrite node to work around undesired transitions between + // columns and rows when AQE is involved. Without this node in the plan, AdaptiveSparkPlanExec + // could be the root node of the plan. In that case we do not have enough context to know + // whether the AdaptiveSparkPlanExec should be columnar or not, since the GPU overrides do not + // see how the parent is using the AdaptiveSparkPlanExec outputs. By using this stub node that + // appears to be a data writing node to AQE (it derives from V2CommandExec), the + // AdaptiveSparkPlanExec will be planned as a child of this new node. That provides enough + // context to plan the AQE sub-plan properly with respect to columnar and row transitions. + // We could force the AQE node to be columnar here by explicitly replacing the node, but that + // breaks the connection between the queryExecution and the node that will actually execute. + val gpuWritePlan = Dataset.ofRows(spark, RapidsDeltaWrite(normalizedQueryExecution.logical)) + val queryExecution = gpuWritePlan.queryExecution + + val partitioningColumns = getPartitioningColumns(partitionSchema, output) + + val committer = getCommitter(outputPath) + + // If Statistics Collection is enabled, then create a stats tracker that will be injected during + // the FileFormatWriter.write call below and will collect per-file stats using + // StatisticsCollection + val (optionalStatsTracker, _) = getOptionalGpuStatsTrackerAndStatsCollection(output, + partitionSchema, data) + + val constraints = + Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints + + SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) { + val outputSpec = FileFormatWriter.OutputSpec( + outputPath.toString, + Map.empty, + output) + + // Remove any unnecessary row conversions added as part of Spark planning + val queryPhysicalPlan = queryExecution.executedPlan match { + case GpuColumnarToRowExec(child, _) => child + case p => p + } + val gpuRapidsWrite = queryPhysicalPlan match { + case g: GpuRapidsDeltaWriteExec => Some(g) + case _ => None + } + + val empty2NullPlan = convertEmptyToNullIfNeeded(queryPhysicalPlan, + partitioningColumns, constraints) + val planWithInvariants = addInvariantChecks(empty2NullPlan, constraints) + val physicalPlan = convertToGpu(planWithInvariants) + + val statsTrackers: ListBuffer[ColumnarWriteJobStatsTracker] = ListBuffer() + + val hadoopConf = spark.sessionState.newHadoopConfWithOptions( + metadata.configuration ++ deltaLog.options) + + if (spark.conf.get(DeltaSQLConf.DELTA_HISTORY_METRICS_ENABLED)) { + val serializableHadoopConf = new SerializableConfiguration(hadoopConf) + val basicWriteJobStatsTracker = new BasicColumnarWriteJobStatsTracker( + serializableHadoopConf, + BasicWriteJobStatsTracker.metrics) + registerSQLMetrics(spark, basicWriteJobStatsTracker.driverSideMetrics) + statsTrackers.append(basicWriteJobStatsTracker) + gpuRapidsWrite.foreach { grw => + val tracker = new GpuWriteJobStatsTracker(serializableHadoopConf, + grw.basicMetrics, grw.taskMetrics) + statsTrackers.append(tracker) + } + } + + // Retain only a minimal selection of Spark writer options to avoid any potential + // compatibility issues + val options = writeOptions match { + case None => Map.empty[String, String] + case Some(writeOptions) => + writeOptions.options.filterKeys { key => + key.equalsIgnoreCase(DeltaOptions.MAX_RECORDS_PER_FILE) || + key.equalsIgnoreCase(DeltaOptions.COMPRESSION) + }.toMap + } + + val deltaFileFormat = deltaLog.fileFormat(metadata) + val gpuFileFormat = if (deltaFileFormat.getClass == classOf[DeltaParquetFileFormat]) { + new GpuParquetFileFormat + } else { + throw new IllegalStateException(s"file format $deltaFileFormat is not supported") + } + + try { + GpuFileFormatWriter.write( + sparkSession = spark, + plan = physicalPlan, + fileFormat = gpuFileFormat, + committer = committer, + outputSpec = outputSpec, + hadoopConf = hadoopConf, + partitionColumns = partitioningColumns, + bucketSpec = None, + statsTrackers = optionalStatsTracker.toSeq ++ statsTrackers, + options = options, + rapidsConf.stableSort, + rapidsConf.concurrentWriterPartitionFlushSize) + } catch { + case s: SparkException => + // Pull an InvariantViolationException up to the top level if it was the root cause. + val violationException = ExceptionUtils.getRootCause(s) + if (violationException.isInstanceOf[InvariantViolationException]) { + throw violationException + } else { + throw s + } + } + } + + val resultFiles = committer.addedStatuses.map { a => + a.copy(stats = optionalStatsTracker.map( + _.recordedStats(new Path(new URI(a.path)).getName)).getOrElse(a.stats)) + }.filter { + // In some cases, we can write out an empty `inputData`. Some examples of this (though, they + // may be fixed in the future) are the MERGE command when you delete with empty source, or + // empty target, or on disjoint tables. This is hard to catch before the write without + // collecting the DF ahead of time. Instead, we can return only the AddFiles that + // a) actually add rows, or + // b) don't have any stats so we don't know the number of rows at all + case a: AddFile => a.numLogicalRecords.forall(_ > 0) + case _ => true + } + + resultFiles.toSeq ++ committer.changeFiles + } +} diff --git a/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala new file mode 100644 index 00000000000..dd425f044d6 --- /dev/null +++ b/delta-lake/delta-23x/src/main/scala/org/apache/spark/sql/delta/rapids/delta23x/GpuUpdateCommand.scala @@ -0,0 +1,275 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from UpdateCommand.scala + * in the Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta23x + +import com.nvidia.spark.rapids.delta.GpuDeltaMetricUpdateUDF +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Column, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.delta.{DeltaLog, DeltaOperations, DeltaTableUtils, DeltaUDF, OptimisticTransaction} +import org.apache.spark.sql.delta.actions.{AddCDCFile, AddFile, FileAction} +import org.apache.spark.sql.delta.commands.{DeltaCommand, UpdateCommand, UpdateMetric} +import org.apache.spark.sql.delta.files.{TahoeBatchFileIndex, TahoeFileIndex} +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.metric.SQLMetrics.{createMetric, createTimingMetric} +import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.types.LongType + +case class GpuUpdateCommand( + gpuDeltaLog: GpuDeltaLog, + tahoeFileIndex: TahoeFileIndex, + target: LogicalPlan, + updateExpressions: Seq[Expression], + condition: Option[Expression]) + extends LeafRunnableCommand with DeltaCommand { + + override val output: Seq[Attribute] = { + Seq(AttributeReference("num_affected_rows", LongType)()) + } + + override def innerChildren: Seq[QueryPlan[_]] = Seq(target) + + @transient private lazy val sc: SparkContext = SparkContext.getOrCreate() + + override lazy val metrics = Map[String, SQLMetric]( + "numAddedFiles" -> createMetric(sc, "number of files added."), + "numAddedBytes" -> createMetric(sc, "number of bytes added."), + "numRemovedFiles" -> createMetric(sc, "number of files removed."), + "numRemovedBytes" -> createMetric(sc, "number of bytes removed."), + "numUpdatedRows" -> createMetric(sc, "number of rows updated."), + "numCopiedRows" -> createMetric(sc, "number of rows copied."), + "executionTimeMs" -> + createTimingMetric(sc, "time taken to execute the entire operation"), + "scanTimeMs" -> + createTimingMetric(sc, "time taken to scan the files for matches"), + "rewriteTimeMs" -> + createTimingMetric(sc, "time taken to rewrite the matched files"), + "numAddedChangeFiles" -> createMetric(sc, "number of change data capture files generated"), + "changeFileBytes" -> createMetric(sc, "total size of change data capture files generated"), + "numTouchedRows" -> createMetric(sc, "number of rows touched (copied + updated)") + ) + + final override def run(sparkSession: SparkSession): Seq[Row] = { + recordDeltaOperation(tahoeFileIndex.deltaLog, "delta.dml.update") { + val deltaLog = tahoeFileIndex.deltaLog + gpuDeltaLog.withNewTransaction { txn => + DeltaLog.assertRemovable(txn.snapshot) + if (hasBeenExecuted(txn, sparkSession)) { + sendDriverMetrics(sparkSession, metrics) + return Seq.empty + } + performUpdate(sparkSession, deltaLog, txn) + } + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to + // this data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, target) + } + Seq(Row(metrics("numUpdatedRows").value)) + } + + private def performUpdate( + sparkSession: SparkSession, deltaLog: DeltaLog, txn: OptimisticTransaction): Unit = { + import org.apache.spark.sql.delta.implicits._ + + var numTouchedFiles: Long = 0 + var numRewrittenFiles: Long = 0 + var numAddedBytes: Long = 0 + var numRemovedBytes: Long = 0 + var numAddedChangeFiles: Long = 0 + var changeFileBytes: Long = 0 + var scanTimeMs: Long = 0 + var rewriteTimeMs: Long = 0 + + val startTime = System.nanoTime() + val numFilesTotal = txn.snapshot.numOfFiles + + val updateCondition = condition.getOrElse(Literal.TrueLiteral) + val (metadataPredicates, dataPredicates) = + DeltaTableUtils.splitMetadataAndDataPredicates( + updateCondition, txn.metadata.partitionColumns, sparkSession) + val candidateFiles = txn.filterFiles(metadataPredicates ++ dataPredicates) + val nameToAddFile = generateCandidateFileMap(deltaLog.dataPath, candidateFiles) + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + val filesToRewrite: Seq[AddFile] = if (candidateFiles.isEmpty) { + // Case 1: Do nothing if no row qualifies the partition predicates + // that are part of Update condition + Nil + } else if (dataPredicates.isEmpty) { + // Case 2: Update all the rows from the files that are in the specified partitions + // when the data filter is empty + candidateFiles + } else { + // Case 3: Find all the affected files using the user-specified condition + val fileIndex = new TahoeBatchFileIndex( + sparkSession, "update", candidateFiles, deltaLog, tahoeFileIndex.path, txn.snapshot) + // Keep everything from the resolved target except a new TahoeFileIndex + // that only involves the affected files instead of all files. + val newTarget = DeltaTableUtils.replaceFileIndex(target, fileIndex) + val data = Dataset.ofRows(sparkSession, newTarget) + val updatedRowCount = metrics("numUpdatedRows") + val updatedRowUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(updatedRowCount) + }.asNondeterministic() + val pathsToRewrite = + withStatusCode("DELTA", UpdateCommand.FINDING_TOUCHED_FILES_MSG) { + data.filter(new Column(updateCondition)) + .select(input_file_name()) + .filter(updatedRowUdf()) + .distinct() + .as[String] + .collect() + } + + scanTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 + + pathsToRewrite.map(getTouchedFile(deltaLog.dataPath, _, nameToAddFile)).toSeq + } + + numTouchedFiles = filesToRewrite.length + + val newActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Generate the new files containing the updated values + withStatusCode("DELTA", UpdateCommand.rewritingFilesMsg(filesToRewrite.size)) { + rewriteFiles(sparkSession, txn, tahoeFileIndex.path, + filesToRewrite.map(_.path), nameToAddFile, updateCondition) + } + } + + rewriteTimeMs = (System.nanoTime() - startTime) / 1000 / 1000 - scanTimeMs + + val (changeActions, addActions) = newActions.partition(_.isInstanceOf[AddCDCFile]) + numRewrittenFiles = addActions.size + numAddedBytes = addActions.map(_.getFileSize).sum + numAddedChangeFiles = changeActions.size + changeFileBytes = changeActions.collect { case f: AddCDCFile => f.size }.sum + + val totalActions = if (filesToRewrite.isEmpty) { + // Do nothing if no row qualifies the UPDATE condition + Nil + } else { + // Delete the old files and return those delete actions along with the new AddFile actions for + // files containing the updated values + val operationTimestamp = System.currentTimeMillis() + val deleteActions = filesToRewrite.map(_.removeWithTimestamp(operationTimestamp)) + + numRemovedBytes = filesToRewrite.map(_.getFileSize).sum + deleteActions ++ newActions + } + + metrics("numAddedFiles").set(numRewrittenFiles) + metrics("numAddedBytes").set(numAddedBytes) + metrics("numAddedChangeFiles").set(numAddedChangeFiles) + metrics("changeFileBytes").set(changeFileBytes) + metrics("numRemovedFiles").set(numTouchedFiles) + metrics("numRemovedBytes").set(numRemovedBytes) + metrics("executionTimeMs").set((System.nanoTime() - startTime) / 1000 / 1000) + metrics("scanTimeMs").set(scanTimeMs) + metrics("rewriteTimeMs").set(rewriteTimeMs) + // In the case where the numUpdatedRows is not captured, we can siphon out the metrics from + // the BasicWriteStatsTracker. This is for case 2 where the update condition contains only + // metadata predicates and so the entire partition is re-written. + val outputRows = txn.getMetric("numOutputRows").map(_.value).getOrElse(-1L) + if (metrics("numUpdatedRows").value == 0 && outputRows != 0 && + metrics("numCopiedRows").value == 0) { + // We know that numTouchedRows = numCopiedRows + numUpdatedRows. + // Since an entire partition was re-written, no rows were copied. + // So numTouchedRows == numUpdateRows + metrics("numUpdatedRows").set(metrics("numTouchedRows").value) + } else { + // This is for case 3 where the update condition contains both metadata and data predicates + // so relevant files will have some rows updated and some rows copied. We don't need to + // consider case 1 here, where no files match the update condition, as we know that + // `totalActions` is empty. + metrics("numCopiedRows").set( + metrics("numTouchedRows").value - metrics("numUpdatedRows").value) + } + txn.registerSQLMetrics(sparkSession, metrics) + val finalActions = createSetTransaction(sparkSession, deltaLog).toSeq ++ totalActions + txn.commitIfNeeded(finalActions, DeltaOperations.Update(condition.map(_.toString))) + sendDriverMetrics(sparkSession, metrics) + + recordDeltaEvent( + deltaLog, + "delta.dml.update.stats", + data = UpdateMetric( + condition = condition.map(_.sql).getOrElse("true"), + numFilesTotal, + numTouchedFiles, + numRewrittenFiles, + numAddedChangeFiles, + changeFileBytes, + scanTimeMs, + rewriteTimeMs) + ) + } + + /** + * Scan all the affected files and write out the updated files. + * + * When CDF is enabled, includes the generation of CDC preimage and postimage columns for + * changed rows. + * + * @return the list of [[AddFile]]s and [[AddCDCFile]]s that have been written. + */ + private def rewriteFiles( + spark: SparkSession, + txn: OptimisticTransaction, + rootPath: Path, + inputLeafFiles: Seq[String], + nameToAddFileMap: Map[String, AddFile], + condition: Expression): Seq[FileAction] = { + // Containing the map from the relative file path to AddFile + val baseRelation = buildBaseRelation( + spark, txn, "update", rootPath, inputLeafFiles, nameToAddFileMap) + val newTarget = DeltaTableUtils.replaceFileIndex(target, baseRelation.location) + val targetDf = Dataset.ofRows(spark, newTarget) + + // Number of total rows that we have seen, i.e. are either copying or updating (sum of both). + // This will be used later, along with numUpdatedRows, to determine numCopiedRows. + val numTouchedRows = metrics("numTouchedRows") + val numTouchedRowsUdf = DeltaUDF.boolean { + new GpuDeltaMetricUpdateUDF(numTouchedRows) + }.asNondeterministic() + + val updatedDataFrame = UpdateCommand.withUpdatedColumns( + target, + updateExpressions, + condition, + targetDf + .filter(numTouchedRowsUdf()) + .withColumn(UpdateCommand.CONDITION_COLUMN_NAME, new Column(condition)), + UpdateCommand.shouldOutputCdc(txn)) + + txn.writeFiles(updatedDataFrame) + } +} diff --git a/pom.xml b/pom.xml index e41d79dc3ca..bc10e6d6180 100644 --- a/pom.xml +++ b/pom.xml @@ -298,6 +298,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -318,6 +319,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -338,6 +340,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -358,6 +361,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -420,6 +424,7 @@ shim-deps/cloudera delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -442,6 +447,7 @@ shim-deps/cloudera delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x diff --git a/scala2.13/aggregator/pom.xml b/scala2.13/aggregator/pom.xml index c672d0c611f..a33ba487fa9 100644 --- a/scala2.13/aggregator/pom.xml +++ b/scala2.13/aggregator/pom.xml @@ -385,6 +385,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -408,6 +414,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -431,6 +443,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -471,6 +489,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -494,6 +518,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + @@ -534,6 +564,12 @@ ${project.version} ${spark.version.classifier} + + com.nvidia + rapids-4-spark-delta-23x_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + diff --git a/scala2.13/delta-lake/delta-23x/pom.xml b/scala2.13/delta-lake/delta-23x/pom.xml new file mode 100644 index 00000000000..6193d34ab44 --- /dev/null +++ b/scala2.13/delta-lake/delta-23x/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + com.nvidia + rapids-4-spark-parent_2.13 + 23.12.0-SNAPSHOT + ../../pom.xml + + + rapids-4-spark-delta-23x_2.13 + RAPIDS Accelerator for Apache Spark Delta Lake 2.3.x Support + Delta Lake 2.3.x support for the RAPIDS Accelerator for Apache Spark + 23.12.0-SNAPSHOT + + + ../delta-lake/delta-23x + false + **/* + package + + + + + com.nvidia + rapids-4-spark-sql_${scala.binary.version} + ${project.version} + ${spark.version.classifier} + provided + + + io.delta + delta-core_${scala.binary.version} + 2.3.0 + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-common-sources + generate-sources + + add-source + + + + + + ${project.basedir}/../../${rapids.module}/../common/src/main/scala + ${project.basedir}/../../${rapids.module}/../common/src/main/delta-io/scala + + + + + + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.rat + apache-rat-plugin + + + + diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml index f3cb9749433..e36413dee2f 100644 --- a/scala2.13/pom.xml +++ b/scala2.13/pom.xml @@ -298,6 +298,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -318,6 +319,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -338,6 +340,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -358,6 +361,7 @@ delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -420,6 +424,7 @@ shim-deps/cloudera delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x @@ -442,6 +447,7 @@ shim-deps/cloudera delta-lake/delta-21x delta-lake/delta-22x + delta-lake/delta-23x