Skip to content

Commit

Permalink
Add limited support for captured vars and athrow (#5487)
Browse files Browse the repository at this point in the history
* Add limited support for captured vars and athrow

Only the primitive type variables captured from a method are supported.
athrow is supported only with SparkException and is converted to raise_error.

Following additional changes have also been made:
* A check to reject lambdas with void return type has been added.
* An operand stack bug fix for constructor calls has been added.

This commit also simplifies the code that handles method calls.

Signed-off-by: Sean Lee <selee@nvidia.com>

* Update copyright years

* Catch RuntimeException instead of Throwable

* Update udf-to-catalyst doc
  • Loading branch information
seanprime7 authored May 17, 2022
1 parent a427977 commit 5f449cd
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 107 deletions.
2 changes: 2 additions & 0 deletions docs/additional-functionality/udf-to-catalyst-expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | lhs += rhs |
| | lhs :+ rhs |
| Method call | Only if the method being called 1. Consists of operations supported by the UDF compiler, and 2. is one of the following: a final method, a method in a final class, or a method in a final object |
| Captured variables | Only primitive type variables captured from a method |
| Throwing exception | Only if the exception thrown is a SparkException. The exception is then converted to a RuntimeException at runtime |

All other expressions, including but not limited to `try` and `catch`, are unsupported and UDFs
with such expressions cannot be compiled.
4 changes: 2 additions & 2 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -108,7 +108,7 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
val defaultState = state.copy(cond = simplify(defaultCondition))
newStates + (defaultSucc -> defaultState.merge(newStates.get(defaultSucc)))
case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN |
Opcode.ARETURN | Opcode.RETURN => states
Opcode.ARETURN | Opcode.RETURN | Opcode.ATHROW => states
case _ =>
val (0, successor) :: Nil = cfg.successor(this)
// The condition, stack and locals from the current BB state need to be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,36 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
// A basic block can have other branching instructions as the last instruction,
// otherwise.
if (basicBlock.lastInstruction.isReturn) {
newStates(basicBlock).expr
// When the compiled method throws exceptions, we combine the returned
// expression (expr) with exceptions (except0, except1, except2.., exceptn)
// based on exception conditions (cond0, cond1, cond2, .., condn) as follows:
// If
// |
// +------+------+
// | | |
// cond0 except0 If
// |
// +------+------+
// | | |
// cond1 except1 If
// |
// +------+------+
// | | |
// cond2 except2 If
// |
// .
// .
// .
// +------+------+
// | | |
// condn exceptn expr
newStates.foldRight(newStates(basicBlock).expr) { case ((bb, s), expr) =>
if (bb.lastInstruction.isThrow) {
Some(If(s.cond, s.stack.head, expr.get))
} else {
expr
}
}
} else {
// account for this block in visited
val newVisited = visited + basicBlock
Expand Down Expand Up @@ -327,6 +356,24 @@ object CatalystExpressionBuilder extends Logging {
case _ => expr
}
}
case And(c1@GreaterThanOrEqual(s1, Literal(v1, t1)),
c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
case IntegerType =>
if (v1.asInstanceOf[Int] > v2.asInstanceOf[Int]) {
c1
} else {
c2
}
case LongType =>
if (v1.asInstanceOf[Long] > v2.asInstanceOf[Long]) {
c1
} else {
c2
}
case _ => expr
}
}
case And(c1@GreaterThan(s1, Literal(v1, t1)),
c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
Expand Down
218 changes: 122 additions & 96 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,37 @@ private[udf] object Repr {
extends CompilerInternal("java.lang.StringBuilder") {
override def dataType: DataType = string.dataType

def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = {
def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
methodName match {
case "StringBuilder" => (this, false)
case "append" => (StringBuilder(Concat(string :: args)), true)
case "toString" => (string, false)
case "StringBuilder" => None
case "append" => Some((StringBuilder(Concat(string :: args)), true))
case "toString" => Some((string, false))
case _ =>
throw new SparkException(s"Unsupported StringBuilder op ${methodName}")
}
}
}

// Internal representation of org.apache.spark.SparkException.
// Records the single message expression handed to the constructor so that
// the athrow handler can later turn this object into a Catalyst RaiseError.
case class SparkExcept(var message: Expression = Literal.default(StringType))
    extends CompilerInternal("org.apache.spark.SparkException") {

  // Handles a method call on this internal exception representation.
  // Returns Some((result, mutatesReceiver)); only the one-argument
  // constructor is supported, every other call is rejected.
  def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
    methodName match {
      case "SparkException" if args.length > 1 =>
        throw new SparkException("Unsupported SparkException construction " +
          "with multiple arguments")
      case "SparkException" =>
        // Single-argument constructor: remember the message expression.
        message = args.head
        Some((this, false))
      case _ =>
        throw new SparkException(s"Unsupported SparkException op ${methodName}")
    }
  }
}

// Internal representation of the bytecode instruction getstatic.
// This class is needed because we can't represent getstatic in Catalyst, but
// we need the getstatic information to handle some method calls
Expand All @@ -96,11 +116,11 @@ private[udf] object Repr {
extends CompilerInternal("scala.collection.mutable.ArrayBuffer") {
override def dataType: DataType = arrayBuffer.dataType

def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = {
def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
methodName match {
case "ArrayBuffer" => (this, false)
case "distinct" => (ArrayBuffer(ArrayDistinct(arrayBuffer)), false)
case "toArray" => (arrayBuffer, false)
case "ArrayBuffer" => Some((this, false))
case "distinct" => Some((ArrayBuffer(ArrayDistinct(arrayBuffer)), false))
case "toArray" => Some((arrayBuffer, false))
case "$plus$eq" | "$colon$plus" =>
val mutable = {
if (methodName == "$plus$eq") {
Expand Down Expand Up @@ -138,11 +158,11 @@ private[udf] object Repr {
}
}
}
(arrayBuffer match {
Some((arrayBuffer match {
case CreateArray(Nil, _) => ArrayBuffer(CreateArray(elem))
case array => ArrayBuffer(Concat(Seq(array, CreateArray(elem))))
},
mutable)
mutable))
case _ =>
throw new SparkException(s"Unsupported ArrayBuffer op ${methodName}")
}
Expand Down Expand Up @@ -233,6 +253,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
const(state, (opcode - Opcode.ICONST_0).asInstanceOf[Int])
case Opcode.LCONST_0 | Opcode.LCONST_1 =>
const(state, (opcode - Opcode.LCONST_0).asInstanceOf[Long])
case Opcode.ATHROW => athrow(state)
case Opcode.DADD | Opcode.FADD | Opcode.IADD | Opcode.LADD => binary(state, Add(_, _))
case Opcode.DSUB | Opcode.FSUB | Opcode.ISUB | Opcode.LSUB => binary(state, Subtract(_, _))
case Opcode.DMUL | Opcode.FMUL | Opcode.IMUL | Opcode.LMUL => binary(state, Multiply(_, _))
Expand Down Expand Up @@ -310,6 +331,11 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
case _ => false
}

/** Whether this instruction is an `athrow` (used when combining per-block
 *  expressions to fold thrown exceptions into the returned expression). */
def isThrow: Boolean = opcode == Opcode.ATHROW

//
// Handle instructions
//
Expand All @@ -323,6 +349,17 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
State(locals.updated(localsIndex, top), rest, cond, expr)
}

// Handles the athrow bytecode: pops the thrown object off the operand stack
// and, when it is our internal SparkException representation, replaces the
// entire stack with a single RaiseError built from the exception message.
// Any other thrown type is rejected.
private def athrow(state: State): State = {
  val State(locals, thrown :: _, cond, expr) = state
  thrown match {
    case e: Repr.SparkExcept =>
      State(locals, List(RaiseError(e.message)), cond, expr)
    case _ =>
      throw new SparkException("Unsupported type for athrow")
  }
}

private def const(state: State, value: Any): State = {
val State(locals, stack, cond, expr) = state
State(locals, Literal(value) :: stack, cond, expr)
Expand Down Expand Up @@ -360,6 +397,9 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
if (typeName.equals("java.lang.StringBuilder")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.StringBuilder() :: stack, cond, expr)
} else if (typeName.equals("org.apache.spark.SparkException")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.SparkExcept() :: stack, cond, expr)
} else if (typeName.equals("scala.collection.mutable.ArrayBuffer")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.ArrayBuffer() :: stack, cond, expr)
Expand Down Expand Up @@ -447,98 +487,84 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
val (args, rest) = getArgs(stack, paramTypes.length)
// We don't support arbitrary calls.
// We support only some math and string methods.
if (declaringClassName.equals("scala.math.package$")) {
State(locals,
mathOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.Predef$")) {
State(locals,
predefOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.Array$")) {
State(locals,
arrayOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.reflect.ClassTag$")) {
State(locals,
classTagOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) {
State(locals,
arrayBufferOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("java.lang.Double")) {
State(locals, doubleOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.Float")) {
State(locals, floatOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.String")) {
State(locals, stringOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.StringBuilder")) {
if (!args.head.isInstanceOf[Repr.StringBuilder]) {
throw new SparkException("Internal error with StringBuilder")
}
val (retval, updateState) = args.head.asInstanceOf[Repr.StringBuilder]
.invoke(method.getName, args.tail)
val newState = State(locals, retval :: rest, cond, expr)
if (updateState) {
newState.remap(args.head, retval)
val retval = {
if (declaringClassName.equals("scala.math.package$")) {
Some((mathOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.Predef$")) {
Some((predefOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.Array$")) {
Some((arrayOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.reflect.ClassTag$")) {
Some((classTagOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) {
Some((arrayBufferOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.Double")) {
Some((doubleOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.Float")) {
Some((floatOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.String")) {
Some((stringOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.StringBuilder")) {
if (!args.head.isInstanceOf[Repr.StringBuilder]) {
throw new SparkException("Internal error with StringBuilder")
}
args.head.asInstanceOf[Repr.StringBuilder].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("org.apache.spark.SparkException")) {
if (!args.head.isInstanceOf[Repr.SparkExcept]) {
throw new SparkException("Internal error with SparkException")
}
args.head.asInstanceOf[Repr.SparkExcept].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") ||
((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) &&
((declaringClassName.equals("scala.collection.AbstractSeq") &&
opcode == Opcode.INVOKEVIRTUAL) ||
(declaringClassName.equals("scala.collection.TraversableOnce") &&
opcode == Opcode.INVOKEINTERFACE)))) {
if (!args.head.isInstanceOf[Repr.ArrayBuffer]) {
throw new SparkException(
s"Unexpected argument for ${declaringClassName}.${method.getName}")
}
args.head.asInstanceOf[Repr.ArrayBuffer].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
Some((dateTimeFormatterOp(method.getName, args), false))
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
Some((localDateTimeOp(method.getName, args), false))
} else {
newState
}
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") ||
((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) &&
((declaringClassName.equals("scala.collection.AbstractSeq") &&
opcode == Opcode.INVOKEVIRTUAL) ||
(declaringClassName.equals("scala.collection.TraversableOnce") &&
opcode == Opcode.INVOKEINTERFACE)))) {
if (!args.head.isInstanceOf[Repr.ArrayBuffer]) {
throw new SparkException(
s"Unexpected argument for ${declaringClassName}.${method.getName}")
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
Some((retval.get, false))
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
val (retval, updateState) = args.head.asInstanceOf[Repr.ArrayBuffer]
.invoke(method.getName, args.tail)
val newState = State(locals, retval :: rest, cond, expr)
}
retval.fold(State(locals, rest, cond, expr)) { case (v, updateState) =>
val newState = State(locals, v :: rest, cond, expr)
if (updateState) {
newState.remap(args.head, retval)
newState.remap(args.head, v)
} else {
newState
}
} else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
State(locals, dateTimeFormatterOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
} else {
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
State(locals, retval.toList ::: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
}

Expand Down
Loading

0 comments on commit 5f449cd

Please sign in to comment.