Add shim for Databricks 10.4 [databricks] (#4974)
* Add shim for Databricks 10.4

Signed-off-by: Jason Lowe <jlowe@nvidia.com>

* Add missing source directory for 304 shim

* Add missing import

* Remove unused HADOOP_FULL_VERSION

* Fix Databricks version check to numerically compare

* Add comments, code cleanup

* Add 311+-db directory

* Move FileOptions into a v2 shim

* Fix Databricks version check
jlowe authored Mar 18, 2022 · 1 parent d50e120 · commit 5818905
Showing 66 changed files with 1,545 additions and 163 deletions.
27 changes: 14 additions & 13 deletions docs/additional-functionality/rapids-shuffle.md
@@ -285,19 +285,20 @@ In this section, we are using a docker container built using the sample dockerfi
1. Choose the version of the shuffle manager that matches your Spark version.
Currently we support:

| Spark Shim | spark.shuffle.manager value |
| --------------| -------------------------------------------------------- |
| 3.0.1 | com.nvidia.spark.rapids.spark301.RapidsShuffleManager |
| 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager |
| 3.0.3 | com.nvidia.spark.rapids.spark303.RapidsShuffleManager |
| 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager |
| 3.1.1 CDH | com.nvidia.spark.rapids.spark311cdh.RapidsShuffleManager |
| 3.1.2 | com.nvidia.spark.rapids.spark312.RapidsShuffleManager |
| 3.1.3 | com.nvidia.spark.rapids.spark313.RapidsShuffleManager |
| 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager |
| 3.2.1 | com.nvidia.spark.rapids.spark321.RapidsShuffleManager |
| Databricks 7.3| com.nvidia.spark.rapids.spark301db.RapidsShuffleManager |
| Databricks 9.1| com.nvidia.spark.rapids.spark312db.RapidsShuffleManager |
| Spark Shim | spark.shuffle.manager value |
| --------------- | -------------------------------------------------------- |
| 3.0.1 | com.nvidia.spark.rapids.spark301.RapidsShuffleManager |
| 3.0.2 | com.nvidia.spark.rapids.spark302.RapidsShuffleManager |
| 3.0.3 | com.nvidia.spark.rapids.spark303.RapidsShuffleManager |
| 3.1.1 | com.nvidia.spark.rapids.spark311.RapidsShuffleManager |
| 3.1.1 CDH | com.nvidia.spark.rapids.spark311cdh.RapidsShuffleManager |
| 3.1.2 | com.nvidia.spark.rapids.spark312.RapidsShuffleManager |
| 3.1.3 | com.nvidia.spark.rapids.spark313.RapidsShuffleManager |
| 3.2.0 | com.nvidia.spark.rapids.spark320.RapidsShuffleManager |
| 3.2.1 | com.nvidia.spark.rapids.spark321.RapidsShuffleManager |
| Databricks 7.3 | com.nvidia.spark.rapids.spark301db.RapidsShuffleManager |
| Databricks 9.1 | com.nvidia.spark.rapids.spark312db.RapidsShuffleManager |
| Databricks 10.4 | com.nvidia.spark.rapids.spark321db.RapidsShuffleManager |

2. Settings for UCX 1.11.2+:

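As a usage note (not part of this diff), here is a minimal PySpark sketch of how the new Databricks 10.4 entry from the table above might be wired up when launching a fresh application. The application name is made up, the plugin jar is assumed to already be on the driver and executor classpaths, and on Databricks these values are normally set in the cluster's Spark config before startup, since the shuffle manager must be chosen before the SparkContext exists.

```python
# Sketch only: select the Databricks 10.4 (Spark 3.2.1-based) RapidsShuffleManager
# from the table above. Assumes the rapids-4-spark plugin jar is already on the
# classpath; on Databricks, set these in the cluster Spark config instead.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("rapids-shuffle-example")
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.shuffle.manager",
                 "com.nvidia.spark.rapids.spark321db.RapidsShuffleManager")
         .getOrCreate())
```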
4 changes: 2 additions & 2 deletions integration_tests/pom.xml
@@ -208,8 +208,8 @@
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${spark.version}</version>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.client.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/array_test.py
@@ -16,7 +16,7 @@

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from data_gen import *
from spark_session import is_before_spark_311, is_before_spark_330
from spark_session import is_before_spark_311, is_before_spark_330, is_databricks104_or_later
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from pyspark.sql.functions import array_contains, col, isnan, element_at
@@ -78,7 +78,7 @@ def test_array_item_with_strict_index(strict_index_enabled, index):
reason="Only in Spark [3.1.1, 3.3.0) with ANSI mode, it throws exceptions for invalid index")
@pytest.mark.parametrize('index', [-2, 100, array_neg_index_gen, array_out_index_gen], ids=idfn)
def test_array_item_ansi_fail_invalid_index(index):
message = "java.lang.ArrayIndexOutOfBoundsException"
message = "SparkArrayIndexOutOfBoundsException" if is_databricks104_or_later() else "java.lang.ArrayIndexOutOfBoundsException"
if isinstance(index, int):
test_func = lambda spark: unary_op_df(spark, ArrayGen(int_gen)).select(col('a')[index]).collect()
else:
5 changes: 4 additions & 1 deletion integration_tests/src/main/python/hash_aggregate_test.py
@@ -24,7 +24,7 @@
from pyspark.sql.types import *
from marks import *
import pyspark.sql.functions as f
from spark_session import is_before_spark_311, with_cpu_session
from spark_session import is_before_spark_311, is_databricks104_or_later, with_cpu_session

pytestmark = pytest.mark.nightly_resource_consuming_test

@@ -421,6 +421,7 @@ def test_hash_grpby_avg(data_gen, conf):
@pytest.mark.parametrize('data_gen', [
StructGen(children=[('a', int_gen), ('b', int_gen)],nullable=False,
special_cases=[((None, None), 400.0), ((None, -1542301795), 100.0)])], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_hash_avg_nulls_partial_only(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, data_gen, length=2).agg(f.avg('b')),
@@ -791,6 +792,7 @@ def test_hash_groupby_collect_partial_replace_fallback(data_gen,
@pytest.mark.parametrize('replace_mode', _replace_modes_single_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.parametrize('use_obj_hash_agg', ['false', 'true'], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_hash_groupby_collect_partial_replace_with_distinct_fallback(data_gen,
replace_mode,
aqe_enabled,
@@ -1668,6 +1670,7 @@ def test_groupby_std_variance_nulls(data_gen, conf, ansi_enabled):
@pytest.mark.parametrize('conf', get_params(_confs, params_markers_for_confs), ids=idfn)
@pytest.mark.parametrize('replace_mode', _replace_modes_non_distinct, ids=idfn)
@pytest.mark.parametrize('aqe_enabled', ['false', 'true'], ids=idfn)
@pytest.mark.xfail(condition=is_databricks104_or_later(), reason='https://github.com/NVIDIA/spark-rapids/issues/4963')
def test_groupby_std_variance_partial_replace_fallback(data_gen,
conf,
replace_mode,
6 changes: 3 additions & 3 deletions integration_tests/src/main/python/map_test.py
@@ -18,7 +18,7 @@
assert_gpu_fallback_collect
from data_gen import *
from marks import incompat, allow_non_gpu
from spark_session import is_before_spark_311, is_before_spark_330
from spark_session import is_before_spark_311, is_before_spark_330, is_databricks104_or_later
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType

@@ -215,7 +215,7 @@ def test_str_to_map_expr_with_all_regex_delimiters():
reason="Only in Spark 3.1.1+ (< 3.3.0) + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_simple_get_map_value_ansi_fail(data_gen):
message = "java.util.NoSuchElementException"
message = "org.apache.spark.SparkNoSuchElementException" if is_databricks104_or_later() else "java.util.NoSuchElementException"
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
'a["NOT_FOUND"]').collect(),
@@ -264,7 +264,7 @@ def test_simple_element_at_map(data_gen):
@pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, map key throws on no such element")
@pytest.mark.parametrize('data_gen', [simple_string_to_string_map_gen], ids=idfn)
def test_map_element_at_ansi_fail(data_gen):
message = "org.apache.spark.SparkNoSuchElementException" if not is_before_spark_330() else "java.util.NoSuchElementException"
message = "org.apache.spark.SparkNoSuchElementException" if (not is_before_spark_330() or is_databricks104_or_later()) else "java.util.NoSuchElementException"
# For 3.3.0+ strictIndexOperator should not affect element_at
test_conf=copy_and_update(ansi_enabled_conf, {'spark.sql.ansi.strictIndexOperator': 'false'})
assert_gpu_and_cpu_error(
14 changes: 12 additions & 2 deletions integration_tests/src/main/python/spark_session.py
@@ -142,6 +142,16 @@ def is_before_spark_330():
def is_spark_330_or_later():
return spark_version() >= "3.3.0"

def is_databricks91_or_later():
    spark = get_spark_i_know_what_i_am_doing()
    return spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "") >= "9.1"
def is_databricks_version_or_later(major, minor):
    spark = get_spark_i_know_what_i_am_doing()
    version = spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "0.0")
    parts = version.split(".")
    if (len(parts) < 2):
        raise RuntimeError("Unable to determine Databricks version from version string: " + version)
    return int(parts[0]) >= major and int(parts[1]) >= minor

def is_databricks91_or_later():
    return is_databricks_version_or_later(9, 1)

def is_databricks104_or_later():
    return is_databricks_version_or_later(10, 4)
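The replacement helper compares the major and minor components independently, which is sufficient for the runtimes targeted here (7.3, 9.1, 10.4). Purely as an illustration, and not part of this commit, a tuple comparison orders versions lexicographically and would also accept a later major release with a smaller minor component (for example 11.0 against a 10.4 check). The sketch below assumes the same get_spark_i_know_what_i_am_doing helper from spark_session.py.

```python
# Hypothetical alternative, not part of this commit: compare (major, minor)
# tuples lexicographically, so (11, 0) >= (10, 4) is True, whereas the
# component-wise check above returns False for that case.
def is_databricks_version_or_later(major, minor):
    spark = get_spark_i_know_what_i_am_doing()
    version = spark.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "0.0")
    parts = version.split(".")
    if len(parts) < 2:
        raise RuntimeError("Unable to determine Databricks version from version string: " + version)
    return (int(parts[0]), int(parts[1])) >= (major, minor)
```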