Skip to content

Commit

Permalink
udf spec (NVIDIA#1150)
Browse files Browse the repository at this point in the history
* Conditional array buffer processing test case

Signed-off-by: Allen Xu <allxu@nvidia.com>

* Add support for interface method ref lookup

Signed-off-by: Sean Lee <selee@nvidia.com>

* Handle the pop instruction

Signed-off-by: Sean Lee <selee@nvidia.com>

* Handle the aconst_null instruction

Signed-off-by: Sean Lee <selee@nvidia.com>

* Ensure type consistency for conditional expressions.

Catalyst If(c, t, f) expression requires t and f to have the same type.

Signed-off-by: Sean Lee <selee@nvidia.com>

* Use Repr.GetStatic to represent getstatic information.

We were using Integer literal, but that can cause an error that the
getstatic index is incorrectly evaluated as a value.

Signed-off-by: Sean Lee <selee@nvidia.com>

* Add support for some ArrayBuffer ops.

* new ArrayBuffer()
* x.distinct
* x.toArray
* lhs += rhs
* lhs :+ rhs

This commit also adds the compiler internal representation of
CanBuildFrom, which is needed to support :+.

Signed-off-by: Sean Lee <selee@nvidia.com>

Co-authored-by: Allen Xu <allxu@nvidia.com>
Co-authored-by: wjxiz <wjxiz1992@gmail.com>
  • Loading branch information
3 people authored Nov 24, 2020
1 parent 382abde commit a02827b
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 61 deletions.
8 changes: 7 additions & 1 deletion docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -419,4 +419,10 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | Array.empty[Float] |
| | Array.empty[Double] |
| | Array.empty[String] |
| Method call | Only if the method being called <ol><li>consists of operations supported by the UDF compiler, and</li><li>is one of the following:<ul><li>a final method, or</li><li>a method in a final class, or</li><li>a method in a final object</li></ul></li></ol> |
| ArrayBuffer | new ArrayBuffer() |
| | x.distinct |
| | x.toArray |
| | lhs += rhs |
| | lhs :+ rhs |
| Method call | Only if the method being called <ol><li>consists of operations supported by the UDF compiler, and</li><li>is one of the following:<ul><li>a final method, or</li><li>a method in a final class, or</li><li>a method in a final object</li></ul></li></ol> |

Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ object CatalystExpressionBuilder extends Logging {
simplifyExpr(If(simplifyExpr(c),
simplifyExpr(Cast(t, BooleanType, tz)),
simplifyExpr(Cast(f, BooleanType, tz))))
case If(c, Repr.ArrayBuffer(t), Repr.ArrayBuffer(f)) => Repr.ArrayBuffer(If(c, t, f))
case _ => expr
}
logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}")
Expand Down
187 changes: 155 additions & 32 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import javassist.bytecode.{CodeIterator, Opcode}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -72,6 +73,81 @@ private[udf] object Repr {
var string: Expression = Literal.default(StringType)
}

// Compiler-internal stand-in for the getstatic bytecode instruction.
// Catalyst has no expression for getstatic, yet the static field it refers to
// is needed later to validate the objref of some method calls
// (see [[Instruction.mathOp]] for example).
case class GetStatic(lambdaReflection: LambdaReflection, index: Int)
    extends CompilerInternal("bytecode$getstatic") {
  // Name of the type of the field that lambdaReflection.lookupField resolves
  // for `index`.
  def getTypeName: String = {
    val field = lambdaReflection.lookupField(index)
    field.getType.getName
  }
}

// Internal representation of CanBuildFrom.
// Catalyst has no expression for CanBuildFrom, so this class only carries the
// getstatic objref the builder was obtained from. That objref is what lets the
// compiler verify where the builder came from when handling some method calls
// (see the $colon$plus case of [[ArrayBuffer.invoke]] for example).
case class CanBuildFrom(objref: Repr.GetStatic)
extends CompilerInternal("scala.collection.generic.CanBuildFrom")

