From beaa93a1ebde1d2bbcaaebd708d4464cbeded0e9 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 2 Oct 2023 15:12:42 -0500 Subject: [PATCH 1/8] Add support for AtomicCreateTableAsSelect with Delta Lake Signed-off-by: Jason Lowe --- .../tahoe/rapids/GpuDeltaCatalogBase.scala | 442 ++++++++++++++++ .../tahoe/rapids/GpuDeltaLog.scala | 9 + ...pl.scala => DatabricksDeltaProvider.scala} | 58 +- .../rapids/delta/DeleteCommandMeta.scala | 4 +- .../rapids/delta/MergeIntoCommandMeta.scala | 6 +- .../spark/rapids/delta/RapidsDeltaUtils.scala | 25 +- .../rapids/delta/UpdateCommandMeta.scala | 4 +- .../spark/rapids/delta/DeltaIOProvider.scala | 33 +- .../spark/rapids/delta/RapidsDeltaUtils.scala | 29 +- .../sql/delta/rapids/DeltaRuntimeShim.scala | 8 + .../delta/rapids/GpuDeltaCatalogBase.scala | 442 ++++++++++++++++ .../spark/sql/delta/rapids/GpuDeltaLog.scala | 9 + .../delta/delta20x/DeleteCommandMeta.scala | 2 +- .../delta/delta20x/Delta20xProvider.scala | 21 +- .../delta/delta20x/MergeIntoCommandMeta.scala | 3 +- .../delta/delta20x/UpdateCommandMeta.scala | 2 +- .../rapids/delta20x/Delta20xRuntimeShim.scala | 8 + .../delta20x/GpuCreateDeltaTableCommand.scala | 465 +++++++++++++++++ .../rapids/delta20x/GpuDeltaCatalog.scala | 111 ++++ .../delta/delta21x/DeleteCommandMeta.scala | 2 +- .../delta/delta21x/Delta21xProvider.scala | 21 +- .../delta/delta21x/MergeIntoCommandMeta.scala | 3 +- .../delta/delta21x/UpdateCommandMeta.scala | 2 +- .../rapids/delta21x/Delta21xRuntimeShim.scala | 7 + .../delta21x/GpuCreateDeltaTableCommand.scala | 466 +++++++++++++++++ .../rapids/delta21x/GpuDeltaCatalog.scala | 120 +++++ .../delta/delta22x/DeleteCommandMeta.scala | 2 +- .../delta/delta22x/Delta22xProvider.scala | 21 +- .../delta/delta22x/MergeIntoCommandMeta.scala | 3 +- .../delta/delta22x/UpdateCommandMeta.scala | 2 +- .../rapids/delta22x/Delta22xRuntimeShim.scala | 8 + .../delta22x/GpuCreateDeltaTableCommand.scala | 465 +++++++++++++++++ .../rapids/delta22x/GpuDeltaCatalog.scala | 220 ++++++++ .../delta/delta24x/DeleteCommandMeta.scala | 2 +- .../delta/delta24x/Delta24xProvider.scala | 21 +- .../delta/delta24x/GpuDeltaCatalog.scala | 193 +++++++ .../delta/delta24x/MergeIntoCommandMeta.scala | 3 +- .../delta/delta24x/UpdateCommandMeta.scala | 2 +- .../rapids/delta24x/Delta24xRuntimeShim.scala | 9 +- .../delta24x/GpuCreateDeltaTableCommand.scala | 494 ++++++++++++++++++ .../rapids/GpuCreateDeltaTableCommand.scala | 455 ++++++++++++++++ .../tahoe/rapids/GpuDeltaCatalog.scala | 109 ++++ .../rapids/GpuCreateDeltaTableCommand.scala | 464 ++++++++++++++++ .../tahoe/rapids/GpuDeltaCatalog.scala | 218 ++++++++ .../rapids/GpuCreateDeltaTableCommand.scala | 464 ++++++++++++++++ .../tahoe/rapids/GpuDeltaCatalog.scala | 218 ++++++++ .../src/main/python/delta_lake_write_test.py | 20 + .../AtomicCreateTableAsSelectExecMeta.scala | 36 ++ .../spark/rapids/delta/DeltaProvider.scala | 28 +- .../spark/sql/rapids/ExternalSource.scala | 23 + .../rapids/shims/Spark320PlusShims.scala | 8 + .../GpuAtomicCreateTableAsSelectExec.scala | 85 +++ .../GpuAtomicCreateTableAsSelectExec.scala | 86 +++ .../GpuAtomicCreateTableAsSelectExec.scala | 86 +++ .../GpuAtomicCreateTableAsSelectExec.scala | 84 +++ .../rapids/execution/ShimTrampolineUtil.scala | 5 + .../GpuAtomicCreateTableAsSelectExec.scala | 75 +++ 57 files changed, 6155 insertions(+), 56 deletions(-) create mode 100644 delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalogBase.scala rename 
delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/{DeltaProviderImpl.scala => DatabricksDeltaProvider.scala} (70%) create mode 100644 delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaCatalogBase.scala create mode 100644 delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala create mode 100644 delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala create mode 100644 delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala create mode 100644 sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala diff --git a/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalogBase.scala b/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalogBase.scala new file mode 100644 index 00000000000..2c3253b7c90 --- /dev/null +++ b/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalogBase.scala @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.util +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.databricks.sql.transaction.tahoe.{DeltaErrors, DeltaLog, DeltaOptions} +import com.databricks.sql.transaction.tahoe.catalog.BucketTransform +import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta} +import com.databricks.sql.transaction.tahoe.sources.{DeltaSourceUtils, DeltaSQLConf} +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCapability, TableCatalog, TableChange} +import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, V1Write, WriteBuilder} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +trait GpuDeltaCatalogBase extends StagingTableCatalog { + val spark: SparkSession = SparkSession.active + + val cpuCatalog: StagingTableCatalog + + val rapidsConf: RapidsConf + + protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand + + protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] + + protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable + + protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or 
RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + protected def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode): Table = { + // These two keys are tableProperties in data source v2 but not in v1, so we have to filter + // them out. Otherwise property consistency checks will fail. + val tableProperties = allTableProperties.asScala.filterKeys { + case TableCatalog.PROP_LOCATION => false + case TableCatalog.PROP_PROVIDER => false + case TableCatalog.PROP_COMMENT => false + case TableCatalog.PROP_OWNER => false + case TableCatalog.PROP_EXTERNAL => false + case "path" => false + case "option.path" => false + case _ => true + }.toMap + val (partitionColumns, maybeBucketSpec) = convertTransforms(partitions) + val newSchema = schema + val newPartitionColumns = partitionColumns + val newBucketSpec = maybeBucketSpec + val conf = spark.sessionState.conf + + val isByPath = isPathIdentifier(ident) + if (isByPath && !conf.getConf(DeltaSQLConf.DELTA_LEGACY_ALLOW_AMBIGUOUS_PATHS) + && allTableProperties.containsKey("location") + // The location property can be qualified and different from the path in the identifier, so + // we check `endsWith` here. + && Option(allTableProperties.get("location")).exists(!_.endsWith(ident.name())) + ) { + throw DeltaErrors.ambiguousPathsInCreateTableException( + ident.name(), allTableProperties.get("location")) + } + val location = if (isByPath) { + Option(ident.name()) + } else { + Option(allTableProperties.get("location")) + } + val id = { + TableIdentifier(ident.name(), ident.namespace().lastOption) + } + val locUriOpt = location.map(CatalogUtils.stringToURI) + val existingTableOpt = getExistingTableIfExists(id) + val loc = locUriOpt + .orElse(existingTableOpt.flatMap(_.storage.locationUri)) + .getOrElse(spark.sessionState.catalog.defaultTablePath(id)) + val storage = DataSource.buildStorageFormatFromOptions(writeOptions) + .copy(locationUri = Option(loc)) + val tableType = + if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED + val commentOpt = Option(allTableProperties.get("comment")) + + + val tableDesc = new CatalogTable( + identifier = id, + tableType = tableType, + storage = storage, + schema = newSchema, + provider = Some(DeltaSourceUtils.ALT_NAME), + partitionColumnNames = newPartitionColumns, + bucketSpec = newBucketSpec, + properties = tableProperties, + comment = commentOpt + ) + + val withDb = verifyTableAndSolidify(tableDesc, None) + + val writer = sourceQuery.map { df => + WriteIntoDelta( + DeltaLog.forTable(spark, new Path(loc)), + operation.mode, + new DeltaOptions(withDb.storage.properties, spark.sessionState.conf), + withDb.partitionColumnNames, + withDb.properties ++ commentOpt.map("comment" -> _), + df, + schemaInCatalog = if (newSchema != schema) Some(newSchema) else None) + } + + val gpuCreateTableCommand = buildGpuCreateDeltaTableCommand( + rapidsConf, + withDb, + existingTableOpt, + operation.mode, + writer, + operation, + tableByPath = isByPath) + gpuCreateTableCommand.run(spark) + + cpuCatalog.loadTable(ident) + } + + override def loadTable(ident: Identifier): Table = cpuCatalog.loadTable(ident) + + private def getProvider(properties: util.Map[String, String]): String = { + Option(properties.get("provider")) + 
.getOrElse(spark.sessionState.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createDeltaTable( + ident, + schema, + partitions, + properties, + Map.empty, + sourceQuery = None, + TableCreationModes.Create + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.Replace + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.CreateOrReplace + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.Create + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + // Copy of V2SessionCatalog.convertTransforms, which is private. + private def convertTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { + val identityCols = new mutable.ArrayBuffer[String] + var bucketSpec = Option.empty[BucketSpec] + + partitions.map { + case IdentityTransform(FieldReference(Seq(col))) => + identityCols += col + + case BucketTransform(numBuckets, bucketCols, sortCols) => + bucketSpec = Some(BucketSpec( + numBuckets, bucketCols.map(_.fieldNames.head), sortCols.map(_.fieldNames.head))) + + case _ => + throw DeltaErrors.operationNotSupportedException(s"Partitioning by expressions") + } + + (identityCols.toSeq, bucketSpec) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2( + ident: Identifier, + override val schema: StructType, + val partitions: Array[Transform], + override val properties: util.Map[String, String], + operation: TableCreationModes.CreationMode + ) extends StagedTable with SupportsWrite { + + private var asSelectQuery: Option[DataFrame] = None + private var writeOptions: Map[String, String] = Map.empty + + override def partitioning(): Array[Transform] = partitions + + override def commitStagedChanges(): Unit = { + val conf = spark.sessionState.conf + val props = new util.HashMap[String, String]() + // Options passed in through the SQL API will show up both with an "option." 
prefix and + // without in Spark 3.1, so we need to remove those from the properties + val optionsThroughProperties = properties.asScala.collect { + case (k, _) if k.startsWith("option.") => k.stripPrefix("option.") + }.toSet + val sqlWriteOptions = new util.HashMap[String, String]() + properties.asScala.foreach { case (k, v) => + if (!k.startsWith("option.") && !optionsThroughProperties.contains(k)) { + // Do not add to properties + props.put(k, v) + } else if (optionsThroughProperties.contains(k)) { + sqlWriteOptions.put(k, v) + } + } + if (writeOptions.isEmpty && !sqlWriteOptions.isEmpty) { + writeOptions = sqlWriteOptions.asScala.toMap + } + if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + writeOptions.foreach { case (k, v) => props.put(k, v) } + } else { + writeOptions.foreach { case (k, v) => + // Continue putting in Delta prefixed options to avoid breaking workloads + if (k.toLowerCase(Locale.ROOT).startsWith("delta.")) { + props.put(k, v) + } + } + } + createDeltaTable( + ident, + schema, + partitions, + props, + writeOptions, + asSelectQuery, + operation + ) + } + + override def name(): String = ident.name() + + override def abortStagedChanges(): Unit = {} + + override def capabilities(): util.Set[TableCapability] = Set(V1_BATCH_WRITE).asJava + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + writeOptions = info.options.asCaseSensitiveMap().asScala.toMap + new DeltaV1WriteBuilder + } + + /* + * WriteBuilder for creating a Delta table. + */ + private class DeltaV1WriteBuilder extends WriteBuilder { + override def build(): V1Write = new V1Write { + override def toInsertableRelation(): InsertableRelation = { + new InsertableRelation { + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + asSelectQuery = Option(data) + } + } + } + } + } + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + cpuCatalog.alterTable(ident, changes: _*) + } + + override def dropTable(ident: Identifier): Boolean = cpuCatalog.dropTable(ident) + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + cpuCatalog.renameTable(oldIdent, newIdent) + } + + override def listTables(namespace: Array[String]): Array[Identifier] = { + throw new IllegalStateException("should not be called") + } + + override def tableExists(ident: Identifier): Boolean = cpuCatalog.tableExists(ident) + + override def initialize(s: String, caseInsensitiveStringMap: CaseInsensitiveStringMap): Unit = { + throw new IllegalStateException("should not be called") + } + + override def name(): String = cpuCatalog.name() + + private def supportSQLOnFile: Boolean = spark.sessionState.conf.runSQLonFile + + private def hasDeltaNamespace(ident: Identifier): Boolean = { + ident.namespace().length == 1 && DeltaSourceUtils.isDeltaDataSourceName(ident.namespace().head) + } + + private def isPathIdentifier(ident: Identifier): Boolean = { + // Should be a simple check of a special PathIdentifier class in the future + try { + supportSQLOnFile && hasDeltaNamespace(ident) && new Path(ident.name()).isAbsolute + } catch { + case _: IllegalArgumentException => false + } + } +} + + +/** + * A trait for handling table access through delta.`/some/path`. This is a stop-gap solution + * until PathIdentifiers are implemented in Apache Spark. 
+ */ +trait SupportsPathIdentifier extends TableCatalog { self: GpuDeltaCatalogBase => + + private def supportSQLOnFile: Boolean = spark.sessionState.conf.runSQLonFile + + protected lazy val catalog: SessionCatalog = spark.sessionState.catalog + + private def hasDeltaNamespace(ident: Identifier): Boolean = { + ident.namespace().length == 1 && DeltaSourceUtils.isDeltaDataSourceName(ident.namespace().head) + } + + protected def isPathIdentifier(ident: Identifier): Boolean = { + // Should be a simple check of a special PathIdentifier class in the future + try { + supportSQLOnFile && hasDeltaNamespace(ident) && new Path(ident.name()).isAbsolute + } catch { + case _: IllegalArgumentException => false + } + } + + protected def isPathIdentifier(table: CatalogTable): Boolean = { + isPathIdentifier(table.identifier) + } + + protected def isPathIdentifier(tableIdentifier: TableIdentifier) : Boolean = { + isPathIdentifier(Identifier.of(tableIdentifier.database.toArray, tableIdentifier.table)) + } +} diff --git a/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaLog.scala b/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaLog.scala index f71dcf98a4f..2927b8607ad 100644 --- a/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaLog.scala +++ b/delta-lake/common/src/main/databricks/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaLog.scala @@ -18,6 +18,7 @@ package com.databricks.sql.transaction.tahoe.rapids import com.databricks.sql.transaction.tahoe.{DeltaLog, OptimisticTransaction} import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.util.Clock @@ -70,4 +71,12 @@ object GpuDeltaLog { val deltaLog = DeltaLog.forTable(spark, dataPath, options) new GpuDeltaLog(deltaLog, rapidsConf) } + + def forTable( + spark: SparkSession, + tableLocation: Path, + rapidsConf: RapidsConf): GpuDeltaLog = { + val deltaLog = DeltaLog.forTable(spark, tableLocation) + new GpuDeltaLog(deltaLog, rapidsConf) + } } diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeltaProviderImpl.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala similarity index 70% rename from delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeltaProviderImpl.scala rename to delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala index 9f7e0896ec6..dc29208cef6 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeltaProviderImpl.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala @@ -16,22 +16,31 @@ package com.nvidia.spark.rapids.delta +import scala.collection.JavaConverters.mapAsScalaMapConverter + +import com.databricks.sql.managedcatalog.UnityCatalogV2Proxy import com.databricks.sql.transaction.tahoe.{DeltaLog, DeltaParquetFileFormat} +import com.databricks.sql.transaction.tahoe.catalog.DeltaCatalog import com.databricks.sql.transaction.tahoe.commands.{DeleteCommand, DeleteCommandEdge, MergeIntoCommand, MergeIntoCommandEdge, UpdateCommand, UpdateCommandEdge} -import com.databricks.sql.transaction.tahoe.sources.DeltaDataSource +import com.databricks.sql.transaction.tahoe.rapids.GpuDeltaCatalog +import 
com.databricks.sql.transaction.tahoe.sources.{DeltaDataSource, DeltaSourceUtils} import com.nvidia.spark.rapids._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.sources.CreatableRelationProvider /** - * Implements the DeltaProvider interface for Databricks Delta Lake. + * Common implementation of the DeltaProvider interface for all Databricks versions. */ -object DeltaProviderImpl extends DeltaProviderImplBase { +object DatabricksDeltaProvider extends DeltaProviderImplBase { override def getCreatableRelationRules: Map[Class[_ <: CreatableRelationProvider], CreatableRelationProviderRule[_ <: CreatableRelationProvider]] = { Seq( @@ -92,6 +101,43 @@ object DeltaProviderImpl extends DeltaProviderImplBase { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] GpuDeltaParquetFileFormat.convertToGpu(cpuFormat) } + + override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = { + catalogClass == classOf[DeltaCatalog] || catalogClass == classOf[UnityCatalogV2Proxy] + } + + override def tagForGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + val properties = cpuExec.properties + val provider = properties.getOrElse("provider", + cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + GpuAtomicCreateTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } } class DeltaCreatableRelationProviderMeta( @@ -115,8 +161,8 @@ class DeltaCreatableRelationProviderMeta( val path = saveCmd.options.get("path") if (path.isDefined) { val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options) - RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options, - SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog), + saveCmd.options, SparkSession.active) } else { willNotWorkOnGpu("no path specified for Delta Lake table") } @@ -131,5 +177,5 @@ class DeltaCreatableRelationProviderMeta( */ class DeltaProbeImpl extends DeltaProbe { // Delta Lake is built-in for Databricks instances, so no probing is necessary. 
- override def getDeltaProvider: DeltaProvider = DeltaProviderImpl + override def getDeltaProvider: DeltaProvider = DatabricksDeltaProvider } diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeleteCommandMeta.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeleteCommandMeta.scala index 8bc1b616bd1..02dba974937 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeleteCommandMeta.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DeleteCommandMeta.scala @@ -37,7 +37,7 @@ class DeleteCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } DeleteCommandMetaShim.tagForGpu(this) - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } @@ -62,7 +62,7 @@ class DeleteCommandEdgeMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } DeleteCommandMetaShim.tagForGpu(this) - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/MergeIntoCommandMeta.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/MergeIntoCommandMeta.scala index 6d8c2cf7f44..540c01e8cc8 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/MergeIntoCommandMeta.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/MergeIntoCommandMeta.scala @@ -38,7 +38,8 @@ class MergeIntoCommandMeta( MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd) val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) } override def convertToGpu(): RunnableCommand = @@ -60,7 +61,8 @@ class MergeIntoCommandEdgeMeta( MergeIntoCommandMetaShim.tagForGpu(this, mergeCmd) val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) } override def convertToGpu(): RunnableCommand = diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala index 17511390fbe..519f3ebef0f 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala @@ -29,12 +29,13 @@ object RapidsDeltaUtils { def tagForDeltaWrite( meta: RapidsMeta[_, _, _], schema: StructType, - deltaLog: DeltaLog, + deltaLog: Option[DeltaLog], options: Map[String, String], spark: SparkSession): Unit = { FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp) - val format = DeltaLogShim.fileFormat(deltaLog) - if (format.getClass == 
classOf[DeltaParquetFileFormat]) { + val format = deltaLog.map(log => DeltaLogShim.fileFormat(log).getClass) + .getOrElse(classOf[DeltaParquetFileFormat]) + if (format == classOf[DeltaParquetFileFormat]) { GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema) } else { meta.willNotWorkOnGpu(s"file format $format is not supported") @@ -45,7 +46,7 @@ object RapidsDeltaUtils { private def checkIncompatibleConfs( meta: RapidsMeta[_, _, _], schema: StructType, - deltaLog: DeltaLog, + deltaLog: Option[DeltaLog], sqlConf: SQLConf, options: Map[String, String]): Unit = { def getSQLConf(key: String): Option[String] = { @@ -65,19 +66,21 @@ object RapidsDeltaUtils { orderableTypeSig.isSupportedByPlugin(t) } if (unorderableTypes.nonEmpty) { - val metadata = DeltaLogShim.getMetadata(deltaLog) - val hasPartitioning = metadata.partitionColumns.nonEmpty || + val metadata = deltaLog.map(log => DeltaLogShim.getMetadata(log)) + val hasPartitioning = metadata.exists(_.partitionColumns.nonEmpty) || options.get(DataSourceUtils.PARTITIONING_COLUMNS_KEY).exists(_.nonEmpty) if (!hasPartitioning) { val optimizeWriteEnabled = { val deltaOptions = new DeltaOptions(options, sqlConf) deltaOptions.optimizeWrite.orElse { getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse { - DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse { - metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse { - getSQLConf( - "spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite") - }.map(_.toBoolean) + metadata.flatMap { m => + DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(m).orElse { + m.configuration.get("delta.autoOptimize.optimizeWrite").orElse { + getSQLConf( + "spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite") + }.map(_.toBoolean) + } } } }.getOrElse(false) diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/UpdateCommandMeta.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/UpdateCommandMeta.scala index a14e44a9ea4..dcb686c147f 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/UpdateCommandMeta.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/UpdateCommandMeta.scala @@ -35,7 +35,7 @@ class UpdateCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { @@ -62,7 +62,7 @@ class UpdateCommandEdgeMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala index 64ad89490dd..049f98180cf 100644 --- a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala +++ b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala @@ -16,15 +16,20 @@ package com.nvidia.spark.rapids.delta +import 
scala.collection.JavaConverters.mapAsScalaMapConverter import scala.util.Try import com.nvidia.spark.rapids._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaParquetFileFormat} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim -import org.apache.spark.sql.delta.sources.DeltaDataSource +import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils} import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil import org.apache.spark.sql.sources.CreatableRelationProvider @@ -48,6 +53,28 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase { override def isSupportedFormat(format: Class[_ <: FileFormat]): Boolean = { format == classOf[DeltaParquetFileFormat] } + + override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = { + catalogClass == classOf[DeltaCatalog] + } + + override def tagForGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + val properties = cpuExec.properties + val provider = properties.getOrElse("provider", + cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } } class DeltaCreatableRelationProviderMeta( @@ -71,8 +98,8 @@ class DeltaCreatableRelationProviderMeta( val path = saveCmd.options.get("path") if (path.isDefined) { val deltaLog = DeltaLog.forTable(SparkSession.active, path.get, saveCmd.options) - RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, deltaLog, saveCmd.options, - SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, saveCmd.query.schema, Some(deltaLog), + saveCmd.options, SparkSession.active) } else { willNotWorkOnGpu("no path specified for Delta Lake table") } diff --git a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala index e0c5725f91d..e78abd3d9c3 100644 --- a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala +++ b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/RapidsDeltaUtils.scala @@ -28,12 +28,13 @@ object RapidsDeltaUtils { def tagForDeltaWrite( meta: RapidsMeta[_, _, _], schema: StructType, - deltaLog: DeltaLog, + deltaLog: Option[DeltaLog], options: Map[String, String], spark: SparkSession): Unit = { FileFormatChecks.tag(meta, schema, DeltaFormatType, WriteFileOp) - val format = DeltaRuntimeShim.fileFormatFromLog(deltaLog) - if (format.getClass == classOf[DeltaParquetFileFormat]) { + val format = deltaLog.map(log => 
DeltaRuntimeShim.fileFormatFromLog(log).getClass) + .getOrElse(classOf[DeltaParquetFileFormat]) + if (format == classOf[DeltaParquetFileFormat]) { GpuParquetFileFormat.tagGpuSupport(meta, spark, options, schema) } else { meta.willNotWorkOnGpu(s"file format $format is not supported") @@ -43,7 +44,7 @@ object RapidsDeltaUtils { private def checkIncompatibleConfs( meta: RapidsMeta[_, _, _], - deltaLog: DeltaLog, + deltaLog: Option[DeltaLog], sqlConf: SQLConf, options: Map[String, String]): Unit = { def getSQLConf(key: String): Option[String] = { @@ -58,11 +59,13 @@ object RapidsDeltaUtils { val deltaOptions = new DeltaOptions(options, sqlConf) deltaOptions.optimizeWrite.orElse { getSQLConf("spark.databricks.delta.optimizeWrite.enabled").map(_.toBoolean).orElse { - val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata - DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse { - metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse { - getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite") - }.map(_.toBoolean) + deltaLog.flatMap { log => + val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata + DeltaConfigs.AUTO_OPTIMIZE.fromMetaData(metadata).orElse { + metadata.configuration.get("delta.autoOptimize.optimizeWrite").orElse { + getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.optimizeWrite") + }.map(_.toBoolean) + } } } }.getOrElse(false) @@ -73,9 +76,11 @@ object RapidsDeltaUtils { val autoCompactEnabled = getSQLConf("spark.databricks.delta.autoCompact.enabled").orElse { - val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(deltaLog).metadata - metadata.configuration.get("delta.autoOptimize.autoCompact").orElse { - getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact") + deltaLog.flatMap { log => + val metadata = DeltaRuntimeShim.unsafeVolatileSnapshotFromLog(log).metadata + metadata.configuration.get("delta.autoOptimize.autoCompact").orElse { + getSQLConf("spark.databricks.delta.properties.defaults.autoOptimize.autoCompact") + } } }.exists(_.toBoolean) if (autoCompactEnabled) { diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala index f6f73da308e..65c9ac54094 100644 --- a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala +++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/DeltaRuntimeShim.scala @@ -22,7 +22,9 @@ import com.nvidia.spark.rapids.{RapidsConf, ShimReflectionUtils, VersionUtils} import com.nvidia.spark.rapids.delta.DeltaProvider import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.util.Clock @@ -35,6 +37,8 @@ trait DeltaRuntimeShim { def fileFormatFromLog(deltaLog: DeltaLog): FileFormat def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean + + def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog } object DeltaRuntimeShim { @@ -81,4 +85,8 @@ object DeltaRuntimeShim { def getTightBoundColumnOnFileInitDisabled(spark: 
SparkSession): Boolean = shimInstance.getTightBoundColumnOnFileInitDisabled(spark) + + def getGpuDeltaCatalog(cpuCatalog: DeltaCatalog, rapidsConf: RapidsConf): StagingTableCatalog = { + shimInstance.getGpuDeltaCatalog(cpuCatalog, rapidsConf) + } } diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaCatalogBase.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaCatalogBase.scala new file mode 100644 index 00000000000..6009c409a3a --- /dev/null +++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaCatalogBase.scala @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids + +import java.util +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCapability, TableCatalog, TableChange} +import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, V1Write, WriteBuilder} +import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOptions} +import org.apache.spark.sql.delta.catalog.{BucketTransform, DeltaCatalog} +import org.apache.spark.sql.delta.commands.{TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf} +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +trait GpuDeltaCatalogBase extends StagingTableCatalog { + val spark: SparkSession + + val cpuCatalog: DeltaCatalog + + val rapidsConf: RapidsConf + + protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand + + protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] + + protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable + + protected def 
createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + protected def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode): Table = { + // These two keys are tableProperties in data source v2 but not in v1, so we have to filter + // them out. Otherwise property consistency checks will fail. + val tableProperties = allTableProperties.asScala.filterKeys { + case TableCatalog.PROP_LOCATION => false + case TableCatalog.PROP_PROVIDER => false + case TableCatalog.PROP_COMMENT => false + case TableCatalog.PROP_OWNER => false + case TableCatalog.PROP_EXTERNAL => false + case "path" => false + case "option.path" => false + case _ => true + }.toMap + val (partitionColumns, maybeBucketSpec) = convertTransforms(partitions) + val newSchema = schema + val newPartitionColumns = partitionColumns + val newBucketSpec = maybeBucketSpec + val conf = spark.sessionState.conf + + val isByPath = isPathIdentifier(ident) + if (isByPath && !conf.getConf(DeltaSQLConf.DELTA_LEGACY_ALLOW_AMBIGUOUS_PATHS) + && allTableProperties.containsKey("location") + // The location property can be qualified and different from the path in the identifier, so + // we check `endsWith` here. 
+ && Option(allTableProperties.get("location")).exists(!_.endsWith(ident.name())) + ) { + throw DeltaErrors.ambiguousPathsInCreateTableException( + ident.name(), allTableProperties.get("location")) + } + val location = if (isByPath) { + Option(ident.name()) + } else { + Option(allTableProperties.get("location")) + } + val id = { + TableIdentifier(ident.name(), ident.namespace().lastOption) + } + val locUriOpt = location.map(CatalogUtils.stringToURI) + val existingTableOpt = getExistingTableIfExists(id) + val loc = locUriOpt + .orElse(existingTableOpt.flatMap(_.storage.locationUri)) + .getOrElse(spark.sessionState.catalog.defaultTablePath(id)) + val storage = DataSource.buildStorageFormatFromOptions(writeOptions) + .copy(locationUri = Option(loc)) + val tableType = + if (location.isDefined) CatalogTableType.EXTERNAL else CatalogTableType.MANAGED + val commentOpt = Option(allTableProperties.get("comment")) + + + val tableDesc = new CatalogTable( + identifier = id, + tableType = tableType, + storage = storage, + schema = newSchema, + provider = Some(DeltaSourceUtils.ALT_NAME), + partitionColumnNames = newPartitionColumns, + bucketSpec = newBucketSpec, + properties = tableProperties, + comment = commentOpt + ) + + val withDb = verifyTableAndSolidify(tableDesc, None) + + val writer = sourceQuery.map { df => + WriteIntoDelta( + DeltaLog.forTable(spark, new Path(loc)), + operation.mode, + new DeltaOptions(withDb.storage.properties, spark.sessionState.conf), + withDb.partitionColumnNames, + withDb.properties ++ commentOpt.map("comment" -> _), + df, + schemaInCatalog = if (newSchema != schema) Some(newSchema) else None) + } + + val gpuCreateTableCommand = buildGpuCreateDeltaTableCommand( + rapidsConf, + withDb, + existingTableOpt, + operation.mode, + writer, + operation, + tableByPath = isByPath) + gpuCreateTableCommand.run(spark) + + cpuCatalog.loadTable(ident) + } + + override def loadTable(ident: Identifier): Table = cpuCatalog.loadTable(ident) + + private def getProvider(properties: util.Map[String, String]): String = { + Option(properties.get("provider")) + .getOrElse(spark.sessionState.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createDeltaTable( + ident, + schema, + partitions, + properties, + Map.empty, + sourceQuery = None, + TableCreationModes.Create + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.Replace + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.CreateOrReplace + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + + override def stageCreate( + ident: 
Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = { + if (DeltaSourceUtils.isDeltaDataSourceName(getProvider(properties))) { + createGpuStagedDeltaTableV2( + ident, + schema, + partitions, + properties, + TableCreationModes.Create + ) + } else { + throw new IllegalStateException(s"Cannot create non-Delta tables") + } + } + +// Copy of V2SessionCatalog.convertTransforms, which is private. + private def convertTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { + val identityCols = new mutable.ArrayBuffer[String] + var bucketSpec = Option.empty[BucketSpec] + + partitions.map { + case IdentityTransform(FieldReference(Seq(col))) => + identityCols += col + + case BucketTransform(numBuckets, bucketCols, sortCols) => + bucketSpec = Some(BucketSpec( + numBuckets, bucketCols.map(_.fieldNames.head), sortCols.map(_.fieldNames.head))) + + case _ => + throw DeltaErrors.operationNotSupportedException(s"Partitioning by expressions") + } + + (identityCols.toSeq, bucketSpec) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2( + ident: Identifier, + override val schema: StructType, + val partitions: Array[Transform], + override val properties: util.Map[String, String], + operation: TableCreationModes.CreationMode + ) extends StagedTable with SupportsWrite { + + private var asSelectQuery: Option[DataFrame] = None + private var writeOptions: Map[String, String] = Map.empty + + override def partitioning(): Array[Transform] = partitions + + override def commitStagedChanges(): Unit = { + val conf = spark.sessionState.conf + val props = new util.HashMap[String, String]() + // Options passed in through the SQL API will show up both with an "option." 
prefix and + // without in Spark 3.1, so we need to remove those from the properties + val optionsThroughProperties = properties.asScala.collect { + case (k, _) if k.startsWith("option.") => k.stripPrefix("option.") + }.toSet + val sqlWriteOptions = new util.HashMap[String, String]() + properties.asScala.foreach { case (k, v) => + if (!k.startsWith("option.") && !optionsThroughProperties.contains(k)) { + // Do not add to properties + props.put(k, v) + } else if (optionsThroughProperties.contains(k)) { + sqlWriteOptions.put(k, v) + } + } + if (writeOptions.isEmpty && !sqlWriteOptions.isEmpty) { + writeOptions = sqlWriteOptions.asScala.toMap + } + if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + writeOptions.foreach { case (k, v) => props.put(k, v) } + } else { + writeOptions.foreach { case (k, v) => + // Continue putting in Delta prefixed options to avoid breaking workloads + if (k.toLowerCase(Locale.ROOT).startsWith("delta.")) { + props.put(k, v) + } + } + } + createDeltaTable( + ident, + schema, + partitions, + props, + writeOptions, + asSelectQuery, + operation + ) + } + + override def name(): String = ident.name() + + override def abortStagedChanges(): Unit = {} + + override def capabilities(): util.Set[TableCapability] = Set(V1_BATCH_WRITE).asJava + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + writeOptions = info.options.asCaseSensitiveMap().asScala.toMap + new DeltaV1WriteBuilder + } + + /* + * WriteBuilder for creating a Delta table. + */ + private class DeltaV1WriteBuilder extends WriteBuilder { + override def build(): V1Write = new V1Write { + override def toInsertableRelation(): InsertableRelation = { + new InsertableRelation { + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + asSelectQuery = Option(data) + } + } + } + } + } + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + cpuCatalog.alterTable(ident, changes: _*) + } + + override def dropTable(ident: Identifier): Boolean = cpuCatalog.dropTable(ident) + + override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + cpuCatalog.renameTable(oldIdent, newIdent) + } + + override def listTables(namespace: Array[String]): Array[Identifier] = { + throw new IllegalStateException("should not be called") + } + + override def tableExists(ident: Identifier): Boolean = cpuCatalog.tableExists(ident) + + override def initialize(s: String, caseInsensitiveStringMap: CaseInsensitiveStringMap): Unit = { + throw new IllegalStateException("should not be called") + } + + override def name(): String = cpuCatalog.name() + + private def supportSQLOnFile: Boolean = spark.sessionState.conf.runSQLonFile + + private def hasDeltaNamespace(ident: Identifier): Boolean = { + ident.namespace().length == 1 && DeltaSourceUtils.isDeltaDataSourceName(ident.namespace().head) + } + + private def isPathIdentifier(ident: Identifier): Boolean = { + // Should be a simple check of a special PathIdentifier class in the future + try { + supportSQLOnFile && hasDeltaNamespace(ident) && new Path(ident.name()).isAbsolute + } catch { + case _: IllegalArgumentException => false + } + } +} + + +/** + * A trait for handling table access through delta.`/some/path`. This is a stop-gap solution + * until PathIdentifiers are implemented in Apache Spark. 
+ */ +trait SupportsPathIdentifier extends TableCatalog { self: GpuDeltaCatalogBase => + + private def supportSQLOnFile: Boolean = spark.sessionState.conf.runSQLonFile + + protected lazy val catalog: SessionCatalog = spark.sessionState.catalog + + private def hasDeltaNamespace(ident: Identifier): Boolean = { + ident.namespace().length == 1 && DeltaSourceUtils.isDeltaDataSourceName(ident.namespace().head) + } + + protected def isPathIdentifier(ident: Identifier): Boolean = { + // Should be a simple check of a special PathIdentifier class in the future + try { + supportSQLOnFile && hasDeltaNamespace(ident) && new Path(ident.name()).isAbsolute + } catch { + case _: IllegalArgumentException => false + } + } + + protected def isPathIdentifier(table: CatalogTable): Boolean = { + isPathIdentifier(table.identifier) + } + + protected def isPathIdentifier(tableIdentifier: TableIdentifier) : Boolean = { + isPathIdentifier(Identifier.of(tableIdentifier.database.toArray, tableIdentifier.table)) + } +} diff --git a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaLog.scala b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaLog.scala index 01016f6fd48..442dee6037c 100644 --- a/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaLog.scala +++ b/delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuDeltaLog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.delta.rapids import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.{DeltaLog, OptimisticTransaction} @@ -70,4 +71,12 @@ object GpuDeltaLog { val deltaLog = DeltaLog.forTable(spark, dataPath, options) new GpuDeltaLog(deltaLog, rapidsConf) } + + def forTable( + spark: SparkSession, + tableLocation: Path, + rapidsConf: RapidsConf): GpuDeltaLog = { + val deltaLog = DeltaLog.forTable(spark, tableLocation) + new GpuDeltaLog(deltaLog, rapidsConf) + } } diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/DeleteCommandMeta.scala b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/DeleteCommandMeta.scala index 76f4906c1ed..67c5454073a 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/DeleteCommandMeta.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/DeleteCommandMeta.scala @@ -37,7 +37,7 @@ class DeleteCommandMeta( willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala index 1e8cef6cca7..7542b7b612b 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala @@ -16,14 +16,18 @@ package com.nvidia.spark.rapids.delta.delta20x -import com.nvidia.spark.rapids.{GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand} +import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec object Delta20xProvider extends DeltaIOProvider { @@ -58,4 +62,19 @@ object Delta20xProvider extends DeltaIOProvider { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] GpuDelta20xParquetFileFormat(cpuFormat.columnMappingMode, cpuFormat.referenceSchema) } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicCreateTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.properties, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } } diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/MergeIntoCommandMeta.scala b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/MergeIntoCommandMeta.scala index c51929810b2..7fb65fce089 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/MergeIntoCommandMeta.scala @@ -39,7 +39,8 @@ class MergeIntoCommandMeta( } val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/UpdateCommandMeta.scala 
b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/UpdateCommandMeta.scala index 2c6cfffd17a..4d479f6b86d 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/UpdateCommandMeta.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/UpdateCommandMeta.scala @@ -37,7 +37,7 @@ class UpdateCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/Delta20xRuntimeShim.scala b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/Delta20xRuntimeShim.scala index d620c5be5bd..9dcdc1aad22 100644 --- a/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/Delta20xRuntimeShim.scala +++ b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/Delta20xRuntimeShim.scala @@ -21,7 +21,9 @@ import com.nvidia.spark.rapids.delta.DeltaProvider import com.nvidia.spark.rapids.delta.delta20x.Delta20xProvider import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.expressions.UserDefinedFunction @@ -54,4 +56,10 @@ class Delta20xRuntimeShim extends DeltaRuntimeShim { deltaLog.fileFormat() override def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean = false + + override def getGpuDeltaCatalog( + cpuCatalog: DeltaCatalog, + rapidsConf: RapidsConf): StagingTableCatalog = { + new GpuDeltaCatalog(cpuCatalog, rapidsConf) + } } diff --git a/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..83afddef1d5 --- /dev/null +++ b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,465 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.delta.rapids.delta20x + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.delta.commands.{TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier, if it exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exists. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + val tableName = table.identifier.quotedString + throw new AnalysisException( + s"The location of the existing table $tableName is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${table.location}`.") + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1.
We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. 
+ val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. + operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog(sparkSession, tableWithLocation, gpuDeltaLog.deltaLog.snapshot, txn) + + result + } + } + + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. 
txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. + if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. 
+ * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw new CannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuDeltaCatalog.scala b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..3c6a1e86457 --- /dev/null +++ b/delta-lake/delta-20x/src/main/scala/org/apache/spark/sql/delta/rapids/delta20x/GpuDeltaCatalog.scala @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta20x + +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaErrors} +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.TableCreationModes +import org.apache.spark.sql.delta.rapids.{GpuDeltaCatalogBase, SupportsPathIdentifier} +import org.apache.spark.sql.delta.sources.DeltaSourceUtils +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils + +class GpuDeltaCatalog( + override val cpuCatalog: DeltaCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with Logging { + + override val spark: SparkSession = cpuCatalog.spark + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. 
You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw DeltaErrors.notADeltaTable(table.table) + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } +} diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/DeleteCommandMeta.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/DeleteCommandMeta.scala index 37cdc447077..dd8c44b77bf 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/DeleteCommandMeta.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/DeleteCommandMeta.scala @@ -37,7 +37,7 @@ class DeleteCommandMeta( willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala index 934fe7eb95b..4fcb2a993bc 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala @@ -16,14 +16,18 @@ package com.nvidia.spark.rapids.delta.delta21x -import com.nvidia.spark.rapids.{GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand} +import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec object Delta21xProvider extends DeltaIOProvider { @@ -58,4 +62,19 @@ object Delta21xProvider extends 
DeltaIOProvider { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] GpuDelta21xParquetFileFormat(cpuFormat.columnMappingMode, cpuFormat.referenceSchema) } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicCreateTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } } diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/MergeIntoCommandMeta.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/MergeIntoCommandMeta.scala index a7ad1bcd4fc..afe21a862f6 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/MergeIntoCommandMeta.scala @@ -39,7 +39,8 @@ class MergeIntoCommandMeta( } val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/UpdateCommandMeta.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/UpdateCommandMeta.scala index 659f857b158..c18ad778169 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/UpdateCommandMeta.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/UpdateCommandMeta.scala @@ -37,7 +37,7 @@ class UpdateCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/Delta21xRuntimeShim.scala b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/Delta21xRuntimeShim.scala index d4112034d99..8aaf6e7ca65 100644 --- a/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/Delta21xRuntimeShim.scala +++ b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/Delta21xRuntimeShim.scala @@ -21,7 +21,9 @@ import com.nvidia.spark.rapids.delta.DeltaProvider import com.nvidia.spark.rapids.delta.delta21x.Delta21xProvider import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.expressions.UserDefinedFunction @@ -54,4 +56,9 @@ class Delta21xRuntimeShim extends DeltaRuntimeShim { override def getTightBoundColumnOnFileInitDisabled(spark: 
SparkSession): Boolean = false + override def getGpuDeltaCatalog( + cpuCatalog: DeltaCatalog, + rapidsConf: RapidsConf): StagingTableCatalog = { + new GpuDeltaCatalog(cpuCatalog, rapidsConf) + } } diff --git a/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..2e32056d352 --- /dev/null +++ b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,466 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta21x + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.delta.commands.{TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier, if it exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exists.
This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + val tableName = table.identifier.quotedString + throw new AnalysisException( + s"The location of the existing table $tableName is " + + s"`${existingTable.location}`. It doesn't match the specified location " + + s"`${table.location}`.") + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. 
The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. 
+ operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog(sparkSession, tableWithLocation, gpuDeltaLog.deltaLog.snapshot, txn) + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. 
+ if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. 
+ * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw new CannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuDeltaCatalog.scala b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..20904641071 --- /dev/null +++ b/delta-lake/delta-21x/src/main/scala/org/apache/spark/sql/delta/rapids/delta21x/GpuDeltaCatalog.scala @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta21x + +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, Table} +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaErrors} +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.TableCreationModes +import org.apache.spark.sql.delta.rapids.{GpuDeltaCatalogBase, SupportsPathIdentifier} +import org.apache.spark.sql.delta.sources.DeltaSourceUtils +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils + +class GpuDeltaCatalog( + override val cpuCatalog: DeltaCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with Logging { + + override val spark: SparkSession = cpuCatalog.spark + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. 
You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw DeltaErrors.notADeltaTable(table.table) + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } +} diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/DeleteCommandMeta.scala b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/DeleteCommandMeta.scala index 7fe489ce334..9145be60ccb 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/DeleteCommandMeta.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/DeleteCommandMeta.scala @@ -37,7 +37,7 @@ class DeleteCommandMeta( willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala index d5b5a79c7bc..73ead6f45e4 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala @@ -16,14 +16,18 @@ package com.nvidia.spark.rapids.delta.delta22x -import com.nvidia.spark.rapids.{GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand} +import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec object Delta22xProvider extends DeltaIOProvider { @@ -58,4 +62,19 @@ object Delta22xProvider extends DeltaIOProvider { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] GpuDelta22xParquetFileFormat(cpuFormat.columnMappingMode, cpuFormat.referenceSchema) } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicCreateTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } } diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/MergeIntoCommandMeta.scala b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/MergeIntoCommandMeta.scala index 259de1c448d..df48b1f0f66 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/MergeIntoCommandMeta.scala @@ -39,7 +39,8 @@ class MergeIntoCommandMeta( } val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), Map.empty, + SparkSession.active) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/UpdateCommandMeta.scala 
b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/UpdateCommandMeta.scala index 4cd1c364e0a..e5857ea74b8 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/UpdateCommandMeta.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/UpdateCommandMeta.scala @@ -37,7 +37,7 @@ class UpdateCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/Delta22xRuntimeShim.scala b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/Delta22xRuntimeShim.scala index ae885495f53..b261f3127c0 100644 --- a/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/Delta22xRuntimeShim.scala +++ b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/Delta22xRuntimeShim.scala @@ -21,7 +21,9 @@ import com.nvidia.spark.rapids.delta.DeltaProvider import com.nvidia.spark.rapids.delta.delta22x.Delta22xProvider import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.expressions.UserDefinedFunction @@ -49,4 +51,10 @@ class Delta22xRuntimeShim extends DeltaRuntimeShim { deltaLog.fileFormat() override def getTightBoundColumnOnFileInitDisabled(spark: SparkSession): Boolean = false + + override def getGpuDeltaCatalog( + cpuCatalog: DeltaCatalog, + rapidsConf: RapidsConf): StagingTableCatalog = { + new GpuDeltaCatalog(cpuCatalog, rapidsConf) + } } diff --git a/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..a412c8fb5c0 --- /dev/null +++ b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,465 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta.rapids.delta22x + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.actions.Metadata +import org.apache.spark.sql.delta.commands.{TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. 
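// Editor's note: illustrative sketch only, not part of the patch hunks above or below. It shows
// the kind of CREATE TABLE ... AS SELECT statement that is expected to route through the GPU
// Delta catalog and into GpuCreateDeltaTableCommand.run() once this support is in place. The
// RAPIDS configuration keys are assumptions inferred from RapidsConf.ENABLE_DELTA_WRITE referenced
// elsewhere in this patch; the table name is hypothetical, and the Delta session extensions shown
// are the usual OSS Delta Lake setup.
object CtasExampleSketch {
  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("gpu-delta-ctas-sketch")
      // Standard OSS Delta Lake session configuration.
      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
      .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
      // Assumed RAPIDS keys; the plugin jar must also be on the classpath.
      .config("spark.rapids.sql.enabled", "true")
      .config("spark.rapids.sql.format.delta.write.enabled", "true")
      .getOrCreate()

    // A CTAS against a Delta catalog is planned as AtomicCreateTableAsSelectExec, which this
    // patch can replace with GpuAtomicCreateTableAsSelectExec backed by the GPU Delta catalog.
    spark.sql(
      """CREATE TABLE gpu_ctas_demo
        |USING delta
        |AS SELECT id, id * 2 AS doubled FROM range(1000)
        |""".stripMargin)
  }
}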
+ table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. 
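// Editor's note: illustrative sketch, not part of the patch. It contrasts the two writer APIs
// that the isV1Writer / replaceMetadataIfNecessary logic in this command distinguishes. Table
// names are hypothetical; a Delta-enabled SparkSession like the one in the earlier sketch is
// assumed to be passed in as `spark`.
object WriterApiSketch {
  import org.apache.spark.sql.SparkSession

  def run(spark: SparkSession): Unit = {
    val df = spark.range(100).withColumnRenamed("id", "key")

    // DataFrameWriterV1: mode("overwrite").saveAsTable behaves like CreateOrReplace, but schema
    // and partitioning are only replaced when overwriteSchema=true is passed explicitly.
    df.write
      .format("delta")
      .mode("overwrite")
      .option("overwriteSchema", "true")
      .saveAsTable("writer_demo")

    // DataFrameWriterV2: createOrReplace() states the intent directly, so the command replaces
    // the table metadata without needing the overwriteSchema option.
    df.writeTo("writer_demo")
      .using("delta")
      .createOrReplace()
  }
}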
+ val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. + operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog( + sparkSession, + tableWithLocation, + gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)), + txn) + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. 
Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. + if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. 
+ * A lot of this is needed for legacy reasons in Databricks Runtime. + * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuDeltaCatalog.scala b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..4408c917a03 --- /dev/null +++ b/delta-lake/delta-22x/src/main/scala/org/apache/spark/sql/delta/rapids/delta22x/GpuDeltaCatalog.scala @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta22x + +import com.nvidia.spark.rapids.RapidsConf +import java.util + +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.delta.{DeltaConfigs, DeltaErrors} +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.TableCreationModes +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.{GpuDeltaCatalogBase, SupportsPathIdentifier} +import org.apache.spark.sql.delta.sources.DeltaSourceUtils +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: DeltaCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with DeltaLogging { + + override val spark: SparkSession = cpuCatalog.spark + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. 
You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw DeltaErrors.notADeltaTable(table.table) + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, + schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + 
recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/DeleteCommandMeta.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/DeleteCommandMeta.scala index b6437cc2ac0..4c2827132bf 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/DeleteCommandMeta.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/DeleteCommandMeta.scala @@ -45,7 +45,7 @@ class DeleteCommandMeta( // https://github.com/NVIDIA/spark-rapids/issues/8554 willNotWorkOnGpu("Deletion vectors are not supported on GPU") } - RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, deleteCmd.deltaLog, + RapidsDeltaUtils.tagForDeltaWrite(this, deleteCmd.target.schema, Some(deleteCmd.deltaLog), Map.empty, SparkSession.active) } diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala index 1f21c2a3e02..f50f1d96ced 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala @@ -16,15 +16,19 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.{GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat import org.apache.spark.sql.delta.DeltaParquetFileFormat.{IS_ROW_DELETED_COLUMN_NAME, ROW_INDEX_COLUMN_NAME} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.commands.{DeleteCommand, MergeIntoCommand, UpdateCommand} +import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec 
object Delta24xProvider extends DeltaIOProvider { @@ -72,4 +76,19 @@ object Delta24xProvider extends DeltaIOProvider { val cpuFormat = format.asInstanceOf[DeltaParquetFileFormat] GpuDelta24xParquetFileFormat(cpuFormat.metadata, cpuFormat.isSplittable) } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicCreateTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.ifNotExists) + } } diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDeltaCatalog.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..7c5791ce265 --- /dev/null +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/GpuDeltaCatalog.scala @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.delta.delta24x + +import com.nvidia.spark.rapids.RapidsConf +import java.util + +import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.delta.catalog.DeltaCatalog +import org.apache.spark.sql.delta.commands.TableCreationModes +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaCatalogBase +import org.apache.spark.sql.delta.rapids.delta24x.GpuCreateDeltaTableCommand +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.rapids.execution.ShimTrampolineUtil +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: DeltaCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with DeltaLogging { + + override val spark: SparkSession = cpuCatalog.spark + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + cpuCatalog.getExistingTableIfExists(table) + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + cpuCatalog.verifyTableAndSolidify(tableDesc, query) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, 
+ schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + columns: Array[org.apache.spark.sql.connector.catalog.Column], + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + createTable( + ident, + ShimTrampolineUtil.v2ColumnsToStructType(columns), + partitions, + properties) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. 
+ */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala index d9576f7793f..4b4dfb624b5 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/MergeIntoCommandMeta.scala @@ -43,7 +43,8 @@ class MergeIntoCommandMeta( } val targetSchema = mergeCmd.migratedSchema.getOrElse(mergeCmd.target.schema) val deltaLog = mergeCmd.targetFileIndex.deltaLog - RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, deltaLog, Map.empty, SparkSession.active) + RapidsDeltaUtils.tagForDeltaWrite(this, targetSchema, Some(deltaLog), + Map.empty, SparkSession.active) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/UpdateCommandMeta.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/UpdateCommandMeta.scala index ce340ce0e54..82faf8ad16d 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/UpdateCommandMeta.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/UpdateCommandMeta.scala @@ -37,7 +37,7 @@ class UpdateCommandMeta( s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } RapidsDeltaUtils.tagForDeltaWrite(this, updateCmd.target.schema, - updateCmd.tahoeFileIndex.deltaLog, Map.empty, updateCmd.tahoeFileIndex.spark) + Some(updateCmd.tahoeFileIndex.deltaLog), Map.empty, updateCmd.tahoeFileIndex.spark) } override def convertToGpu(): RunnableCommand = { diff --git a/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/Delta24xRuntimeShim.scala b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/Delta24xRuntimeShim.scala index d79eae37e20..ab9cd4753c0 100644 --- a/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/Delta24xRuntimeShim.scala +++ b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/Delta24xRuntimeShim.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.delta.rapids.delta24x import com.nvidia.spark.rapids.RapidsConf import com.nvidia.spark.rapids.delta.DeltaProvider -import com.nvidia.spark.rapids.delta.delta24x.Delta24xProvider +import com.nvidia.spark.rapids.delta.delta24x.{Delta24xProvider, GpuDeltaCatalog} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.delta.{DeltaLog, DeltaUDF, Snapshot} +import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.{DeltaRuntimeShim, GpuOptimisticTransactionBase} import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.execution.datasources.FileFormat @@ -54,4 +56,9 @@ class Delta24xRuntimeShim extends DeltaRuntimeShim { 
spark.sessionState.conf.getConf(DeltaSQLConf.TIGHT_BOUND_COLUMN_ON_FILE_INIT_DISABLED) } + override def getGpuDeltaCatalog( + cpuCatalog: DeltaCatalog, + rapidsConf: RapidsConf): StagingTableCatalog = { + new GpuDeltaCatalog(cpuCatalog, rapidsConf) + } } diff --git a/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..73a359cde6f --- /dev/null +++ b/delta-lake/delta-24x/src/main/scala/org/apache/spark/sql/delta/rapids/delta24x/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,494 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.rapids.delta24x + +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.delta._ +import org.apache.spark.sql.delta.DeltaColumnMapping.{dropColumnMappingMetadata, filterColumnMappingProperties} +import org.apache.spark.sql.delta.actions.{Action, Metadata, Protocol} +import org.apache.spark.sql.delta.commands.{DeltaCommand, TableCreationModes, WriteIntoDelta} +import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.rapids.GpuDeltaLog +import org.apache.spark.sql.delta.schema.SchemaUtils +import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. 
This can come from + * - CTAS + * - saveAsTable + * @param protocol This is used to create a table with specific protocol version + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil, + protocol: Option[Protocol] = None)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. 
The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + + // Execute write command for `deltaWriter` by + // - replacing the metadata new target table for DataFrameWriterV2 writer if it is a + // REPLACE or CREATE_OR_REPLACE command, + // - running the write procedure of DataFrameWriter command and returning the + // new created actions, + // - returning the Delta Operation type of this DataFrameWriter + def doDeltaWrite( + deltaWriter: WriteIntoDelta, + schema: StructType): (Seq[Action], DeltaOperations.Operation) = { + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, schema) + } + val actions = deltaWriter.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + (actions, op) + } + + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case deltaWriter: WriteIntoDelta => + if (!hasBeenExecuted(txn, sparkSession, Some(options))) { + val (actions, op) = doDeltaWrite(deltaWriter, deltaWriter.data.schema.asNullable) + txn.commit(actions, op) + } + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + val deltaWriter = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data) + if (!hasBeenExecuted(txn, sparkSession, Some(options))) { + val (actions, op) = doDeltaWrite(deltaWriter, other.schema.asNullable) + txn.commit(actions, op) + } + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + protocol.foreach { protocol => + txn.updateProtocol(protocol) + } + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. 
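+        // A rough sketch of how these modes are typically reached from SQL (an illustration,
+        // not an exhaustive mapping): CREATE TABLE -> Create, REPLACE TABLE -> Replace,
+        // CREATE OR REPLACE TABLE -> CreateOrReplace; CREATE TABLE IF NOT EXISTS arrives as
+        // Create with mode = Ignore.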
+ operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog( + sparkSession, + tableWithLocation, + gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)), + txn) + + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. 
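+    // For example, `CREATE TABLE t USING delta LOCATION '/some/path'` with no column list
+    // relies on an existing Delta log at that path to supply the schema (or on the
+    // empty-schema allowance computed above); otherwise the checks below reject it.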
+ if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + } + + if (tableDesc.properties.nonEmpty) { + // When comparing properties of the existing table and the new table, remove some + // internal column mapping properties for the sake of comparison. + val filteredTableProperties = filterColumnMappingProperties(tableDesc.properties) + val filteredExistingProperties = filterColumnMappingProperties( + existingMetadata.configuration) + if (filteredTableProperties != filteredExistingProperties) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, filteredTableProperties, filteredExistingProperties) + } + // If column mapping properties are present in both configs, verify they're the same value. + if (!DeltaColumnMapping.verifyInternalProperties( + tableDesc.properties, existingMetadata.configuration)) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. 
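+   * For example (assuming the standard DataFrameWriterV2 entry point), a call like
+   * `df.writeTo("db.t").using("delta").createOrReplace()` is expected to arrive here as
+   * CreateOrReplace and resolve to a ReplaceTable operation with orCreate = true.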
+ * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(spark, table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(spark: SparkSession, table: CatalogTable, snapshot: Snapshot) + : CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..bb8738c950f --- /dev/null +++ b/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,455 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.Metadata +import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta} +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.schema.SchemaUtils +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. + * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw new AnalysisException(s"DTable ${table.identifier.quotedString} already exists.") + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + val tableName = table.identifier.quotedString + throw new AnalysisException( + s"The location of the existing table $tableName is " + + s"`${existingTable.location}`. 
It doesn't match the specified location " + + s"`${table.location}`.") + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. 
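+            // (A rough sketch of the distinction: V1 is the
+            // `df.write.format("delta").mode("overwrite").saveAsTable("t")` style API, V2 is
+            // `df.writeTo("t").using("delta").createOrReplace()`; only the V2 style implies
+            // replacing the metadata here, and isV1Writer below tells them apart by
+            // inspecting the call stack.)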
+ if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. + operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. 
Update catalog with mode: $operation") + updateCatalog(sparkSession, tableWithLocation, gpuDeltaLog.deltaLog.snapshot, txn) + + result + } + } + + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw new AnalysisException(s"Cannot create table ('${tableWithLocation.identifier}')." + + s" The associated location ('${tableWithLocation.location}') is not empty but " + + s"it's not a Delta table") + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // Users did not specify the schema. We expect the schema exists in Delta. + if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. 
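+      // (e.g. re-declaring an existing Delta table with the same columns but a different
+      // PARTITIONED BY clause fails here.)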
+ if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. + * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. 
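+   * (For path-based tables, e.g. delta.`/tmp/t`, no metastore update is made at all.)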
+ */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw new CannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. */ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. 
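+   * (Concretely, the check below looks for an `org.apache.spark.sql.DataFrameWriter` frame
+   * on the current thread's stack trace; roughly, V1 is
+   * `df.write.format("delta").saveAsTable("t")` and V2 is
+   * `df.writeTo("t").using("delta").create()`.)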
+ */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala b/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..44b34a7ad18 --- /dev/null +++ b/delta-lake/delta-spark321db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaErrors} +import com.databricks.sql.transaction.tahoe.commands.TableCreationModes +import com.databricks.sql.transaction.tahoe.sources.DeltaSourceUtils +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.StagingTableCatalog +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils + +class GpuDeltaCatalog( + override val cpuCatalog: StagingTableCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with Logging { + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw new AnalysisException(s"$table is not a Delta table. 
Please drop this " + + "table first if you would like to recreate it with Delta Lake.") + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } +} diff --git a/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..320485eb1ee --- /dev/null +++ b/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,464 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.Metadata +import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta} +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.schema.SchemaUtils +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. 
+ * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. 
The retry should still succeed even if we have + // garbage data + if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) { + assertPathEmpty(hadoopConf, tableWithLocation) + } + } + // We are either appending/overwriting with saveAsTable or creating a new table with CTAS or + // we are creating a table as part of a RunnableCommand + query.get match { + case writer: WriteIntoDelta => + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, writer.data.schema.asNullable) + } + val actions = writer.write(txn, sparkSession) + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + case cmd: RunnableCommand => + result = cmd.run(sparkSession) + case other => + // When using V1 APIs, the `other` plan is not yet optimized, therefore, it is safe + // to once again go through analysis + val data = Dataset.ofRows(sparkSession, other) + + // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that + // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1. + if (!isV1Writer) { + replaceMetadataIfNecessary( + txn, tableWithLocation, options, other.schema.asNullable) + } + + val actions = WriteIntoDelta( + deltaLog = gpuDeltaLog.deltaLog, + mode = mode, + options, + partitionColumns = table.partitionColumnNames, + configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull), + data = data).write(txn, sparkSession) + + val op = getOperation(txn.metadata, isManagedTable, Some(options)) + txn.commit(actions, op) + } + } else { + def createTransactionLogOrVerify(): Unit = { + if (isManagedTable) { + // When creating a managed table, the table path should not exist or is empty, or + // users would be surprised to see the data, or see the data directory being dropped + // after the table is dropped. + assertPathEmpty(hadoopConf, tableWithLocation) + } + + // This is either a new table, or, we never defined the schema of the table. While it is + // unexpected that `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still + // guard against it, in case of checkpoint corruption bugs. + val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty + if (noExistingMetadata) { + assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession) + assertPathEmpty(hadoopConf, tableWithLocation) + // This is a user provided schema. + // Doesn't come from a query, Follow nullability invariants. + val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json) + txn.updateMetadataForNewTable(newMetadata) + + val op = getOperation(newMetadata, isManagedTable, None) + txn.commit(Nil, op) + } else { + verifyTableMetadata(txn, tableWithLocation) + } + } + // We are defining a table using the Create or Replace Table statements. 
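+        // (In the replace cases below, existing files are removed logically in the same
+        // commit rather than deleted from storage, so prior table versions stay readable
+        // via time travel until they are vacuumed.)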
+ operation match { + case TableCreationModes.Create => + require(!tableExists, "Can't recreate a table when it exists") + createTransactionLogOrVerify() + + case TableCreationModes.CreateOrReplace if !tableExists => + // If the table doesn't exist, CREATE OR REPLACE must provide a schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + createTransactionLogOrVerify() + case _ => + // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be + // empty, since we'll use the entry to replace the schema + if (tableWithLocation.schema.isEmpty) { + throw DeltaErrors.schemaNotProvidedException + } + // We need to replace + replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema) + // Truncate the table + val operationTimestamp = System.currentTimeMillis() + val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp)) + val op = getOperation(txn.metadata, isManagedTable, None) + txn.commit(removes, op) + } + } + + // We would have failed earlier on if we couldn't ignore the existence of the table + // In addition, we just might using saveAsTable to append to the table, so ignore the creation + // if it already exists. + // Note that someone may have dropped and recreated the table in a separate location in the + // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks. + logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation") + updateCatalog( + sparkSession, + tableWithLocation, + gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)), + txn) + + result + } + } + + private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = { + Metadata( + description = table.comment.orNull, + schemaString = schemaString, + partitionColumns = table.partitionColumnNames, + configuration = table.properties, + createdTime = Some(System.currentTimeMillis())) + } + + private def assertPathEmpty( + hadoopConf: Configuration, + tableWithLocation: CatalogTable): Unit = { + val path = new Path(tableWithLocation.location) + val fs = path.getFileSystem(hadoopConf) + // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that + // we intentionally diverge from this behavior w.r.t regular datasource tables (that silently + // overwrite any previous data) + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createTableWithNonEmptyLocation( + tableWithLocation.identifier.toString, + tableWithLocation.location.toString) + } + } + + private def assertTableSchemaDefined( + fs: FileSystem, + path: Path, + table: CatalogTable, + txn: OptimisticTransaction, + sparkSession: SparkSession): Unit = { + // If we allow creating an empty schema table and indeed the table is new, we just need to + // make sure: + // 1. txn.readVersion == -1 to read a new table + // 2. for external tables: path must either doesn't exist or is completely empty + val allowCreatingTableWithEmptySchema = sparkSession.sessionState + .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1 + + // Users did not specify the schema. We expect the schema exists in Delta. 
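+    // (The empty schema is tolerated only when DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE is
+    // enabled and this is a brand new table; otherwise a schema-less CREATE TABLE is
+    // rejected by the checks below.)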
+ if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. 
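+   * For example (assuming the SQL entry point), a plain `REPLACE TABLE t ...` without an
+   * AS SELECT clause is expected to arrive here as Replace and resolve to a ReplaceTable
+   * operation with orCreate = false.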
+ * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala b/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..8c62f4f3fd7 --- /dev/null +++ b/delta-lake/delta-spark330db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.util + +import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaErrors} +import com.databricks.sql.transaction.tahoe.commands.TableCreationModes +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.sources.DeltaSourceUtils +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: StagingTableCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with DeltaLogging { + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw new AnalysisException(s"$table is not a Delta table. 
Please drop this " + + "table first if you would like to recreate it with Delta Lake.") + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, + schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + 
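+      // Only a profiling wrapper: the staging logic itself lives in
+      // GpuDeltaCatalogBase.stageCreateOrReplace.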
super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala b/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala new file mode 100644 index 00000000000..320485eb1ee --- /dev/null +++ b/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuCreateDeltaTableCommand.scala @@ -0,0 +1,464 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from CreateDeltaTableCommand.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import com.databricks.sql.transaction.tahoe._ +import com.databricks.sql.transaction.tahoe.actions.Metadata +import com.databricks.sql.transaction.tahoe.commands.{TableCreationModes, WriteIntoDelta} +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.schema.SchemaUtils +import com.databricks.sql.transaction.tahoe.sources.DeltaSQLConf +import com.nvidia.spark.rapids.RapidsConf +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.execution.command.{LeafRunnableCommand, RunnableCommand} +import org.apache.spark.sql.types.StructType + +/** + * Single entry point for all write or declaration operations for Delta tables accessed through + * the table name. 
+ * + * @param table The table identifier for the Delta table + * @param existingTableOpt The existing table for the same identifier if exists + * @param mode The save mode when writing data. Relevant when the query is empty or set to Ignore + * with `CREATE TABLE IF NOT EXISTS`. + * @param query The query to commit into the Delta table if it exist. This can come from + * - CTAS + * - saveAsTable + */ +case class GpuCreateDeltaTableCommand( + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode = TableCreationModes.Create, + tableByPath: Boolean = false, + override val output: Seq[Attribute] = Nil)(@transient rapidsConf: RapidsConf) + extends LeafRunnableCommand + with DeltaLogging { + + override def otherCopyArgs: Seq[AnyRef] = Seq(rapidsConf) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val table = this.table + + assert(table.tableType != CatalogTableType.VIEW) + assert(table.identifier.database.isDefined, "Database should've been fixed at analysis") + // There is a subtle race condition here, where the table can be created by someone else + // while this command is running. Nothing we can do about that though :( + val tableExists = existingTableOpt.isDefined + if (mode == SaveMode.Ignore && tableExists) { + // Early exit on ignore + return Nil + } else if (mode == SaveMode.ErrorIfExists && tableExists) { + throw DeltaErrors.tableAlreadyExists(table) + } + + val tableWithLocation = if (tableExists) { + val existingTable = existingTableOpt.get + table.storage.locationUri match { + case Some(location) if location.getPath != existingTable.location.getPath => + throw DeltaErrors.tableLocationMismatch(table, existingTable) + case _ => + } + table.copy( + storage = existingTable.storage, + tableType = existingTable.tableType) + } else if (table.storage.locationUri.isEmpty) { + // We are defining a new managed table + assert(table.tableType == CatalogTableType.MANAGED) + val loc = sparkSession.sessionState.catalog.defaultTablePath(table.identifier) + table.copy(storage = table.storage.copy(locationUri = Some(loc))) + } else { + // 1. We are defining a new external table + // 2. It's a managed table which already has the location populated. This can happen in DSV2 + // CTAS flow. + table + } + + val isManagedTable = tableWithLocation.tableType == CatalogTableType.MANAGED + val tableLocation = new Path(tableWithLocation.location) + val gpuDeltaLog = GpuDeltaLog.forTable(sparkSession, tableLocation, rapidsConf) + val hadoopConf = gpuDeltaLog.deltaLog.newDeltaHadoopConf() + val fs = tableLocation.getFileSystem(hadoopConf) + val options = new DeltaOptions(table.storage.properties, sparkSession.sessionState.conf) + var result: Seq[Row] = Nil + + recordDeltaOperation(gpuDeltaLog.deltaLog, "delta.ddl.createTable") { + val txn = gpuDeltaLog.startTransaction() + val opStartTs = System.currentTimeMillis() + if (query.isDefined) { + // If the mode is Ignore or ErrorIfExists, the table must not exist, or we would return + // earlier. And the data should not exist either, to match the behavior of + // Ignore/ErrorIfExists mode. This means the table path should not exist or is empty. + if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { + assert(!tableExists) + // We may have failed a previous write. 
The retry should still succeed even if we have
+          // garbage data
+          if (txn.readVersion > -1 || !fs.exists(gpuDeltaLog.deltaLog.logPath)) {
+            assertPathEmpty(hadoopConf, tableWithLocation)
+          }
+        }
+        // We are either appending/overwriting with saveAsTable or creating a new table with CTAS
+        // or we are creating a table as part of a RunnableCommand
+        query.get match {
+          case writer: WriteIntoDelta =>
+            // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+            // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+            if (!isV1Writer) {
+              replaceMetadataIfNecessary(
+                txn, tableWithLocation, options, writer.data.schema.asNullable)
+            }
+            val actions = writer.write(txn, sparkSession)
+            val op = getOperation(txn.metadata, isManagedTable, Some(options))
+            txn.commit(actions, op)
+          case cmd: RunnableCommand =>
+            result = cmd.run(sparkSession)
+          case other =>
+            // When using V1 APIs, the `other` plan is not yet optimized; therefore, it is safe
+            // to once again go through analysis
+            val data = Dataset.ofRows(sparkSession, other)
+
+            // In the V2 Writer, methods like "replace" and "createOrReplace" implicitly mean that
+            // the metadata should be changed. This wasn't the behavior for DataFrameWriterV1.
+            if (!isV1Writer) {
+              replaceMetadataIfNecessary(
+                txn, tableWithLocation, options, other.schema.asNullable)
+            }
+
+            val actions = WriteIntoDelta(
+              deltaLog = gpuDeltaLog.deltaLog,
+              mode = mode,
+              options,
+              partitionColumns = table.partitionColumnNames,
+              configuration = tableWithLocation.properties + ("comment" -> table.comment.orNull),
+              data = data).write(txn, sparkSession)
+
+            val op = getOperation(txn.metadata, isManagedTable, Some(options))
+            txn.commit(actions, op)
+        }
+      } else {
+        def createTransactionLogOrVerify(): Unit = {
+          if (isManagedTable) {
+            // When creating a managed table, the table path should not exist or should be empty,
+            // or users would be surprised to see the data, or to see the data directory being
+            // dropped after the table is dropped.
+            assertPathEmpty(hadoopConf, tableWithLocation)
+          }
+
+          // This is either a new table, or we never defined the schema of the table. While it is
+          // unexpected for `txn.metadata.schema` to be empty when txn.readVersion >= 0, we still
+          // guard against it, in case of checkpoint corruption bugs.
+          val noExistingMetadata = txn.readVersion == -1 || txn.metadata.schema.isEmpty
+          if (noExistingMetadata) {
+            assertTableSchemaDefined(fs, tableLocation, tableWithLocation, txn, sparkSession)
+            assertPathEmpty(hadoopConf, tableWithLocation)
+            // This is a user-provided schema that does not come from a query, so follow
+            // nullability invariants.
+            val newMetadata = getProvidedMetadata(tableWithLocation, table.schema.json)
+            txn.updateMetadataForNewTable(newMetadata)
+
+            val op = getOperation(newMetadata, isManagedTable, None)
+            txn.commit(Nil, op)
+          } else {
+            verifyTableMetadata(txn, tableWithLocation)
+          }
+        }
+        // We are defining a table using the Create or Replace Table statements.
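+        // The three creation modes behave differently when there is no query to run:
+        // - Create: the table must not already exist; create or verify the transaction log.
+        // - CreateOrReplace on a missing table: behaves like Create, but a schema must be given.
+        // - Replace / CreateOrReplace on an existing table: replace the metadata and truncate the
+        //   table by committing remove actions for all current files.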
+        operation match {
+          case TableCreationModes.Create =>
+            require(!tableExists, "Can't recreate a table when it exists")
+            createTransactionLogOrVerify()
+
+          case TableCreationModes.CreateOrReplace if !tableExists =>
+            // If the table doesn't exist, CREATE OR REPLACE must provide a schema
+            if (tableWithLocation.schema.isEmpty) {
+              throw DeltaErrors.schemaNotProvidedException
+            }
+            createTransactionLogOrVerify()
+          case _ =>
+            // When the operation is a REPLACE or CREATE OR REPLACE, then the schema shouldn't be
+            // empty, since we'll use the entry to replace the schema
+            if (tableWithLocation.schema.isEmpty) {
+              throw DeltaErrors.schemaNotProvidedException
+            }
+            // We need to replace
+            replaceMetadataIfNecessary(txn, tableWithLocation, options, tableWithLocation.schema)
+            // Truncate the table
+            val operationTimestamp = System.currentTimeMillis()
+            val removes = txn.filterFiles().map(_.removeWithTimestamp(operationTimestamp))
+            val op = getOperation(txn.metadata, isManagedTable, None)
+            txn.commit(removes, op)
+        }
+      }
+
+      // We would have failed earlier on if we couldn't ignore the existence of the table.
+      // In addition, we just might be using saveAsTable to append to the table, so ignore the
+      // creation if it already exists.
+      // Note that someone may have dropped and recreated the table in a separate location in the
+      // meantime... Unfortunately we can't do anything there at the moment, because Hive sucks.
+      logInfo(s"Table is path-based table: $tableByPath. Update catalog with mode: $operation")
+      updateCatalog(
+        sparkSession,
+        tableWithLocation,
+        gpuDeltaLog.deltaLog.update(checkIfUpdatedSinceTs = Some(opStartTs)),
+        txn)
+
+      result
+    }
+  }
+
+  private def getProvidedMetadata(table: CatalogTable, schemaString: String): Metadata = {
+    Metadata(
+      description = table.comment.orNull,
+      schemaString = schemaString,
+      partitionColumns = table.partitionColumnNames,
+      configuration = table.properties,
+      createdTime = Some(System.currentTimeMillis()))
+  }
+
+  private def assertPathEmpty(
+      hadoopConf: Configuration,
+      tableWithLocation: CatalogTable): Unit = {
+    val path = new Path(tableWithLocation.location)
+    val fs = path.getFileSystem(hadoopConf)
+    // Verify that the table location associated with CREATE TABLE doesn't have any data. Note that
+    // we intentionally diverge from this behavior w.r.t. regular datasource tables (that silently
+    // overwrite any previous data)
+    if (fs.exists(path) && fs.listStatus(path).nonEmpty) {
+      throw DeltaErrors.createTableWithNonEmptyLocation(
+        tableWithLocation.identifier.toString,
+        tableWithLocation.location.toString)
+    }
+  }
+
+  private def assertTableSchemaDefined(
+      fs: FileSystem,
+      path: Path,
+      table: CatalogTable,
+      txn: OptimisticTransaction,
+      sparkSession: SparkSession): Unit = {
+    // If we allow creating an empty schema table and indeed the table is new, we just need to
+    // make sure:
+    // 1. txn.readVersion == -1 to read a new table
+    // 2. for external tables: the path must either not exist or be completely empty
+    val allowCreatingTableWithEmptySchema = sparkSession.sessionState
+      .conf.getConf(DeltaSQLConf.DELTA_ALLOW_CREATE_EMPTY_SCHEMA_TABLE) && txn.readVersion == -1
+
+    // Users did not specify the schema. We expect the schema to exist in Delta.
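+    // At this point neither the user nor the Delta log provides a schema, so:
+    // - an external table over a non-empty directory with no Delta log is an error,
+    // - otherwise creation fails with the missing-schema error for external or managed tables,
+    //   unless creating tables with an empty schema is explicitly allowed.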
+ if (table.schema.isEmpty) { + if (table.tableType == CatalogTableType.EXTERNAL) { + if (fs.exists(path) && fs.listStatus(path).nonEmpty) { + throw DeltaErrors.createExternalTableWithoutLogException( + path, table.identifier.quotedString, sparkSession) + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createExternalTableWithoutSchemaException( + path, table.identifier.quotedString, sparkSession) + } + } else { + if (allowCreatingTableWithEmptySchema) return + throw DeltaErrors.createManagedTableWithoutSchemaException( + table.identifier.quotedString, sparkSession) + } + } + } + + /** + * Verify against our transaction metadata that the user specified the right metadata for the + * table. + */ + private def verifyTableMetadata( + txn: OptimisticTransaction, + tableDesc: CatalogTable): Unit = { + val existingMetadata = txn.metadata + val path = new Path(tableDesc.location) + + // The delta log already exists. If they give any configuration, we'll make sure it all matches. + // Otherwise we'll just go with the metadata already present in the log. + // The schema compatibility checks will be made in `WriteIntoDelta` for CreateTable + // with a query + if (txn.readVersion > -1) { + if (tableDesc.schema.nonEmpty) { + // We check exact alignment on create table if everything is provided + // However, if in column mapping mode, we can safely ignore the related metadata fields in + // existing metadata because new table desc will not have related metadata assigned yet + val differences = SchemaUtils.reportDifferences( + DeltaColumnMapping.dropColumnMappingMetadata(existingMetadata.schema), + tableDesc.schema) + if (differences.nonEmpty) { + throw DeltaErrors.createTableWithDifferentSchemaException( + path, tableDesc.schema, existingMetadata.schema, differences) + } + } + + // If schema is specified, we must make sure the partitioning matches, even the partitioning + // is not specified. + if (tableDesc.schema.nonEmpty && + tableDesc.partitionColumnNames != existingMetadata.partitionColumns) { + throw DeltaErrors.createTableWithDifferentPartitioningException( + path, tableDesc.partitionColumnNames, existingMetadata.partitionColumns) + } + + if (tableDesc.properties.nonEmpty && tableDesc.properties != existingMetadata.configuration) { + throw DeltaErrors.createTableWithDifferentPropertiesException( + path, tableDesc.properties, existingMetadata.configuration) + } + } + } + + /** + * Based on the table creation operation, and parameters, we can resolve to different operations. + * A lot of this is needed for legacy reasons in Databricks Runtime. 
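+   * In short (summarizing the match below): Create maps to a CreateTable operation, or to a
+   * Write for legacy saveAsTable over an existing table; Replace maps to
+   * ReplaceTable(orCreate = false); CreateOrReplace maps to ReplaceTable(orCreate = true), or to
+   * a Write when a replaceWhere option is supplied.
+   *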
+ * @param metadata The table metadata, which we are creating or replacing + * @param isManagedTable Whether we are creating or replacing a managed table + * @param options Write options, if this was a CTAS/RTAS + */ + private def getOperation( + metadata: Metadata, + isManagedTable: Boolean, + options: Option[DeltaOptions]): DeltaOperations.Operation = operation match { + // This is legacy saveAsTable behavior in Databricks Runtime + case TableCreationModes.Create if existingTableOpt.isDefined && query.isDefined => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // DataSourceV2 table creation + // CREATE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Create => + DeltaOperations.CreateTable(metadata, isManagedTable, query.isDefined) + + // DataSourceV2 table replace + // REPLACE TABLE (non-DataFrameWriter API) doesn't have options syntax + // (userMetadata uses SQLConf in this case) + case TableCreationModes.Replace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = false, query.isDefined) + + // Legacy saveAsTable with Overwrite mode + case TableCreationModes.CreateOrReplace if options.exists(_.replaceWhere.isDefined) => + DeltaOperations.Write(mode, Option(table.partitionColumnNames), options.get.replaceWhere, + options.flatMap(_.userMetadata)) + + // New DataSourceV2 saveAsTable with overwrite mode behavior + case TableCreationModes.CreateOrReplace => + DeltaOperations.ReplaceTable(metadata, isManagedTable, orCreate = true, query.isDefined, + options.flatMap(_.userMetadata)) + } + + /** + * Similar to getOperation, here we disambiguate the catalog alterations we need to do based + * on the table operation, and whether we have reached here through legacy code or DataSourceV2 + * code paths. + */ + private def updateCatalog( + spark: SparkSession, + table: CatalogTable, + snapshot: Snapshot, + txn: OptimisticTransaction): Unit = { + val cleaned = cleanupTableDefinition(table, snapshot) + operation match { + case _ if tableByPath => // do nothing with the metastore if this is by path + case TableCreationModes.Create => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = existingTableOpt.isDefined, + validateLocation = false) + case TableCreationModes.Replace | TableCreationModes.CreateOrReplace + if existingTableOpt.isDefined => + spark.sessionState.catalog.alterTable(table) + case TableCreationModes.Replace => + val ident = Identifier.of(table.identifier.database.toArray, table.identifier.table) + throw DeltaErrors.cannotReplaceMissingTableException(ident) + case TableCreationModes.CreateOrReplace => + spark.sessionState.catalog.createTable( + cleaned, + ignoreIfExists = false, + validateLocation = false) + } + } + + /** Clean up the information we pass on to store in the catalog. 
*/ + private def cleanupTableDefinition(table: CatalogTable, snapshot: Snapshot): CatalogTable = { + // These actually have no effect on the usability of Delta, but feature flagging legacy + // behavior for now + val storageProps = if (conf.getConf(DeltaSQLConf.DELTA_LEGACY_STORE_WRITER_OPTIONS_AS_PROPS)) { + // Legacy behavior + table.storage + } else { + table.storage.copy(properties = Map.empty) + } + + table.copy( + schema = new StructType(), + properties = Map.empty, + partitionColumnNames = Nil, + // Remove write specific options when updating the catalog + storage = storageProps, + tracksPartitionsInCatalog = true) + } + + /** + * With DataFrameWriterV2, methods like `replace()` or `createOrReplace()` mean that the + * metadata of the table should be replaced. If overwriteSchema=false is provided with these + * methods, then we will verify that the metadata match exactly. + */ + private def replaceMetadataIfNecessary( + txn: OptimisticTransaction, + tableDesc: CatalogTable, + options: DeltaOptions, + schema: StructType): Unit = { + val isReplace = (operation == TableCreationModes.CreateOrReplace || + operation == TableCreationModes.Replace) + // If a user explicitly specifies not to overwrite the schema, during a replace, we should + // tell them that it's not supported + val dontOverwriteSchema = options.options.contains(DeltaOptions.OVERWRITE_SCHEMA_OPTION) && + !options.canOverwriteSchema + if (isReplace && dontOverwriteSchema) { + throw DeltaErrors.illegalUsageException(DeltaOptions.OVERWRITE_SCHEMA_OPTION, "replacing") + } + if (txn.readVersion > -1L && isReplace && !dontOverwriteSchema) { + // When a table already exists, and we're using the DataFrameWriterV2 API to replace + // or createOrReplace a table, we blindly overwrite the metadata. + txn.updateMetadataForNewTable(getProvidedMetadata(table, schema.json)) + } + } + + /** + * Horrible hack to differentiate between DataFrameWriterV1 and V2 so that we can decide + * what to do with table metadata. In DataFrameWriterV1, mode("overwrite").saveAsTable, + * behaves as a CreateOrReplace table, but we have asked for "overwriteSchema" as an + * explicit option to overwrite partitioning or schema information. With DataFrameWriterV2, + * the behavior asked for by the user is clearer: .createOrReplace(), which means that we + * should overwrite schema and/or partitioning. Therefore we have this hack. + */ + private def isV1Writer: Boolean = { + Thread.currentThread().getStackTrace.exists(_.toString.contains( + classOf[DataFrameWriter[_]].getCanonicalName + ".")) + } +} diff --git a/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala b/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala new file mode 100644 index 00000000000..8c62f4f3fd7 --- /dev/null +++ b/delta-lake/delta-spark332db/src/main/scala/com/databricks/sql/transaction/tahoe/rapids/GpuDeltaCatalog.scala @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * This file was derived from DeltaDataSource.scala in the + * Delta Lake project at https://github.com/delta-io/delta. + * + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.sql.transaction.tahoe.rapids + +import java.util + +import com.databricks.sql.transaction.tahoe.{DeltaConfigs, DeltaErrors} +import com.databricks.sql.transaction.tahoe.commands.TableCreationModes +import com.databricks.sql.transaction.tahoe.metering.DeltaLogging +import com.databricks.sql.transaction.tahoe.sources.DeltaSourceUtils +import com.nvidia.spark.rapids.RapidsConf + +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.command.LeafRunnableCommand +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.types.StructType + +class GpuDeltaCatalog( + override val cpuCatalog: StagingTableCatalog, + override val rapidsConf: RapidsConf) + extends GpuDeltaCatalogBase with SupportsPathIdentifier with DeltaLogging { + + override protected def buildGpuCreateDeltaTableCommand( + rapidsConf: RapidsConf, + table: CatalogTable, + existingTableOpt: Option[CatalogTable], + mode: SaveMode, + query: Option[LogicalPlan], + operation: TableCreationModes.CreationMode, + tableByPath: Boolean): LeafRunnableCommand = { + GpuCreateDeltaTableCommand( + table, + existingTableOpt, + mode, + query, + operation, + tableByPath = tableByPath + )(rapidsConf) + } + + override protected def getExistingTableIfExists(table: TableIdentifier): Option[CatalogTable] = { + // If this is a path identifier, we cannot return an existing CatalogTable. The Create command + // will check the file system itself + if (isPathIdentifier(table)) return None + val tableExists = catalog.tableExists(table) + if (tableExists) { + val oldTable = catalog.getTableMetadata(table) + if (oldTable.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"$table is a view. You may not write data into a view.") + } + if (!DeltaSourceUtils.isDeltaTable(oldTable.provider)) { + throw new AnalysisException(s"$table is not a Delta table. 
Please drop this " + + "table first if you would like to recreate it with Delta Lake.") + } + Some(oldTable) + } else { + None + } + } + + override protected def verifyTableAndSolidify( + tableDesc: CatalogTable, + query: Option[LogicalPlan]): CatalogTable = { + + if (tableDesc.bucketSpec.isDefined) { + throw DeltaErrors.operationNotSupportedException("Bucketing", tableDesc.identifier) + } + + val schema = query.map { plan => + assert(tableDesc.schema.isEmpty, "Can't specify table schema in CTAS.") + plan.schema.asNullable + }.getOrElse(tableDesc.schema) + + PartitioningUtils.validatePartitionColumn( + schema, + tableDesc.partitionColumnNames, + caseSensitive = false) // Delta is case insensitive + + val validatedConfigurations = DeltaConfigs.validateConfigurations(tableDesc.properties) + + val db = tableDesc.identifier.database.getOrElse(catalog.getCurrentDatabase) + val tableIdentWithDB = tableDesc.identifier.copy(database = Some(db)) + tableDesc.copy( + identifier = tableIdentWithDB, + schema = schema, + properties = validatedConfigurations) + } + + override protected def createGpuStagedDeltaTableV2( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode): StagedTable = { + new GpuStagedDeltaTableV2WithLogging(ident, schema, partitions, properties, operation) + } + + override def loadTable(ident: Identifier, timestamp: Long): Table = { + cpuCatalog.loadTable(ident, timestamp) + } + + override def loadTable(ident: Identifier, version: String): Table = { + cpuCatalog.loadTable(ident, version) + } + + /** + * Creates a Delta table using GPU for writing the data + * + * @param ident The identifier of the table + * @param schema The schema of the table + * @param partitions The partition transforms for the table + * @param allTableProperties The table properties that configure the behavior of the table or + * provide information about the table + * @param writeOptions Options specific to the write during table creation or replacement + * @param sourceQuery A query if this CREATE request came from a CTAS or RTAS + * @param operation The specific table creation mode, whether this is a + * Create/Replace/Create or Replace + */ + override def createDeltaTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + allTableProperties: util.Map[String, String], + writeOptions: Map[String, String], + sourceQuery: Option[DataFrame], + operation: TableCreationModes.CreationMode + ): Table = recordFrameProfile( + "DeltaCatalog", "createDeltaTable") { + super.createDeltaTable( + ident, + schema, + partitions, + allTableProperties, + writeOptions, + sourceQuery, + operation) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = + recordFrameProfile("DeltaCatalog", "createTable") { + super.createTable(ident, schema, partitions, properties) + } + + override def stageReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageReplace") { + super.stageReplace(ident, schema, partitions, properties) + } + + override def stageCreateOrReplace( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreateOrReplace") { + 
super.stageCreateOrReplace(ident, schema, partitions, properties) + } + + override def stageCreate( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): StagedTable = + recordFrameProfile("DeltaCatalog", "stageCreate") { + super.stageCreate(ident, schema, partitions, properties) + } + + /** + * A staged Delta table, which creates a HiveMetaStore entry and appends data if this was a + * CTAS/RTAS command. We have a ugly way of using this API right now, but it's the best way to + * maintain old behavior compatibility between Databricks Runtime and OSS Delta Lake. + */ + protected class GpuStagedDeltaTableV2WithLogging( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + operation: TableCreationModes.CreationMode) + extends GpuStagedDeltaTableV2(ident, schema, partitions, properties, operation) { + + override def commitStagedChanges(): Unit = recordFrameProfile( + "DeltaCatalog", "commitStagedChanges") { + super.commitStagedChanges() + } + } +} diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 94e7c529bb1..0327947880f 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -256,6 +256,26 @@ def test_delta_overwrite_round_trip_unmanaged(spark_tmp_path): def test_delta_append_round_trip_unmanaged(spark_tmp_path): do_update_round_trip_managed(spark_tmp_path, "append") +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order(local=True) +@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) +def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): + gen_list = [("c" + str(i), gen) for i, gen in enumerate(gens)] + data_path = spark_tmp_path + "/DELTA_DATA" + confs = copy_and_update(writer_confs, delta_writes_enabled_conf) + path_to_table= {} + def do_write(spark, path): + table = spark_tmp_table_factory.get() + path_to_table[path] = table + gen_df(spark, gen_list).coalesce(1).write.format("delta").saveAsTable(table) + assert_gpu_and_cpu_writes_are_equal_collect( + do_write, + lambda spark, path: spark.read.format("delta").table(path_to_table[path]), + data_path, + conf = confs) + @allow_non_gpu(*delta_meta_allow) @delta_lake @ignore_order diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala new file mode 100644 index 00000000000..54b3c95894d --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.rapids.ExternalSource + +class AtomicCreateTableAsSelectExecMeta( + wrapped: AtomicCreateTableAsSelectExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[AtomicCreateTableAsSelectExec](wrapped, conf, parent, rule) { + + override def tagPlanForGpu(): Unit = { + ExternalSource.tagForGpu(wrapped, this) + } + + override def convertToGpu(): GpuExec = { + ExternalSource.convertToGpu(wrapped, this) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala index 7e521bd4829..6c5e53df74f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala @@ -16,12 +16,14 @@ package com.nvidia.spark.rapids.delta -import com.nvidia.spark.rapids.{CreatableRelationProviderRule, ExecRule, RunnableCommandRule, ShimLoader, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta} import org.apache.spark.sql.Strategy +import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec import org.apache.spark.sql.sources.CreatableRelationProvider /** Probe interface to determine which Delta Lake provider to use. 
*/ @@ -46,6 +48,16 @@ trait DeltaProvider { def tagSupportForGpuFileSourceScan(meta: SparkPlanMeta[FileSourceScanExec]): Unit def getReadFileFormat(format: FileFormat): FileFormat + + def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean + + def tagForGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): Unit + + def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec } object DeltaProvider { @@ -74,4 +86,18 @@ object NoDeltaProvider extends DeltaProvider { override def getReadFileFormat(format: FileFormat): FileFormat = throw new IllegalStateException("unsupported format") + + override def isSupportedCatalog(catalogClass: Class[_ <: StagingTableCatalog]): Boolean = false + + override def tagForGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): Unit = { + throw new IllegalStateException("catalog not supported, should not be called") + } + + override def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + throw new IllegalStateException("catalog not supported, should not be called") + } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index fe560551aa3..aa3588fcab8 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec import org.apache.spark.sql.sources.CreatableRelationProvider import org.apache.spark.util.Utils @@ -136,4 +137,26 @@ object ExternalSource extends Logging { require(doWrap != null) new CreatableRelationProviderRule[INPUT](doWrap, desc, tag) } + + def tagForGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): Unit = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.tagForGpu(cpuExec, meta) + } else { + meta.willNotWorkOnGpu(s"catalog ${cpuExec.catalog.getClass} is not supported") + } + } + + def convertToGpu( + cpuExec: AtomicCreateTableAsSelectExec, + meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.convertToGpu(cpuExec, meta) + } else { + throw new IllegalStateException("No GPU conversion") + } + } } diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 05fc736c733..d977f4f5506 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -52,6 +52,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.command._ +import 
org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan @@ -251,6 +252,13 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { val maps: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq( + exec[AtomicCreateTableAsSelectExec]( + "Create table as select for datasource V2 tables that support staging table creation", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.BINARY + + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), + TypeSig.all), + (e, conf, p, r) => new AtomicCreateTableAsSelectExecMeta(e, conf, p, r)), GpuOverrides.exec[WindowInPandasExec]( "The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between" + " the Java process and the Python process. It also supports scheduling GPU resources" + diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..250a4fa904c --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. + * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. 
+ * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. + */ +case class GpuAtomicCreateTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + properties: Map[String, String], + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable + val stagedTable = catalog.stageCreate( + ident, schema, partitioning.toArray, properties.asJava) + writeToTable(catalog, stagedTable, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicCreateTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..ddf9ce97b40 --- /dev/null +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "321db"} +{"spark": "330db"} +{"spark": "332db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. 
+ * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. + */ +case class GpuAtomicCreateTableAsSelectExec( + override val output: Seq[Attribute], + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + val stagedTable = catalog.stageCreate( + ident, schema, partitioning.toArray, properties.asJava) + writeToTable(catalog, stagedTable, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicCreateTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..f4c1fb198a9 --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "331"} +{"spark": "332"} +{"spark": "333"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. + * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. + */ +case class GpuAtomicCreateTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + val stagedTable = catalog.stageCreate( + ident, schema, partitioning.toArray, properties.asJava) + writeToTable(catalog, stagedTable, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicCreateTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..28a6e5910d3 --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. + * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. 
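A minimal sketch (not part of this patch) of the kind of query that produces an AtomicCreateTableAsSelectExec node, which this GPU node replaces when the catalog is Delta's DeltaCatalog; table and view names are illustrative, and the configs assume the Delta Lake and spark-rapids jars are on the classpath.

    import org.apache.spark.sql.SparkSession

    object CtasSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("ctas-sketch")
          // DeltaCatalog implements StagingTableCatalog, so CTAS plans the atomic exec.
          .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
          .config("spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog")
          .config("spark.rapids.sql.enabled", "true")
          .getOrCreate()

        spark.range(0, 1000).selectExpr("id", "id % 10 AS bucket")
          .createOrReplaceTempView("src")
        // Plans AtomicCreateTableAsSelectExec; with this patch the plugin can swap in
        // GpuAtomicCreateTableAsSelectExec so the SELECT side runs on the GPU.
        spark.sql("CREATE TABLE ctas_demo USING delta AS SELECT * FROM src")
        spark.stop()
      }
    }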
+ */ +case class GpuAtomicCreateTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val columns = CatalogV2Util.structTypeToV2Columns( + CharVarcharUtils.getRawSchema(query.schema, conf).asNullable) + val stagedTable = catalog.stageCreate( + ident, columns, partitioning.toArray, properties.asJava) + writeToTable(catalog, stagedTable, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicCreateTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/ShimTrampolineUtil.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/ShimTrampolineUtil.scala index 3333b446b5d..5a4e0126367 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/ShimTrampolineUtil.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/ShimTrampolineUtil.scala @@ -22,6 +22,7 @@ spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.rapids.execution import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.types.{DataType, StructType} @@ -35,4 +36,8 @@ object ShimTrampolineUtil { case IdentityBroadcastMode => true case _ => false } + + def v2ColumnsToStructType(columns: Array[Column]): StructType = { + CatalogV2Util.v2ColumnsToStructType(columns) + } } diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala new file mode 100644 index 00000000000..2f6869be7e4 --- /dev/null +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicCreateTableAsSelectExec. + * + * Physical plan node for v2 create table as select, when the catalog is determined to support + * staging table creation. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * The CTAS operation is atomic. The creation of the table is staged and the commit of the write + * should bundle the commitment of the metadata and the table contents in a single unit. If the + * write fails, the table is instructed to roll back all staged changes. + */ +case class GpuAtomicCreateTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + tableSpec: TableSpec, + writeOptions: Map[String, String], + ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return Nil + } + + throw QueryCompilationErrors.tableAlreadyExistsError(ident) + } + val stagedTable = catalog.stageCreate( + ident, getV2Columns(query.schema, catalog.useNullableQuerySchema), + partitioning.toArray, properties.asJava) + writeToTable(catalog, stagedTable, writeOptions, ident, query) + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} From 6307e257683cac1e40feede68967c5279881d801 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 12 Oct 2023 17:22:07 -0500 Subject: [PATCH 2/8] Add AtomicReplaceTableAsSelectExec support for Delta Lake Signed-off-by: Jason Lowe --- .../delta/DatabricksDeltaProvider.scala | 38 ++++++- .../spark/rapids/delta/DeltaIOProvider.scala | 30 ++++-- .../delta/delta20x/Delta20xProvider.scala | 22 +++- .../delta/delta21x/Delta21xProvider.scala | 22 +++- .../delta/delta22x/Delta22xProvider.scala | 22 +++- .../delta/delta24x/Delta24xProvider.scala | 22 +++- .../src/main/python/delta_lake_write_test.py | 30 ++++-- .../spark/rapids/delta/DeltaProvider.scala | 24 ++++- ...ecMeta.scala => v2WriteCommandMetas.scala} | 18 +++- .../spark/sql/rapids/ExternalSource.scala | 24 ++++- .../rapids/shims/Spark320PlusShims.scala | 9 +- .../GpuAtomicReplaceTableAsSelectExec.scala | 96 +++++++++++++++++ .../GpuAtomicReplaceTableAsSelectExec.scala | 100 ++++++++++++++++++ .../GpuAtomicReplaceTableAsSelectExec.scala | 100 ++++++++++++++++++ .../GpuAtomicReplaceTableAsSelectExec.scala | 98 +++++++++++++++++ .../GpuAtomicReplaceTableAsSelectExec.scala | 89 ++++++++++++++++ 16 files changed, 711 insertions(+), 33 deletions(-) rename 
sql-plugin/src/main/scala/com/nvidia/spark/rapids/{AtomicCreateTableAsSelectExecMeta.scala => v2WriteCommandMetas.scala} (66%) create mode 100644 sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala create mode 100644 sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala index dc29208cef6..4186ca937af 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.sources.CreatableRelationProvider @@ -138,6 +138,40 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + val properties = cpuExec.properties + val provider = properties.getOrElse("provider", + cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } class DeltaCreatableRelationProviderMeta( diff --git a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala index 049f98180cf..e35f85f9859 100644 --- a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala +++ b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils} import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil @@ -66,15 +66,33 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase { meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } - val properties = cpuExec.properties - val provider = properties.getOrElse("provider", - cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) - if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { - meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf) + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } + checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf) RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) } + + private def checkDeltaProvider( + meta: RapidsMeta[_, _, _], + properties: Map[String, String], + conf: SQLConf): Unit = { + val provider = properties.getOrElse("provider", conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + } } class DeltaCreatableRelationProviderMeta( diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala index 7542b7b612b..961c6632a7e 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta20x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta20xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta20xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.properties, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala index 4fcb2a993bc..9e5387b8ef1 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta21x -import 
com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta21xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta21xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala index 73ead6f45e4..47d158bb37b 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta22x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta22xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta22xProvider extends DeltaIOProvider 
{ cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala index f50f1d96ced..d3f952b856c 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -27,8 +27,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta24xProvider extends DeltaIOProvider { @@ -91,4 +91,20 @@ object Delta24xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 0327947880f..3ed897f2142 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -256,12 +256,7 @@ def test_delta_overwrite_round_trip_unmanaged(spark_tmp_path): def test_delta_append_round_trip_unmanaged(spark_tmp_path): do_update_round_trip_managed(spark_tmp_path, "append") -@allow_non_gpu(*delta_meta_allow) -@delta_lake -@ignore_order(local=True) -@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") 
-@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) -def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): +def _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite): gen_list = [("c" + str(i), gen) for i, gen in enumerate(gens)] data_path = spark_tmp_path + "/DELTA_DATA" confs = copy_and_update(writer_confs, delta_writes_enabled_conf) @@ -269,12 +264,31 @@ def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spar def do_write(spark, path): table = spark_tmp_table_factory.get() path_to_table[path] = table - gen_df(spark, gen_list).coalesce(1).write.format("delta").saveAsTable(table) + writer = gen_df(spark, gen_list).coalesce(1).write.format("delta") + if overwrite: + writer = writer.mode("overwrite") + writer.saveAsTable(table) assert_gpu_and_cpu_writes_are_equal_collect( do_write, lambda spark, path: spark.read.format("delta").table(path_to_table[path]), data_path, - conf = confs) + conf=confs) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order(local=True) +@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) +def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): + _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order(local=True) +@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) +def test_delta_atomic_replace_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): + _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True) @allow_non_gpu(*delta_meta_allow) @delta_lake diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala index 6c5e53df74f..1fc6492826a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala @@ -16,14 +16,14 @@ package com.nvidia.spark.rapids.delta -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta} import org.apache.spark.sql.Strategy import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.sources.CreatableRelationProvider /** Probe interface to determine which Delta Lake provider to use. 
*/ @@ -58,6 +58,14 @@ trait DeltaProvider { def convertToGpu( cpuExec: AtomicCreateTableAsSelectExec, meta: AtomicCreateTableAsSelectExecMeta): GpuExec + + def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit + + def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec } object DeltaProvider { @@ -100,4 +108,16 @@ object NoDeltaProvider extends DeltaProvider { meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { throw new IllegalStateException("catalog not supported, should not be called") } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + throw new IllegalStateException("catalog not supported, should not be called") + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + throw new IllegalStateException("catalog not supported, should not be called") + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala similarity index 66% rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala rename to sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala index 54b3c95894d..dbb3f5d1e4b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.rapids.ExternalSource class AtomicCreateTableAsSelectExecMeta( @@ -34,3 +34,19 @@ class AtomicCreateTableAsSelectExecMeta( ExternalSource.convertToGpu(wrapped, this) } } + +class AtomicReplaceTableAsSelectExecMeta( + wrapped: AtomicReplaceTableAsSelectExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[AtomicReplaceTableAsSelectExec](wrapped, conf, parent, rule) { + + override def tagPlanForGpu(): Unit = { + ExternalSource.tagForGpu(wrapped, this) + } + + override def convertToGpu(): GpuExec = { + ExternalSource.convertToGpu(wrapped, this) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index aa3588fcab8..dfc4aad1192 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.sources.CreatableRelationProvider import org.apache.spark.util.Utils @@ -159,4 +159,26 @@ object ExternalSource extends Logging { throw new 
IllegalStateException("No GPU conversion") } } + + def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.tagForGpu(cpuExec, meta) + } else { + meta.willNotWorkOnGpu(s"catalog ${cpuExec.catalog.getClass} is not supported") + } + } + + def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.convertToGpu(cpuExec, meta) + } else { + throw new IllegalStateException("No GPU conversion") + } + } } diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index d977f4f5506..7d0807cdccc 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -52,7 +52,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan @@ -259,6 +259,13 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), (e, conf, p, r) => new AtomicCreateTableAsSelectExecMeta(e, conf, p, r)), + exec[AtomicReplaceTableAsSelectExec]( + "Replace table as select for datasource V2 tables that support staging table creation", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.BINARY + + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), + TypeSig.all), + (e, conf, p, r) => new AtomicReplaceTableAsSelectExecMeta(e, conf, p, r)), GpuOverrides.exec[WindowInPandasExec]( "The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between" + " the Java process and the Python process. It also supports scheduling GPU resources" + diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..ee3402c4f12 --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog does not support staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This is a non-atomic implementation that drops the table and then runs non-atomic + * CTAS. For an atomic implementation for catalogs with the appropriate support, see + * ReplaceTableAsSelectStagingExec. + */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + properties: Map[String, String], + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec { + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + // Note that this operation is potentially unsafe, but these are the strict semantics of + // RTAS if the catalog does not support atomic operations. + // + // There are numerous cases we concede to where the table will be dropped and irrecoverable: + // + // 1. Creating the new table fails, + // 2. Writing to the new table fails, + // 3. The table returned by catalog.createTable doesn't support writing. 
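A hedged sketch of why the caveats above apply: the drop-then-create sequence used below has no rollback point, so any failure after dropTable loses the original table. The helper name is hypothetical and only mirrors the flow of this run() method.

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
    import org.apache.spark.sql.connector.expressions.Transform
    import org.apache.spark.sql.types.StructType

    // Hypothetical illustration of the non-atomic window, not plugin code.
    def replaceByDropAndCreate(
        catalog: TableCatalog,
        ident: Identifier,
        schema: StructType,
        partitioning: Array[Transform],
        properties: Map[String, String]): Table = {
      if (catalog.tableExists(ident)) {
        catalog.dropTable(ident) // point of no return: the old table is gone
      }
      // If createTable or the subsequent write fails, there is nothing left to
      // restore, which is exactly the irrecoverable case described above.
      catalog.createTable(ident, schema, partitioning, properties.asJava)
    }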
+ if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + catalog.dropTable(ident) + } else if (!orCreate) { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable + val table = catalog.createTable( + ident, schema, partitioning.toArray, properties.asJava) + writeToTable(catalog, table, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..1f00da4ad50 --- /dev/null +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "321db"} +{"spark": "330db"} +{"spark": "332db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. 
If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. + */ +case class GpuAtomicReplaceTableAsSelectExec( + override val output: Seq[Attribute], + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..97e90b4a801 --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "331"} +{"spark": "332"} +{"spark": "333"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
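For reference, a brief hedged sketch of the SQL forms that reach this node when the catalog supports staging (table names illustrative): CREATE OR REPLACE maps to orCreate = true and takes the stageCreateOrReplace path, while plain REPLACE maps to orCreate = false and uses stageReplace, failing with cannotReplaceMissingTableError if the table does not already exist.

    // orCreate = true: stageCreateOrReplace, succeeds whether or not the table exists.
    spark.sql("CREATE OR REPLACE TABLE rtas_demo USING delta AS SELECT * FROM src")
    // orCreate = false: stageReplace, which requires rtas_demo to already exist.
    spark.sql("REPLACE TABLE rtas_demo USING delta AS SELECT * FROM src")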
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..e3e5cd5a40b --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. + */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val columns = CatalogV2Util.structTypeToV2Columns( + CharVarcharUtils.getRawSchema(query.schema, conf).asNullable) + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, columns, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, columns, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala 
b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..9ac78288bef --- /dev/null +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
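The rollback guarantee described above comes from the StagedTable contract: the staged replacement is only published by commitStagedChanges() once the write succeeds, and is discarded by abortStagedChanges() otherwise. A minimal hedged sketch of that pattern follows; the wrapper name and write step are illustrative, not the plugin's actual write helper.

    import scala.util.control.NonFatal

    import org.apache.spark.sql.connector.catalog.StagedTable

    // Illustrative commit/abort wrapper around a staged table, not plugin code.
    def commitOrAbort(staged: StagedTable)(write: => Unit): Unit = {
      try {
        write                         // append the query results to the staged table
        staged.commitStagedChanges()  // atomically publish metadata and contents
      } catch {
        case NonFatal(e) =>
          staged.abortStagedChanges() // any previously written table is left untouched
          throw e
      }
    }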
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + tableSpec: TableSpec, + writeOptions: Map[String, String], + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends V2CreateTableAsSelectBaseExec with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, columns, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, columns, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident, query) + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} From 7425ae95a43182ee2c9849ce8b3c6c09949d1758 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 16 Oct 2023 11:21:39 -0500 Subject: [PATCH 3/8] Avoid showing GpuColumnarToRow transition in plan that does not actually execute --- .../com/nvidia/spark/rapids/GpuTransitionOverrides.scala | 2 ++ .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 4 ++-- .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 4 +++- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 3 ++- 11 files changed, 23 insertions(+), 18 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index b80913aa64a..3ab11c6428b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -510,6 +510,8 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { private def insertColumnarFromGpu(plan: SparkPlan): SparkPlan = { if (plan.supportsColumnar && plan.isInstanceOf[GpuExec]) { GpuBringBackToHost(insertColumnarToGpu(plan)) + } else if (plan.isInstanceOf[ColumnarToRowTransition] && plan.isInstanceOf[GpuExec]) { + plan.withNewChildren(plan.children.map(insertColumnarToGpu)) } else { plan.withNewChildren(plan.children.map(insertColumnarFromGpu)) } diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index 250a4fa904c..9e868454902 100644 --- 
a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -59,7 +59,7 @@ case class GpuAtomicCreateTableAsSelectExec( query: SparkPlan, properties: Map[String, String], writeOptions: CaseInsensitiveStringMap, - ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { override def supportsColumnar: Boolean = false diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index ee3402c4f12..d1f842a937a 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -62,7 +62,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends TableWriteExecHelper with GpuExec { + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { override def supportsColumnar: Boolean = false diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index ddf9ce97b40..e76d9b7f58a 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} import 
org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -58,7 +58,7 @@ case class GpuAtomicCreateTableAsSelectExec( query: SparkPlan, tableSpec: TableSpec, writeOptions: CaseInsensitiveStringMap, - ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index 1f00da4ad50..c5059046867 100644 --- a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -63,7 +63,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends TableWriteExecHelper with GpuExec { + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index f4c1fb198a9..fd64ae0d568 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -58,7 
+58,7 @@ case class GpuAtomicCreateTableAsSelectExec( query: SparkPlan, tableSpec: TableSpec, writeOptions: CaseInsensitiveStringMap, - ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index 97e90b4a801..adedd9f9091 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -63,7 +63,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends TableWriteExecHelper with GpuExec { + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index 28a6e5910d3..1918e979c44 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -55,7 +55,7 @@ case class GpuAtomicCreateTableAsSelectExec( query: SparkPlan, tableSpec: TableSpec, writeOptions: CaseInsensitiveStringMap, - ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec { + ifNotExists: Boolean) extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git 
a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index e3e5cd5a40b..8031e503fbe 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch @@ -60,7 +60,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends TableWriteExecHelper with GpuExec { + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index 2f6869be7e4..2c7039d328c 100644 --- a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.ColumnarToRowTransition import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec import org.apache.spark.sql.vectorized.ColumnarBatch @@ -50,7 +51,8 @@ case class GpuAtomicCreateTableAsSelectExec( query: LogicalPlan, tableSpec: TableSpec, writeOptions: Map[String, String], - ifNotExists: Boolean) extends V2CreateTableAsSelectBaseExec with GpuExec { + ifNotExists: Boolean) + extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index 9ac78288bef..cae386eeadf 100644 --- a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ 
b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.ColumnarToRowTransition import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec import org.apache.spark.sql.vectorized.ColumnarBatch @@ -55,7 +56,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: Map[String, String], orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec with GpuExec { + extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition { val properties = CatalogV2Util.convertTableProperties(tableSpec) From 72d44f5fa8e3206b6d6b92d00dbe7d3a8a91bcaa Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 16 Oct 2023 11:26:00 -0500 Subject: [PATCH 4/8] Fix 3.5 build --- .../v2/rapids/GpuAtomicCreateTableAsSelectExec.scala | 3 +-- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala index 2c7039d328c..18ebf15042f 100644 --- a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicCreateTableAsSelectExec.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.ColumnarToRowTransition import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec import org.apache.spark.sql.vectorized.ColumnarBatch @@ -52,7 +51,7 @@ case class GpuAtomicCreateTableAsSelectExec( tableSpec: TableSpec, writeOptions: Map[String, String], ifNotExists: Boolean) - extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition { + extends V2CreateTableAsSelectBaseExec with GpuExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index cae386eeadf..9ac78288bef 100644 --- a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, 
StagingTableCatalog, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.ColumnarToRowTransition import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec import org.apache.spark.sql.vectorized.ColumnarBatch @@ -56,7 +55,7 @@ case class GpuAtomicReplaceTableAsSelectExec( writeOptions: Map[String, String], orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) - extends V2CreateTableAsSelectBaseExec with GpuExec with ColumnarToRowTransition { + extends V2CreateTableAsSelectBaseExec with GpuExec { val properties = CatalogV2Util.convertTableProperties(tableSpec) From 97436a344037b6405c677e9a5e5fafc0a6787356 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 18 Oct 2023 10:28:06 -0500 Subject: [PATCH 5/8] Fix 332cdh build --- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index adedd9f9091..66da3dd9e5f 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -19,6 +19,7 @@ {"spark": "330cdh"} {"spark": "331"} {"spark": "332"} +{"spark": "332cdh"} {"spark": "333"} spark-rapids-shim-json-lines ***/ package org.apache.spark.sql.execution.datasources.v2.rapids From 3a0c72c7a3722c999f2995be82db0b9a0902a8b1 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 18 Oct 2023 14:01:37 -0500 Subject: [PATCH 6/8] Remove accidental test duplication from upmerge --- .../src/main/python/delta_lake_write_test.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 11eb9a3f23d..3ed897f2142 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -290,26 +290,6 @@ def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spar def test_delta_atomic_replace_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True) -@allow_non_gpu(*delta_meta_allow) -@delta_lake -@ignore_order(local=True) -@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") -@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) -def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): - gen_list = [("c" + str(i), gen) for i, gen in enumerate(gens)] - data_path = spark_tmp_path + "/DELTA_DATA" - confs = copy_and_update(writer_confs, delta_writes_enabled_conf) - path_to_table= {} - def do_write(spark, path): - table = spark_tmp_table_factory.get() - path_to_table[path] = table - gen_df(spark, gen_list).coalesce(1).write.format("delta").saveAsTable(table) - assert_gpu_and_cpu_writes_are_equal_collect( - do_write, - lambda spark, path: spark.read.format("delta").table(path_to_table[path]), - data_path, - conf = 
confs) - @allow_non_gpu(*delta_meta_allow) @delta_lake @ignore_order From e2a1816ceff18dc8d124a614de52fe2a1132bb58 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 19 Oct 2023 11:22:49 -0500 Subject: [PATCH 7/8] Fix comment --- .../v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index d1f842a937a..fefcb66c10e 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -43,14 +43,15 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * GPU version of AtomicReplaceTableAsSelectExec. * - * Physical plan node for v2 replace table as select when the catalog does not support staging + * Physical plan node for v2 replace table as select when the catalog supports staging * table replacement. * * A new table will be created using the schema of the query, and rows from the query are appended. * If the table exists, its contents and schema should be replaced with the schema and the contents - * of the query. This is a non-atomic implementation that drops the table and then runs non-atomic - * CTAS. For an atomic implementation for catalogs with the appropriate support, see - * ReplaceTableAsSelectStagingExec. + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
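+ *
+ * The staged table returned by the catalog provides the atomicity: writeToTable commits the
+ * staged changes once the write succeeds, or aborts them if the write fails.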
*/ case class GpuAtomicReplaceTableAsSelectExec( catalog: StagingTableCatalog, From 0f5450b6ea85493bdb91a595fbbd3da6b1d4ad12 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 19 Oct 2023 12:28:14 -0500 Subject: [PATCH 8/8] Fix GpuAtomicReplaceTableAsSelectExec for 3.2.x --- .../GpuAtomicReplaceTableAsSelectExec.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala index fefcb66c10e..ba992a908d7 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -30,6 +30,7 @@ import com.nvidia.spark.rapids.GpuExec import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, Table, TableCatalog} @@ -68,25 +69,26 @@ case class GpuAtomicReplaceTableAsSelectExec( override def supportsColumnar: Boolean = false override protected def run(): Seq[InternalRow] = { - // Note that this operation is potentially unsafe, but these are the strict semantics of - // RTAS if the catalog does not support atomic operations. - // - // There are numerous cases we concede to where the table will be dropped and irrecoverable: - // - // 1. Creating the new table fails, - // 2. Writing to the new table fails, - // 3. The table returned by catalog.createTable doesn't support writing. + val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable if (catalog.tableExists(ident)) { val table = catalog.loadTable(ident) invalidateCache(catalog, table, ident) - catalog.dropTable(ident) - } else if (!orCreate) { + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) } - val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable - val table = catalog.createTable( - ident, schema, partitioning.toArray, properties.asJava) - writeToTable(catalog, table, writeOptions, ident) + writeToTable(catalog, staged, writeOptions, ident) } override protected def withNewChildInternal(