diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
index 27bbc361f19..907d608af74 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
@@ -642,8 +642,10 @@ trait Spark32XShims extends SparkShims {
           + TypeSig.ARRAY).nested(), TypeSig.all),
       (scan, conf, p, r) => new SparkPlanMeta[InMemoryTableScanExec](scan, conf, p, r) {
         override def tagPlanForGpu(): Unit = {
-          if (!scan.relation.cacheBuilder.serializer.isInstanceOf[ParquetCachedBatchSerializer]) {
-            willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
+          scan.relation.cacheBuilder.serializer match {
+            case _: com.nvidia.spark.ParquetCachedBatchSerializer => ()
+            case _ =>
+              willNotWorkOnGpu("ParquetCachedBatchSerializer is not being used")
           }
         }
@@ -885,7 +887,7 @@ trait Spark32XShims extends SparkShims {
       exportColumnRdd: Boolean): GpuColumnarToRowExecParent = {
     val serName = plan.conf.getConf(StaticSQLConf.SPARK_CACHE_SERIALIZER)
     val serClass = ShimLoader.loadClass(serName)
-    if (serClass == classOf[ParquetCachedBatchSerializer]) {
+    if (serClass == classOf[com.nvidia.spark.ParquetCachedBatchSerializer]) {
      GpuColumnarToRowTransitionExec(plan)
    } else {
      GpuColumnarToRowExec(plan)
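
Note on the first hunk: the negated isInstanceOf test is replaced by a pattern match against the fully qualified com.nvidia.spark.ParquetCachedBatchSerializer, which makes explicit exactly which serializer class the check targets. The sketch below is illustrative only and not part of the patch; CachedBatchSerializer, PluginSerializer, and OtherSerializer are hypothetical stand-in types, used to show that the isInstanceOf form and the match form make the same decision.

// Minimal, self-contained sketch (hypothetical types, not the real Spark or
// spark-rapids classes) of the check rewritten in this hunk.
object SerializerCheckSketch {
  trait CachedBatchSerializer
  // Plays the role of com.nvidia.spark.ParquetCachedBatchSerializer.
  class PluginSerializer extends CachedBatchSerializer
  class OtherSerializer extends CachedBatchSerializer

  // Before: negated isInstanceOf check; returns a reason when the plan
  // would not work on the GPU, None otherwise.
  def tagWithIsInstanceOf(ser: CachedBatchSerializer): Option[String] =
    if (!ser.isInstanceOf[PluginSerializer]) {
      Some("ParquetCachedBatchSerializer is not being used")
    } else {
      None
    }

  // After: the same decision expressed as a pattern match, which keeps the
  // supported case explicit and is easy to extend with further cases.
  def tagWithMatch(ser: CachedBatchSerializer): Option[String] = ser match {
    case _: PluginSerializer => None
    case _ => Some("ParquetCachedBatchSerializer is not being used")
  }

  def main(args: Array[String]): Unit = {
    // Both forms agree for the supported and the unsupported serializer.
    assert(tagWithIsInstanceOf(new PluginSerializer) == tagWithMatch(new PluginSerializer))
    assert(tagWithIsInstanceOf(new OtherSerializer) == tagWithMatch(new OtherSerializer))
  }
}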