Skip to content

Commit

Permalink
Add division by zero tests for Spark 3.1 behavior (NVIDIA#1599)
Browse files Browse the repository at this point in the history
Adds tests reproducing SPARK-33008 on CPU and GPU. Contributes to NVIDIA#1464 

Signed-off-by: Gera Shegalov <gera@apache.org>
  • Loading branch information
gerashegalov authored Jan 30, 2021
1 parent 4987bd4 commit 8ed16d8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
36 changes: 31 additions & 5 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error
from data_gen import *
from marks import incompat, approximate_float
from pyspark.sql.types import *
from spark_session import with_spark_session, is_before_spark_310
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_310
import pyspark.sql.functions as f

decimal_gens_not_max_prec = [decimal_gen_neg_scale, decimal_gen_scale_precision,
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_subtraction(data_gen):
f.col('a') - f.col('b')),
conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', numeric_gens +
@pytest.mark.parametrize('data_gen', numeric_gens +
[decimal_gen_neg_scale, decimal_gen_scale_precision, decimal_gen_same_scale_precision, DecimalGen(8, 8)], ids=idfn)
def test_multiplication(data_gen):
data_type = data_gen.data_type
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_least(data_gen):
num_cols = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
# we want lots of nulls
gen = StructGen([('_c' + str(x), data_gen.copy_special_case(None, weight=100.0))
gen = StructGen([('_c' + str(x), data_gen.copy_special_case(None, weight=100.0))
for x in range(0, num_cols)], nullable=False)

command_args = [f.col('_c' + str(x)) for x in range(0, num_cols)]
Expand All @@ -491,7 +491,7 @@ def test_greatest(data_gen):
num_cols = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
# we want lots of nulls
gen = StructGen([('_c' + str(x), data_gen.copy_special_case(None, weight=100.0))
gen = StructGen([('_c' + str(x), data_gen.copy_special_case(None, weight=100.0))
for x in range(0, num_cols)], nullable=False)
command_args = [f.col('_c' + str(x)) for x in range(0, num_cols)]
command_args.append(s1)
Expand All @@ -500,3 +500,29 @@ def test_greatest(data_gen):
lambda spark : gen_df(spark, gen).select(
f.greatest(*command_args)), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('ansi_mode', ['nonAnsi', 'ansi'])
@pytest.mark.parametrize('exp_type', ['bothLiterals', 'justZeroLiteral', 'noLiterals'])
def test_div_by_zero(ansi_mode, exp_type):
    # Covers division-by-zero behavior with and without ANSI mode, for every
    # combination of literal vs. column operands.
    ansi_enabled = ansi_mode == 'ansi'
    if ansi_enabled:
        if is_before_spark_310():
            # ANSI divide-by-zero error only exists on Spark 3.1+
            pytest.xfail('https://github.com/apache/spark/pull/29882')
        elif exp_type != 'bothLiterals':
            # known gap tracked by the issue below
            pytest.xfail('https://github.com/NVIDIA/spark-rapids/issues/1464')

    conf = {'spark.sql.ansi.enabled': ansi_enabled}

    def one_row_df(spark):
        # single row: 'a' is an arbitrary int, 'b' is constrained to zero
        return two_col_df(spark, IntegerGen(), IntegerGen(min_val=0, max_val=0), length=1)

    # pick the division expression matching the requested operand shapes
    divisions = {
        'bothLiterals': f.lit(1) / f.lit(0),
        'justZeroLiteral': f.col('a') / f.lit(0),
        'noLiterals': f.col('a') / f.col('b'),
    }
    div_expr = divisions[exp_type]

    if ansi_enabled:
        assert_gpu_and_cpu_error(df_fun=lambda spark: one_row_df(spark).select(div_expr).collect(),
                                 conf=conf,
                                 error_message='java.lang.ArithmeticException: divide by zero')
    else:
        assert_gpu_and_cpu_are_equal_collect(lambda spark: one_row_df(spark).select(div_expr), conf)

25 changes: 25 additions & 0 deletions integration_tests/src/main/python/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from decimal import Decimal
import math
from pyspark.sql import Row
from py4j.protocol import Py4JJavaError

import pytest
from spark_session import with_cpu_session, with_gpu_session
import time
Expand Down Expand Up @@ -360,3 +362,26 @@ def do_it_all(spark):
df.createOrReplaceTempView(table_name)
return spark.sql(sql)
assert_gpu_and_cpu_are_equal_collect(do_it_all, conf)

def assert_py4j_exception(func, error_message):
    """
    Assert that a specific Java exception is thrown
    :param func: a function to be verified
    :param error_message: a string such as the one produced by java.lang.Exception.toString
    :return: Assertion failure if no exception matching error_message has occurred.
    """
    with pytest.raises(Py4JJavaError) as py4jError:
        func()
    # substring match against the Java exception's string form, so callers can
    # match on just the exception class and/or message fragment
    assert error_message in str(py4jError.value.java_exception)

def assert_gpu_and_cpu_error(df_fun, conf, error_message):
    """
    Assert that GPU and CPU execution results in a specific Java exception thrown
    :param df_fun: a function to be verified
    :param conf: Spark config
    :param error_message: a string such as the one produced by java.lang.Exception.toString
    :return: Assertion failure if either GPU or CPU versions has not generated error messages
             expected
    """
    # run the same action under a CPU session and a GPU session; both must raise
    assert_py4j_exception(lambda: with_cpu_session(df_fun, conf), error_message)
    assert_py4j_exception(lambda: with_gpu_session(df_fun, conf), error_message)

0 comments on commit 8ed16d8

Please sign in to comment.