Skip to content

Commit

Permalink
Inline method calls when possible.
Browse files Browse the repository at this point in the history
With this change, method calls are inlined only if the method being
called
  1. consists of operations supported by the UDF compiler, and
  2. is one of the folllowing:
    * a final method, or
    * a method in a final class, or
    * a method in a final class

Signed-off-by: Sean Lee <selee@nvidia.com>
  • Loading branch information
seanprime7 committed Oct 26, 2020
1 parent 2e81d82 commit b37dba8
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 51 deletions.
1 change: 1 addition & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | Array.empty[Float] |
| | Array.empty[Double] |
| | Array.empty[String] |
| Method call | Only if the method being called <ol><li>consists of operations supported by the UDF compiler, and</li><li>is one of the folllowing:<ul><li>a final method, or</li><li>a method in a final class, or</li><li>a method in a final object</li></ul></li></ol> |
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
* @param children a sequence of catalyst arguments to the udf.
* @return the compiled expression, optionally
*/
def compile(children: Seq[Expression]): Option[Expression] = {
def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {

// create starting state, this will be:
// State([children expressions], [empty stack], cond = true, expr = None)
val entryState = State.makeStartingState(lambdaReflection, children)
val entryState = State.makeStartingState(lambdaReflection, children, objref)

// pick first of the Basic Blocks, and start recursing
val entryBlock = cfg.basicBlocks.head
Expand Down
38 changes: 32 additions & 6 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ package com.nvidia.spark.udf

import CatalystExpressionBuilder.simplify
import java.nio.charset.Charset
import javassist.Modifier
import javassist.bytecode.{CodeIterator, Opcode}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


private object Repr {
private[udf] object Repr {

abstract class CompilerInternal(name: String) extends Expression {
override def dataType: DataType = {
Expand Down Expand Up @@ -110,6 +110,8 @@ private object Repr {

case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
extends CompilerInternal("scala.reflect.ClassTag")

case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
}

/**
Expand Down Expand Up @@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
(List[Expression], List[Expression])): State = {
val State(locals, stack, cond, expr) = state
val method = lambdaReflection.lookupBehavior(operand)
val declaringClassName = method.getDeclaringClass.getName
val declaringClass = method.getDeclaringClass
val declaringClassName = declaringClass.getName
val paramTypes = method.getParameterTypes
val (args, rest) = getArgs(stack, paramTypes.length)
// We don't support arbitrary calls.
Expand Down Expand Up @@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
State(locals, retval.toList ::: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

package com.nvidia.spark.udf

import Repr.UnknownCapturedArg
import java.lang.invoke.SerializedLambda

import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
Descriptor, MethodParametersAttribute}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types._

//
Expand All @@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
// Provides the interface the class and the method that implements the body of the lambda
// used by the rest of the compiler.
//
case class LambdaReflection(private val classPool: ClassPool,
private val serializedLambda: SerializedLambda) {
class LambdaReflection private(private val classPool: ClassPool,
private val ctClass: CtClass,
private val ctMethod: CtMethod,
val capturedArgs: Seq[Expression] = Seq()) {
def lookupConstant(constPoolIndex: Int): Any = {
constPool.getTag(constPoolIndex) match {
case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
Expand Down Expand Up @@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
val methodName = constPool.getMethodrefName(constPoolIndex)
val descriptor = constPool.getMethodrefType(constPoolIndex)
val className = constPool.getMethodrefClassName(constPoolIndex)
val params = Descriptor.getParameterTypes(descriptor, classPool)
if (constPool.isConstructor(className, constPoolIndex) == 0) {
classPool.getCtClass(className).getDeclaredMethod(methodName, params)
classPool.getCtClass(className).getMethod(methodName, descriptor)
} else {
val params = Descriptor.getParameterTypes(descriptor, classPool)
classPool.getCtClass(className).getDeclaredConstructor(params)
}
}
Expand All @@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
constPool.getClassInfo(constPoolIndex)
}

// Get the CtClass object for the class that capture the lambda.
private val ctClass = {
val name = serializedLambda.getCapturingClass.replace('/', '.')
val classForName = LambdaReflection.getClass(name)
classPool.insertClassPath(new ClassClassPath(classForName))
classPool.getCtClass(name)
}

// Get the CtMethod object for the method that implements the lambda body.
private val ctMethod = {
val lambdaImplName = serializedLambda.getImplMethodName
ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
}

private val methodInfo = ctMethod.getMethodInfo

val constPool = methodInfo.getConstPool
Expand All @@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,

object LambdaReflection {
def apply(function: AnyRef): LambdaReflection = {
if (function.isInstanceOf[CtMethod]) {
return LambdaReflection(function.asInstanceOf[CtMethod])
}
// writeReplace is supposed to return an object of SerializedLambda from
// the function class (See
// https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
Expand All @@ -117,9 +110,35 @@ object LambdaReflection {
writeReplace.setAccessible(true)
val serializedLambda = writeReplace.invoke(function)
.asInstanceOf[SerializedLambda]
val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
// Add UnknownCapturedArg expressions to the list of captured arguments
// until we figure out how to handle captured variables in various
// situations.
// This, however, lets us pass captured arguments as long as they are not
// evaluated.
_ => Repr.UnknownCapturedArg()
}

val classPool = new ClassPool(true)
LambdaReflection(classPool, serializedLambda)
// Get the CtClass object for the class that capture the lambda.
val ctClass = {
val name = serializedLambda.getCapturingClass.replace('/', '.')
val classForName = LambdaReflection.getClass(name)
classPool.insertClassPath(new ClassClassPath(classForName))
classPool.getCtClass(name)
}
// Get the CtMethod object for the method that implements the lambda body.
val ctMethod = {
val lambdaImplName = serializedLambda.getImplMethodName
ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
}
new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
}

private def apply(ctMethod: CtMethod): LambdaReflection = {
val ctClass = ctMethod.getDeclaringClass
val classPool = ctClass.getClassPool
new LambdaReflection(classPool, ctClass, ctMethod)
}

def getClass(name: String): Class[_] = {
Expand Down
27 changes: 14 additions & 13 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.nvidia.spark.udf
import CatalystExpressionBuilder.simplify
import javassist.CtClass

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}

/**
Expand Down Expand Up @@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
* @param cond
* @param expr
*/
case class State(locals: Array[Expression],
case class State(locals: IndexedSeq[Expression],
stack: List[Expression] = List(),
cond: Expression = Literal.TrueLiteral,
expr: Option[Expression] = None) {
Expand Down Expand Up @@ -117,23 +118,23 @@ case class State(locals: Array[Expression],

object State {
def makeStartingState(lambdaReflection: LambdaReflection,
children: Seq[Expression]): State = {
children: Seq[Expression],
objref: Option[Expression]): State = {
val max = lambdaReflection.maxLocals
val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
val (locals: Array[Expression], index: Int) = l
val (param: CtClass, argExp: Expression) = p

val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
val args = lambdaReflection.capturedArgs ++ children
val paramTypesAndArgs: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(args)
val locals = paramTypesAndArgs.foldLeft(objref.toVector) { (l, p) =>
val (paramType : CtClass, argExp) = p
if (paramType == CtClass.doubleType || paramType == CtClass.longType) {
// Long and Double occupies two slots in the local variable array.
// Append null to occupy an extra slot.
// See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
index + 2
l :+ argExp :+ null
} else {
index + 1
l :+ argExp
}

(locals.updated(index, argExp), newIndex)
}
State(locals)
// Ensure locals have enough slots with padTo.
State(locals.padTo(max, null))
}
}
Loading

0 comments on commit b37dba8

Please sign in to comment.