Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arbitrary function call in UDF #854

Merged
merged 2 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | Array.empty[Float] |
| | Array.empty[Double] |
| | Array.empty[String] |
| Method call | Only if the method being called <ol><li>consists of operations supported by the UDF compiler, and</li><li>is one of the following:<ul><li>a final method, or</li><li>a method in a final class, or</li><li>a method in a final object</li></ul></li></ol> |
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
* @param children a sequence of catalyst arguments to the udf.
* @return the compiled expression, optionally
*/
def compile(children: Seq[Expression]): Option[Expression] = {
def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {

// create starting state, this will be:
// State([children expressions], [empty stack], cond = true, expr = None)
val entryState = State.makeStartingState(lambdaReflection, children)
val entryState = State.makeStartingState(lambdaReflection, children, objref)

// pick first of the Basic Blocks, and start recursing
val entryBlock = cfg.basicBlocks.head
Expand Down
38 changes: 32 additions & 6 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ package com.nvidia.spark.udf

import CatalystExpressionBuilder.simplify
import java.nio.charset.Charset
import javassist.Modifier
import javassist.bytecode.{CodeIterator, Opcode}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


private object Repr {
private[udf] object Repr {

abstract class CompilerInternal(name: String) extends Expression {
override def dataType: DataType = {
Expand Down Expand Up @@ -110,6 +110,8 @@ private object Repr {

case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
extends CompilerInternal("scala.reflect.ClassTag")

case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
}

/**
Expand Down Expand Up @@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
(List[Expression], List[Expression])): State = {
val State(locals, stack, cond, expr) = state
val method = lambdaReflection.lookupBehavior(operand)
val declaringClassName = method.getDeclaringClass.getName
val declaringClass = method.getDeclaringClass
val declaringClassName = declaringClass.getName
val paramTypes = method.getParameterTypes
val (args, rest) = getArgs(stack, paramTypes.length)
// We don't support arbitrary calls.
Expand Down Expand Up @@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
val mModifiers = method.getModifiers
val cModifiers = declaringClass.getModifiers
if (!javassist.Modifier.isEnum(mModifiers) &&
!javassist.Modifier.isInterface(mModifiers) &&
!javassist.Modifier.isNative(mModifiers) &&
!javassist.Modifier.isPackage(mModifiers) &&
!javassist.Modifier.isStrict(mModifiers) &&
!javassist.Modifier.isSynchronized(mModifiers) &&
!javassist.Modifier.isTransient(mModifiers) &&
!javassist.Modifier.isVarArgs(mModifiers) &&
!javassist.Modifier.isVolatile(mModifiers) &&
(javassist.Modifier.isFinal(mModifiers) ||
javassist.Modifier.isFinal(cModifiers))) {
val retval = {
if (javassist.Modifier.isStatic(mModifiers)) {
CatalystExpressionBuilder(method).compile(args)
} else {
CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
}
}
State(locals, retval.toList ::: rest, cond, expr)
} else {
// Other functions
throw new SparkException(
s"Unsupported invocation of ${declaringClassName}.${method.getName}")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

package com.nvidia.spark.udf

import Repr.UnknownCapturedArg
import java.lang.invoke.SerializedLambda

import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
Descriptor, MethodParametersAttribute}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types._

//
Expand All @@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
// Provides the interface the class and the method that implements the body of the lambda
// used by the rest of the compiler.
//
case class LambdaReflection(private val classPool: ClassPool,
private val serializedLambda: SerializedLambda) {
class LambdaReflection private(private val classPool: ClassPool,
private val ctClass: CtClass,
private val ctMethod: CtMethod,
val capturedArgs: Seq[Expression] = Seq()) {
def lookupConstant(constPoolIndex: Int): Any = {
constPool.getTag(constPoolIndex) match {
case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
Expand Down Expand Up @@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
val methodName = constPool.getMethodrefName(constPoolIndex)
val descriptor = constPool.getMethodrefType(constPoolIndex)
val className = constPool.getMethodrefClassName(constPoolIndex)
val params = Descriptor.getParameterTypes(descriptor, classPool)
if (constPool.isConstructor(className, constPoolIndex) == 0) {
classPool.getCtClass(className).getDeclaredMethod(methodName, params)
classPool.getCtClass(className).getMethod(methodName, descriptor)
} else {
val params = Descriptor.getParameterTypes(descriptor, classPool)
classPool.getCtClass(className).getDeclaredConstructor(params)
}
}
Expand All @@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
constPool.getClassInfo(constPoolIndex)
}

// Get the CtClass object for the class that captures the lambda.
private val ctClass = {
val name = serializedLambda.getCapturingClass.replace('/', '.')
val classForName = LambdaReflection.getClass(name)
classPool.insertClassPath(new ClassClassPath(classForName))
classPool.getCtClass(name)
}

// Get the CtMethod object for the method that implements the lambda body.
private val ctMethod = {
val lambdaImplName = serializedLambda.getImplMethodName
ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
}

private val methodInfo = ctMethod.getMethodInfo

val constPool = methodInfo.getConstPool
Expand All @@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,

object LambdaReflection {
def apply(function: AnyRef): LambdaReflection = {
if (function.isInstanceOf[CtMethod]) {
return LambdaReflection(function.asInstanceOf[CtMethod])
}
// writeReplace is supposed to return an object of SerializedLambda from
// the function class (See
// https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
Expand All @@ -117,9 +110,35 @@ object LambdaReflection {
writeReplace.setAccessible(true)
val serializedLambda = writeReplace.invoke(function)
.asInstanceOf[SerializedLambda]
val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
// Add UnknownCapturedArg expressions to the list of captured arguments
// until we figure out how to handle captured variables in various
// situations.
// This, however, lets us pass captured arguments as long as they are not
// evaluated.
_ => Repr.UnknownCapturedArg()
}

val classPool = new ClassPool(true)
LambdaReflection(classPool, serializedLambda)
// Get the CtClass object for the class that captures the lambda.
val ctClass = {
val name = serializedLambda.getCapturingClass.replace('/', '.')
val classForName = LambdaReflection.getClass(name)
classPool.insertClassPath(new ClassClassPath(classForName))
classPool.getCtClass(name)
}
// Get the CtMethod object for the method that implements the lambda body.
val ctMethod = {
val lambdaImplName = serializedLambda.getImplMethodName
ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
}
new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
}

// Builds a LambdaReflection directly from an already-resolved CtMethod, taking the
// declaring class and that class's ClassPool from the method itself. No captured
// arguments are passed, so capturedArgs keeps its default value.
private def apply(ctMethod: CtMethod): LambdaReflection = {
val ctClass = ctMethod.getDeclaringClass
val classPool = ctClass.getClassPool
new LambdaReflection(classPool, ctClass, ctMethod)
}

def getClass(name: String): Class[_] = {
Expand Down
27 changes: 14 additions & 13 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.nvidia.spark.udf
import CatalystExpressionBuilder.simplify
import javassist.CtClass

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}

/**
Expand Down Expand Up @@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
* @param cond
* @param expr
*/
case class State(locals: Array[Expression],
case class State(locals: IndexedSeq[Expression],
stack: List[Expression] = List(),
cond: Expression = Literal.TrueLiteral,
expr: Option[Expression] = None) {
Expand Down Expand Up @@ -117,23 +118,23 @@ case class State(locals: Array[Expression],

object State {
def makeStartingState(lambdaReflection: LambdaReflection,
children: Seq[Expression]): State = {
children: Seq[Expression],
objref: Option[Expression]): State = {
val max = lambdaReflection.maxLocals
val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
val (locals: Array[Expression], index: Int) = l
val (param: CtClass, argExp: Expression) = p

val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
val args = lambdaReflection.capturedArgs ++ children
val paramTypesAndArgs: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(args)
val locals = paramTypesAndArgs.foldLeft(objref.toVector) { (l, p) =>
val (paramType : CtClass, argExp) = p
if (paramType == CtClass.doubleType || paramType == CtClass.longType) {
// Long and Double each occupy two slots in the local variable array.
// Append null to occupy an extra slot.
// See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
index + 2
l :+ argExp :+ null
} else {
index + 1
l :+ argExp
}

(locals.updated(index, argExp), newIndex)
}
State(locals)
// Ensure locals have enough slots with padTo.
State(locals.padTo(max, null))
}
}
Loading