Commit

Revert unnecessary changes
liurenjie1024 committed Jun 14, 2024
1 parent 5e8af47 commit 9fa4cf2
Showing 2 changed files with 108 additions and 21 deletions.
127 changes: 107 additions & 20 deletions integration_tests/src/main/python/delta_lake_merge_test.py
@@ -14,17 +14,66 @@

import pyspark.sql.functions as f
import pytest
import string

from delta_lake_merge_common import *
from asserts import *
from data_gen import *
from delta_lake_utils import *
from marks import *
from pyspark.sql.types import *
from spark_session import is_before_spark_320, is_databricks_runtime, spark_version

# Databricks changes the number of files being written, so we cannot compare logs
num_slices_to_test = [10] if is_databricks_runtime() else [1, 10]

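# Extend the Delta-write-enabled confs so the plugin also handles Delta Lake MERGE INTO commands (standard and Edge variants) on the GPU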
delta_merge_enabled_conf = copy_and_update(delta_writes_enabled_conf,
{"spark.rapids.sql.command.MergeIntoCommand": "true",
"spark.rapids.sql.command.MergeIntoCommandEdge": "true"})

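# Builds a three-column DataFrame from the given gen plus lowercase and uppercase ASCII letter columns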
def make_df(spark, gen, num_slices):
return three_col_df(spark, gen, SetValuesGen(StringType(), string.ascii_lowercase),
SetValuesGen(StringType(), string.ascii_uppercase), num_slices=num_slices)

def delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, check_func,
partition_columns=None):
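    """Sets up the destination Delta table(s) under data_path and registers the source table
    as a temp view on the CPU, then passes check_func a do_merge closure that runs merge_sql
    against the table at a given path and collects the result."""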
data_path = spark_tmp_path + "/DELTA_DATA"
src_table = spark_tmp_table_factory.get()
def setup_tables(spark):
setup_delta_dest_tables(spark, data_path, dest_table_func, use_cdf, partition_columns)
src_table_func(spark).createOrReplaceTempView(src_table)
def do_merge(spark, path):
dest_table = spark_tmp_table_factory.get()
read_delta_path(spark, path).createOrReplaceTempView(dest_table)
return spark.sql(merge_sql.format(src_table=src_table, dest_table=dest_table)).collect()
with_cpu_session(setup_tables)
check_func(data_path, do_merge)

def assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql,
compare_logs, partition_columns=None,
conf=delta_merge_enabled_conf):
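    """Runs do_merge against CPU and GPU copies of the destination table and asserts that the
    merge command results match, that the merged table contents match when both are read on
    the CPU, and optionally that the CPU and GPU Delta logs are equivalent."""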
def read_data(spark, path):
read_func = read_delta_path_with_cdf if use_cdf else read_delta_path
df = read_func(spark, path)
return df.sort(df.columns)
def checker(data_path, do_merge):
cpu_path = data_path + "/CPU"
gpu_path = data_path + "/GPU"
# compare resulting dataframe from the merge operation (some older Spark versions return empty here)
cpu_result = with_cpu_session(lambda spark: do_merge(spark, cpu_path), conf=conf)
gpu_result = with_gpu_session(lambda spark: do_merge(spark, gpu_path), conf=conf)
assert_equal(cpu_result, gpu_result)
# compare merged table data results, read both via CPU to make sure GPU write can be read by CPU
cpu_result = with_cpu_session(lambda spark: read_data(spark, cpu_path).collect(), conf=conf)
gpu_result = with_cpu_session(lambda spark: read_data(spark, gpu_path).collect(), conf=conf)
assert_equal(cpu_result, gpu_result)
# Using partition columns involves sorting, and there are no guarantees on the task
# partitioning due to random sampling.
if compare_logs and not partition_columns:
with_cpu_session(lambda spark: assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path))
delta_sql_merge_test(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, checker, partition_columns)

@allow_non_gpu(delta_write_fallback_allow, *delta_meta_allow)
@delta_lake
@@ -113,9 +162,16 @@ def test_delta_merge_partial_fallback_via_conf(spark_tmp_path, spark_tmp_table_f
@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn)
def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory, table_ranges,
use_cdf, partition_columns, num_slices):
do_test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_factory,
table_ranges, use_cdf, partition_columns,
num_slices, num_slices == 1, delta_merge_enabled_conf)
src_range, dest_range = table_ranges
src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices)
dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices)
merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \
" WHEN NOT MATCHED THEN INSERT *"
# Non-deterministic input for each task means we can only reliably compare record counts when using only one task
compare_logs = num_slices == 1
assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs,
partition_columns)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@@ -130,9 +186,16 @@ def test_delta_merge_not_match_insert_only(spark_tmp_path, spark_tmp_table_facto
@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn)
def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges,
use_cdf, partition_columns, num_slices):
do_test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory, table_ranges,
use_cdf, partition_columns, num_slices, num_slices == 1,
delta_merge_enabled_conf)
src_range, dest_range = table_ranges
src_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), src_range), num_slices)
dest_table_func = lambda spark: make_df(spark, SetValuesGen(IntegerType(), dest_range), num_slices)
merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \
" WHEN MATCHED THEN DELETE"
# Non-deterministic input for each task means we can only reliably compare record counts when using only one task
compare_logs = num_slices == 1
assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs,
partition_columns)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@@ -141,9 +204,15 @@ def test_delta_merge_match_delete_only(spark_tmp_path, spark_tmp_table_factory,
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn)
def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices):
do_test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, use_cdf,
num_slices, num_slices == 1, delta_merge_enabled_conf)

# Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous
src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b"))
dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices)
merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \
" WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"
# Non-deterministic input for each task means we can only reliably compare record counts when using only one task
compare_logs = num_slices == 1
assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@@ -163,10 +232,13 @@ def test_delta_merge_standard_upsert(spark_tmp_path, spark_tmp_table_factory, us
" WHEN NOT MATCHED AND s.b > 'f' AND s.b < 'z' THEN INSERT (b) VALUES ('not here')" ], ids=idfn)
@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn)
def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, merge_sql, num_slices):
do_test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf,
merge_sql, num_slices, num_slices == 1,
delta_merge_enabled_conf)

# Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous
src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b"))
dest_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, seed=1, num_slices=num_slices)
# Non-deterministic input for each task means we can only reliably compare record counts when using only one task
compare_logs = num_slices == 1
assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@@ -175,19 +247,34 @@ def test_delta_merge_upsert_with_condition(spark_tmp_path, spark_tmp_table_facto
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
@pytest.mark.parametrize("num_slices", num_slices_to_test, ids=idfn)
def test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path, spark_tmp_table_factory, use_cdf, num_slices):
do_test_delta_merge_upsert_with_unmatchable_match_condition(spark_tmp_path,
spark_tmp_table_factory, use_cdf,
num_slices, num_slices == 1,
delta_merge_enabled_conf)
# Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous
src_table_func = lambda spark: two_col_df(spark, int_gen, string_gen, num_slices=num_slices).groupBy("a").agg(f.max("b").alias("b"))
dest_table_func = lambda spark: two_col_df(spark, SetValuesGen(IntegerType(), range(100)), string_gen, seed=1, num_slices=num_slices)
merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.a == {src_table}.a" \
" WHEN MATCHED AND {dest_table}.a > 100 THEN UPDATE SET *"
# Non-deterministic input for each task means we can only reliably compare record counts when using only one task
compare_logs = num_slices == 1
assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
@pytest.mark.parametrize("use_cdf", [True, False], ids=idfn)
def test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf):
do_test_delta_merge_update_with_aggregation(spark_tmp_path, spark_tmp_table_factory, use_cdf,
delta_merge_enabled_conf)
# Need to eliminate duplicate keys in the source table otherwise update semantics are ambiguous
src_table_func = lambda spark: spark.range(10).withColumn("x", f.col("id") + 1)\
.select(f.col("id"), (f.col("x") + 1).alias("x"))\
.drop_duplicates(["id"])\
.limit(10)
dest_table_func = lambda spark: spark.range(5).withColumn("x", f.col("id") + 1)
merge_sql = "MERGE INTO {dest_table} USING {src_table} ON {dest_table}.id == {src_table}.id" \
" WHEN MATCHED THEN UPDATE SET {dest_table}.x = {src_table}.x + 2" \
" WHEN NOT MATCHED AND {src_table}.x < 7 THEN INSERT *"

assert_delta_sql_merge_collect(spark_tmp_path, spark_tmp_table_factory, use_cdf,
src_table_func, dest_table_func, merge_sql, compare_logs=False)

@allow_non_gpu(*delta_meta_allow)
@delta_lake
2 changes: 1 addition & 1 deletion jenkins/spark-premerge-build.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
