From 5f449cdf18c3f1ec9c36b140d7a3d9583f4ea1ef Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Tue, 17 May 2022 12:40:23 -0700
Subject: [PATCH] Add limited support for captured vars and athrow (#5487)

* Add limited support for captured vars and athrow

Only primitive type variables captured from a method are supported.
athrow is supported only with SparkException and is converted to
raise_error.

The following additional changes have also been made:
* A check to reject lambdas with void return type has been added.
* An operand stack bug fix for constructor calls has been added.

This commit also simplifies the code that handles method calls.

Signed-off-by: Sean Lee

* Update copyright years

* Catch RuntimeException instead of Throwable

* Update udf-to-catalyst doc

---
 .../udf-to-catalyst-expressions.md            |   2 +
 .../main/scala/com/nvidia/spark/udf/CFG.scala |   4 +-
 .../spark/udf/CatalystExpressionBuilder.scala |  49 +++-
 .../com/nvidia/spark/udf/Instruction.scala    | 218 ++++++++++--------
 .../nvidia/spark/udf/LambdaReflection.scala   |  28 ++-
 .../scala/com/nvidia/spark/OpcodeSuite.scala  |  39 ++++
 6 files changed, 233 insertions(+), 107 deletions(-)

diff --git a/docs/additional-functionality/udf-to-catalyst-expressions.md b/docs/additional-functionality/udf-to-catalyst-expressions.md
index 8bd241eddcc..ab5e8f08456 100644
--- a/docs/additional-functionality/udf-to-catalyst-expressions.md
+++ b/docs/additional-functionality/udf-to-catalyst-expressions.md
@@ -109,6 +109,8 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
 |             | lhs += rhs |
 |             | lhs :+ rhs |
 | Method call | Only if the method being called 1. Consists of operations supported by the UDF compiler, and 2. is one of the following: a final method, a method in a final class, or a method in a final object |
+| Captured variables | Only primitive type variables captured from a method |
+| Throwing exception | Only if the exception thrown is a SparkException. The exception is then converted to a RuntimeException at runtime |
 
 All other expressions, including but not limited to `try` and `catch`, are unsupported and UDFs with such expressions cannot be compiled.

diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
index c7abed92e75..add24a9fdf4 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
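To make the two new documentation rows concrete, here is a minimal sketch of a UDF that exercises both features. It is not part of the patch itself; it is modeled on the tests added to OpcodeSuite at the end of this change. `bound`, `myudf`, and `u` are illustrative names, and `makeUdf` is the test suite's alias for `org.apache.spark.sql.functions.udf`:

    import org.apache.spark.SparkException
    import org.apache.spark.sql.functions.{udf => makeUdf}

    val bound: Int = 10 // captured primitive; the compiler inlines it as Literal(10)
    val myudf: Int => Int = i => {
      if (i < 0 || i >= bound) {
        // Translated to raise_error(...), which surfaces as a RuntimeException
        // when the expression is evaluated.
        throw new SparkException(s"value $i is not in [0, $bound)")
      }
      i + bound
    }
    val u = makeUdf(myudf)

Capturing a non-primitive (a String, say) or throwing any exception type other than SparkException still falls outside what the compiler supports.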
@@ -108,7 +108,7 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { val defaultState = state.copy(cond = simplify(defaultCondition)) newStates + (defaultSucc -> defaultState.merge(newStates.get(defaultSucc))) case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN | - Opcode.ARETURN | Opcode.RETURN => states + Opcode.ARETURN | Opcode.RETURN | Opcode.ATHROW => states case _ => val (0, successor) :: Nil = cfg.successor(this) // The condition, stack and locals from the current BB state need to be diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala index 82789523ce1..fee1d9272ec 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala @@ -176,7 +176,36 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi // A basic block can have other branching instructions as the last instruction, // otherwise. if (basicBlock.lastInstruction.isReturn) { - newStates(basicBlock).expr + // When the compiled method throws exceptions, we combine the returned + // expression (expr) with exceptions (except0, except1, except2.., exceptn) + // based on exception conditions (cond0, cond1, cond2, .., condn) as follows: + // If + // | + // +------+------+ + // | | | + // cond0 except0 If + // | + // +------+------+ + // | | | + // cond1 except1 If + // | + // +------+------+ + // | | | + // cond2 except2 If + // | + // . + // . + // . + // +------+------+ + // | | | + // condn exceptn expr + newStates.foldRight(newStates(basicBlock).expr) { case ((bb, s), expr) => + if (bb.lastInstruction.isThrow) { + Some(If(s.cond, s.stack.head, expr.get)) + } else { + expr + } + } } else { // account for this block in visited val newVisited = visited + basicBlock @@ -327,6 +356,24 @@ object CatalystExpressionBuilder extends Logging { case _ => expr } } + case And(c1@GreaterThanOrEqual(s1, Literal(v1, t1)), + c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { + t1 match { + case IntegerType => + if (v1.asInstanceOf[Int] > v2.asInstanceOf[Int]) { + c1 + } else { + c2 + } + case LongType => + if (v1.asInstanceOf[Long] > v2.asInstanceOf[Long]) { + c1 + } else { + c2 + } + case _ => expr + } + } case And(c1@GreaterThan(s1, Literal(v1, t1)), c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => { t1 match { diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala index fa9a7192ca0..3f212394179 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala @@ -63,17 +63,37 @@ private[udf] object Repr { extends CompilerInternal("java.lang.StringBuilder") { override def dataType: DataType = string.dataType - def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = { + def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = { methodName match { - case "StringBuilder" => (this, false) - case "append" => (StringBuilder(Concat(string :: args)), true) - case "toString" => (string, false) + case "StringBuilder" => None + case "append" => Some((StringBuilder(Concat(string :: args)), true)) + case "toString" => Some((string, false)) case _ => throw new 
SparkException(s"Unsupported StringBuilder op ${methodName}") } } } + // Internal representation of org.apache.spark.SparkException + case class SparkExcept(var message: Expression = Literal.default(StringType)) + extends CompilerInternal("org.apache.spark.SparkException") { + + def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = { + methodName match { + case "SparkException" => + if (args.length > 1) { + throw new SparkException("Unsupported SparkException construction " + + "with multiple arguments") + } else { + message = args.head + Some((this, false)) + } + case _ => + throw new SparkException(s"Unsupported SparkException op ${methodName}") + } + } + } + // Internal representation of the bytecode instruction getstatic. // This class is needed because we can't represent getstatic in Catalyst, but // we need the getstatic information to handle some method calls @@ -96,11 +116,11 @@ private[udf] object Repr { extends CompilerInternal("scala.collection.mutable.ArrayBuffer") { override def dataType: DataType = arrayBuffer.dataType - def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = { + def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = { methodName match { - case "ArrayBuffer" => (this, false) - case "distinct" => (ArrayBuffer(ArrayDistinct(arrayBuffer)), false) - case "toArray" => (arrayBuffer, false) + case "ArrayBuffer" => Some((this, false)) + case "distinct" => Some((ArrayBuffer(ArrayDistinct(arrayBuffer)), false)) + case "toArray" => Some((arrayBuffer, false)) case "$plus$eq" | "$colon$plus" => val mutable = { if (methodName == "$plus$eq") { @@ -138,11 +158,11 @@ private[udf] object Repr { } } } - (arrayBuffer match { + Some((arrayBuffer match { case CreateArray(Nil, _) => ArrayBuffer(CreateArray(elem)) case array => ArrayBuffer(Concat(Seq(array, CreateArray(elem)))) }, - mutable) + mutable)) case _ => throw new SparkException(s"Unsupported ArrayBuffer op ${methodName}") } @@ -233,6 +253,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend const(state, (opcode - Opcode.ICONST_0).asInstanceOf[Int]) case Opcode.LCONST_0 | Opcode.LCONST_1 => const(state, (opcode - Opcode.LCONST_0).asInstanceOf[Long]) + case Opcode.ATHROW => athrow(state) case Opcode.DADD | Opcode.FADD | Opcode.IADD | Opcode.LADD => binary(state, Add(_, _)) case Opcode.DSUB | Opcode.FSUB | Opcode.ISUB | Opcode.LSUB => binary(state, Subtract(_, _)) case Opcode.DMUL | Opcode.FMUL | Opcode.IMUL | Opcode.LMUL => binary(state, Multiply(_, _)) @@ -310,6 +331,11 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend case _ => false } + def isThrow: Boolean = opcode match { + case Opcode.ATHROW => true + case _ => false + } + // // Handle instructions // @@ -323,6 +349,17 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend State(locals.updated(localsIndex, top), rest, cond, expr) } + private def athrow(state: State): State = { + val State(locals, top :: rest, cond, expr) = state + if (!top.isInstanceOf[Repr.SparkExcept]) { + throw new SparkException("Unsupported type for athrow") + } + // Empty the stack and convert the internal representation of + // org.apache.spark.SparkException object to RaiseError, then push it to the + // stack. 
+ State(locals, List(RaiseError(top.asInstanceOf[Repr.SparkExcept].message)), cond, expr) + } + private def const(state: State, value: Any): State = { val State(locals, stack, cond, expr) = state State(locals, Literal(value) :: stack, cond, expr) @@ -360,6 +397,9 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend if (typeName.equals("java.lang.StringBuilder")) { val State(locals, stack, cond, expr) = state State(locals, Repr.StringBuilder() :: stack, cond, expr) + } else if (typeName.equals("org.apache.spark.SparkException")) { + val State(locals, stack, cond, expr) = state + State(locals, Repr.SparkExcept() :: stack, cond, expr) } else if (typeName.equals("scala.collection.mutable.ArrayBuffer")) { val State(locals, stack, cond, expr) = state State(locals, Repr.ArrayBuffer() :: stack, cond, expr) @@ -447,98 +487,84 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend val (args, rest) = getArgs(stack, paramTypes.length) // We don't support arbitrary calls. // We support only some math and string methods. - if (declaringClassName.equals("scala.math.package$")) { - State(locals, - mathOp(method.getName, args) :: rest, - cond, - expr) - } else if (declaringClassName.equals("scala.Predef$")) { - State(locals, - predefOp(method.getName, args) :: rest, - cond, - expr) - } else if (declaringClassName.equals("scala.Array$")) { - State(locals, - arrayOp(method.getName, args) :: rest, - cond, - expr) - } else if (declaringClassName.equals("scala.reflect.ClassTag$")) { - State(locals, - classTagOp(method.getName, args) :: rest, - cond, - expr) - } else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) { - State(locals, - arrayBufferOp(method.getName, args) :: rest, - cond, - expr) - } else if (declaringClassName.equals("java.lang.Double")) { - State(locals, doubleOp(method.getName, args) :: rest, cond, expr) - } else if (declaringClassName.equals("java.lang.Float")) { - State(locals, floatOp(method.getName, args) :: rest, cond, expr) - } else if (declaringClassName.equals("java.lang.String")) { - State(locals, stringOp(method.getName, args) :: rest, cond, expr) - } else if (declaringClassName.equals("java.lang.StringBuilder")) { - if (!args.head.isInstanceOf[Repr.StringBuilder]) { - throw new SparkException("Internal error with StringBuilder") - } - val (retval, updateState) = args.head.asInstanceOf[Repr.StringBuilder] - .invoke(method.getName, args.tail) - val newState = State(locals, retval :: rest, cond, expr) - if (updateState) { - newState.remap(args.head, retval) + val retval = { + if (declaringClassName.equals("scala.math.package$")) { + Some((mathOp(method.getName, args), false)) + } else if (declaringClassName.equals("scala.Predef$")) { + Some((predefOp(method.getName, args), false)) + } else if (declaringClassName.equals("scala.Array$")) { + Some((arrayOp(method.getName, args), false)) + } else if (declaringClassName.equals("scala.reflect.ClassTag$")) { + Some((classTagOp(method.getName, args), false)) + } else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) { + Some((arrayBufferOp(method.getName, args), false)) + } else if (declaringClassName.equals("java.lang.Double")) { + Some((doubleOp(method.getName, args), false)) + } else if (declaringClassName.equals("java.lang.Float")) { + Some((floatOp(method.getName, args), false)) + } else if (declaringClassName.equals("java.lang.String")) { + Some((stringOp(method.getName, args), false)) + } else if 
(declaringClassName.equals("java.lang.StringBuilder")) { + if (!args.head.isInstanceOf[Repr.StringBuilder]) { + throw new SparkException("Internal error with StringBuilder") + } + args.head.asInstanceOf[Repr.StringBuilder].invoke(method.getName, args.tail) + } else if (declaringClassName.equals("org.apache.spark.SparkException")) { + if (!args.head.isInstanceOf[Repr.SparkExcept]) { + throw new SparkException("Internal error with SparkException") + } + args.head.asInstanceOf[Repr.SparkExcept].invoke(method.getName, args.tail) + } else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") || + ((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) && + ((declaringClassName.equals("scala.collection.AbstractSeq") && + opcode == Opcode.INVOKEVIRTUAL) || + (declaringClassName.equals("scala.collection.TraversableOnce") && + opcode == Opcode.INVOKEINTERFACE)))) { + if (!args.head.isInstanceOf[Repr.ArrayBuffer]) { + throw new SparkException( + s"Unexpected argument for ${declaringClassName}.${method.getName}") + } + args.head.asInstanceOf[Repr.ArrayBuffer].invoke(method.getName, args.tail) + } else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) { + Some((dateTimeFormatterOp(method.getName, args), false)) + } else if (declaringClassName.equals("java.time.LocalDateTime")) { + Some((localDateTimeOp(method.getName, args), false)) } else { - newState - } - } else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") || - ((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) && - ((declaringClassName.equals("scala.collection.AbstractSeq") && - opcode == Opcode.INVOKEVIRTUAL) || - (declaringClassName.equals("scala.collection.TraversableOnce") && - opcode == Opcode.INVOKEINTERFACE)))) { - if (!args.head.isInstanceOf[Repr.ArrayBuffer]) { - throw new SparkException( - s"Unexpected argument for ${declaringClassName}.${method.getName}") + val mModifiers = method.getModifiers + val cModifiers = declaringClass.getModifiers + if (!javassist.Modifier.isEnum(mModifiers) && + !javassist.Modifier.isInterface(mModifiers) && + !javassist.Modifier.isNative(mModifiers) && + !javassist.Modifier.isPackage(mModifiers) && + !javassist.Modifier.isStrict(mModifiers) && + !javassist.Modifier.isSynchronized(mModifiers) && + !javassist.Modifier.isTransient(mModifiers) && + !javassist.Modifier.isVarArgs(mModifiers) && + !javassist.Modifier.isVolatile(mModifiers) && + (javassist.Modifier.isFinal(mModifiers) || + javassist.Modifier.isFinal(cModifiers))) { + val retval = { + if (javassist.Modifier.isStatic(mModifiers)) { + CatalystExpressionBuilder(method).compile(args) + } else { + CatalystExpressionBuilder(method).compile(args.tail, Some(args.head)) + } + } + Some((retval.get, false)) + } else { + // Other functions + throw new SparkException( + s"Unsupported invocation of ${declaringClassName}.${method.getName}") + } } - val (retval, updateState) = args.head.asInstanceOf[Repr.ArrayBuffer] - .invoke(method.getName, args.tail) - val newState = State(locals, retval :: rest, cond, expr) + } + retval.fold(State(locals, rest, cond, expr)) { case (v, updateState) => + val newState = State(locals, v :: rest, cond, expr) if (updateState) { - newState.remap(args.head, retval) + newState.remap(args.head, v) } else { newState } - } else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) { - State(locals, dateTimeFormatterOp(method.getName, args) :: rest, cond, expr) - } else if (declaringClassName.equals("java.time.LocalDateTime")) { - State(locals, 
localDateTimeOp(method.getName, args) :: rest, cond, expr) - } else { - val mModifiers = method.getModifiers - val cModifiers = declaringClass.getModifiers - if (!javassist.Modifier.isEnum(mModifiers) && - !javassist.Modifier.isInterface(mModifiers) && - !javassist.Modifier.isNative(mModifiers) && - !javassist.Modifier.isPackage(mModifiers) && - !javassist.Modifier.isStrict(mModifiers) && - !javassist.Modifier.isSynchronized(mModifiers) && - !javassist.Modifier.isTransient(mModifiers) && - !javassist.Modifier.isVarArgs(mModifiers) && - !javassist.Modifier.isVolatile(mModifiers) && - (javassist.Modifier.isFinal(mModifiers) || - javassist.Modifier.isFinal(cModifiers))) { - val retval = { - if (javassist.Modifier.isStatic(mModifiers)) { - CatalystExpressionBuilder(method).compile(args) - } else { - CatalystExpressionBuilder(method).compile(args.tail, Some(args.head)) - } - } - State(locals, retval.toList ::: rest, cond, expr) - } else { - // Other functions - throw new SparkException( - s"Unsupported invocation of ${declaringClassName}.${method.getName}") - } } } diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala index 4c1e248538d..cb21aacc9e5 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala @@ -22,7 +22,7 @@ import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMet import javassist.bytecode.{CodeIterator, ConstPool, Descriptor} import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.types._ // @@ -101,6 +101,10 @@ class LambdaReflection private(private val classPool: ClassPool, lazy val ret: CtClass = ctMethod.getReturnType + if (ret == CtClass.voidType) { + throw new SparkException("Cannot construct Catalyst expression with void return type") + } + lazy val maxLocals: Int = codeAttribute.getMaxLocals } @@ -119,13 +123,21 @@ object LambdaReflection { writeReplace.setAccessible(true) val serializedLambda = writeReplace.invoke(function) .asInstanceOf[SerializedLambda] - val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){ - // Add UnknownCapturedArg expressions to the list of captured arguments - // until we figure out how to handle captured variables in various - // situations. - // This, however, lets us pass captured arguments as long as they are not - // evaluated. - _ => Repr.UnknownCapturedArg() + val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount) { i => + val capturedArg = serializedLambda.getCapturedArg(i) + if (capturedArg.isInstanceOf[Byte] || capturedArg.isInstanceOf[Char] || + capturedArg.isInstanceOf[Short] || capturedArg.isInstanceOf[Int] || + capturedArg.isInstanceOf[Long] || capturedArg.isInstanceOf[Float] || + capturedArg.isInstanceOf[Double] || capturedArg.isInstanceOf[Boolean]) { + Literal(capturedArg) + } else { + // Add UnknownCapturedArg expressions to the list of captured arguments + // until we figure out how to handle captured variables in various + // situations. + // This, however, lets us pass captured arguments as long as they are + // not evaluated. 
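+        // (For instance, a captured String still reaches this branch; the
+        // UDF compiles as long as the String is never used in the body.)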
+ Repr.UnknownCapturedArg() + } } val classPool = new ClassPool(true) diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala index b0210d84042..8f1f225179e 100644 --- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala +++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala @@ -25,6 +25,7 @@ import com.nvidia.spark.rapids.RapidsConf import org.scalatest.FunSuite import org.apache.spark.SparkConf +import org.apache.spark.SparkException import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{udf => makeUdf} @@ -2350,6 +2351,44 @@ class OpcodeSuite extends FunSuite { checkEquivNotCompiled(result, ref) } + test("capture a primitive var in method") { + def run() = { + val capturedArg: Int = 4 + val myudf: (String) => Int = str => { + str.length + capturedArg + } + val u = makeUdf(myudf) + val dataset = List("hello", "world").toDS() + val result = dataset.withColumn("new", u(col("value"))) + val ref = dataset.withColumn("new", length(col("value")) + capturedArg) + checkEquiv(result, ref) + } + run + } + + test("throw a SparkException object") { + def run(x: Int) = { + val myudf: (Int) => Boolean = i => { + if (i < 0 || i >= x) { + throw new SparkException(s"Fold number must be in range [0, $x), but got $i.") + } + true + } + val u = makeUdf(myudf) + val dataset = List(2, 20).toDS() + val result = dataset.withColumn("new", u('value)) + val ref = dataset.withColumn("new", lit(true)) + checkEquiv(result, ref) + } + run(30) + try { + run(20) + } catch { + case e: RuntimeException => + assert(e.getMessage == "Fold number must be in range [0, 20), but got 20.") + } + } + test("Conditional array buffer processing") { def cond(s: String): Boolean = { s == null || s.trim.isEmpty