diff --git a/docs/compatibility.md b/docs/compatibility.md
index 653037c46aef..807205179047 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
 | | Array.empty[Float] |
 | | Array.empty[Double] |
 | | Array.empty[String] |
+| Method call | Only if the method being called 1. consists of operations supported by the UDF compiler, and 2. is one of the following: a final method, or a method in a final class or a final object |
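To make the new rule concrete, here is a hedged sketch mirroring the tests added below (class names `A` and `B` are invented for the illustration): only the call whose target can never be overridden is eligible for translation.

```scala
final class A {
  def twice(x: Int): Int = x * 2
  val ok: Int => Int = x => twice(x)        // compilable: A is final
}
class B {
  def twice(x: Int): Int = x * 2
  val fallback: Int => Int = x => twice(x)  // not compilable: twice is overridable
}
```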
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
index 8a6634473a2b..52c31444749c 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
@@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
    * @param children a sequence of catalyst arguments to the udf.
    * @return the compiled expression, optionally
    */
-  def compile(children: Seq[Expression]): Option[Expression] = {
+  def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {

     // create starting state, this will be:
     // State([children expressions], [empty stack], cond = true, expr = None)
-    val entryState = State.makeStartingState(lambdaReflection, children)
+    val entryState = State.makeStartingState(lambdaReflection, children, objref)

     // pick first of the Basic Blocks, and start recursing
     val entryBlock = cfg.basicBlocks.head
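The new `objref` parameter threads the receiver of an instance-method call into the compiler, so it can be seeded into local slot 0. A minimal sketch of the convention, assuming `invokevirtual`-style argument order; the helper `translateCall` and its parameters are invented for this illustration:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical helper showing how the two compile(...) shapes are chosen.
def translateCall(isStatic: Boolean, args: List[Expression],
    compile: (Seq[Expression], Option[Expression]) => Option[Expression])
  : Option[Expression] = {
  if (isStatic) {
    compile(args, None)                  // locals: arg0, arg1, ...
  } else {
    compile(args.tail, Some(args.head))  // locals: this, arg0, arg1, ...
  }
}
```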
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
index 2f36a70e3d09..5ddceaa38559 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -18,16 +18,16 @@ package com.nvidia.spark.udf
 import CatalystExpressionBuilder.simplify
 import java.nio.charset.Charset
+import javassist.Modifier
 import javassist.bytecode.{CodeIterator, Opcode}

 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._

-private object Repr {
+private[udf] object Repr {

   abstract class CompilerInternal(name: String) extends Expression {
     override def dataType: DataType = {
@@ -110,6 +110,8 @@ private object Repr {

   case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
       extends CompilerInternal("scala.reflect.ClassTag")
+
+  case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
 }

 /**
@@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
       (List[Expression], List[Expression])): State = {
     val State(locals, stack, cond, expr) = state
     val method = lambdaReflection.lookupBehavior(operand)
-    val declaringClassName = method.getDeclaringClass.getName
+    val declaringClass = method.getDeclaringClass
+    val declaringClassName = declaringClass.getName
     val paramTypes = method.getParameterTypes
     val (args, rest) = getArgs(stack, paramTypes.length)
     // We don't support arbitrary calls.
@@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
     } else if (declaringClassName.equals("java.time.LocalDateTime")) {
       State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
     } else {
-      // Other functions
-      throw new SparkException(
-        s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
+      val mModifiers = method.getModifiers
+      val cModifiers = declaringClass.getModifiers
+      if (!javassist.Modifier.isEnum(mModifiers) &&
+          !javassist.Modifier.isInterface(mModifiers) &&
+          !javassist.Modifier.isNative(mModifiers) &&
+          !javassist.Modifier.isPackage(mModifiers) &&
+          !javassist.Modifier.isStrict(mModifiers) &&
+          !javassist.Modifier.isSynchronized(mModifiers) &&
+          !javassist.Modifier.isTransient(mModifiers) &&
+          !javassist.Modifier.isVarArgs(mModifiers) &&
+          !javassist.Modifier.isVolatile(mModifiers) &&
+          (javassist.Modifier.isFinal(mModifiers) ||
+           javassist.Modifier.isFinal(cModifiers))) {
+        val retval = {
+          if (javassist.Modifier.isStatic(mModifiers)) {
+            CatalystExpressionBuilder(method).compile(args)
+          } else {
+            CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
+          }
+        }
+        State(locals, retval.toList ::: rest, cond, expr)
+      } else {
+        // Other functions
+        throw new SparkException(
+          s"Unsupported invocation of ${declaringClassName}.${method.getName}")
+      }
     }
   }
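Taken out of context, the gate above amounts to a dispatch-safety predicate: a call site is only compiled when the resolved method body is guaranteed to be the one that runs. A standalone sketch assuming a javassist `CtMethod`, using only a subset of the modifier checks (this mirrors, but is not, the plugin code):

```scala
import javassist.{CtMethod, Modifier}

// Sketch: translate a call only when neither the method nor its declaring
// class can be overridden, and no disqualifying modifiers apply.
def isCompilableCall(method: CtMethod): Boolean = {
  val m = method.getModifiers
  val c = method.getDeclaringClass.getModifiers
  val disqualified = Modifier.isNative(m) || Modifier.isSynchronized(m) ||
    Modifier.isPackage(m) // package-private: the target is not globally visible
  // Finality guarantees the bytecode we translate is the bytecode that runs.
  !disqualified && (Modifier.isFinal(m) || Modifier.isFinal(c))
}
```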
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
index ff91a1aef6bd..133f5166252d 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
@@ -16,12 +16,14 @@ package com.nvidia.spark.udf
+import Repr.UnknownCapturedArg
 import java.lang.invoke.SerializedLambda
-
-import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
-import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
+import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
+import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
+  Descriptor, MethodParametersAttribute}

 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types._

 //
@@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
 // Provides the interface the class and the method that implements the body of the lambda
 // used by the rest of the compiler.
 //
-case class LambdaReflection(private val classPool: ClassPool,
-    private val serializedLambda: SerializedLambda) {
+class LambdaReflection private(private val classPool: ClassPool,
+    private val ctClass: CtClass,
+    private val ctMethod: CtMethod,
+    val capturedArgs: Seq[Expression] = Seq()) {
   def lookupConstant(constPoolIndex: Int): Any = {
     constPool.getTag(constPoolIndex) match {
       case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
@@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
     val methodName = constPool.getMethodrefName(constPoolIndex)
     val descriptor = constPool.getMethodrefType(constPoolIndex)
     val className = constPool.getMethodrefClassName(constPoolIndex)
-    val params = Descriptor.getParameterTypes(descriptor, classPool)
     if (constPool.isConstructor(className, constPoolIndex) == 0) {
-      classPool.getCtClass(className).getDeclaredMethod(methodName, params)
+      classPool.getCtClass(className).getMethod(methodName, descriptor)
     } else {
+      val params = Descriptor.getParameterTypes(descriptor, classPool)
       classPool.getCtClass(className).getDeclaredConstructor(params)
     }
   }
@@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
     constPool.getClassInfo(constPoolIndex)
   }

-  // Get the CtClass object for the class that capture the lambda.
-  private val ctClass = {
-    val name = serializedLambda.getCapturingClass.replace('/', '.')
-    val classForName = LambdaReflection.getClass(name)
-    classPool.insertClassPath(new ClassClassPath(classForName))
-    classPool.getCtClass(name)
-  }
-
-  // Get the CtMethod object for the method that implements the lambda body.
-  private val ctMethod = {
-    val lambdaImplName = serializedLambda.getImplMethodName
-    ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
-  }
-
   private val methodInfo = ctMethod.getMethodInfo

   val constPool = methodInfo.getConstPool
@@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,

 object LambdaReflection {
   def apply(function: AnyRef): LambdaReflection = {
+    if (function.isInstanceOf[CtMethod]) {
+      return LambdaReflection(function.asInstanceOf[CtMethod])
+    }
     // writeReplace is supposed to return an object of SerializedLambda from
     // the function class (See
     // https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
@@ -117,9 +110,35 @@ object LambdaReflection {
     writeReplace.setAccessible(true)
     val serializedLambda = writeReplace.invoke(function)
       .asInstanceOf[SerializedLambda]
+    val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
+      // Add UnknownCapturedArg expressions to the list of captured arguments
+      // until we figure out how to handle captured variables in various
+      // situations.
+      // This, however, lets us pass captured arguments as long as they are not
+      // evaluated.
+      _ => Repr.UnknownCapturedArg()
+    }
     val classPool = new ClassPool(true)
-    LambdaReflection(classPool, serializedLambda)
+    // Get the CtClass object for the class that captures the lambda.
+    val ctClass = {
+      val name = serializedLambda.getCapturingClass.replace('/', '.')
+      val classForName = LambdaReflection.getClass(name)
+      classPool.insertClassPath(new ClassClassPath(classForName))
+      classPool.getCtClass(name)
+    }
+    // Get the CtMethod object for the method that implements the lambda body.
+    val ctMethod = {
+      val lambdaImplName = serializedLambda.getImplMethodName
+      ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
+    }
+    new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
+  }
+
+  private def apply(ctMethod: CtMethod): LambdaReflection = {
+    val ctClass = ctMethod.getDeclaringClass
+    val classPool = ctClass.getClassPool
+    new LambdaReflection(classPool, ctClass, ctMethod)
   }

   def getClass(name: String): Class[_] = {
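The captured-argument change is a placeholder scheme: each captured value is represented by an opaque marker that is harmless until something tries to evaluate it, at which point compilation fails and the UDF falls back to the CPU. A toy sketch of the idea (`Arg`, `Placeholder`, and `Child` are invented stand-ins for Catalyst expressions and `Repr.UnknownCapturedArg`):

```scala
sealed trait Arg
case object Placeholder extends Arg               // unknown captured value
final case class Child(name: String) extends Arg  // a real child expression

// Captured arguments precede the lambda's declared parameters in the
// implementation method, so they are prepended to the children.
def startingArgs(capturedCount: Int, children: Seq[Arg]): Seq[Arg] =
  Seq.fill(capturedCount)(Placeholder) ++ children
```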
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
index df49b80a6648..21ebe3e034f9 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.udf
 import CatalystExpressionBuilder.simplify
 import javassist.CtClass

+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}

 /**
@@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
  * @param cond
  * @param expr
  */
-case class State(locals: Array[Expression],
+case class State(locals: IndexedSeq[Expression],
     stack: List[Expression] = List(),
     cond: Expression = Literal.TrueLiteral,
     expr: Option[Expression] = None) {
@@ -117,23 +118,23 @@ case class State(locals: Array[Expression],

 object State {
   def makeStartingState(lambdaReflection: LambdaReflection,
-      children: Seq[Expression]): State = {
+      children: Seq[Expression],
+      objref: Option[Expression]): State = {
     val max = lambdaReflection.maxLocals
-    val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
-    val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
-      val (locals: Array[Expression], index: Int) = l
-      val (param: CtClass, argExp: Expression) = p
-
-      val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
+    val _children = lambdaReflection.capturedArgs ++ children
+    val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(_children)
+    val locals = params.foldLeft(objref.toVector) { (l, p) =>
+      val (param: CtClass, argExp) = p
+      if (param == CtClass.doubleType || param == CtClass.longType) {
         // Long and Double occupies two slots in the local variable array.
+        // Append null to occupy an extra slot.
         // See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
-        index + 2
+        l :+ argExp :+ null
       } else {
-        index + 1
+        l :+ argExp
       }
-
-      (locals.updated(index, argExp), newIndex)
     }
-    State(locals)
+    // Ensure locals have enough slots with padTo.
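The fold above builds the JVM local-variable layout: the receiver (if any) takes slot 0, each argument takes one slot, longs and doubles take two (the second slot padded with `null`), and the result is padded out to `maxLocals`. A self-contained sketch of the same rule, with plain strings standing in for Catalyst expressions (`layoutLocals` is invented for the illustration):

```scala
def layoutLocals(objref: Option[String],
    args: Seq[(String, Boolean)], // (expression, occupiesTwoSlots)
    maxLocals: Int): IndexedSeq[String] = {
  val locals = args.foldLeft(objref.toVector) { case (l, (arg, wide)) =>
    if (wide) l :+ arg :+ null else l :+ arg // wide values get a null filler
  }
  locals.padTo(maxLocals, null) // remaining slots start out empty
}

// e.g. an instance method (String, Long, Int) with maxLocals = 6 yields
// Vector(this, s, x, null, n, null)
layoutLocals(Some("this"), Seq(("s", false), ("x", true), ("n", false)), 6)
```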
+    State(locals.padTo(max, null))
   }
 }
diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
index ecf9d794431c..9b4d0de65f58 100644
--- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
+++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Assertions._
 import org.scalatest.FunSuite

 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{udf => makeUdf}
@@ -35,7 +35,6 @@ class OpcodeSuite extends FunSuite {

   val conf: SparkConf = new SparkConf()
     .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
-    .set("spark.rapids.sql.test.enabled", "true")
     .set("spark.rapids.sql.udfCompiler.enabled", "true")
     .set(RapidsConf.EXPLAIN.key, "true")
@@ -2088,19 +2087,201 @@ class OpcodeSuite extends FunSuite {
     checkEquiv(result, ref)
   }

-  test("Arbitrary function call inside UDF - ") {
-    def simple_func(str: String): Int = {
+  test("final static method call inside UDF - ") {
+    def simple_func2(str: String): Int = {
       str.length
     }
+    def simple_func1(str: String): Int = {
+      simple_func2(str) + simple_func2(str)
+    }
     val myudf: (String) => Int = str => {
-      simple_func(str)
+      simple_func1(str)
     }
     val u = makeUdf(myudf)
     val dataset = List("hello", "world").toDS()
     val result = dataset.withColumn("new", u(col("value")))
+    val ref = dataset.withColumn("new", length(col("value")) + length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: class method call inside UDF - ") {
+    class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("final class method call inside UDF - ") {
+    final class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: object method call inside UDF - ") {
+    object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
     val ref = dataset.withColumn("new", length(col("value")))
-    result.explain(true)
-    result.show
-//    checkEquiv(result, ref)
+    checkEquivNotCompiled(result, ref)
   }
+
+  test("final object method call inside UDF - ") {
+    final object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("super class final method call inside UDF - ") {
+    class B {
+      final def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: final class calls super class method inside UDF - ") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    final class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: super class method call inside UDF - ") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var in class - ") {
+    class C {
+      var capturedArg: Int = 4
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + runner.capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var outside class - ") {
+    var capturedArg: Int = 4
+    class C {
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
+
 }
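For reference, a hedged end-to-end sketch of exercising the new path, assuming the udf-compiler plugin is on the classpath; the config keys are the ones used by the suite above, and everything else (`Lengths`, `len`, the session setup) is invented for the example:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf => makeUdf}

val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
  .config("spark.rapids.sql.udfCompiler.enabled", "true")
  .getOrCreate()
import spark.implicits._

final class Lengths {                 // final, so the method call can compile
  def len(s: String): Int = s.length
  val f: String => Int = s => len(s)
}

val u = makeUdf(new Lengths().f)
val ds = List("hello", "world").toDS()
// If the UDF was compiled, the optimized plan should contain length(value)
// rather than an opaque UDF call.
ds.withColumn("new", u(col("value"))).explain(true)
```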