From 806f86ed03602a8aa8ab4f4a8c52b70b2d4c54a2 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 16 Feb 2021 15:48:14 -0600 Subject: [PATCH] Skip RAPIDS accelerated Java UDF tests if UDF fails to load (#1728) Signed-off-by: Jason Lowe --- .../src/main/python/rapids_udf_test.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/integration_tests/src/main/python/rapids_udf_test.py b/integration_tests/src/main/python/rapids_udf_test.py index 47ff95b0183..45424c72ea9 100644 --- a/integration_tests/src/main/python/rapids_udf_test.py +++ b/integration_tests/src/main/python/rapids_udf_test.py @@ -30,7 +30,7 @@ def skip_if_no_hive(spark): if spark.conf.get("spark.sql.catalogImplementation") != "hive": skip_unless_precommit_tests('The Spark session does not have Hive support') -def load_udf_or_skip_test(spark, udfname, udfclass): +def load_hive_udf_or_skip_test(spark, udfname, udfclass): drop_udf(spark, udfname) try: spark.sql("CREATE TEMPORARY FUNCTION {} AS '{}'".format(udfname, udfclass)) @@ -41,7 +41,7 @@ def test_hive_simple_udf(): with_spark_session(skip_if_no_hive) data_gens = [["i", int_gen], ["s", encoded_url_gen]] def evalfn(spark): - load_udf_or_skip_test(spark, "urldecode", "com.nvidia.spark.rapids.udf.hive.URLDecode") + load_hive_udf_or_skip_test(spark, "urldecode", "com.nvidia.spark.rapids.udf.hive.URLDecode") return gen_df(spark, data_gens) assert_gpu_and_cpu_are_equal_sql( evalfn, @@ -52,7 +52,7 @@ def test_hive_generic_udf(): with_spark_session(skip_if_no_hive) data_gens = [["s", StringGen('.{0,30}')]] def evalfn(spark): - load_udf_or_skip_test(spark, "urlencode", "com.nvidia.spark.rapids.udf.hive.URLEncode") + load_hive_udf_or_skip_test(spark, "urlencode", "com.nvidia.spark.rapids.udf.hive.URLEncode") return gen_df(spark, data_gens) assert_gpu_and_cpu_are_equal_sql( evalfn, @@ -63,23 +63,28 @@ def evalfn(spark): def test_hive_simple_udf_native(enable_rapids_udf_example_native): with_spark_session(skip_if_no_hive) def evalfn(spark): - load_udf_or_skip_test(spark, "wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount") + load_hive_udf_or_skip_test(spark, "wordcount", "com.nvidia.spark.rapids.udf.hive.StringWordCount") return gen_df(spark, data_gens) assert_gpu_and_cpu_are_equal_sql( evalfn, "hive_native_udf_test_table", "SELECT wordcount(s) FROM hive_native_udf_test_table") +def load_java_udf_or_skip_test(spark, udfname, udfclass): + drop_udf(spark, udfname) + try: + spark.udf.registerJavaFunction(udfname, udfclass) + except AnalysisException: + skip_unless_precommit_tests("UDF {} failed to load, udf-examples jar is probably missing".format(udfname)) + def test_java_url_decode(): def evalfn(spark): - drop_udf(spark, 'urldecode') - spark.udf.registerJavaFunction('urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode') + load_java_udf_or_skip_test(spark, 'urldecode', 'com.nvidia.spark.rapids.udf.java.URLDecode') return unary_op_df(spark, encoded_url_gen).selectExpr("urldecode(a)") assert_gpu_and_cpu_are_equal_collect(evalfn) def test_java_url_encode(): def evalfn(spark): - drop_udf(spark, 'urlencode') - spark.udf.registerJavaFunction('urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode') + load_java_udf_or_skip_test(spark, 'urlencode', 'com.nvidia.spark.rapids.udf.java.URLEncode') return unary_op_df(spark, StringGen('.{0,30}')).selectExpr("urlencode(a)") assert_gpu_and_cpu_are_equal_collect(evalfn)