From 2e81d82ce656b950db585d23cc4390030c6c9a9a Mon Sep 17 00:00:00 2001 From: Allen Xu Date: Fri, 25 Sep 2020 16:26:30 +0800 Subject: [PATCH] Add test case for arbitrary function call in UDF Signed-off-by: Allen Xu --- .../scala/com/nvidia/spark/OpcodeSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala index c8c5e060404..ecf9d794431 100644 --- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala +++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala @@ -35,6 +35,7 @@ class OpcodeSuite extends FunSuite { val conf: SparkConf = new SparkConf() .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin") + .set("spark.rapids.sql.test.enabled", "true") .set("spark.rapids.sql.udfCompiler.enabled", "true") .set(RapidsConf.EXPLAIN.key, "true") @@ -2086,4 +2087,20 @@ class OpcodeSuite extends FunSuite { val ref = dataset.select(lit(Array.empty[String]).as("emptyArrOfStr")) checkEquiv(result, ref) } + + test("Arbitrary function call inside UDF - ") { + def simple_func(str: String): Int = { + str.length + } + val myudf: (String) => Int = str => { + simple_func(str) + } + val u = makeUdf(myudf) + val dataset = List("hello", "world").toDS() + val result = dataset.withColumn("new", u(col("value"))) + val ref = dataset.withColumn("new", length(col("value"))) + result.explain(true) + result.show +// checkEquiv(result, ref) + } }