diff --git a/docs/compatibility.md b/docs/compatibility.md
index 653037c46aef..807205179047 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
| | Array.empty[Float] |
| | Array.empty[Double] |
| | Array.empty[String] |
+| Method call | Only if the method being called
- consists of operations supported by the UDF compiler, and
- is one of the following:
- a final method, or
- a method in a final class, or
- a method in a final object
|
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
index 8a6634473a2b..52c31444749c 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
@@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
* @param children a sequence of catalyst arguments to the udf.
* @return the compiled expression, optionally
*/
- def compile(children: Seq[Expression]): Option[Expression] = {
+ def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {
// create starting state, this will be:
// State([children expressions], [empty stack], cond = true, expr = None)
- val entryState = State.makeStartingState(lambdaReflection, children)
+ val entryState = State.makeStartingState(lambdaReflection, children, objref)
// pick first of the Basic Blocks, and start recursing
val entryBlock = cfg.basicBlocks.head
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
index 2f36a70e3d09..5ddceaa38559 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -18,16 +18,16 @@ package com.nvidia.spark.udf
import CatalystExpressionBuilder.simplify
import java.nio.charset.Charset
+import javassist.Modifier
import javassist.bytecode.{CodeIterator, Opcode}
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-private object Repr {
+private[udf] object Repr {
abstract class CompilerInternal(name: String) extends Expression {
override def dataType: DataType = {
@@ -110,6 +110,8 @@ private object Repr {
case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
extends CompilerInternal("scala.reflect.ClassTag")
+
+ case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
}
/**
@@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
(List[Expression], List[Expression])): State = {
val State(locals, stack, cond, expr) = state
val method = lambdaReflection.lookupBehavior(operand)
- val declaringClassName = method.getDeclaringClass.getName
+ val declaringClass = method.getDeclaringClass
+ val declaringClassName = declaringClass.getName
val paramTypes = method.getParameterTypes
val (args, rest) = getArgs(stack, paramTypes.length)
// We don't support arbitrary calls.
@@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
} else if (declaringClassName.equals("java.time.LocalDateTime")) {
State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
} else {
- // Other functions
- throw new SparkException(
- s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
+ val mModifiers = method.getModifiers
+ val cModifiers = declaringClass.getModifiers
+ if (!javassist.Modifier.isEnum(mModifiers) &&
+ !javassist.Modifier.isInterface(mModifiers) &&
+ !javassist.Modifier.isNative(mModifiers) &&
+ !javassist.Modifier.isPackage(mModifiers) &&
+ !javassist.Modifier.isStrict(mModifiers) &&
+ !javassist.Modifier.isSynchronized(mModifiers) &&
+ !javassist.Modifier.isTransient(mModifiers) &&
+ !javassist.Modifier.isVarArgs(mModifiers) &&
+ !javassist.Modifier.isVolatile(mModifiers) &&
+ (javassist.Modifier.isFinal(mModifiers) ||
+ javassist.Modifier.isFinal(cModifiers))) {
+ val retval = {
+ if (javassist.Modifier.isStatic(mModifiers)) {
+ CatalystExpressionBuilder(method).compile(args)
+ } else {
+ CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
+ }
+ }
+ State(locals, retval.toList ::: rest, cond, expr)
+ } else {
+ // Other functions
+ throw new SparkException(
+ s"Unsupported invocation of ${declaringClassName}.${method.getName}")
+ }
}
}
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
index ff91a1aef6bd..133f5166252d 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
@@ -16,12 +16,14 @@
package com.nvidia.spark.udf
+import Repr.UnknownCapturedArg
import java.lang.invoke.SerializedLambda
-
-import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
-import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
+import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
+import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
+ Descriptor, MethodParametersAttribute}
import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types._
//
@@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
// Provides the interface the class and the method that implements the body of the lambda
// used by the rest of the compiler.
//
-case class LambdaReflection(private val classPool: ClassPool,
- private val serializedLambda: SerializedLambda) {
+class LambdaReflection private(private val classPool: ClassPool,
+ private val ctClass: CtClass,
+ private val ctMethod: CtMethod,
+ val capturedArgs: Seq[Expression] = Seq()) {
def lookupConstant(constPoolIndex: Int): Any = {
constPool.getTag(constPoolIndex) match {
case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
@@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
val methodName = constPool.getMethodrefName(constPoolIndex)
val descriptor = constPool.getMethodrefType(constPoolIndex)
val className = constPool.getMethodrefClassName(constPoolIndex)
- val params = Descriptor.getParameterTypes(descriptor, classPool)
if (constPool.isConstructor(className, constPoolIndex) == 0) {
- classPool.getCtClass(className).getDeclaredMethod(methodName, params)
+ classPool.getCtClass(className).getMethod(methodName, descriptor)
} else {
+ val params = Descriptor.getParameterTypes(descriptor, classPool)
classPool.getCtClass(className).getDeclaredConstructor(params)
}
}
@@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
constPool.getClassInfo(constPoolIndex)
}
- // Get the CtClass object for the class that capture the lambda.
- private val ctClass = {
- val name = serializedLambda.getCapturingClass.replace('/', '.')
- val classForName = LambdaReflection.getClass(name)
- classPool.insertClassPath(new ClassClassPath(classForName))
- classPool.getCtClass(name)
- }
-
- // Get the CtMethod object for the method that implements the lambda body.
- private val ctMethod = {
- val lambdaImplName = serializedLambda.getImplMethodName
- ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
- }
-
private val methodInfo = ctMethod.getMethodInfo
val constPool = methodInfo.getConstPool
@@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,
object LambdaReflection {
def apply(function: AnyRef): LambdaReflection = {
+ if (function.isInstanceOf[CtMethod]) {
+ return LambdaReflection(function.asInstanceOf[CtMethod])
+ }
// writeReplace is supposed to return an object of SerializedLambda from
// the function class (See
// https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
@@ -117,9 +110,35 @@ object LambdaReflection {
writeReplace.setAccessible(true)
val serializedLambda = writeReplace.invoke(function)
.asInstanceOf[SerializedLambda]
+ val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
+ // Add UnknownCapturedArg expressions to the list of captured arguments
+ // until we figure out how to handle captured variables in various
+ // situations.
+ // This, however, lets us pass captured arguments as long as they are not
+ // evaluated.
+ _ => Repr.UnknownCapturedArg()
+ }
val classPool = new ClassPool(true)
- LambdaReflection(classPool, serializedLambda)
+ // Get the CtClass object for the class that captures the lambda.
+ val ctClass = {
+ val name = serializedLambda.getCapturingClass.replace('/', '.')
+ val classForName = LambdaReflection.getClass(name)
+ classPool.insertClassPath(new ClassClassPath(classForName))
+ classPool.getCtClass(name)
+ }
+ // Get the CtMethod object for the method that implements the lambda body.
+ val ctMethod = {
+ val lambdaImplName = serializedLambda.getImplMethodName
+ ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
+ }
+ new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
+ }
+
+ private def apply(ctMethod: CtMethod): LambdaReflection = {
+ val ctClass = ctMethod.getDeclaringClass
+ val classPool = ctClass.getClassPool
+ new LambdaReflection(classPool, ctClass, ctMethod)
}
def getClass(name: String): Class[_] = {
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
index df49b80a6648..21ebe3e034f9 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.udf
import CatalystExpressionBuilder.simplify
import javassist.CtClass
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
/**
@@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
* @param cond
* @param expr
*/
-case class State(locals: Array[Expression],
+case class State(locals: IndexedSeq[Expression],
stack: List[Expression] = List(),
cond: Expression = Literal.TrueLiteral,
expr: Option[Expression] = None) {
@@ -117,23 +118,23 @@ case class State(locals: Array[Expression],
object State {
def makeStartingState(lambdaReflection: LambdaReflection,
- children: Seq[Expression]): State = {
+ children: Seq[Expression],
+ objref: Option[Expression]): State = {
val max = lambdaReflection.maxLocals
- val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
- val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
- val (locals: Array[Expression], index: Int) = l
- val (param: CtClass, argExp: Expression) = p
-
- val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
+ val _children = lambdaReflection.capturedArgs ++ children
+ val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(_children)
+ val locals = params.foldLeft(objref.toVector) { (l, p) =>
+ val (param: CtClass, argExp) = p
+ if (param == CtClass.doubleType || param == CtClass.longType) {
// Long and Double occupies two slots in the local variable array.
+ // Append null to occupy an extra slot.
// See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
- index + 2
+ l :+ argExp :+ null
} else {
- index + 1
+ l :+ argExp
}
-
- (locals.updated(index, argExp), newIndex)
}
- State(locals)
+ // Ensure locals have enough slots with padTo.
+ State(locals.padTo(max, null))
}
}
diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
index ecf9d794431c..9b4d0de65f58 100644
--- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
+++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Assertions._
import org.scalatest.FunSuite
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{udf => makeUdf}
@@ -35,7 +35,6 @@ class OpcodeSuite extends FunSuite {
val conf: SparkConf = new SparkConf()
.set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
- .set("spark.rapids.sql.test.enabled", "true")
.set("spark.rapids.sql.udfCompiler.enabled", "true")
.set(RapidsConf.EXPLAIN.key, "true")
@@ -2088,19 +2087,201 @@ class OpcodeSuite extends FunSuite {
checkEquiv(result, ref)
}
- test("Arbitrary function call inside UDF - ") {
- def simple_func(str: String): Int = {
+ test("final static method call inside UDF - ") {
+ def simple_func2(str: String): Int = {
str.length
}
+ def simple_func1(str: String): Int = {
+ simple_func2(str) + simple_func2(str)
+ }
val myudf: (String) => Int = str => {
- simple_func(str)
+ simple_func1(str)
}
val u = makeUdf(myudf)
val dataset = List("hello", "world").toDS()
val result = dataset.withColumn("new", u(col("value")))
+ val ref = dataset.withColumn("new", length(col("value")) + length(col("value")))
+ checkEquiv(result, ref)
+ }
+
+ test("FALLBACK TO CPU: class method call inside UDF - ") {
+ class C {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new C
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquivNotCompiled(result, ref)
+ }
+
+ test("final class method call inside UDF - ") {
+ final class C {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new C
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquiv(result, ref)
+ }
+
+ test("FALL BACK TO CPU: object method call inside UDF - ") {
+ object C {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val dataset = List("hello", "world").toDS()
+ val result = C(dataset)
val ref = dataset.withColumn("new", length(col("value")))
- result.explain(true)
- result.show
-// checkEquiv(result, ref)
+ checkEquivNotCompiled(result, ref)
}
+
+ test("final object method call inside UDF - ") {
+ final object C {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val dataset = List("hello", "world").toDS()
+ val result = C(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquiv(result, ref)
+ }
+
+ test("super class final method call inside UDF - ") {
+ class B {
+ final def simple_func(str: String): Int = {
+ str.length
+ }
+ }
+ class D extends B {
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new D
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquiv(result, ref)
+ }
+
+ test("FALLBACK TO CPU: final class calls super class method inside UDF - ") {
+ class B {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ }
+ final class D extends B {
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new D
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquivNotCompiled(result, ref)
+ }
+
+ test("FALLBACK TO CPU: super class method call inside UDF - ") {
+ class B {
+ def simple_func(str: String): Int = {
+ str.length
+ }
+ }
+ class D extends B {
+ val myudf: (String) => Int = str => {
+ simple_func(str)
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new D
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")))
+ checkEquivNotCompiled(result, ref)
+ }
+
+ test("FALLBACK TO CPU: capture a var in class - ") {
+ class C {
+ var capturedArg: Int = 4
+ val myudf: (String) => Int = str => {
+ str.length + capturedArg
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new C
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")) + runner.capturedArg)
+ checkEquivNotCompiled(result, ref)
+ }
+
+ test("FALLBACK TO CPU: capture a var outside class - ") {
+ var capturedArg: Int = 4
+ class C {
+ val myudf: (String) => Int = str => {
+ str.length + capturedArg
+ }
+ def apply(dataset: Dataset[String]): Dataset[Row] = {
+ val u = makeUdf(myudf)
+ dataset.withColumn("new", u(col("value")))
+ }
+ }
+ val runner = new C
+ val dataset = List("hello", "world").toDS()
+ val result = runner(dataset)
+ val ref = dataset.withColumn("new", length(col("value")) + capturedArg)
+ checkEquivNotCompiled(result, ref)
+ }
+
}