Skip to content

Commit

Permalink
Add lambda classname check and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowe committed Feb 12, 2021
1 parent 9b7439b commit cb1ea9d
Showing 1 changed file with 17 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.{ExprChecks, ExprMeta, ExprRule, GpuExpression, G
import com.nvidia.spark.rapids.GpuUserDefinedFunction.udfTypeSig

import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.types.DataType

case class GpuScalaUDF(
Expand Down Expand Up @@ -76,14 +77,23 @@ object GpuScalaUDF {
case f: RapidsUDF => Some(f)
case f =>
try {
// This may be a lambda that Spark's UDFRegistration wrapped around a Java UDF instance.
val clazz = f.getClass
val writeReplace = clazz.getDeclaredMethod("writeReplace")
writeReplace.setAccessible(true)
val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda]
if (serializedLambda.getCapturedArgCount == 1) {
serializedLambda.getCapturedArg(0) match {
case c: RapidsUDF => Some(c)
case _ => None
if (TrampolineUtil.getSimpleName(clazz).toLowerCase().contains("lambda")) {
// Try to find a `writeReplace` method, further indicating it is likely a lambda
// instance, and invoke it to serialize the lambda. Once serialized, captured arguments
// can be examine to locate the Java UDF instance.
// Note this relies on implementation details of Spark's UDFRegistration class.
val writeReplace = clazz.getDeclaredMethod("writeReplace")
writeReplace.setAccessible(true)
val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda]
if (serializedLambda.getCapturedArgCount == 1) {
serializedLambda.getCapturedArg(0) match {
case c: RapidsUDF => Some(c)
case _ => None
}
} else {
None
}
} else {
None
Expand Down

0 comments on commit cb1ea9d

Please sign in to comment.