From 2e81d82ce656b950db585d23cc4390030c6c9a9a Mon Sep 17 00:00:00 2001
From: Allen Xu
Date: Fri, 25 Sep 2020 16:26:30 +0800
Subject: [PATCH 1/2] Add test case for arbitrary function call in UDF

Signed-off-by: Allen Xu
---
 .../scala/com/nvidia/spark/OpcodeSuite.scala | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
index c8c5e060404..ecf9d794431 100644
--- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
+++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
@@ -35,6 +35,7 @@ class OpcodeSuite extends FunSuite {
 
   val conf: SparkConf = new SparkConf()
     .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
+    .set("spark.rapids.sql.test.enabled", "true")
     .set("spark.rapids.sql.udfCompiler.enabled", "true")
     .set(RapidsConf.EXPLAIN.key, "true")
 
@@ -2086,4 +2087,20 @@ class OpcodeSuite extends FunSuite {
     val ref = dataset.select(lit(Array.empty[String]).as("emptyArrOfStr"))
     checkEquiv(result, ref)
   }
+
+  test("Arbitrary function call inside UDF - ") {
+    def simple_func(str: String): Int = {
+      str.length
+    }
+    val myudf: (String) => Int = str => {
+      simple_func(str)
+    }
+    val u = makeUdf(myudf)
+    val dataset = List("hello", "world").toDS()
+    val result = dataset.withColumn("new", u(col("value")))
+    val ref = dataset.withColumn("new", length(col("value")))
+    result.explain(true)
+    result.show
+//    checkEquiv(result, ref)
+  }
 }

From b37dba8ed74cbd09888da8d1e91140c3c0ab2c58 Mon Sep 17 00:00:00 2001
From: Sean Lee
Date: Wed, 14 Oct 2020 23:14:55 -0700
Subject: [PATCH 2/2] Inline method calls when possible.

With this change, method calls are inlined only if the method being
called

1. consists of operations supported by the UDF compiler, and
2. is one of the following:
   * a final method, or
   * a method in a final class, or
   * a method in a final object

Signed-off-by: Sean Lee
---
 docs/compatibility.md                         |   1 +
 .../spark/udf/CatalystExpressionBuilder.scala |   4 +-
 .../com/nvidia/spark/udf/Instruction.scala    |  38 +++-
 .../nvidia/spark/udf/LambdaReflection.scala   |  63 ++++--
 .../scala/com/nvidia/spark/udf/State.scala    |  27 +--
 .../scala/com/nvidia/spark/OpcodeSuite.scala  | 197 +++++++++++++++++-
 6 files changed, 279 insertions(+), 51 deletions(-)

diff --git a/docs/compatibility.md b/docs/compatibility.md
index 653037c46ae..80720517904 100644
--- a/docs/compatibility.md
+++ b/docs/compatibility.md
@@ -384,3 +384,4 @@ When translating UDFs to Catalyst expressions, the supported UDF functions are l
 | | Array.empty[Float] |
 | | Array.empty[Double] |
 | | Array.empty[String] |
+| Method call | Only if the method being called
  1. consists of operations supported by the UDF compiler, and
  2. is one of the following:
    • a final method, or
    • a method in a final class, or
    • a method in a final object
|
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
index 8a6634473a2..52c31444749 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala
@@ -63,11 +63,11 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi
    * @param children a sequence of catalyst arguments to the udf.
    * @return the compiled expression, optionally
    */
-  def compile(children: Seq[Expression]): Option[Expression] = {
+  def compile(children: Seq[Expression], objref: Option[Expression] = None): Option[Expression] = {
 
     // create starting state, this will be:
     // State([children expressions], [empty stack], cond = true, expr = None)
-    val entryState = State.makeStartingState(lambdaReflection, children)
+    val entryState = State.makeStartingState(lambdaReflection, children, objref)
 
     // pick first of the Basic Blocks, and start recursing
     val entryBlock = cfg.basicBlocks.head
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
index 2f36a70e3d0..5ddceaa3855 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -18,16 +18,16 @@ package com.nvidia.spark.udf
 import CatalystExpressionBuilder.simplify
 import java.nio.charset.Charset
 
+import javassist.Modifier
 import javassist.bytecode.{CodeIterator, Opcode}
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
-private object Repr {
+private[udf] object Repr {
 
   abstract class CompilerInternal(name: String) extends Expression {
     override def dataType: DataType = {
@@ -110,6 +110,8 @@ private object Repr {
 
   case class ClassTag[T](classTag: scala.reflect.ClassTag[T])
       extends CompilerInternal("scala.reflect.ClassTag")
+
+  case class UnknownCapturedArg() extends CompilerInternal("unknown captured arg")
 }
 
 /**
@@ -346,7 +348,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
       (List[Expression], List[Expression])): State = {
     val State(locals, stack, cond, expr) = state
     val method = lambdaReflection.lookupBehavior(operand)
-    val declaringClassName = method.getDeclaringClass.getName
+    val declaringClass = method.getDeclaringClass
+    val declaringClassName = declaringClass.getName
     val paramTypes = method.getParameterTypes
     val (args, rest) = getArgs(stack, paramTypes.length)
     // We don't support arbitrary calls.
@@ -389,9 +392,32 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
     } else if (declaringClassName.equals("java.time.LocalDateTime")) {
       State(locals, localDateTimeOp(method.getName, args) :: rest, cond, expr)
     } else {
-      // Other functions
-      throw new SparkException(
-        s"Unsupported instruction: ${Opcode.INVOKEVIRTUAL} ${declaringClassName}")
+      val mModifiers = method.getModifiers
+      val cModifiers = declaringClass.getModifiers
+      if (!javassist.Modifier.isEnum(mModifiers) &&
+          !javassist.Modifier.isInterface(mModifiers) &&
+          !javassist.Modifier.isNative(mModifiers) &&
+          !javassist.Modifier.isPackage(mModifiers) &&
+          !javassist.Modifier.isStrict(mModifiers) &&
+          !javassist.Modifier.isSynchronized(mModifiers) &&
+          !javassist.Modifier.isTransient(mModifiers) &&
+          !javassist.Modifier.isVarArgs(mModifiers) &&
+          !javassist.Modifier.isVolatile(mModifiers) &&
+          (javassist.Modifier.isFinal(mModifiers) ||
+           javassist.Modifier.isFinal(cModifiers))) {
+        val retval = {
+          if (javassist.Modifier.isStatic(mModifiers)) {
+            CatalystExpressionBuilder(method).compile(args)
+          } else {
+            CatalystExpressionBuilder(method).compile(args.tail, Some(args.head))
+          }
+        }
+        State(locals, retval.toList ::: rest, cond, expr)
+      } else {
+        // Other functions
+        throw new SparkException(
+          s"Unsupported invocation of ${declaringClassName}.${method.getName}")
+      }
     }
   }
 
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
index ff91a1aef6b..133f5166252 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LambdaReflection.scala
@@ -16,12 +16,14 @@
 
 package com.nvidia.spark.udf
 
+import Repr.UnknownCapturedArg
 import java.lang.invoke.SerializedLambda
-
-import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField}
-import javassist.bytecode.{CodeIterator, ConstPool, Descriptor}
+import javassist.{ClassClassPath, ClassPool, CtBehavior, CtClass, CtField, CtMethod}
+import javassist.bytecode.{AccessFlag, CodeIterator, ConstPool,
+  Descriptor, MethodParametersAttribute}
 
 import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types._
 
 //
@@ -30,8 +32,10 @@ import org.apache.spark.sql.types._
 // Provides the interface the class and the method that implements the body of the lambda
 // used by the rest of the compiler.
 //
-case class LambdaReflection(private val classPool: ClassPool,
-    private val serializedLambda: SerializedLambda) {
+class LambdaReflection private(private val classPool: ClassPool,
+    private val ctClass: CtClass,
+    private val ctMethod: CtMethod,
+    val capturedArgs: Seq[Expression] = Seq()) {
   def lookupConstant(constPoolIndex: Int): Any = {
     constPool.getTag(constPoolIndex) match {
       case ConstPool.CONST_Integer => constPool.getIntegerInfo(constPoolIndex)
@@ -61,10 +65,10 @@ case class LambdaReflection(private val classPool: ClassPool,
     val methodName = constPool.getMethodrefName(constPoolIndex)
     val descriptor = constPool.getMethodrefType(constPoolIndex)
     val className = constPool.getMethodrefClassName(constPoolIndex)
-    val params = Descriptor.getParameterTypes(descriptor, classPool)
     if (constPool.isConstructor(className, constPoolIndex) == 0) {
-      classPool.getCtClass(className).getDeclaredMethod(methodName, params)
+      classPool.getCtClass(className).getMethod(methodName, descriptor)
     } else {
+      val params = Descriptor.getParameterTypes(descriptor, classPool)
      classPool.getCtClass(className).getDeclaredConstructor(params)
     }
   }
@@ -76,20 +80,6 @@ case class LambdaReflection(private val classPool: ClassPool,
     constPool.getClassInfo(constPoolIndex)
   }
 
-  // Get the CtClass object for the class that capture the lambda.
-  private val ctClass = {
-    val name = serializedLambda.getCapturingClass.replace('/', '.')
-    val classForName = LambdaReflection.getClass(name)
-    classPool.insertClassPath(new ClassClassPath(classForName))
-    classPool.getCtClass(name)
-  }
-
-  // Get the CtMethod object for the method that implements the lambda body.
-  private val ctMethod = {
-    val lambdaImplName = serializedLambda.getImplMethodName
-    ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
-  }
-
   private val methodInfo = ctMethod.getMethodInfo
 
   val constPool = methodInfo.getConstPool
@@ -107,6 +97,9 @@ case class LambdaReflection(private val classPool: ClassPool,
 
 object LambdaReflection {
   def apply(function: AnyRef): LambdaReflection = {
+    if (function.isInstanceOf[CtMethod]) {
+      return LambdaReflection(function.asInstanceOf[CtMethod])
+    }
     // writeReplace is supposed to return an object of SerializedLambda from
     // the function class (See
     // https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html).
@@ -117,9 +110,35 @@ object LambdaReflection {
     writeReplace.setAccessible(true)
     val serializedLambda = writeReplace.invoke(function)
       .asInstanceOf[SerializedLambda]
+    val capturedArgs = Seq.tabulate(serializedLambda.getCapturedArgCount){
+      // Add UnknownCapturedArg expressions to the list of captured arguments
+      // until we figure out how to handle captured variables in various
+      // situations.
+      // This, however, lets us pass captured arguments as long as they are not
+      // evaluated.
+      _ => Repr.UnknownCapturedArg()
+    }
 
     val classPool = new ClassPool(true)
-    LambdaReflection(classPool, serializedLambda)
+    // Get the CtClass object for the class that captures the lambda.
+    val ctClass = {
+      val name = serializedLambda.getCapturingClass.replace('/', '.')
+      val classForName = LambdaReflection.getClass(name)
+      classPool.insertClassPath(new ClassClassPath(classForName))
+      classPool.getCtClass(name)
+    }
+    // Get the CtMethod object for the method that implements the lambda body.
+    val ctMethod = {
+      val lambdaImplName = serializedLambda.getImplMethodName
+      ctClass.getDeclaredMethod(lambdaImplName.stripSuffix("$adapted"))
+    }
+    new LambdaReflection(classPool, ctClass, ctMethod, capturedArgs)
+  }
+
+  private def apply(ctMethod: CtMethod): LambdaReflection = {
+    val ctClass = ctMethod.getDeclaringClass
+    val classPool = ctClass.getClassPool
+    new LambdaReflection(classPool, ctClass, ctMethod)
   }
 
   def getClass(name: String): Class[_] = {
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
index df49b80a664..8998ee248e9 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.udf
 import CatalystExpressionBuilder.simplify
 import javassist.CtClass
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
 
 /**
@@ -74,7 +75,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal, Or}
  * @param cond
  * @param expr
  */
-case class State(locals: Array[Expression],
+case class State(locals: IndexedSeq[Expression],
     stack: List[Expression] = List(),
     cond: Expression = Literal.TrueLiteral,
     expr: Option[Expression] = None) {
@@ -117,23 +118,23 @@ case class State(locals: Array[Expression],
 
 object State {
   def makeStartingState(lambdaReflection: LambdaReflection,
-      children: Seq[Expression]): State = {
+      children: Seq[Expression],
+      objref: Option[Expression]): State = {
     val max = lambdaReflection.maxLocals
-    val params: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(children)
-    val (locals, _) = params.foldLeft((new Array[Expression](max), 0)) { (l, p) =>
-      val (locals: Array[Expression], index: Int) = l
-      val (param: CtClass, argExp: Expression) = p
-
-      val newIndex = if (param == CtClass.doubleType || param == CtClass.longType) {
+    val args = lambdaReflection.capturedArgs ++ children
+    val paramTypesAndArgs: Seq[(CtClass, Expression)] = lambdaReflection.parameters.view.zip(args)
+    val locals = paramTypesAndArgs.foldLeft(objref.toVector) { (l, p) =>
+      val (paramType: CtClass, argExp) = p
+      if (paramType == CtClass.doubleType || paramType == CtClass.longType) {
         // Long and Double occupies two slots in the local variable array.
+        // Append null to occupy an extra slot.
         // See https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.6.1
-        index + 2
+        l :+ argExp :+ null
       } else {
-        index + 1
+        l :+ argExp
      }
-
-      (locals.updated(index, argExp), newIndex)
     }
-    State(locals)
+    // Ensure locals have enough slots with padTo.
+    State(locals.padTo(max, null))
   }
 }
diff --git a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
index ecf9d794431..ab2ef7ae250 100644
--- a/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
+++ b/udf-compiler/src/test/scala/com/nvidia/spark/OpcodeSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Assertions._
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{udf => makeUdf}
@@ -35,7 +35,6 @@ class OpcodeSuite extends FunSuite {
 
   val conf: SparkConf = new SparkConf()
     .set("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
-    .set("spark.rapids.sql.test.enabled", "true")
     .set("spark.rapids.sql.udfCompiler.enabled", "true")
     .set(RapidsConf.EXPLAIN.key, "true")
 
@@ -2088,19 +2087,201 @@ class OpcodeSuite extends FunSuite {
     checkEquiv(result, ref)
   }
 
-  test("Arbitrary function call inside UDF - ") {
-    def simple_func(str: String): Int = {
+  test("final static method call inside UDF") {
+    def simple_func2(str: String): Int = {
       str.length
     }
+    def simple_func1(str: String): Int = {
+      simple_func2(str) + simple_func2(str)
+    }
     val myudf: (String) => Int = str => {
-      simple_func(str)
+      simple_func1(str)
     }
     val u = makeUdf(myudf)
     val dataset = List("hello", "world").toDS()
     val result = dataset.withColumn("new", u(col("value")))
+    val ref = dataset.withColumn("new", length(col("value")) + length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: class method call inside UDF") {
+    class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("final class method call inside UDF") {
+    final class C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: object method call inside UDF") {
+    object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
     val ref = dataset.withColumn("new", length(col("value")))
-    result.explain(true)
-    result.show
-//    checkEquiv(result, ref)
+    checkEquivNotCompiled(result, ref)
   }
+
+  test("final object method call inside UDF") {
+    final object C {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val dataset = List("hello", "world").toDS()
+    val result = C(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("super class final method call inside UDF") {
+    class B {
+      final def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquiv(result, ref)
+  }
+
+  test("FALLBACK TO CPU: final class calls super class method inside UDF") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    final class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: super class method call inside UDF") {
+    class B {
+      def simple_func(str: String): Int = {
+        str.length
+      }
+    }
+    class D extends B {
+      val myudf: (String) => Int = str => {
+        simple_func(str)
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new D
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")))
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var in class") {
+    class C {
+      var capturedArg: Int = 4
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + runner.capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
+
+  test("FALLBACK TO CPU: capture a var outside class") {
+    var capturedArg: Int = 4
+    class C {
+      val myudf: (String) => Int = str => {
+        str.length + capturedArg
+      }
+      def apply(dataset: Dataset[String]): Dataset[Row] = {
+        val u = makeUdf(myudf)
+        dataset.withColumn("new", u(col("value")))
+      }
+    }
+    val runner = new C
+    val dataset = List("hello", "world").toDS()
+    val result = runner(dataset)
+    val ref = dataset.withColumn("new", length(col("value")) + capturedArg)
+    checkEquivNotCompiled(result, ref)
+  }
 }
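
For reference, the following self-contained sketch shows the new inlining rules from a user's point of view. It is not part of the patches above; the object and method names are illustrative, and it assumes the RAPIDS plugin jar is on the classpath (the session settings mirror the SparkConf used in OpcodeSuite):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object InliningSketch {
  // A final method built only from supported operations (String.length), so it
  // satisfies both conditions from the commit message and may be inlined.
  final def wordLength(str: String): Int = str.length

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.extensions", "com.nvidia.spark.udf.Plugin")
      .config("spark.rapids.sql.udfCompiler.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    val u = udf((s: String) => wordLength(s))
    val ds = List("hello", "world").toDS()
    // If the compiler translates the UDF, the optimized plan contains
    // length(value) instead of the UDF call.
    ds.withColumn("new", u(col("value"))).explain(true)
    spark.stop()
  }
}

A non-final method on a non-final class or object would fail the modifier checks added to Instruction.scala and leave the UDF to run on the CPU, which is what the FALLBACK TO CPU tests above verify.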