Add AtomicReplaceTableAsSelectExec support for Delta Lake [databricks] #9443

Merged (10 commits) on Oct 19, 2023
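
This PR extends the RAPIDS Accelerator's atomic Delta Lake table writes from
CREATE TABLE AS SELECT to REPLACE TABLE AS SELECT: each Delta shim (Databricks
and Delta 2.0.x through 2.4.x) gains tagForGpu/convertToGpu overloads for
AtomicReplaceTableAsSelectExec, the duplicated provider check is factored into
a shared helper, and an integration test covers the overwrite path. A sketch of
a statement this change can accelerate, assuming a session with the RAPIDS
plugin and Delta Lake on the classpath, Delta write acceleration enabled (via
RapidsConf.ENABLE_DELTA_WRITE, as the tagging messages below note), and a
staging-capable catalog; the names src and demo are illustrative:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
spark.range(100).createOrReplaceTempView("src")

// CREATE OR REPLACE TABLE ... AS SELECT is planned as
// AtomicReplaceTableAsSelectExec (orCreate = true) when the target catalog is
// a StagingTableCatalog; with this change it converts to
// GpuAtomicReplaceTableAsSelectExec and the Delta write runs on the GPU.
spark.sql("CREATE OR REPLACE TABLE demo USING delta AS SELECT id FROM src")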
@@ -31,8 +31,8 @@ import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.sources.CreatableRelationProvider
@@ -138,6 +138,40 @@ object DatabricksDeltaProvider extends DeltaProviderImplBase {
cpuExec.writeOptions,
cpuExec.ifNotExists)
}

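// Tagging falls back to the CPU unless the staging catalog is one we support,
// Delta write acceleration is enabled, and the table's provider resolves to
// Delta; RapidsDeltaUtils then vets the schema and write options.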
override def tagForGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): Unit = {
require(isSupportedCatalog(cpuExec.catalog.getClass))
if (!meta.conf.isDeltaWriteEnabled) {
meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
val properties = cpuExec.properties
val provider = properties.getOrElse("provider",
cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
GpuAtomicReplaceTableAsSelectExec(
cpuExec.output,
new GpuDeltaCatalog(cpuExec.catalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.tableSpec,
cpuExec.writeOptions,
cpuExec.orCreate,
cpuExec.invalidateCache)
}
}

class DeltaCreatableRelationProviderMeta(
@@ -28,7 +28,7 @@ import org.apache.spark.sql.delta.catalog.DeltaCatalog
import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.delta.sources.{DeltaDataSource, DeltaSourceUtils}
import org.apache.spark.sql.execution.datasources.{FileFormat, SaveIntoDataSourceCommand}
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.ExternalSource
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
@@ -66,15 +66,33 @@ abstract class DeltaIOProvider extends DeltaProviderImplBase {
meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
val properties = cpuExec.properties
val provider = properties.getOrElse("provider",
cpuExec.conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf)
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
}

override def tagForGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): Unit = {
require(isSupportedCatalog(cpuExec.catalog.getClass))
if (!meta.conf.isDeltaWriteEnabled) {
meta.willNotWorkOnGpu("Delta Lake output acceleration has been disabled. To enable set " +
s"${RapidsConf.ENABLE_DELTA_WRITE} to true")
}
checkDeltaProvider(meta, cpuExec.properties, cpuExec.conf)
RapidsDeltaUtils.tagForDeltaWrite(meta, cpuExec.query.schema, None,
cpuExec.writeOptions.asCaseSensitiveMap().asScala.toMap, cpuExec.session)
}

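// Shared by the create and replace tagging overloads above: resolve the
// table's provider, falling back to the session default data source, and
// reject anything that is not Delta.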
private def checkDeltaProvider(
meta: RapidsMeta[_, _, _],
properties: Map[String, String],
conf: SQLConf): Unit = {
val provider = properties.getOrElse("provider", conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))
if (!DeltaSourceUtils.isDeltaDataSourceName(provider)) {
meta.willNotWorkOnGpu(s"table provider '$provider' is not a Delta Lake provider")
}
}
}

class DeltaCreatableRelationProviderMeta(
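
The common DeltaIOProvider now tags both the create and replace execs, and the
previously duplicated provider check lives in checkDeltaProvider. The fallback
mirrors Spark's resolution rules: a table with no explicit "provider" property
is attributed to the session default data source, so only tables actually
backed by Delta get tagged for the GPU. A standalone sketch of that fallback,
with illustrative names in place of the real exec and meta types:

import org.apache.spark.sql.internal.SQLConf

// Mirrors checkDeltaProvider above: an explicit "provider" table property
// wins, otherwise spark.sql.sources.default decides what the table is; the
// result is then vetted by DeltaSourceUtils.isDeltaDataSourceName.
def resolveProvider(properties: Map[String, String], conf: SQLConf): String =
  properties.getOrElse("provider", conf.getConf(SQLConf.DEFAULT_DATA_SOURCE_NAME))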
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.delta.delta20x

import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.delta.DeltaIOProvider

import org.apache.spark.sql.delta.DeltaParquetFileFormat
@@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}

object Delta20xProvider extends DeltaIOProvider {

@@ -77,4 +77,20 @@ object Delta20xProvider extends DeltaIOProvider {
cpuExec.writeOptions,
cpuExec.ifNotExists)
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
GpuAtomicReplaceTableAsSelectExec(
DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.properties,
cpuExec.writeOptions,
cpuExec.orCreate,
cpuExec.invalidateCache)
}
}
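
One shim-to-shim difference worth noting: the Delta 2.0.x conversion above
passes cpuExec.properties where the 2.1.x and later shims pass
cpuExec.tableSpec. This tracks the underlying Spark API, which bundled the raw
properties map into a TableSpec on the replace exec in newer Spark releases,
so each shim forwards whichever form its Spark version exposes.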
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.delta.delta21x

import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.delta.DeltaIOProvider

import org.apache.spark.sql.delta.DeltaParquetFileFormat
@@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}

object Delta21xProvider extends DeltaIOProvider {

@@ -77,4 +77,20 @@ object Delta21xProvider extends DeltaIOProvider {
cpuExec.writeOptions,
cpuExec.ifNotExists)
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
GpuAtomicReplaceTableAsSelectExec(
DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.tableSpec,
cpuExec.writeOptions,
cpuExec.orCreate,
cpuExec.invalidateCache)
}
}
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.delta.delta22x

import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.delta.DeltaIOProvider

import org.apache.spark.sql.delta.DeltaParquetFileFormat
@@ -26,8 +26,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}

object Delta22xProvider extends DeltaIOProvider {

@@ -77,4 +77,20 @@ object Delta22xProvider extends DeltaIOProvider {
cpuExec.writeOptions,
cpuExec.ifNotExists)
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
GpuAtomicReplaceTableAsSelectExec(
DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.tableSpec,
cpuExec.writeOptions,
cpuExec.orCreate,
cpuExec.invalidateCache)
}
}
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids.delta.delta24x

import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, GpuExec, GpuOverrides, GpuReadParquetFileFormat, RunnableCommandRule, SparkPlanMeta}
import com.nvidia.spark.rapids.delta.DeltaIOProvider

import org.apache.spark.sql.delta.DeltaParquetFileFormat
@@ -27,8 +27,8 @@ import org.apache.spark.sql.delta.rapids.DeltaRuntimeShim
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.rapids.GpuAtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.execution.datasources.v2.rapids.{GpuAtomicCreateTableAsSelectExec, GpuAtomicReplaceTableAsSelectExec}

object Delta24xProvider extends DeltaIOProvider {

@@ -91,4 +91,20 @@ object Delta24xProvider extends DeltaIOProvider {
cpuExec.writeOptions,
cpuExec.ifNotExists)
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
val cpuCatalog = cpuExec.catalog.asInstanceOf[DeltaCatalog]
GpuAtomicReplaceTableAsSelectExec(
DeltaRuntimeShim.getGpuDeltaCatalog(cpuCatalog, meta.conf),
cpuExec.ident,
cpuExec.partitioning,
cpuExec.plan,
meta.childPlans.head.convertIfNeeded(),
cpuExec.tableSpec,
cpuExec.writeOptions,
cpuExec.orCreate,
cpuExec.invalidateCache)
}
}
30 changes: 22 additions & 8 deletions integration_tests/src/main/python/delta_lake_write_test.py
@@ -256,25 +256,39 @@ def test_delta_overwrite_round_trip_unmanaged(spark_tmp_path):
def test_delta_append_round_trip_unmanaged(spark_tmp_path):
do_update_round_trip_managed(spark_tmp_path, "append")

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn)
def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path):
def _atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite):
gen_list = [("c" + str(i), gen) for i, gen in enumerate(gens)]
data_path = spark_tmp_path + "/DELTA_DATA"
confs = copy_and_update(writer_confs, delta_writes_enabled_conf)
path_to_table = {}
def do_write(spark, path):
table = spark_tmp_table_factory.get()
path_to_table[path] = table
gen_df(spark, gen_list).coalesce(1).write.format("delta").saveAsTable(table)
writer = gen_df(spark, gen_list).coalesce(1).write.format("delta")
if overwrite:
writer = writer.mode("overwrite")
writer.saveAsTable(table)
assert_gpu_and_cpu_writes_are_equal_collect(
do_write,
lambda spark, path: spark.read.format("delta").table(path_to_table[path]),
data_path,
conf = confs)
conf=confs)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn)
def test_delta_atomic_create_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path):
_atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=False)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order(local=True)
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@pytest.mark.parametrize("gens", parquet_write_gens_list, ids=idfn)
def test_delta_atomic_replace_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path):
_atomic_write_table_as_select(gens, spark_tmp_table_factory, spark_tmp_path, overwrite=True)

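The overwrite=True variant is what exercises the new exec: for a V2 catalog,
saveAsTable in overwrite mode on an existing table is planned as a
replace-table-as-select, which a staging catalog executes atomically. A
minimal Scala equivalent of what the parametrized test drives, reusing the
spark session from the earlier sketch (the table name is illustrative):

val df = spark.range(0, 20)
df.write.format("delta").saveAsTable("rtas_demo")                    // create path
df.write.format("delta").mode("overwrite").saveAsTable("rtas_demo") // replace path
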
@allow_non_gpu(*delta_meta_allow)
@delta_lake
@@ -16,14 +16,14 @@

package com.nvidia.spark.rapids.delta

import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta}
import com.nvidia.spark.rapids.{AtomicCreateTableAsSelectExecMeta, AtomicReplaceTableAsSelectExecMeta, CreatableRelationProviderRule, ExecRule, GpuExec, RunnableCommandRule, ShimLoader, SparkPlanMeta}

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.connector.catalog.StagingTableCatalog
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.sources.CreatableRelationProvider

/** Probe interface to determine which Delta Lake provider to use. */
@@ -58,6 +58,14 @@ trait DeltaProvider {
def convertToGpu(
cpuExec: AtomicCreateTableAsSelectExec,
meta: AtomicCreateTableAsSelectExecMeta): GpuExec

def tagForGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): Unit

def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec
}

object DeltaProvider {
@@ -100,4 +108,16 @@ object NoDeltaProvider extends DeltaProvider {
meta: AtomicCreateTableAsSelectExecMeta): GpuExec = {
throw new IllegalStateException("catalog not supported, should not be called")
}

override def tagForGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): Unit = {
throw new IllegalStateException("catalog not supported, should not be called")
}

override def convertToGpu(
cpuExec: AtomicReplaceTableAsSelectExec,
meta: AtomicReplaceTableAsSelectExecMeta): GpuExec = {
throw new IllegalStateException("catalog not supported, should not be called")
}
}
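
NoDeltaProvider's new replace overloads throw rather than tag, matching its
existing create-table behavior: the plugin only routes these execs to the
Delta provider for catalogs it claims (note the require(isSupportedCatalog(...))
guards in the shims above), so these methods exist to complete the trait and
should be unreachable in practice.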
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import org.apache.spark.sql.execution.datasources.v2.AtomicCreateTableAsSelectExec
import org.apache.spark.sql.execution.datasources.v2.{AtomicCreateTableAsSelectExec, AtomicReplaceTableAsSelectExec}
import org.apache.spark.sql.rapids.ExternalSource

class AtomicCreateTableAsSelectExecMeta(
@@ -34,3 +34,19 @@ class AtomicCreateTableAsSelectExecMeta(
ExternalSource.convertToGpu(wrapped, this)
}
}

class AtomicReplaceTableAsSelectExecMeta(
wrapped: AtomicReplaceTableAsSelectExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends SparkPlanMeta[AtomicReplaceTableAsSelectExec](wrapped, conf, parent, rule) {

override def tagPlanForGpu(): Unit = {
ExternalSource.tagForGpu(wrapped, this)
}

override def convertToGpu(): GpuExec = {
ExternalSource.convertToGpu(wrapped, this)
}
}
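
AtomicReplaceTableAsSelectExecMeta mirrors its create-table counterpart and
delegates both phases to ExternalSource, which routes to whichever
DeltaProvider the shim layer selected. The plugin's standard two-phase flow,
sketched here as comments because the real driver loop lives in GpuOverrides
and also handles children, costs, and explain output:

// meta.tagPlanForGpu()                  // may record willNotWorkOnGpu reasons
// if (meta.canThisBeReplaced) {         // true only if nothing vetoed the node
//   meta.convertToGpu()                 // build GpuAtomicReplaceTableAsSelectExec
// } else {
//   // keep the CPU AtomicReplaceTableAsSelectExec
// }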