// Internal representation of scala.collection.mutable.ArrayBuffer.
// The buffer contents are modelled by a Catalyst array expression, which starts
// out as an empty CreateArray.
case class ArrayBuffer(var arrayBuffer: Expression = CreateArray(Seq.empty[Expression]))
    extends CompilerInternal("scala.collection.mutable.ArrayBuffer") {
  override def dataType: DataType = arrayBuffer.dataType

  // Translates a call to the ArrayBuffer method `methodName` into Catalyst.
  //
  // `args` are the call arguments without the receiver. The result pairs the
  // translated expression with a flag that is true only for `+=`, i.e. the one
  // supported operation that updates the receiver in place.
  //
  // Throws SparkException for unsupported methods, for a wrong number of
  // arguments, for a `:+` whose last argument is not ArrayBuffer's
  // CanBuildFrom, and when the appended element's type has no common type with
  // the current element type.
  def invoke(methodName: String, args: List[Expression]): (Expression, Boolean) = {
    // Catalyst has no append operator, so appending is expressed as a Concat of
    // the current array with a single-element array (or just the
    // single-element array when the buffer is still empty).
    def appendElement(newElem: Expression): Expression = {
      val currentElemType = arrayBuffer.dataType.asInstanceOf[ArrayType].elementType
      val newElemType = newElem.dataType
      val coerced: Expression =
        TypeCoercion.findTightestCommonType(currentElemType, newElemType) match {
          case Some(t) if newElemType == t => newElem
          case Some(t) => Cast(newElem, t)
          case None => throw new SparkException(s"type check failure")
        }
      arrayBuffer match {
        case CreateArray(Nil, _) => ArrayBuffer(CreateArray(Seq(coerced)))
        case array => ArrayBuffer(Concat(Seq(array, CreateArray(Seq(coerced)))))
      }
    }

    methodName match {
      case "ArrayBuffer" => (this, false)
      case "distinct" => (ArrayBuffer(ArrayDistinct(arrayBuffer)), false)
      case "toArray" => (arrayBuffer, false)
      case "$plus$eq" =>
        if (args.length != 1) {
          throw new SparkException(
            s"ArrayBuffer.+= operation expects 1 argument, " +
              s"but instead got ${args.length} argument(s)")
        }
        (appendElement(args.head), true)
      case "$colon$plus" =>
        if (args.length != 2) {
          throw new SparkException(
            s"ArrayBuffer.:+ operation expects 2 arguments, " +
              s"but instead got ${args.length} argument(s)")
        }
        // The trailing implicit argument must be the CanBuildFrom obtained from
        // the ArrayBuffer companion object.
        args.last match {
          case cbf: Repr.CanBuildFrom
              if cbf.objref.getTypeName.equals("scala.collection.mutable.ArrayBuffer$") => ()
          case _ =>
            throw new SparkException(
              s"ArrayBuffer.:+ operation expects CanBuildFrom for the last argument")
        }
        (appendElement(args.head), false)
      case _ =>
        throw new SparkException(s"Unsupported ArrayBuffer op ${methodName}")
    }
  }
}

