Add limited support for captured vars and athrow #5487

Merged (4 commits) on May 17, 2022
Changes from 1 commit
2 changes: 1 addition & 1 deletion udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
@@ -108,7 +108,7 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
val defaultState = state.copy(cond = simplify(defaultCondition))
newStates + (defaultSucc -> defaultState.merge(newStates.get(defaultSucc)))
case Opcode.IRETURN | Opcode.LRETURN | Opcode.FRETURN | Opcode.DRETURN |
Opcode.ARETURN | Opcode.RETURN => states
Opcode.ARETURN | Opcode.RETURN | Opcode.ATHROW => states
case _ =>
val (0, successor) :: Nil = cfg.successor(this)
// The condition, stack and locals from the current BB state need to be
udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
@@ -176,7 +176,36 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging {
// A basic block can have other branching instructions as the last instruction,
// otherwise.
if (basicBlock.lastInstruction.isReturn) {
newStates(basicBlock).expr
// When the compiled method throws exceptions, we combine the returned
// expression (expr) with the exception expressions (except0, except1, ..., exceptn)
// based on the exception conditions (cond0, cond1, ..., condn) as follows:
// If
// |
// +------+------+
// | | |
// cond0 except0 If
// |
// +------+------+
// | | |
// cond1 except1 If
// |
// +------+------+
// | | |
// cond2 except2 If
// |
// .
// .
// .
// +------+------+
// | | |
// condn exceptn expr
newStates.foldRight(newStates(basicBlock).expr) { case ((bb, s), expr) =>
if (bb.lastInstruction.isThrow) {
Some(If(s.cond, s.stack.head, expr.get))
} else {
expr
}
}
} else {
// account for this block in visited
val newVisited = visited + basicBlock
Expand Down Expand Up @@ -327,6 +356,24 @@ object CatalystExpressionBuilder extends Logging {
case _ => expr
}
}
case And(c1@GreaterThanOrEqual(s1, Literal(v1, t1)),
c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
case IntegerType =>
if (v1.asInstanceOf[Int] > v2.asInstanceOf[Int]) {
c1
} else {
c2
}
case LongType =>
if (v1.asInstanceOf[Long] > v2.asInstanceOf[Long]) {
c1
} else {
c2
}
case _ => expr
}
}
case And(c1@GreaterThan(s1, Literal(v1, t1)),
c2@GreaterThanOrEqual(s2, Literal(v2, t2))) if s1 == s2 && t1 == t2 => {
t1 match {
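For concreteness, the fold above builds one If layer per throwing basic block. The sketch below is illustrative only: the attribute name, condition, and message are assumed rather than taken from the PR. It shows roughly the Catalyst tree produced for a UDF like `(x: Int) => if (x < 0) throw new SparkException("neg") else x + 1`, whose single throwing block contributes one If layer:

```scala
// Hedged sketch: roughly the tree the fold builds for one throwing block.
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val x = AttributeReference("x", IntegerType)()

val compiled: Expression =
  If(
    LessThan(x, Literal(0)),      // cond0: condition under which the block throws
    RaiseError(Literal("neg")),   // except0: SparkExcept lowered to RaiseError
    Add(x, Literal(1)))           // expr: the normal return expression
```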
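The new simplify case added in this hunk collapses a conjunction of two lower bounds on the same operand into the tighter bound. A small illustration under the same assumptions (attribute name invented):

```scala
// Hedged illustration: x >= 5 && x >= 3 simplifies to x >= 5, because for
// two lower bounds on the same operand the larger literal is the tighter one.
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val x = AttributeReference("x", IntegerType)()
val conj = And(GreaterThanOrEqual(x, Literal(5)), GreaterThanOrEqual(x, Literal(3)))
// simplify(conj) keeps only the first conjunct: GreaterThanOrEqual(x, Literal(5))
```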
218 changes: 122 additions & 96 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -63,17 +63,37 @@ private[udf] object Repr {
extends CompilerInternal("java.lang.StringBuilder") {
override def dataType: DataType = string.dataType

def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = {
def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
methodName match {
case "StringBuilder" => (this, false)
case "append" => (StringBuilder(Concat(string :: args)), true)
case "toString" => (string, false)
case "StringBuilder" => None
case "append" => Some((StringBuilder(Concat(string :: args)), true))
case "toString" => Some((string, false))
case _ =>
throw new SparkException(s"Unsupported StringBuilder op ${methodName}")
}
}
}

// Internal representation of org.apache.spark.SparkException
case class SparkExcept(var message: Expression = Literal.default(StringType))
extends CompilerInternal("org.apache.spark.SparkException") {

def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
methodName match {
case "SparkException" =>
if (args.length > 1) {
throw new SparkException("Unsupported SparkException construction " +
"with multiple arguments")
} else {
message = args.head
Some((this, false))
}
case _ =>
throw new SparkException(s"Unsupported SparkException op ${methodName}")
}
}
}

// Internal representation of the bytecode instruction getstatic.
// This class is needed because we can't represent getstatic in Catalyst, but
// we need the getstatic information to handle some method calls
@@ -96,11 +116,11 @@ private[udf] object Repr {
extends CompilerInternal("scala.collection.mutable.ArrayBuffer") {
override def dataType: DataType = arrayBuffer.dataType

def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = {
def invoke(methodName: String, args: List[Expression]): Option[(Expression, Boolean)] = {
methodName match {
case "ArrayBuffer" => (this, false)
case "distinct" => (ArrayBuffer(ArrayDistinct(arrayBuffer)), false)
case "toArray" => (arrayBuffer, false)
case "ArrayBuffer" => Some((this, false))
case "distinct" => Some((ArrayBuffer(ArrayDistinct(arrayBuffer)), false))
case "toArray" => Some((arrayBuffer, false))
case "$plus$eq" | "$colon$plus" =>
val mutable = {
if (methodName == "$plus$eq") {
@@ -138,11 +158,11 @@ private[udf] object Repr {
}
}
}
(arrayBuffer match {
Some((arrayBuffer match {
case CreateArray(Nil, _) => ArrayBuffer(CreateArray(elem))
case array => ArrayBuffer(Concat(Seq(array, CreateArray(elem))))
},
mutable)
mutable))
case _ =>
throw new SparkException(s"Unsupported ArrayBuffer op ${methodName}")
}
@@ -233,6 +253,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
const(state, (opcode - Opcode.ICONST_0).asInstanceOf[Int])
case Opcode.LCONST_0 | Opcode.LCONST_1 =>
const(state, (opcode - Opcode.LCONST_0).asInstanceOf[Long])
case Opcode.ATHROW => athrow(state)
case Opcode.DADD | Opcode.FADD | Opcode.IADD | Opcode.LADD => binary(state, Add(_, _))
case Opcode.DSUB | Opcode.FSUB | Opcode.ISUB | Opcode.LSUB => binary(state, Subtract(_, _))
case Opcode.DMUL | Opcode.FMUL | Opcode.IMUL | Opcode.LMUL => binary(state, Multiply(_, _))
@@ -310,6 +331,11 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
case _ => false
}

def isThrow: Boolean = opcode match {
case Opcode.ATHROW => true
case _ => false
}

//
// Handle instructions
//
Expand All @@ -323,6 +349,17 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
State(locals.updated(localsIndex, top), rest, cond, expr)
}

private def athrow(state: State): State = {
val State(locals, top :: rest, cond, expr) = state
if (!top.isInstanceOf[Repr.SparkExcept]) {
throw new SparkException("Unsupported type for athrow")
}
// Empty the stack and convert the internal representation of
// org.apache.spark.SparkException object to RaiseError, then push it to the
// stack.
State(locals, List(RaiseError(top.asInstanceOf[Repr.SparkExcept].message)), cond, expr)
Collaborator:
I believe one approach here would be to code up a new catalyst node RaiseErrorV2 or RaiseSpecificError or something like this, where it has the CPU eval and doGenCode, as it would be very much like RaiseError, but it has an extra parameter which is the specific type of exception to throw. It would then instantiate this exception and throw the specific type.

For the GPU, we could replace this new node with the GPU-specific type, and for the CPU it would work the way the caller intended. @jlowe @revans2 for some 👀
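
A minimal sketch of the kind of node described in this comment, assuming the extra parameter is the exception class name. RaiseErrorV2 here is only the name floated above; none of this is code from the PR:

```scala
// Hedged sketch of the suggested node. It mirrors Spark's RaiseError but
// instantiates a caller-specified exception type reflectively on CPU eval.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, NullType}
import org.apache.spark.unsafe.types.UTF8String

case class RaiseErrorV2(child: Expression, exceptionClass: String)
    extends UnaryExpression {
  override def dataType: DataType = NullType

  // CPU eval: build and throw the requested exception type, instead of
  // RaiseError's fixed RuntimeException.
  override def eval(input: InternalRow): Any = {
    val msg = child.eval(input).asInstanceOf[UTF8String].toString
    val ctor = Class.forName(exceptionClass).getConstructor(classOf[String])
    throw ctor.newInstance(msg).asInstanceOf[Throwable]
  }

  // A real implementation would also supply doGenCode, much like RaiseError.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    throw new UnsupportedOperationException("sketch only")

  override protected def withNewChildInternal(newChild: Expression): RaiseErrorV2 =
    copy(child = newChild)
}
```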

Collaborator:
After discussing with @revans2 we were thinking it's OK if the exceptions are not the same for now. Could we make a note of this in https://github.com/NVIDIA/spark-rapids/blob/branch-22.06/docs/additional-functionality/udf-to-catalyst-expressions.md?

We may need to revisit this if someone needs SparkException thrown in the future.
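
For context on the mismatch being accepted here: Spark evaluates RaiseError by throwing java.lang.RuntimeException, so a UDF body that originally threw SparkException surfaces a different exception type once compiled. A hedged illustration, not code from the PR:

```scala
import org.apache.spark.sql.functions.{lit, raise_error}

// Original UDF body throws SparkException on the CPU:
//   if (x < 0) throw new SparkException("neg") else x + 1
// The compiled plan lowers the throw to RaiseError, which surfaces as
// java.lang.RuntimeException instead:
// spark.range(1).select(raise_error(lit("neg"))).collect()
//   => java.lang.RuntimeException: neg
```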

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated udf-to-catalyst-expressions.md

}

private def const(state: State, value: Any): State = {
val State(locals, stack, cond, expr) = state
State(locals, Literal(value) :: stack, cond, expr)
@@ -360,6 +397,9 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
if (typeName.equals("java.lang.StringBuilder")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.StringBuilder() :: stack, cond, expr)
} else if (typeName.equals("org.apache.spark.SparkException")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.SparkExcept() :: stack, cond, expr)
} else if (typeName.equals("scala.collection.mutable.ArrayBuffer")) {
val State(locals, stack, cond, expr) = state
State(locals, Repr.ArrayBuffer() :: stack, cond, expr)
@@ -447,98 +487,84 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
val (args, rest) = getArgs(stack, paramTypes.length)
// We don't support arbitrary calls.
// We support only some math and string methods.
if (declaringClassName.equals("scala.math.package$")) {
State(locals,
mathOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.Predef$")) {
State(locals,
predefOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.Array$")) {
State(locals,
arrayOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.reflect.ClassTag$")) {
State(locals,
classTagOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) {
State(locals,
arrayBufferOp(method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("java.lang.Double")) {
State(locals, doubleOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.Float")) {
State(locals, floatOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.String")) {
State(locals, stringOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.StringBuilder")) {
if (!args.head.isInstanceOf[Repr.StringBuilder]) {
throw new SparkException("Internal error with StringBuilder")
}
val (retval, updateState) = args.head.asInstanceOf[Repr.StringBuilder]
.invoke(method.getName, args.tail)
val newState = State(locals, retval :: rest, cond, expr)
if (updateState) {
newState.remap(args.head, retval)
val retval = {
if (declaringClassName.equals("scala.math.package$")) {
Some((mathOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.Predef$")) {
Some((predefOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.Array$")) {
Some((arrayOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.reflect.ClassTag$")) {
Some((classTagOp(method.getName, args), false))
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) {
Some((arrayBufferOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.Double")) {
Some((doubleOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.Float")) {
Some((floatOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.String")) {
Some((stringOp(method.getName, args), false))
} else if (declaringClassName.equals("java.lang.StringBuilder")) {
if (!args.head.isInstanceOf[Repr.StringBuilder]) {
throw new SparkException("Internal error with StringBuilder")
}
args.head.asInstanceOf[Repr.StringBuilder].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("org.apache.spark.SparkException")) {
if (!args.head.isInstanceOf[Repr.SparkExcept]) {
throw new SparkException("Internal error with SparkException")
}
args.head.asInstanceOf[Repr.SparkExcept].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") ||
((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) &&
((declaringClassName.equals("scala.collection.AbstractSeq") &&
opcode == Opcode.INVOKEVIRTUAL) ||
(declaringClassName.equals("scala.collection.TraversableOnce") &&
opcode == Opcode.INVOKEINTERFACE)))) {
if (!args.head.isInstanceOf[Repr.ArrayBuffer]) {
throw new SparkException(
s"Unexpected argument for ${declaringClassName}.${method.getName}")
}
args.head.asInstanceOf[Repr.ArrayBuffer].invoke(method.getName, args.tail)
} else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
Some((dateTimeFormatterOp(method.getName, args), false))
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
Some((localDateTimeOp(method.getName, args), false))
} else {
newState
}
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") ||
((args.nonEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) &&
((declaringClassName.equals("scala.collection.AbstractSeq") &&
opcode == Opcode.INVOKEVIRTUAL) ||
(declaringClassName.equals("scala.collection.TraversableOnce") &&
opcode == Opcode.INVOKEINTERFACE)))) {
if (!args.head.isInstanceOf[Repr.ArrayBuffer]) {
throw new SparkException(
s"Unexpected argument for ${declaringClassName}.${method.getName}")
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
Some((retval.get, false))
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
val (retval, updateState) = args.head.asInstanceOf[Repr.ArrayBuffer]
.invoke(method.getName, args.tail)
val newState = State(locals, retval :: rest, cond, expr)
}
retval.fold(State(locals, rest, cond, expr)) { case (v, updateState) =>
val newState = State(locals, v :: rest, cond, expr)
if (updateState) {
newState.remap(args.head, retval)
newState.remap(args.head, v)
} else {
newState
}
} else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
State(locals, dateTimeFormatterOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
} else {
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
State(locals, retval.toList ::: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
}

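One hedged reading of how the new fallback at the end of this diff supports the "captured vars" half of the PR title: calls to methods the compiler can analyze safely (final, or declared on a final class, with none of the listed modifiers) are compiled recursively, static methods with the plain argument list and instance methods with the receiver passed as the first argument, so lambdas that delegate to such helpers can be translated end to end. Identifiers below are illustrative, not from the PR:

```scala
// Hedged example of a UDF shape the fallback aims to handle.
object Helpers {
  // Lives on an object (a final class), so the modifiers check passes and the
  // udf-compiler can recursively translate the call via
  // CatalystExpressionBuilder(method).compile(args).
  def clamp(x: Int): Int = if (x > 100) 100 else x
}

val f = (x: Int) => Helpers.clamp(x) + 1
// Expected translation, roughly:
//   Add(If(GreaterThan(x, Literal(100)), Literal(100), x), Literal(1))
```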