From 1b221f982d0988e01310e2fa0499a2a389fd81cc Mon Sep 17 00:00:00 2001 From: wjxiz Date: Tue, 5 Jan 2021 22:36:58 +0800 Subject: [PATCH] Scala UDF will compile children expressions in Project (#1153) * Scala UDF will compile child expressions in Project Signed-off-by: Allen Xu Co-authored-by: Alessandro Bellina * Rebased to branch-0.3 to resolve conflicts Signed-off-by: Allen Xu * Remove flatten test case Co-authored-by: Allen Xu Co-authored-by: Alessandro Bellina --- .../scala/com/nvidia/spark/udf/Plugin.scala | 2 +- .../scala/com/nvidia/spark/OpcodeSuite.scala | 43 ++++++++++++------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala index 679d126b95c..98164e70ea1 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Plugin.scala @@ -82,7 +82,7 @@ case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { plan match { case project: Project => Project(project.projectList.map(e => attemptToReplaceExpression(plan, e)) - .asInstanceOf[Seq[NamedExpression]], project.child) + .asInstanceOf[Seq[NamedExpression]], apply(project.child)) case x => { x.transformExpressions(replacePartialFunc(plan)) } diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala index 064e755706d..048544483d4 100644 --- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala +++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala @@ -2333,28 +2333,39 @@ class OpcodeSuite extends FunSuite { } val u = makeUdf((x: String, y: String, z: Boolean) => { - var r = new mutable.ArrayBuffer[String]() - r = r :+ x - if (!cond(y)) { - r = r :+ y - - if (z) { - r = r :+ transform(y) - } - } + var r = new mutable.ArrayBuffer[String]() + r = r :+ x + if (!cond(y)) { + r = r :+ y + if (z) { - r = r :+ transform(x) + r = r :+ transform(y) } - r.distinct.toArray - }) + } + if (z) { + r = r :+ transform(x) + } + r.distinct.toArray + }) val dataset = List(("######hello", null), - ("world", "######hello"), - ("", "@@@@target")).toDF("x", "y") + ("world", "######hello"), + ("", "@@@@target")).toDF("x", "y") val result = dataset.withColumn("new", u('x, 'y, lit(true))) val ref = List(("######hello", null, Array("######hello", "@@@@hello")), - ("world", "######hello", Array("world", "######hello", "@@@@hello")), - ("", "@@@@target", Array("", "@@@@target", "######target", null))).toDF + ("world", "######hello", Array("world", "######hello", "@@@@hello")), + ("", "@@@@target", Array("", "@@@@target", "######target", null))).toDF checkEquiv(result, ref) } + + test("compile child expresion in explode") { + val myudf: (String) => Array[String] = a => { + a.split(",") + } + val u = makeUdf(myudf) + val dataset = List("first,second").toDF("x").repartition(1) + var result = dataset.withColumn("new", explode(u(col("x")))) + val ref = List(("first,second","first"),("first,second","second")).toDF("x","new") + checkEquiv(result,ref) + } }