diff --git a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala index dc29208cef6..4186ca937af 100644 --- a/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala +++ b/delta-lake/common/src/main/databricks/scala/com/nvidia/spark/rapids/delta/DatabricksDeltaProvider.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.sources.CreatableRelationProvider @@ -138,6 +138,40 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") + } + val properties = cpuExec.properties + val provider = properties.getOrElse("provider", + cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + GpuAtomicReplaceTableAsSelectExec( + cpuExec.output, + new GpuDeltaCatalog(cpuExec.catalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } class DeltaCreatableRelationProviderMeta( diff --git a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala index 049f98180cf..e35f85f9859 100644 --- a/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala +++ b/delta-lake/common/src/main/delta-io/scala/com/nvidia/spark/rapids/delta/DeltaIOProvider.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.delta.catalog.DeltaCatalog import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils} import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand} -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.ExternalSource import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil @@ -66,15 +66,33 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase { meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } - val properties = cpuExec.properties - val provider = properties.getOrElse("provider", - cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) - if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { - meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf) + RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, + cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) + } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + require(isSupportedCatalog(cpuExec.catalog.getClass)) + if (!meta.conf.isDeltaWriteEnabled) { + meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. 
To enable set " + + s"${RapidsConf.ENABLE_DELTA_WRITE} to true") } + checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf) RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None, cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session) } + + private def checkDeltaProvider( + meta: RapidsMeta[_, _, _], + properties: Map[String, String], + conf: SQLConf): Unit = { + val provider = properties.getOrElse("provider", conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME)) + if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) { + meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider") + } + } } class DeltaCreatableRelationProviderMeta( diff --git a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala index 7542b7b612b..961c6632a7e 100644 --- a/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala +++ b/delta-lake/delta-20x/src/main/scala/com/nvidia/spark/rapids/delta/delta20x/Delta20xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta20x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta20xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta20xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.properties, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala index 4fcb2a993bc..9e5387b8ef1 100644 --- a/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala +++ b/delta-lake/delta-21x/src/main/scala/com/nvidia/spark/rapids/delta/delta21x/Delta21xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta21x -import 
com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta21xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta21xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala index 73ead6f45e4..47d158bb37b 100644 --- a/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala +++ b/delta-lake/delta-22x/src/main/scala/com/nvidia/spark/rapids/delta/delta22x/Delta22xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta22x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta22xProvider extends DeltaIOProvider { @@ -77,4 +77,20 @@ object Delta22xProvider extends DeltaIOProvider 
{ cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala index f50f1d96ced..d3f952b856c 100644 --- a/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala +++ b/delta-lake/delta-24x/src/main/scala/com/nvidia/spark/rapids/delta/delta24x/Delta24xProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.delta.delta24x -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta} import com.nvidia.spark.rapids.delta.DeltaIOProvider import org.apache.spark.sql.delta.DeltaParquetFileFormat @@ -27,8 +27,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec -import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} +import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec} object Delta24xProvider extends DeltaIOProvider { @@ -91,4 +91,20 @@ object Delta24xProvider extends DeltaIOProvider { cpuExec.writeOptions, cpuExec.ifNotExists) } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog] + GpuAtomicReplaceTableAsSelectExec( + DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf), + cpuExec.ident, + cpuExec.partitioning, + cpuExec.plan, + meta.childPlans.head.convertIfNeeded(), + cpuExec.tableSpec, + cpuExec.writeOptions, + cpuExec.orCreate, + cpuExec.invalidateCache) + } } diff --git a/integration_tests/src/main/python/delta_lake_write_test.py b/integration_tests/src/main/python/delta_lake_write_test.py index 0327947880f..3ed897f2142 100644 --- a/integration_tests/src/main/python/delta_lake_write_test.py +++ b/integration_tests/src/main/python/delta_lake_write_test.py @@ -256,12 +256,7 @@ def test_delta_overwrite_round_trip_unmanaged(spark_tmp_path): def test_delta_append_round_trip_unmanaged(spark_tmp_path): do_update_round_trip_managed(spark_tmp_path, "append") -@allow_non_gpu(*delta_meta_allow) -@delta_lake -@ignore_order(local=True) -@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") 
-@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) -def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): +def _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite): gen_list = [("c" + str(i), gen) for i, gen in enumerate(gens)] data_path = spark_tmp_path + "/DELTA_DATA" confs = copy_and_update(writer_confs, delta_writes_enabled_conf) @@ -269,12 +264,31 @@ def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spar def do_write(spark, path): table = spark_tmp_table_factory.get() path_to_table[path] = table - gen_df(spark, gen_list).coalesce(1).write.format("delta").saveAsTable(table) + writer = gen_df(spark, gen_list).coalesce(1).write.format("delta") + if overwrite: + writer = writer.mode("overwrite") + writer.saveAsTable(table) assert_gpu_and_cpu_writes_are_equal_collect( do_write, lambda spark, path: spark.read.format("delta").table(path_to_table[path]), data_path, - conf = confs) + conf=confs) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order(local=True) +@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) +def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): + _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False) + +@allow_non_gpu(*delta_meta_allow) +@delta_lake +@ignore_order(local=True) +@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x") +@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn) +def test_delta_atomic_replace_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path): + _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True) @allow_non_gpu(*delta_meta_allow) @delta_lake diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala index 6c5e53df74f..1fc6492826a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/delta/DeltaProvider.scala @@ -16,14 +16,14 @@ package com.nvidia.spark.rapids.delta -import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta} +import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta} import org.apache.spark.sql.Strategy import org.apache.spark.sql.connector.catalog.StagingTableCatalog import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.sources.CreatableRelationProvider /** Probe interface to determine which Delta Lake provider to use. 
*/ @@ -58,6 +58,14 @@ trait DeltaProvider { def convertToGpu( cpuExec: AtomicCreateTableAsSelectExec, meta: AtomicCreateTableAsSelectExecMeta): GpuExec + + def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit + + def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec } object DeltaProvider { @@ -100,4 +108,16 @@ object NoDeltaProvider extends DeltaProvider { meta: AtomicCreateTableAsSelectExecMeta): GpuExec = { throw new IllegalStateException("catalog not supported, should not be called") } + + override def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + throw new IllegalStateException("catalog not supported, should not be called") + } + + override def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + throw new IllegalStateException("catalog not supported, should not be called") + } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala similarity index 66% rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala rename to sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala index 54b3c95894d..dbb3f5d1e4b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AtomicCreateTableAsSelectExecMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/v2WriteCommandMetas.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.rapids.ExternalSource class AtomicCreateTableAsSelectExecMeta( @@ -34,3 +34,19 @@ class AtomicCreateTableAsSelectExecMeta( ExternalSource.convertToGpu(wrapped, this) } } + +class AtomicReplaceTableAsSelectExecMeta( + wrapped: AtomicReplaceTableAsSelectExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends SparkPlanMeta[AtomicReplaceTableAsSelectExec](wrapped, conf, parent, rule) { + + override def tagPlanForGpu(): Unit = { + ExternalSource.tagForGpu(wrapped, this) + } + + override def convertToGpu(): GpuExec = { + ExternalSource.convertToGpu(wrapped, this) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala index aa3588fcab8..dfc4aad1192 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ExternalSource.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.sources.CreatableRelationProvider import org.apache.spark.util.Utils @@ -159,4 +159,26 @@ object ExternalSource extends Logging { throw new 
IllegalStateException("No GPU conversion") } } + + def tagForGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): Unit = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.tagForGpu(cpuExec, meta) + } else { + meta.willNotWorkOnGpu(s"catalog ${cpuExec.catalog.getClass} is not supported") + } + } + + def convertToGpu( + cpuExec: AtomicReplaceTableAsSelectExec, + meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = { + val catalogClass = cpuExec.catalog.getClass + if (deltaProvider.isSupportedCatalog(catalogClass)) { + deltaProvider.convertToGpu(cpuExec, meta) + } else { + throw new IllegalStateException("No GPU conversion") + } + } } diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 195ad544db9..f94351ac6c6 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -53,7 +53,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec +import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec} import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan @@ -261,6 +261,13 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging { GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), (e, conf, p, r) => new AtomicCreateTableAsSelectExecMeta(e, conf, p, r)), + exec[AtomicReplaceTableAsSelectExec]( + "Replace table as select for datasource V2 tables that support staging table creation", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.BINARY + + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), + TypeSig.all), + (e, conf, p, r) => new AtomicReplaceTableAsSelectExecMeta(e, conf, p, r)), GpuOverrides.exec[WindowInPandasExec]( "The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between" + " the Java process and the Python process. It also supports scheduling GPU resources" + diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..ba992a908d7 --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "320"} +{"spark": "321"} +{"spark": "321cdh"} +{"spark": "322"} +{"spark": "323"} +{"spark": "324"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + properties: Map[String, String], + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..c5059046867 --- /dev/null +++ b/sql-plugin/src/main/spark321db/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "321db"} +{"spark": "330db"} +{"spark": "332db"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + override val output: Seq[Attribute], + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..66da3dd9e5f --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330cdh"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332cdh"} +{"spark": "333"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, schema, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, schema, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..8031e503fbe --- /dev/null +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "340"} +{"spark": "341"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.TableWriteExecHelper +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. + */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + plan: LogicalPlan, + query: SparkPlan, + tableSpec: TableSpec, + writeOptions: CaseInsensitiveStringMap, + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends TableWriteExecHelper with GpuExec with ColumnarToRowTransition { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val columns = CatalogV2Util.structTypeToV2Columns( + CharVarcharUtils.getRawSchema(query.schema, conf).asNullable) + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, columns, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, columns, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): GpuAtomicReplaceTableAsSelectExec = copy(query = newChild) + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +} diff --git a/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala 
b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala new file mode 100644 index 00000000000..9ac78288bef --- /dev/null +++ b/sql-plugin/src/main/spark350/scala/org/apache/spark/sql/execution/datasources/v2/rapids/GpuAtomicReplaceTableAsSelectExec.scala @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "350"} +spark-rapids-shim-json-lines ***/ +package org.apache.spark.sql.execution.datasources.v2.rapids + +import scala.collection.JavaConverters._ + +import com.nvidia.spark.rapids.GpuExec + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, Table, TableCatalog} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.V2CreateTableAsSelectBaseExec +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * GPU version of AtomicReplaceTableAsSelectExec. + * + * Physical plan node for v2 replace table as select when the catalog supports staging + * table replacement. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If the table exists, its contents and schema should be replaced with the schema and the contents + * of the query. This implementation is atomic. The table replacement is staged, and the commit + * operation at the end should perform the replacement of the table's metadata and contents. If the + * write fails, the table is instructed to roll back staged changes and any previously written table + * is left untouched. 
+ */ +case class GpuAtomicReplaceTableAsSelectExec( + catalog: StagingTableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + tableSpec: TableSpec, + writeOptions: Map[String, String], + orCreate: Boolean, + invalidateCache: (TableCatalog, Table, Identifier) => Unit) + extends V2CreateTableAsSelectBaseExec with GpuExec { + + val properties = CatalogV2Util.convertTableProperties(tableSpec) + + override def supportsColumnar: Boolean = false + + override protected def run(): Seq[InternalRow] = { + val columns = getV2Columns(query.schema, catalog.useNullableQuerySchema) + if (catalog.tableExists(ident)) { + val table = catalog.loadTable(ident) + invalidateCache(catalog, table, ident) + } + val staged = if (orCreate) { + catalog.stageCreateOrReplace( + ident, columns, partitioning.toArray, properties.asJava) + } else if (catalog.tableExists(ident)) { + try { + catalog.stageReplace( + ident, columns, partitioning.toArray, properties.asJava) + } catch { + case e: NoSuchTableException => + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident, Some(e)) + } + } else { + throw QueryCompilationErrors.cannotReplaceMissingTableError(ident) + } + writeToTable(catalog, staged, writeOptions, ident, query) + } + + override protected def internalDoExecuteColumnar(): RDD[ColumnarBatch] = + throw new IllegalStateException("Columnar execution not supported") +}
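
For context on what the new GpuAtomicReplaceTableAsSelectExec nodes accelerate, a minimal end-to-end sketch follows (not part of the patch). It assumes the RAPIDS Accelerator jar and a matching Delta Lake release are on the classpath, and that the config referenced as RapidsConf.ENABLE_DELTA_WRITE is spark.rapids.sql.format.delta.write.enabled; on a supported Spark version (3.2.0+), the CREATE OR REPLACE TABLE ... AS SELECT statement below is the plan shape that should now map to GpuAtomicReplaceTableAsSelectExec instead of falling back to the CPU exec. The object and table names are purely illustrative.

    import org.apache.spark.sql.SparkSession

    object ReplaceTableAsSelectSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("gpu-rtas-sketch")
          // Assumes the RAPIDS Accelerator and Delta Lake jars are on the classpath.
          .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
          .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
          .config("spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog")
          // Delta writes are disabled by default in the plugin; assumed key name for
          // RapidsConf.ENABLE_DELTA_WRITE.
          .config("spark.rapids.sql.format.delta.write.enabled", "true")
          .getOrCreate()

        spark.range(0, 1000).selectExpr("id", "id * 2 AS doubled")
          .createOrReplaceTempView("src")

        // The first statement creates the Delta table; the second is planned as
        // (Gpu)AtomicReplaceTableAsSelectExec because DeltaCatalog is a StagingTableCatalog.
        spark.sql("CREATE TABLE demo USING DELTA AS SELECT id, doubled FROM src")
        spark.sql("CREATE OR REPLACE TABLE demo USING DELTA AS SELECT id FROM src WHERE id < 500")

        spark.table("demo").show(5)
        spark.stop()
      }
    }

The new integration test exercises the same replace path through the DataFrame API rather than SQL: test_delta_atomic_replace_table_as_select calls DataFrameWriter.mode("overwrite").saveAsTable on a Delta table, which with a staging-capable catalog resolves to the same AtomicReplaceTableAsSelectExec node.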