From 0bcfc170b02bde0324825fb6b0fbf240ae50340a Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Wed, 14 Oct 2020 23:14:55 -0700
Subject: [PATCH] Inline method calls when possible.

With this change, method calls are inlined only if the method being called
1. consists of operations supported by the UDF compiler, and
2. is one of the following:
   * a final method, or
   * a method in a final class, or
   * a method in a final object

Signed-off-by: Sean Lee
---
 docs/compatibility.md                         |   1 +
 .../spark/udf/CatalystExpressionBuilder.scala |   4 +-
 .../com/nvidia/spark/udf/Instruction.scala    |  38 +++-
 .../nvidia/spark/udf/LambdaReflection.scala   |  63 ++++--
 .../scala/com/nvidia/spark/udf/State.scala    |  27 +--
 .../scala/com/nvidia/spark/OpcodeSuite.scala  | 197 +++++++++++++++++-
 6 files changed, 279 insertions(+), 51 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 653037c46aef..807205179047 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
 | | Array.empty[Float] |
 | | Array.empty[Double] |
 | | Array.empty[String] |
+| Method call | Only if the method being called 1. consists of operations supported by the UDF compiler, and 2. is one of the following: a final method, or a method in a final class, or a method in a final object |
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
index 8a6634473a2b..52c31444749c 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
@@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
    * @param children a sequence of catalyst arguments to the udf.
    * @return the compiled expression, optionally
    */
-  def compile(children: Seq[Expression]): Option[Expression] = {
+  def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {
 
     // create starting state, this will be:
     // State([children expressions], [empty stack], cond = true, expr = None)
-    val entryState = State.makeStartingState(lambdaReflection, children)
+    val entryState = State.makeStartingState(lambdaReflection, children, objref)
 
     // pick first of the Basic Blocks, and start recursing
     val entryBlock = cfg.basicBlocks.head
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
index 2f36a70e3d09..5ddceaa38559 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -18,16 +18,16 @@ package com.nvidia.spark.udf
 
 import CatalystExpressionBuilder.simplify
 import java.nio.charset.Charset
+import javassist.Modifier
 import javassist.bytecode.{CodeIterator, Opcode}
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
-private object Repr {
+private[udf] object Repr {
 
   abstract class CompilerInternal(name: String) extends Expression {
     override def dataType: DataType = {
@@ -110,6 +110,8 @@ private object Repr {
 
   case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
       extends CompilerInternal("scala.reflect.ClassTag")
+
+  case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
 }
 
 /**
@@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
       (List[Expression], List[Expression])): State = {
     val State(locals, stack, cond, expr) = state
     val method = lambdaReflection.lookupBehavior(operand)
-    val declaringClassName = method.getDeclaringClass.getName
+    val declaringClass = method.getDeclaringClass
+    val declaringClassName = declaringClass.getName
     val paramTypes = method.getParameterTypes
     val (args, rest) = getArgs(stack, paramTypes.length)
     // We don't support arbitrary calls.
@@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
     } else if (declaringClassName.equals("java.time.LocalDateTime")) {
       State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
     } else {
-      // Other functions
-      throw new SparkException(
-        s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
+      val mModifiers = method.getModifiers
+      val cModifiers = declaringClass.getModifiers
+      if (!javassist.Modifier.isEnum(mModifiers) &&
+          !javassist.Modifier.isInterface(mModifiers) &&
+          !javassist.Modifier.isNative(mModifiers) &&
+          !javassist.Modifier.isPackage(mModifiers) &&
+          !javassist.Modifier.isStrict(mModifiers) &&
+          !javassist.Modifier.isSynchronized(mModifiers) &&
+          !javassist.Modifier.isTransient(mModifiers) &&
+          !javassist.Modifier.isVarArgs(mModifiers) &&
+          !javassist.Modifier.isVolatile(mModifiers) &&
+          (javassist.Modifier.isFinal(mModifiers) ||
+           javassist.Modifier.isFinal(cModifiers))) {
+        val retval = {
+          if (javassist.Modifier.isStatic(mModifiers)) {
+            CatalystExpressionBuilder(method).compile(args)
+          } else {
+            CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
+          }
+        }
+        State(locals, retval.toList ::: rest, cond, expr)
+      } else {
+        // Other functions
+        throw new SparkException(
+          s"Unsupported invocation of ${declaringClassName}.${method.getName}")
+      }
     }
   }
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
index ff91a1aef6bd..133f5166252d 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
@@ -16,12 +16,14 @@
 
 package com.nvidia.spark.udf
 
+import Repr.UnknownCapturedArg
 import java.lang.invoke.SerializedLambda
-
-import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
-import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
+import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
+import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
+  Descriptor, MethodParametersAttribute}
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types._
 
 //
@@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
 // Provides the interface the class and the method that implements the body of the lambda
 // used by the rest of the compiler.
 //
-case class LambdaReflection(private val classPool: ClassPool,
-    private val serializedLambda: SerializedLambda) {
+class LambdaReflection private(private val classPool: ClassPool,
+    private val ctClass: CtClass,
+    private val ctMethod: CtMethod,
+    val capturedArgs: Seq[Expression] = Seq()) {
   def lookupConstant(constPoolIndex: Int): Any = {
     constPool.getTag(constPoolIndex) match {
       case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
@@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
     val methodName = constPool.getMethodrefName(constPoolIndex)
     val descriptor = constPool.getMethodrefType(constPoolIndex)
     val className = constPool.getMethodrefClassName(constPoolIndex)
-    val params = Descriptor.getParameterTypes(descriptor, classPool)
     if (constPool.isConstructor(className, constPoolIndex) == 0) {
-      classPool.getCtClass(className).getDeclaredMethod(methodName, params)
+      classPool.getCtClass(className).getMethod(methodName, descriptor)
     } else {
+      val params = Descriptor.getParameterTypes(descriptor, classPool)
      classPool.getCtClass(className).getDeclaredConstructor(params)
     }
   }
@@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
     constPool.getClassInfo(constPoolIndex)
   }
 
-  // Get the CtClass object for the class that capture the lambda.
-  private val ctClass = {
-    val name = serializedLambda.getCapturingClass.replace('/', '.')
-    val classForName = LambdaReflection.getClass(name)
-    classPool.insertClassPath(new ClassClassPath(classForName))
-    classPool.getCtClass(name)
-  }
-
-  // Get the CtMethod object for the method that implements the lambda body.
-  private val ctMethod = {
-    val lambdaImplName = serializedLambda.getImplMethodName
-    ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
-  }
-
   private val methodInfo = ctMethod.getMethodInfo
 
   val constPool = methodInfo.getConstPool
@@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,
 
 object LambdaReflection {
   def apply(function: AnyRef): LambdaReflection = {
+    if (function.isInstanceOf[CtMethod]) {
+      return LambdaReflection(function.asInstanceOf[CtMethod])
+    }
     // writeReplace is supposed to return an object of SerializedLambda from
     // the function class (See
     // https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
@@ -117,9 +110,35 @@ object LambdaReflection {
     writeReplace.setAccessible(true)
     val serializedLambda = writeReplace.invoke(function)
       .asInstanceOf[SerializedLambda]
+    val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
+      // Add UnknownCapturedArg expressions to the list of captured arguments
+      // until we figure out how to handle captured variables in various
+      // situations.
+      // This, however, lets us pass captured arguments as long as they are not
+      // evaluated.
+      _ => Repr.UnknownCapturedArg()
+    }
 
     val classPool = new ClassPool(true)
-    LambdaReflection(classPool, serializedLambda)
+    // Get the CtClass object for the class that captures the lambda.
+    val ctClass = {
+      val name = serializedLambda.getCapturingClass.replace('/', '.')
+      val classForName = LambdaReflection.getClass(name)
+      classPool.insertClassPath(new ClassClassPath(classForName))
+      classPool.getCtClass(name)
+    }
+    // Get the CtMethod object for the method that implements the lambda body.
+    val ctMethod = {
+      val lambdaImplName = serializedLambda.getImplMethodName
+      ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
+    }
+    new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
+  }
+
+  private def apply(ctMethod: CtMethod): LambdaReflection = {
+    val ctClass = ctMethod.getDeclaringClass
+    val classPool = ctClass.getClassPool
+    new LambdaReflection(classPool, ctClass, ctMethod)
   }
 
   def getClass(name: String): Class[_] = {
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
index df49b80a6648..21ebe3e034f9 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.udf
 
 import CatalystExpressionBuilder.simplify
 import javassist.CtClass
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
 
 /**
@@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
  * @param cond
  * @param expr
  */
-case class State(locals: Array[Expression],
+case class State(locals: IndexedSeq[Expression],
     stack: List[Expression] = List(),
     cond: Expression = Literal.TrueLiteral,
     expr: Option[Expression] = None) {
@@ -117,23 +118,23 @@ case class State(locals: Array[Expression],
 
 object State {
   def makeStartingState(lambdaReflection: LambdaReflection,
-      children: Seq[Expression]): State = {
+      children: Seq[Expression],
+      objref: Option[Expression]): State = {
     val max = lambdaReflection.maxLocals
-    val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
-    val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
-      val (locals: Array[Expression], index: Int) = l
-      val (param: CtClass, argExp: Expression) = p
-
-      val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
+    val _children = lambdaReflection.capturedArgs ++ children
+    val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(_children)
+    val locals = params.foldLeft(objref.toVector) { (l, p) =>
+      val (param: CtClass, argExp) = p
+      if (param == CtClass.doubleType || param == CtClass.longType) {
         // Long and Double occupies two slots in the local variable array.
+        // Append null to occupy an extra slot.
         // See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
-        index + 2
+        l :+ argExp :+ null
       } else {
-        index + 1
+        l :+ argExp
       }
-
-      (locals.updated(index, argExp), newIndex)
     }
-    State(locals)
+    // Ensure locals have enough slots with padTo.
+    State(locals.padTo(max, null))
   }
 }
diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
index ecf9d794431c..9b4d0de65f58 100644
--- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
+++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Assertions._
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{udf => makeUdf}
@@ -35,7 +35,6 @@ class OpcodeSuite extends FunSuite {
 
   val conf: SparkConf = new SparkConf()
     .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
-    .set("spark.rapids.sql.test.enabled", "true")
     .set("spark.rapids.sql.udfCompiler.enabled", "true")
     .set(RapidsConf.EXPLAIN.key, "true")
 
@@ -2088,19 +2087,201 @@ class OpcodeSuite extends FunSuite {
     checkEquiv(result, ref)
   }
 
-  test("Arbitrary function call inside UDF - ") {
-    def simple_func(str: String): Int = {
+  test("final static method call inside UDF - ") {
+    def simple_func2(str: String): Int = {
       str.length
     }
+    def simple_func1(str: String): Int = {
+      simple_func2(str) + simple_func2(str)
+    }
     val myudf: (String) => Int = str => {
-      simple_func(str)
+      simple_func1(str)
     }
     val u = makeUdf(myudf)
     val dataset = List("hello", "world").toDS()
     val result = dataset.withColumn("new", u(col("value")))
+    val ref = dataset.withColumn("new", length(col("value")) + length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: class method call inside UDF - ") {
+    class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("final class method call inside UDF - ") {
+    final class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: object method call inside UDF - ") {
+    object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
     val ref = dataset.withColumn("new", length(col("value")))
-    result.explain(true)
-    result.show
-// checkEquiv(result, ref)
+    checkEquivNotCompiled(result, ref)
   }
+
+  test("final object method call inside UDF - ") {
+    final object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("super class final method call inside UDF - ") {
+    class B {
+      final def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: final class calls super class method inside UDF - ") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    final class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: super class method call inside UDF - ") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var in class - ") {
+    class C {
+      var capturedArg: Int = 4
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + runner.capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var outside class - ") {
+    var capturedArg: Int = 4
+    class C {
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
 }
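
Note on the inlining criterion: a non-final instance method on a non-final
class can be overridden, so the INVOKEVIRTUAL target recorded in the bytecode
is not guaranteed to be the method that actually runs at execution time.
Requiring the method, or its declaring class, to be final is what makes
inlining safe. The standalone sketch below is not part of the patch; the
names InlineCheck and canInline are hypothetical, and only a subset of the
rejected modifier flags is reproduced. It shows the heart of the test using
the same javassist Modifier API the patch relies on:

    import javassist.{ClassPool, Modifier}

    object InlineCheck {
      // Resolve the method with javassist and apply the final-method /
      // final-class test: inline only when no override can be dispatched.
      def canInline(className: String, methodName: String): Boolean = {
        val pool = ClassPool.getDefault
        val ctClass = pool.get(className)
        val ctMethod = ctClass.getDeclaredMethods.find(_.getName == methodName).get
        val m = ctMethod.getModifiers
        val c = ctClass.getModifiers
        // Reject flag combinations the compiler does not handle, then
        // require that the method or its declaring class is final.
        !Modifier.isNative(m) && !Modifier.isSynchronized(m) &&
          (Modifier.isFinal(m) || Modifier.isFinal(c))
      }
    }

For example, canInline("java.lang.String", "length") holds because
java.lang.String is final, while canInline("java.util.ArrayList", "size")
does not, since neither java.util.ArrayList nor its size method is final.
This mirrors the split between the compiled tests and the FALLBACK TO CPU
tests above.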