Add shim for Databricks 10.4 [databricks] #4974

Merged: 9 commits, Mar 18, 2022
27 changes: 14 additions & 13 deletions docs/additional-functionality/rapids-shuffle.md
@@ -285,19 +285,20 @@ In this section, we are using a docker container built using the sample dockerfi
1. Choose the version of the shuffle manager that matches your Spark version.
Currently we support:

| Spark Shim | spark.shuffle.manager value |
| --------------| -------------------------------------------------------- |
| 3.0.1 | com.nvidia.spark.rapids.spark301.RapidsShuffleManager |
| 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager |
| 3.0.3 | com.nvidia.spark.rapids.spark303.RapidsShuffleManager |
| 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager |
| 3.1.1 CDH | com.nvidia.spark.rapids.spark311cdh.RapidsShuffleManager |
| 3.1.2 | com.nvidia.spark.rapids.spark312.RapidsShuffleManager |
| 3.1.3 | com.nvidia.spark.rapids.spark313.RapidsShuffleManager |
| 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager |
| 3.2.1 | com.nvidia.spark.rapids.spark321.RapidsShuffleManager |
| Databricks 7.3| com.nvidia.spark.rapids.spark301db.RapidsShuffleManager |
| Databricks 9.1| com.nvidia.spark.rapids.spark312db.RapidsShuffleManager |
| Spark Shim | spark.shuffle.manager value |
| --------------- | -------------------------------------------------------- |
| 3.0.1 | com.nvidia.spark.rapids.spark301.RapidsShuffleManager |
| 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager |
| 3.0.3 | com.nvidia.spark.rapids.spark303.RapidsShuffleManager |
| 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager |
| 3.1.1 CDH | com.nvidia.spark.rapids.spark311cdh.RapidsShuffleManager |
| 3.1.2 | com.nvidia.spark.rapids.spark312.RapidsShuffleManager |
| 3.1.3 | com.nvidia.spark.rapids.spark313.RapidsShuffleManager |
| 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager |
| 3.2.1 | com.nvidia.spark.rapids.spark321.RapidsShuffleManager |
| Databricks 7.3 | com.nvidia.spark.rapids.spark301db.RapidsShuffleManager |
| Databricks 9.1 | com.nvidia.spark.rapids.spark312db.RapidsShuffleManager |
| Databricks 10.4 | com.nvidia.spark.rapids.spark321db.RapidsShuffleManager |

2. Settings for UCX 1.11.2+:

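As a rough illustration of the new Databricks 10.4 row in the shuffle manager table above, the sketch below sets `spark.shuffle.manager` to the spark321db class when building a PySpark session. This is a minimal sketch only: on Databricks the value is normally supplied through the cluster's Spark config, and the remaining RAPIDS settings (plugin jar, UCX options, and so on) are omitted.

```python
from pyspark.sql import SparkSession

# Sketch: choose the RapidsShuffleManager class that matches the Spark shim.
# Databricks 10.4 maps to the spark321db package per the table above.
shuffle_manager = "com.nvidia.spark.rapids.spark321db.RapidsShuffleManager"

spark = (SparkSession.builder
         .config("spark.shuffle.manager", shuffle_manager)
         .getOrCreate())
```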
4 changes: 2 additions & 2 deletions integration_tests/pom.xml
@@ -208,8 +208,8 @@
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${spark.version}</version>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.client.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/array_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
from spark_session import is_before_spark_311, is_before_spark_330
from spark_session import is_before_spark_311, is_before_spark_330, is_databricks104_or_later
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from pyspark.sql.functions import array_contains, col, isnan, element_at
@@ -78,7 +78,7 @@ def test_array_item_with_strict_index(strict_index_enabled, index):
reason="Only in Spark [3.1.1, 3.3.0) with ANSI mode, it throws exceptions for invalid index")
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
def test_array_item_ansi_fail_invalid_index(index):
message = "java.lang.ArrayIndexOutOfBoundsException"
message = "SparkArrayIndexOutOfBoundsException" if is_databricks104_or_later() else "java.lang.ArrayIndexOutOfBoundsException"
if isinstance(index, int):
test_func = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(col('a')[index]).collect()
else:
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/hash_aggregate_test.py
@@ -24,7 +24,7 @@
from pyspark.sql.types import *
from marks import *
import pyspark.sql.functions as f
from spark_session import is_before_spark_311, with_cpu_session
from spark_session import is_before_spark_311, is_databricks104_or_later, with_cpu_session

pytestmark = pytest.mark.nightly_resource_consuming_test

@@ -421,6 +421,7 @@ def test_hash_grpby_avg(data_gen, conf):
@pytest.mark.parametrize('data_gen', [
StructGen(children=[('a', int_gen), ('b', int_gen)],nullable=False,
special_cases=[((None, None), 400.0), ((None, -1542301795), 100.0)])], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_hash_avg_nulls_partial_only(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=2).agg(f.avg('b')),
@@ -791,6 +792,7 @@ def test_hash_groupby_collect_partial_replace_fallback(data_gen,
@pytest.mark.parametrize('replace_mode', _replace_modes_single_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.parametrize('use_obj_hash_agg', ['false', 'true'], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_hash_groupby_collect_partial_replace_with_distinct_fallback(data_gen,
replace_mode,
aqe_enabled,
@@ -1668,6 +1670,7 @@ def test_groupby_std_variance_nulls(data_gen, conf, ansi_enabled):
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
@pytest.mark.parametrize('replace_mode', _replace_modes_non_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_groupby_std_variance_partial_replace_fallback(data_gen,
conf,
replace_mode,
6 changes: 3 additions & 3 deletions integration_tests/src/main/python/map_test.py
@@ -18,7 +18,7 @@
assert_gpu_fallback_collect
from data_gen import *
from marks import incompat, allow_non_gpu
from spark_session import is_before_spark_311, is_before_spark_330
from spark_session import is_before_spark_311, is_before_spark_330, is_databricks104_or_later
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType

@@ -215,7 +215,7 @@ def test_str_to_map_expr_with_all_regex_delimiters():
reason="Only in Spark 3.1.1+ (< 3.3.0) + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value_ansi_fail(data_gen):
message = "java.util.NoSuchElementException"
message = "org.apache.spark.SparkNoSuchElementException" if is_databricks104_or_later() else "java.util.NoSuchElementException"
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a["NOT_FOUND"]').collect(),
@@ -264,7 +264,7 @@ def test_simple_element_at_map(data_gen):
@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_map_element_at_ansi_fail(data_gen):
message = "org.apache.spark.SparkNoSuchElementException" if not is_before_spark_330() else "java.util.NoSuchElementException"
message = "org.apache.spark.SparkNoSuchElementException" if (not is_before_spark_330() or is_databricks104_or_later()) else "java.util.NoSuchElementException"
# For 3.3.0+ strictIndexOperator should not affect element_at
test_conf=copy_and_update(ansi_enabled_conf, {'spark.sql.ansi.strictIndexOperator': 'false'})
assert_gpu_and_cpu_error(
14 changes: 12 additions & 2 deletions integration_tests/src/main/python/spark_session.py
@@ -142,6 +142,16 @@ def is_before_spark_330():
def is_spark_330_or_later():
return spark_version() >= "3.3.0"

def is_databricks91_or_later():
    spark = get_spark_i_know_what_i_am_doing()
    return spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "") >= "9.1"
def is_databricks_version_or_later(major, minor):
    spark = get_spark_i_know_what_i_am_doing()
    version = spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "0.0")
    parts = version.split(".")
    if (len(parts) < 2):
        raise RuntimeError("Unable to determine Databricks version from version string: " + version)
    return int(parts[0]) >= major and int(parts[1]) >= minor

def is_databricks91_or_later():
    return is_databricks_version_or_later(9, 1)

def is_databricks104_or_later():
    return is_databricks_version_or_later(10, 4)
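For context, the snippet below condenses how the test changes in this PR consume the new helper. The expected exception names and the xfail reason are taken from the diffs above; the test name and body are illustrative only.

```python
import pytest
# Assumes the integration test layout above, where these helpers live in spark_session.py.
from spark_session import is_databricks104_or_later

# Databricks 10.4 surfaces a different exception class for an invalid array index
# under ANSI mode, so tests pick the expected message per platform.
expected_error = ("SparkArrayIndexOutOfBoundsException"
                  if is_databricks104_or_later()
                  else "java.lang.ArrayIndexOutOfBoundsException")

# Aggregation tests tracked by issue 4963 are marked as expected failures
# only when running on Databricks 10.4.
@pytest.mark.xfail(condition=is_databricks104_or_later(),
                   reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_example_aggregation():  # hypothetical test, not part of the PR
    pass
```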