From 4c102b66a8f6320f3fab98dcd5c814363d870fbe Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Sun, 29 Nov 2020 23:09:55 -0800 Subject: [PATCH] Fix a bug with the support for java.lang.StringBuilder.append. (#1202) * Fix a bug with the support for java.lang.StringBuilder.append. As java.lang.StringBuilder.append is a mutable operation, when this method is called with a string builder object, all the references of this object needs to be updated in locals and stack. Signed-off-by: Sean Lee * Add multiple appends to the test case Signed-off-by: Sean Lee --- .../spark/udf/CatalystExpressionBuilder.scala | 1 + .../com/nvidia/spark/udf/Instruction.scala | 25 ++++++++++------- .../scala/com/nvidia/spark/OpcodeSuite.scala | 27 +++++++++++++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala index 44145a9cc9c..5084ce1e018 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala @@ -417,6 +417,7 @@ object CatalystExpressionBuilder extends Logging { simplifyExpr(Cast(t, BooleanType, tz)), simplifyExpr(Cast(f, BooleanType, tz)))) case If(c, Repr.ArrayBuffer(t), Repr.ArrayBuffer(f)) => Repr.ArrayBuffer(If(c, t, f)) + case If(c, Repr.StringBuilder(t), Repr.StringBuilder(f)) => Repr.StringBuilder(If(c, t, f)) case _ => expr } logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}") diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala index 7093ed4a726..733d23aec17 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala @@ -58,19 +58,19 @@ private[udf] object Repr { } // Internal representation of java.lang.StringBuilder. - case class StringBuilder() extends CompilerInternal("java.lang.StringBuilder") { - def invoke(methodName: String, args: List[Expression]): Expression = { + case class StringBuilder(var string: Expression = Literal.default(StringType)) + extends CompilerInternal("java.lang.StringBuilder") { + override def dataType: DataType = string.dataType + + def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = { methodName match { - case "StringBuilder" => this - case "append" => string = Concat(string :: args) - this - case "toString" => string + case "StringBuilder" => (this, false) + case "append" => (StringBuilder(Concat(string :: args)), true) + case "toString" => (string, false) case _ => throw new SparkException(s"Unsupported StringBuilder op ${methodName}") } } - - var string: Expression = Literal.default(StringType) } // Internal representation of the bytecode instruction getstatic. @@ -481,9 +481,14 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend if (!args.head.isInstanceOf[Repr.StringBuilder]) { throw new SparkException("Internal error with StringBuilder") } - val retval = args.head.asInstanceOf[Repr.StringBuilder] + val (retval, updateState) = args.head.asInstanceOf[Repr.StringBuilder] .invoke(method.getName, args.tail) - State(locals, retval :: rest, cond, expr) + val newState = State(locals, retval :: rest, cond, expr) + if (updateState) { + newState.remap(args.head, retval) + } else { + newState + } } else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") || ((!args.isEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) && ((declaringClassName.equals("scala.collection.AbstractSeq") && diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala index bfe9ed45d5d..064e755706d 100644 --- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala +++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala @@ -1432,6 +1432,33 @@ class OpcodeSuite extends FunSuite { } // Tests for string ops + test("java lang string builder test - append") { + // We do not support java.lang.StringBuilder officially in the udf compiler, + // but string + string in Scala generates some code with + // java.lang.StringBuilder. For that reason, we have some tests with + // java.lang.StringBuilder. + val myudf: (String, String, Boolean) => String = (a,b,c) => { + val sb = new java.lang.StringBuilder() + if (c) { + sb.append(a) + sb.append(" ") + sb.append(b) + sb.toString + "@@@" + " true" + } else { + sb.append(b) + sb.append(" ") + sb.append(a) + sb.toString + "!!!" + " false" + } + } + val u = makeUdf(myudf) + val dataset = List(("Hello", "World", false), + ("Oh", "Hello", true)).toDF("x","y","z").repartition(1) + val result = dataset.withColumn("new", u(col("x"),col("y"),col("z"))) + val ref = List(("Hello", "World", false, "World Hello!!! false"), + ("Oh", "Hello", true, "Oh Hello@@@ true")).toDF + checkEquiv(result, ref) + } test("string test - + concat") { val myudf: (String, String) => String = (a,b) => {