
Commit

Fallback to CPU for Parquet reads with _databricks_internal columns [databricks] (#6257)

* Fallback to CPU for Parquet reads with _databricks_internal columns

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Add integration test

* fix import

* fix test

* fix test

* test passes

Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove authored Aug 10, 2022
1 parent eb36525 commit f3f6bab
Showing 2 changed files with 33 additions and 1 deletion.
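
For context, the behavioral change is a single new guard in GpuOverrides.isDeltaLakeMetadataQuery (second diff below): when any column required by a Parquet FileSourceScanExec has a name starting with "_databricks_internal", the scan is treated as part of a Delta metadata query and kept on the CPU. The following is a minimal standalone sketch of that predicate, not the plugin code itself; the object name, the main method, and the sample column name _databricks_internal_row_id are illustrative assumptions.

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object DatabricksInternalColumnCheck {
  // The predicate the commit adds: does the scan's required schema contain
  // any column whose name starts with the Databricks-internal prefix?
  def hasDatabricksInternalColumn(requiredSchema: StructType): Boolean =
    requiredSchema.fields.exists(_.name.startsWith("_databricks_internal"))

  def main(args: Array[String]): Unit = {
    val plainSchema = StructType(Seq(
      StructField("c0", StringType),
      StructField("c1", IntegerType)))
    // "_databricks_internal_row_id" is a made-up example name; only the prefix matters.
    val mergeSchema = plainSchema.add(StructField("_databricks_internal_row_id", IntegerType))

    assert(!hasDatabricksInternalColumn(plainSchema)) // ordinary scan: no fallback needed
    assert(hasDatabricksInternalColumn(mergeSchema))  // internal column present: fall back to CPU
  }
}

The new integration test below exercises the real code path with a Delta MERGE on Databricks 10.4+ and then verifies the merged rows on the CPU.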
30 changes: 29 additions & 1 deletion integration_tests/src/main/python/delta_lake_test.py
@@ -13,9 +13,10 @@
# limitations under the License.

import pytest
from pyspark.sql import Row
from asserts import assert_gpu_fallback_collect
from marks import allow_non_gpu, delta_lake
from spark_session import with_cpu_session, is_databricks91_or_later
from spark_session import with_cpu_session, with_gpu_session, is_databricks91_or_later, is_databricks104_or_later

_conf = {'spark.rapids.sql.explain': 'ALL'}

@@ -34,3 +34,30 @@ def setup_delta_table(spark):
    assert_gpu_fallback_collect(
        lambda spark : spark.read.json("/tmp/delta-table/{}/_delta_log/00000000000000000000.json".format(table)),
        "FileSourceScanExec", conf = _conf)

@delta_lake
@pytest.mark.skipif(not is_databricks104_or_later(), \
    reason="This test is specific to Databricks because we only fall back to CPU for merges on Databricks")
@allow_non_gpu(any = True)
def test_delta_merge_query(spark_tmp_table_factory):
    table_name_1 = spark_tmp_table_factory.get()
    table_name_2 = spark_tmp_table_factory.get()
    def setup_delta_table1(spark):
        df = spark.createDataFrame([('a', 10), ('b', 20)], ["c0", "c1"])
        df.write.format("delta").save("/tmp/delta-table/{}".format(table_name_1))
    def setup_delta_table2(spark):
        df = spark.createDataFrame([('a', 30), ('c', 30)], ["c0", "c1"])
        df.write.format("delta").save("/tmp/delta-table/{}".format(table_name_2))
    with_cpu_session(setup_delta_table1)
    with_cpu_session(setup_delta_table2)
    def merge(spark):
        spark.read.format("delta").load("/tmp/delta-table/{}".format(table_name_1)).createOrReplaceTempView("t1")
        spark.read.format("delta").load("/tmp/delta-table/{}".format(table_name_2)).createOrReplaceTempView("t2")
        return spark.sql("MERGE INTO t1 USING t2 ON t1.c0 = t2.c0 \
            WHEN MATCHED THEN UPDATE SET c1 = t1.c1 + t2.c1 \
            WHEN NOT MATCHED THEN INSERT (c0, c1) VALUES (t2.c0, t2.c1)").collect()
    # run the MERGE on GPU
    with_gpu_session(lambda spark : merge(spark), conf = _conf)
    # check the results on CPU
    result = with_cpu_session(lambda spark: spark.sql("SELECT * FROM t1 ORDER BY c0").collect(), conf=_conf)
    assert [Row(c0='a', c1=40), Row(c0='b', c1=20), Row(c0='c', c1=30)] == result
4 changes: 4 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -4431,6 +4431,10 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
   */
  def isDeltaLakeMetadataQuery(plan: SparkPlan): Boolean = {
    val deltaLogScans = PlanUtils.findOperators(plan, {
      case f: FileSourceScanExec if f.requiredSchema.fields
          .exists(_.name.startsWith("_databricks_internal")) =>
        logDebug(s"Fallback for FileSourceScanExec with _databricks_internal: $f")
        true
      case f: FileSourceScanExec =>
        // example filename: "file:/tmp/delta-table/_delta_log/00000000000000000000.json"
        val found = f.relation.inputFiles.exists(name =>