case class DateTimeFormatter private (private[Repr] val pattern: Expression)
extends CompilerInternal("java.time.format.DateTimeFormatter") {
def invoke(methodName: String, args: List[Expression]): Expression = {
Expand Down Expand Up @@ -142,6 +218,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
Opcode.ISTORE_3 | Opcode.LSTORE_3 => store(state, 3)
case Opcode.ASTORE | Opcode.DSTORE | Opcode.FSTORE |
Opcode.ISTORE | Opcode.LSTORE => store(state, operand)
case Opcode.ACONST_NULL =>
const(state, null)
case Opcode.DCONST_0 | Opcode.DCONST_1 =>
const(state, (opcode - Opcode.DCONST_0).asInstanceOf[Double])
case Opcode.FCONST_0 | Opcode.FCONST_1 | Opcode.FCONST_2 =>
Expand Down Expand Up @@ -172,7 +250,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
case Opcode.LCMP => cmp(state)
case Opcode.LDC | Opcode.LDC_W | Opcode.LDC2_W => ldc(lambdaReflection, state)
case Opcode.DUP => dup(state)
case Opcode.GETSTATIC => getstatic(state)
case Opcode.POP => pop(state)
case Opcode.GETSTATIC => getstatic(lambdaReflection, state)
case Opcode.NEW => newObj(lambdaReflection, state)
// Cast instructions
case Opcode.I2B => cast(state, ByteType)
Expand Down Expand Up @@ -207,13 +286,13 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
state.copy(expr = Some(state.stack.head))
// Call instructions
case Opcode.INVOKESTATIC =>
invoke(lambdaReflection, state,
invoke(opcode, lambdaReflection, state,
(stack, n) => {
val (args, rest) = stack.splitAt(n)
(args.reverse, rest)
})
case Opcode.INVOKEVIRTUAL | Opcode.INVOKESPECIAL =>
invoke(lambdaReflection, state,
case Opcode.INVOKEVIRTUAL | Opcode.INVOKESPECIAL | Opcode.INVOKEINTERFACE =>
invoke(opcode, lambdaReflection, state,
(stack, n) => {
val (args, rest) = stack.splitAt(n + 1)
(args.reverse, rest)
Expand Down Expand Up @@ -269,20 +348,28 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
State(locals, top :: top :: rest, cond, expr)
}

// Handles the pop instruction: discards the operand on top of the stack and
// leaves the rest of the state untouched.
private def pop(state: State): State = {
  // The pattern also requires a non-empty stack; an empty one fails with a
  // MatchError.
  val State(_, _ :: remaining, _, _) = state
  state.copy(stack = remaining)
}

// Handles the new instruction: resolves the class name for this instruction's
// operand and pushes the matching compiler-internal representation onto the
// stack. Only StringBuilder and ArrayBuffer are supported; any other type
// results in a SparkException.
private def newObj(lambdaReflection: LambdaReflection,
    state: State): State = {
  val instance: Expression = lambdaReflection.lookupClassName(operand) match {
    case "java.lang.StringBuilder" => Repr.StringBuilder()
    case "scala.collection.mutable.ArrayBuffer" => Repr.ArrayBuffer()
    case unsupported => throw new SparkException("Unsupported type for new:" + unsupported)
  }
  state.copy(stack = instance :: state.stack)
}

private def getstatic(state: State): State = {
private def getstatic(lambdaReflection: LambdaReflection, state: State): State = {
val State(locals, stack, cond, expr) = state
State(locals, Literal(operand) :: stack, cond, expr)
State(locals, Repr.GetStatic(lambdaReflection, operand) :: stack, cond, expr)
}

private def cmp(state: State, default: Int): State = {
Expand Down Expand Up @@ -319,10 +406,15 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
private def checkcast(lambdaReflection: LambdaReflection, state: State): State = {
val State(locals, top :: rest, cond, expr) = state
val typeName = lambdaReflection.lookupClassName(operand)
if (LambdaReflection.parseTypeSig(typeName) != top.dataType) {
throw new SparkException(s"checkcast failed: ${typeName}")
LambdaReflection.parseTypeSig(typeName).fold{
// Defer the check until top is actually used.
state
}{ t =>
if (t != top.dataType) {
throw new SparkException(s"checkcast failed: ${typeName} ${t}")
}
state
}
state
}

private def ifCmp(state: State,
Expand All @@ -343,7 +435,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
State(locals, rest, cond, Some(top))
}

private def invoke(lambdaReflection: LambdaReflection, state: State,
private def invoke(opcode:Int, lambdaReflection: LambdaReflection, state: State,
getArgs: (List[Expression], Int) =>
(List[Expression], List[Expression])): State = {
val State(locals, stack, cond, expr) = state
Expand Down Expand Up @@ -374,6 +466,11 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
classTagOp(lambdaReflection, method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer$")) {
State(locals,
arrayBufferOp(lambdaReflection, method.getName, args) :: rest,
cond,
expr)
} else if (declaringClassName.equals("java.lang.Double")) {
State(locals, doubleOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.lang.Float")) {
Expand All @@ -387,6 +484,24 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
val retval = args.head.asInstanceOf[Repr.StringBuilder]
.invoke(method.getName, args.tail)
State(locals, retval :: rest, cond, expr)
} else if (declaringClassName.equals("scala.collection.mutable.ArrayBuffer") ||
((!args.isEmpty && args.head.isInstanceOf[Repr.ArrayBuffer]) &&
((declaringClassName.equals("scala.collection.AbstractSeq") &&
opcode == Opcode.INVOKEVIRTUAL) ||
(declaringClassName.equals("scala.collection.TraversableOnce") &&
opcode == Opcode.INVOKEINTERFACE)))) {
if (!args.head.isInstanceOf[Repr.ArrayBuffer]) {
throw new SparkException(
s"Unexpected argument for ${declaringClassName}.${method.getName}")
}
val (retval, updateState) = args.head.asInstanceOf[Repr.ArrayBuffer]
.invoke(method.getName, args.tail)
val newState = State(locals, retval :: rest, cond, expr)
if (updateState) {
newState.remap(args.head, retval)
} else {
newState
}
} else if (declaringClassName.equals("java.time.format.DateTimeFormatter")) {
State(locals, dateTimeFormatterOp(method.getName, args) :: rest, cond, expr)
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
Expand Down Expand Up @@ -449,9 +564,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
}
// Make sure that the objref is scala.math.package$.
args.head match {
case IntegerLiteral(index) =>
if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
.getType.getName.equals("scala.math.package$")) {
case getstatic: Repr.GetStatic =>
if (!getstatic.getTypeName.equals("scala.math.package$")) {
throw new SparkException("Unsupported math function objref: " + args.head)
}
case _ =>
Expand Down Expand Up @@ -483,9 +597,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
methodName: String, args: List[Expression]): Expression = {
// Make sure that the objref is scala.Predef$.
args.head match {
case IntegerLiteral(index) =>
if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
.getType.getName.equals("scala.Predef$")) {
case getstatic: Repr.GetStatic =>
if (!getstatic.getTypeName.equals("scala.Predef$")) {
throw new SparkException("Unsupported predef function objref: " + args.head)
}
case _ =>
Expand All @@ -494,10 +607,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
// Translate to Catalyst
methodName match {
case "double2Double" =>
checkArgs(methodName, List(IntegerType, DoubleType), args)
checkArgs(methodName, List(DoubleType), args.tail)
args.last
case "float2Float" =>
checkArgs(methodName, List(IntegerType, FloatType), args)
checkArgs(methodName, List(FloatType), args.tail)
args.last
case _ => throw new SparkException("Unsupported predef function: " + methodName)
}
Expand All @@ -507,9 +620,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
methodName: String, args: List[Expression]): Expression = {
// Make sure that the objref is scala.Array$.
args.head match {
case IntegerLiteral(index) =>
if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
.getType.getName.equals("scala.Array$")) {
case getstatic: Repr.GetStatic =>
if (!getstatic.getTypeName.equals("scala.Array$")) {
throw new SparkException("Unsupported array function objref: " + args.head)
}
case _ =>
Expand Down Expand Up @@ -550,9 +662,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
methodName: String, args: List[Expression]): Expression = {
// Make sure that the objref is scala.reflect.ClassTag$.
args.head match {
case IntegerLiteral(index) =>
if (!lambdaReflection.lookupField(index.asInstanceOf[Int])
.getType.getName.equals("scala.reflect.ClassTag$")) {
case getstatic: Repr.GetStatic =>
if (!getstatic.getTypeName.equals("scala.reflect.ClassTag$")) {
throw new SparkException("Unsupported classTag function objref: " + args.head)
}
case _ =>
Expand All @@ -561,34 +672,46 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
// Translate to Catalyst
methodName match {
case "Boolean" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Boolean)
case "Byte" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Byte)
case "Short" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Short)
case "Int" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Int)
case "Long" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Long)
case "Float" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Float)
case "Double" =>
checkArgs(methodName, List(IntegerType), args)
new Repr.ClassTag(scala.reflect.ClassTag.Double)
case "apply" =>
checkArgs(methodName, List(IntegerType, StringType), args)
checkArgs(methodName, List(StringType), args.tail)
new Repr.ClassTag(scala.reflect.ClassTag.apply(
LambdaReflection.getClass(args.last.toString)))
case _ => throw new SparkException("Unsupported classTag function: " + methodName)
}
}

// Translates a method call on the ArrayBuffer companion object
// (scala.collection.mutable.ArrayBuffer$). Only canBuildFrom is supported; it
// becomes a Repr.CanBuildFrom wrapping the companion's getstatic objref.
// Throws SparkException for any other objref or method.
private def arrayBufferOp(lambdaReflection: LambdaReflection,
    methodName: String, args: List[Expression]): Expression = {
  // Make sure that the objref is scala.collection.mutable.ArrayBuffer$.
  val companion = args.head match {
    case ref: Repr.GetStatic
        if ref.getTypeName.equals("scala.collection.mutable.ArrayBuffer$") => ref
    case other =>
      throw new SparkException("Unsupported arrayBuffer function objref: " + other)
  }
  // Translate to the compiler-internal representation.
  methodName match {
    case "canBuildFrom" => new Repr.CanBuildFrom(companion)
    case unsupported =>
      throw new SparkException("Unsupported arrayBuffer function: " + unsupported)
  }
}

private def doubleOp(methodName: String, args: List[Expression]): Expression = {
methodName match {
case "isNaN" =>
Expand Down
Loading

0 comments on commit a02827b

Please sign in to comment.