Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip RAPIDS accelerated Java UDF tests if UDF fails to load #1756

Merged
merged 1 commit into from
Feb 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions integration_tests/src/main/python/rapids_udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def skip_if_no_hive(spark):
if spark.conf.get("spark.sql.catalogImplementation") != "hive":
skip_unless_precommit_tests('The Spark session does not have Hive support')

def load_udf_or_skip_test(spark, udfname, udfclass):
def load_hive_udf_or_skip_test(spark, udfname, udfclass):
drop_udf(spark, udfname)
try:
spark.sql("CREATE TEMPORARY FUNCTION {} AS '{}'".format(udfname, udfclass))
Expand All @@ -41,7 +41,7 @@ def test_hive_simple_udf():
with_spark_session(skip_if_no_hive)
data_gens = [["i", int_gen], ["s", encoded_url_gen]]
def evalfn(spark):
load_udf_or_skip_test(spark, "urldecode", "com.nvidia.spark.rapids.udf.hive.URLDecode")
load_hive_udf_or_skip_test(spark, "urldecode", "com.nvidia.spark.rapids.udf.hive.URLDecode")
return gen_df(spark, data_gens)
assert_gpu_and_cpu_are_equal_sql(
evalfn,
Expand All @@ -52,7 +52,7 @@ def test_hive_generic_udf():
with_spark_session(skip_if_no_hive)
data_gens = [["s", StringGen('.{0,30}')]]
def evalfn(spark):
load_udf_or_skip_test(spark, "urlencode", "com.nvidia.spark.rapids.udf.hive.URLEncode")
load_hive_udf_or_skip_test(spark, "urlencode", "com.nvidia.spark.rapids.udf.hive.URLEncode")
return gen_df(spark, data_gens)
assert_gpu_and_cpu_are_equal_sql(
evalfn,
Expand All @@ -63,23 +63,28 @@ def evalfn(spark):
def test_hive_simple_udf_native(enable_rapids_udf_example_native):
with_spark_session(skip_if_no_hive)
def evalfn(spark):
load_udf_or_skip_test(spark, "wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount")
load_hive_udf_or_skip_test(spark, "wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount")
return gen_df(spark, data_gens)
assert_gpu_and_cpu_are_equal_sql(
evalfn,
"hive_native_udf_test_table",
"SELECT wordcount(s) FROM hive_native_udf_test_table")

def load_java_udf_or_skip_test(spark, udfname, udfclass):
drop_udf(spark, udfname)
try:
spark.udf.registerJavaFunction(udfname, udfclass)
except AnalysisException:
skip_unless_precommit_tests("UDF {} failed to load, udf-examples jar is probably missing".format(udfname))

def test_java_url_decode():
def evalfn(spark):
drop_udf(spark, 'urldecode')
spark.udf.registerJavaFunction('urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode')
load_java_udf_or_skip_test(spark, 'urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode')
return unary_op_df(spark, encoded_url_gen).selectExpr("urldecode(a)")
assert_gpu_and_cpu_are_equal_collect(evalfn)

def test_java_url_encode():
def evalfn(spark):
drop_udf(spark, 'urlencode')
spark.udf.registerJavaFunction('urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode')
load_java_udf_or_skip_test(spark, 'urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode')
return unary_op_df(spark, StringGen('.{0,30}')).selectExpr("urlencode(a)")
assert_gpu_and_cpu_are_equal_collect(evalfn)