diff --git a/shims/spark300emr/src/main/scala/com/nvidia/spark/rapids/shims/spark300emr/Spark300EMRShims.scala b/shims/spark300emr/src/main/scala/com/nvidia/spark/rapids/shims/spark300emr/Spark300EMRShims.scala index 8446caeb6dd..4279a21c4dd 100644 --- a/shims/spark300emr/src/main/scala/com/nvidia/spark/rapids/shims/spark300emr/Spark300EMRShims.scala +++ b/shims/spark300emr/src/main/scala/com/nvidia/spark/rapids/shims/spark300emr/Spark300EMRShims.scala @@ -16,6 +16,8 @@ package com.nvidia.spark.rapids.shims.spark300emr +import java.lang.reflect.Constructor + import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.shims.spark300.Spark300Shims import com.nvidia.spark.rapids.spark300emr.RapidsShuffleManager @@ -27,6 +29,8 @@ import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, P class Spark300EMRShims extends Spark300Shims { + private var fileScanRddConstructor: Option[Constructor[_]] = None + override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION override def getRapidsShuffleManagerClass: String = { @@ -39,12 +43,16 @@ class Spark300EMRShims extends Spark300Shims { readFunction: (PartitionedFile) => Iterator[InternalRow], filePartitions: Seq[FilePartition]): RDD[InternalRow] = { - val tclass = classOf[org.apache.spark.sql.execution.datasources.FileScanRDD] - val constructors = tclass.getConstructors() - if (constructors.size > 1) { - throw new IllegalStateException(s"Only expected 1 constructor for FileScanRDD") + val constructor = fileScanRddConstructor.getOrElse { + val tclass = classOf[org.apache.spark.sql.execution.datasources.FileScanRDD] + val constructors = tclass.getConstructors() + if (constructors.size > 1) { + throw new IllegalStateException(s"Only expected 1 constructor for FileScanRDD") + } + val cnstr = constructors(0) + fileScanRddConstructor = Some(cnstr) + cnstr } - val constructor = constructors(0) val instance = if (constructor.getParameterCount() == 4) { constructor.newInstance(sparkSession, readFunction, filePartitions, None) } else if (constructor.getParameterCount() == 3) {