diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala index fdda0c69c0d..eac51be377b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalaUDF.scala @@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.{ExprChecks, ExprMeta, ExprRule, GpuExpression, G import com.nvidia.spark.rapids.GpuUserDefinedFunction.udfTypeSig import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} +import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.DataType case class GpuScalaUDF( @@ -76,14 +77,23 @@ object GpuScalaUDF { case f: RapidsUDF => Some(f) case f => try { + // This may be a lambda that Spark's UDFRegistration wrapped around a Java UDF instance. val clazz = f.getClass - val writeReplace = clazz.getDeclaredMethod("writeReplace") - writeReplace.setAccessible(true) - val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda] - if (serializedLambda.getCapturedArgCount == 1) { - serializedLambda.getCapturedArg(0) match { - case c: RapidsUDF => Some(c) - case _ => None + if (TrampolineUtil.getSimpleName(clazz).toLowerCase().contains("lambda")) { + // Try to find a `writeReplace` method, further indicating it is likely a lambda + // instance, and invoke it to serialize the lambda. Once serialized, captured arguments + // can be examined to locate the Java UDF instance. + // Note this relies on implementation details of Spark's UDFRegistration class. + val writeReplace = clazz.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda] + if (serializedLambda.getCapturedArgCount == 1) { + serializedLambda.getCapturedArg(0) match { + case c: RapidsUDF => Some(c) + case _ => None + } + } else { + None } } else { None