From 2bda1eab79d64ae401488713a3266993bfdc1e02 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Mon, 11 Apr 2022 13:54:10 +0800
Subject: [PATCH 1/3] Optimize Spark33XShims to reduce code duplication

Signed-off-by: Chong Gao
---
 .../spark/rapids/shims/GpuTypeShims.scala     |  23 +-
 .../rapids/shims/Spark320PlusShims.scala      | 176 ++++++------
 .../spark/rapids/shims/GpuTypeShims.scala     |  22 +-
 .../spark/rapids/shims/Spark33XShims.scala    | 251 +-----------------
 .../nvidia/spark/rapids/GpuOverrides.scala    |  72 +++--
 .../com/nvidia/spark/rapids/SparkShims.scala  |   1 -
 6 files changed, 189 insertions(+), 356 deletions(-)

diff --git a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
index 4e014d36935..fad78ab194a 100644
--- a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
+++ b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
@@ -95,8 +95,29 @@ object GpuTypeShims {
   def isSupportedYearMonthType(dt: DataType): Boolean = false
 
   /**
-   * Get additional supported types for this Shim
+   * Get additional arithmetic supported types for this Shim
    */
   def additionalArithmeticSupportedTypes: TypeSig = TypeSig.none
 
+  /**
+   * Get additional predicate supported types for this Shim
+   */
+  def additionalPredicateSupportedTypes: TypeSig = TypeSig.none
+
+  /**
+   * Get additional Csv supported types for this Shim
+   */
+  def additionalCsvSupportedTypes: TypeSig = TypeSig.none
+
+  /**
+   * Get additional Parquet supported types for this Shim
+   */
+  def additionalParquetSupportedTypes: TypeSig = TypeSig.none
+
+  /**
+   * Get additional common operators supported types for this Shim
+   * (filter, sample, project, alias, table scan, etc., which the GPU supports from Spark 330)
+   */
+  def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.none
+
 }
diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
index c1a78e3c728..2a0d2143cc1 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.DateFormatter
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.{FileSourceScanExec, _}
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources._
@@ -337,93 +337,14 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
       "Reading data from files, often from Hive tables",
       ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
           TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
-      (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
-
-        // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
-        // if possible. Instead regarding filters as childExprs of current Meta, we create
-        // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of
-        // FileSourceScan is independent from the replacement of the partitionFilters. It is
-        // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters
-        // are on the GPU. And vice versa.
-        private lazy val partitionFilters = {
-          val convertBroadcast = (bc: SubqueryBroadcastExec) => {
-            val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
-            meta.tagForExplain()
-            meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
-          }
-          wrapped.partitionFilters.map { filter =>
-            filter.transformDown {
-              case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) =>
-                inSub.plan match {
-                  case bc: SubqueryBroadcastExec =>
-                    dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
-                  case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
-                    dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
-                  case _ =>
-                    dpe
-                }
-            }
-          }
-        }
-
-        // partition filters and data filters are not run on the GPU
-        override val childExprs: Seq[ExprMeta[_]] = Seq.empty
-
-        override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this)
-
-        override def convertToCpu(): SparkPlan = {
-          wrapped.copy(partitionFilters = partitionFilters)
-        }
-
-        override def convertToGpu(): GpuExec = {
-          val sparkSession = wrapped.relation.sparkSession
-          val options = wrapped.relation.options
-
-          val location = AlluxioUtils.replacePathIfNeeded(
-            conf,
-            wrapped.relation,
-            partitionFilters,
-            wrapped.dataFilters)
-
-          val newRelation = HadoopFsRelation(
-            location,
-            wrapped.relation.partitionSchema,
-            wrapped.relation.dataSchema,
-            wrapped.relation.bucketSpec,
-            GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
-            options)(sparkSession)
-
-          GpuFileSourceScanExec(
-            newRelation,
-            wrapped.output,
-            wrapped.requiredSchema,
-            partitionFilters,
-            wrapped.optionalBucketSet,
-            wrapped.optionalNumCoalescedBuckets,
-            wrapped.dataFilters,
-            wrapped.tableIdentifier,
-            wrapped.disableBucketedScan)(conf)
-        }
-      }),
+      (fsse, conf, p, r) => new FileSourceScanExecMeta320Plus(fsse, conf, p, r)),
     GpuOverrides.exec[BatchScanExec](
       "The backend for most file input",
       ExecChecks(
         (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
             TypeSig.DECIMAL_128).nested(),
         TypeSig.all),
-      (p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
-        override val childScans: scala.Seq[ScanMeta[_]] =
-          Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this)))
-
-        override def tagPlanForGpu(): Unit = {
-          if (!p.runtimeFilters.isEmpty) {
-            willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported")
-          }
-        }
-
-        override def convertToGpu(): GpuExec =
-          GpuBatchScanExec(p.output, childScans.head.convertToGpu(), p.runtimeFilters)
-      })
+      (p, conf, parent, r) => new BatchScanExecMeta320Plus(p, conf, parent, r))
   ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
 }
@@ -492,4 +413,95 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
   def tagFileSourceScanExec(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
     GpuFileSourceScanExec.tagSupport(meta)
   }
+
+  class FileSourceScanExecMeta320Plus(plan: FileSourceScanExec,
+      conf: RapidsConf,
+      parent: Option[RapidsMeta[_, _, _]],
+      rule: DataFromReplacementRule)
+    extends SparkPlanMeta[FileSourceScanExec](plan, conf, parent, rule) {
+
+    // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
+    // if possible. Instead of regarding filters as childExprs of the current Meta, we create
+    // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of
+    // FileSourceScan is independent of the replacement of the partitionFilters. It is
+    // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters
+    // are on the GPU. And vice versa.
+    private lazy val partitionFilters = {
+      val convertBroadcast = (bc: SubqueryBroadcastExec) => {
+        val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
+        meta.tagForExplain()
+        meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
+      }
+      wrapped.partitionFilters.map { filter =>
+        filter.transformDown {
+          case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) =>
+            inSub.plan match {
+              case bc: SubqueryBroadcastExec =>
+                dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
+              case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
+                dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
+              case _ =>
+                dpe
+            }
+        }
+      }
+    }
+
+    // partition filters and data filters are not run on the GPU
+    override val childExprs: Seq[ExprMeta[_]] = Seq.empty
+
+    override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this)
+
+    override def convertToCpu(): SparkPlan = {
+      wrapped.copy(partitionFilters = partitionFilters)
+    }
+
+    override def convertToGpu(): GpuExec = {
+      val sparkSession = wrapped.relation.sparkSession
+      val options = wrapped.relation.options
+
+      val location = AlluxioUtils.replacePathIfNeeded(
+        conf,
+        wrapped.relation,
+        partitionFilters,
+        wrapped.dataFilters)
+
+      val newRelation = HadoopFsRelation(
+        location,
+        wrapped.relation.partitionSchema,
+        wrapped.relation.dataSchema,
+        wrapped.relation.bucketSpec,
+        GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
+        options)(sparkSession)
+
+      GpuFileSourceScanExec(
+        newRelation,
+        wrapped.output,
+        wrapped.requiredSchema,
+        partitionFilters,
+        wrapped.optionalBucketSet,
+        wrapped.optionalNumCoalescedBuckets,
+        wrapped.dataFilters,
+        wrapped.tableIdentifier,
+        wrapped.disableBucketedScan)(conf)
+    }
+  }
+
+  class BatchScanExecMeta320Plus(p: BatchScanExec,
+      conf: RapidsConf,
+      parent: Option[RapidsMeta[_, _, _]],
+      rule: DataFromReplacementRule)
+    extends SparkPlanMeta[BatchScanExec](p, conf, parent, rule) {
+    override val childScans: scala.Seq[ScanMeta[_]] =
+      Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this)))
+
+    override def tagPlanForGpu(): Unit = {
+      if (!p.runtimeFilters.isEmpty) {
+        willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported")
+      }
+    }
+
+    override def convertToGpu(): GpuExec =
+      GpuBatchScanExec(p.output, childScans.head.convertToGpu(), p.runtimeFilters)
+  }
 }
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
index 443eec6a3f9..45432da4428 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala
@@ -182,8 +182,28 @@ object GpuTypeShims {
   def isSupportedYearMonthType(dt: DataType): Boolean = dt.isInstanceOf[YearMonthIntervalType]
 
   /**
-   * Get additional supported types for this Shim
+   * Get additional arithmetic supported types for this Shim
    */
   def additionalArithmeticSupportedTypes: TypeSig = TypeSig.ansiIntervals
 
+  /**
+   * Get additional predicate supported types for this Shim
+   */
+  def additionalPredicateSupportedTypes: TypeSig = TypeSig.DAYTIME
+
+  /**
+   * Get additional Csv supported types for this Shim
+   */
+  def additionalCsvSupportedTypes: TypeSig = TypeSig.DAYTIME
+
+  /**
+   * Get additional Parquet supported types for this Shim
+   */
+  def additionalParquetSupportedTypes: TypeSig = TypeSig.ansiIntervals
+
+  /**
+   * Get additional common operators supported types for this Shim
+   * (filter, sample, project, alias, table scan, etc., which the GPU supports from Spark 330)
+   */
+  def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.ansiIntervals
 }
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
index f2c6ec115f3..28cfe275397 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
@@ -24,10 +24,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{BaseSubqueryExec, FileSourceScanExec, FilterExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionedFile}
+import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
@@ -75,25 +73,8 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
     super.tagFileSourceScanExec(meta)
   }
 
-  // 330+ supports DAYTIME interval types
-  override def getFileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = {
-    Map(
-      (CsvFormatType, FileFormatChecks(
-        cudfRead = TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-        cudfWrite = TypeSig.none,
-        sparkSig = TypeSig.cpuAtomics)),
-      (ParquetFormatType, FileFormatChecks(
-        cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
-            TypeSig.ARRAY + TypeSig.MAP + TypeSig.ansiIntervals).nested(),
-        cudfWrite = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
-            TypeSig.ARRAY + TypeSig.MAP + TypeSig.ansiIntervals).nested(),
-        sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
-            TypeSig.UDT + TypeSig.ansiIntervals).nested())))
-  }
-
   // 330+ supports DAYTIME interval types
   override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
-    val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL
     val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
       GpuOverrides.expr[RoundCeil](
         "Computes the ceiling of the given expression to d decimal places",
@@ -183,101 +164,6 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
           GpuTimeAdd(lhs, rhs)
       }),
-      GpuOverrides.expr[IsNull](
-        "Checks if a value is null",
-        ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
-              TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(),
-          TypeSig.all),
-        (a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
-          override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child)
-        }),
-      GpuOverrides.expr[IsNotNull](
-        "Checks if a value is not null",
-        ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
-              TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(),
-          TypeSig.all),
-        (a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
-          override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child)
-        }),
-      GpuOverrides.expr[EqualNullSafe](
-        "Check if the values are equal including nulls <=>",
-        ExprChecks.binaryProject(
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.comparable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.comparable)),
-        (a, conf, p, r) => new BinaryExprMeta[EqualNullSafe](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuEqualNullSafe(lhs, rhs)
-        }),
-      GpuOverrides.expr[EqualTo](
-        "Check if the values are equal",
-        ExprChecks.binaryProjectAndAst(
-          TypeSig.comparisonAstTypes,
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.comparable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.comparable)),
-        (a, conf, p, r) => new BinaryAstExprMeta[EqualTo](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuEqualTo(lhs, rhs)
-        }),
-      GpuOverrides.expr[GreaterThan](
-        "> operator",
-        ExprChecks.binaryProjectAndAst(
-          TypeSig.comparisonAstTypes,
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable)),
-        (a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuGreaterThan(lhs, rhs)
-        }),
-      GpuOverrides.expr[GreaterThanOrEqual](
-        ">= operator",
-        ExprChecks.binaryProjectAndAst(
-          TypeSig.comparisonAstTypes,
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable)),
-        (a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuGreaterThanOrEqual(lhs, rhs)
-        }),
-      GpuOverrides.expr[LessThan](
-        "< operator",
-        ExprChecks.binaryProjectAndAst(
-          TypeSig.comparisonAstTypes,
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable)),
-        (a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuLessThan(lhs, rhs)
-        }),
-      GpuOverrides.expr[LessThanOrEqual](
-        "<= operator",
-        ExprChecks.binaryProjectAndAst(
-          TypeSig.comparisonAstTypes,
-          TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-          ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable),
-          ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.DAYTIME,
-              TypeSig.orderable)),
-        (a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) {
-          override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-            GpuLessThanOrEqual(lhs, rhs)
-        }),
       GpuOverrides.expr[Abs](
         "Absolute value",
         ExprChecks.unaryProjectAndAstInputMatchesOutput(
@@ -300,144 +186,23 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
     super.getExprs ++ map
   }
 
-  // 330+ supports DAYTIME interval types
+  // GPU supports ANSI interval types from 330
   override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
-    val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL
     val map: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
       GpuOverrides.exec[BatchScanExec](
         "The backend for most file input",
         ExecChecks(
           (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
-            TypeSig.DECIMAL_128 + TypeSig.ansiIntervals).nested(),
-          TypeSig.all),
-        (p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
-          override val childScans: scala.Seq[ScanMeta[_]] =
-            Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this)))
-
-          override def tagPlanForGpu(): Unit = {
-            if (!p.runtimeFilters.isEmpty) {
-              willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported")
-            }
-            if (!p.keyGroupedPartitioning.isEmpty) {
-              willNotWorkOnGpu("key grouped partitioning is not supported")
-            }
-          }
-
-          override def convertToGpu(): GpuExec = GpuBatchScanExec(p.output,
-            childScans.head.convertToGpu(), p.runtimeFilters, p.keyGroupedPartitioning)
-        }),
-      GpuOverrides.exec[DataWritingCommandExec](
-        "Writing data",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128.withPsNote(
-          TypeEnum.DECIMAL, "128bit decimal only supported for Orc and Parquet") +
-          TypeSig.STRUCT.withPsNote(TypeEnum.STRUCT, "Only supported for Parquet") +
-          TypeSig.MAP.withPsNote(TypeEnum.MAP, "Only supported for Parquet") +
-          TypeSig.ARRAY.withPsNote(TypeEnum.ARRAY, "Only supported for Parquet") +
-          TypeSig.ansiIntervals).nested(),
+            TypeSig.DECIMAL_128 + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
           TypeSig.all),
-        (p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) {
-          override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] =
-            Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this)))
-
-          override def convertToGpu(): GpuExec =
-            GpuDataWritingCommandExec(childDataWriteCmds.head.convertToGpu(),
-              childPlans.head.convertIfNeeded())
-        }),
-      // this is copied, only added TypeSig.DAYTIME and TypeSig.YEARMONTH check
+        (p, conf, parent, r) => new BatchScanExecMeta320Plus(p, conf, parent, r)),
       GpuOverrides.exec[FileSourceScanExec](
         "Reading data from files, often from Hive tables",
        ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-          TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.ansiIntervals).nested(),
+          TypeSig.ARRAY + TypeSig.DECIMAL_128 +
+          GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
           TypeSig.all),
-        (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
-
-          // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
-          // if possible. Instead regarding filters as childExprs of current Meta, we create
-          // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of
-          // FileSourceScan is independent from the replacement of the partitionFilters. It is
-          // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters
-          // are on the GPU. And vice versa.
-          private lazy val partitionFilters = {
-            val convertBroadcast = (bc: SubqueryBroadcastExec) => {
-              val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
-              meta.tagForExplain()
-              meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
-            }
-            wrapped.partitionFilters.map { filter =>
-              filter.transformDown {
-                case dpe@DynamicPruningExpression(inSub: InSubqueryExec) =>
-                  inSub.plan match {
-                    case bc: SubqueryBroadcastExec =>
-                      dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
-                    case reuse@ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
-                      dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
-                    case _ =>
-                      dpe
-                  }
-              }
-            }
-          }
-
-          // partition filters and data filters are not run on the GPU
-          override val childExprs: Seq[ExprMeta[_]] = Seq.empty
-
-          override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this)
-
-          override def convertToCpu(): SparkPlan = {
-            wrapped.copy(partitionFilters = partitionFilters)
-          }
-
-          override def convertToGpu(): GpuExec = {
-            val sparkSession = wrapped.relation.sparkSession
-            val options = wrapped.relation.options
-
-            val location = AlluxioUtils.replacePathIfNeeded(
-              conf,
-              wrapped.relation,
-              partitionFilters,
-              wrapped.dataFilters)
-
-            val newRelation = HadoopFsRelation(
-              location,
-              wrapped.relation.partitionSchema,
-              wrapped.relation.dataSchema,
-              wrapped.relation.bucketSpec,
-              GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
-              options)(sparkSession)
-
-            GpuFileSourceScanExec(
-              newRelation,
-              wrapped.output,
-              wrapped.requiredSchema,
-              partitionFilters,
-              wrapped.optionalBucketSet,
-              wrapped.optionalNumCoalescedBuckets,
-              wrapped.dataFilters,
-              wrapped.tableIdentifier,
-              wrapped.disableBucketedScan)(conf)
-          }
-        }),
-      GpuOverrides.exec[InMemoryTableScanExec](
-        "Implementation of InMemoryTableScanExec to use GPU accelerated Caching",
-        // NullType is actually supported
-        ExecChecks(TypeSig.commonCudfTypesWithNested + TypeSig.ansiIntervals,
-          TypeSig.all),
-        (scan, conf, p, r) => new InMemoryTableScanMeta(scan, conf, p, r)),
-      GpuOverrides.exec[ProjectExec](
-        "The backend for most select, withColumn and dropColumn statements",
-        ExecChecks(
-          (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-            TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.ansiIntervals).nested(),
-          TypeSig.all),
-        (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)),
-      GpuOverrides.exec[FilterExec](
-        "The backend for most filter statements",
-        ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-          TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), TypeSig.all),
-        (filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
-          override def convertToGpu(): GpuExec =
-            GpuFilterExec(childExprs.head.convertToGpu(), childPlans.head.convertIfNeeded())
-        })
+        (fsse, conf, p, r) => new FileSourceScanExecMeta320Plus(fsse, conf, p, r))
     ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
     super.getExecs ++ map
   }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index b6f626e4c90..f1804d75fe8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -808,18 +808,19 @@ object GpuOverrides extends Logging {
       .map(r => r.wrap(expr, conf, parent, r).asInstanceOf[BaseExprMeta[INPUT]])
       .getOrElse(new RuleNotFoundExprMeta(expr, conf, parent))
 
-  lazy val basicFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map(
+  lazy val fileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map(
     (CsvFormatType, FileFormatChecks(
-      cudfRead = TypeSig.commonCudfTypes + TypeSig.DECIMAL_128,
+      cudfRead = TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
+          GpuTypeShims.additionalCsvSupportedTypes,
       cudfWrite = TypeSig.none,
       sparkSig = TypeSig.cpuAtomics)),
     (ParquetFormatType, FileFormatChecks(
       cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
-          TypeSig.ARRAY + TypeSig.MAP).nested(),
+          TypeSig.ARRAY + TypeSig.MAP + GpuTypeShims.additionalParquetSupportedTypes).nested(),
       cudfWrite = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
-          TypeSig.ARRAY + TypeSig.MAP).nested(),
+          TypeSig.ARRAY + TypeSig.MAP + GpuTypeShims.additionalParquetSupportedTypes).nested(),
       sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
-          TypeSig.UDT).nested())),
+          TypeSig.UDT + GpuTypeShims.additionalParquetSupportedTypes).nested())),
     (OrcFormatType, FileFormatChecks(
       cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.DECIMAL_128 +
           TypeSig.STRUCT + TypeSig.MAP).nested(),
@@ -840,8 +841,6 @@ object GpuOverrides extends Logging {
       sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
           TypeSig.UDT).nested())))
 
-  lazy val fileFormats = basicFormats ++ SparkShimImpl.getFileFormats
-
   val commonExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
     expr[Literal](
       "Holds a static value from the query",
@@ -862,9 +861,9 @@ object GpuOverrides extends Logging {
     expr[Alias](
       "Gives a column a name",
       ExprChecks.unaryProjectAndAstInputMatchesOutput(
-        TypeSig.astTypes + GpuTypeShims.additionalArithmeticSupportedTypes,
+        TypeSig.astTypes + GpuTypeShims.additionalCommonOperatorSupportedTypes,
         (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT
-          + TypeSig.DECIMAL_128 + GpuTypeShims.additionalArithmeticSupportedTypes).nested(),
+          + TypeSig.DECIMAL_128 + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
         TypeSig.all),
       (a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression =
@@ -1364,7 +1363,8 @@ object GpuOverrides extends Logging {
       "Checks if a value is null",
       ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
         (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
-            TypeSig.STRUCT + TypeSig.DECIMAL_128).nested(),
+            TypeSig.STRUCT + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes).nested(),
         TypeSig.all),
       (a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child)
@@ -1373,7 +1373,8 @@ object GpuOverrides extends Logging {
       "Checks if a value is not null",
       ExprChecks.unaryProject(TypeSig.BOOLEAN, TypeSig.BOOLEAN,
         (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
-            TypeSig.STRUCT + TypeSig.DECIMAL_128).nested(),
+            TypeSig.STRUCT + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes).nested(),
         TypeSig.all),
       (a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child)
@@ -1952,9 +1953,11 @@ object GpuOverrides extends Logging {
       "Check if the values are equal including nulls <=>",
       ExprChecks.binaryProject(
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.comparable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.comparable)),
       (a, conf, p, r) => new BinaryExprMeta[EqualNullSafe](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -1965,9 +1968,11 @@ object GpuOverrides extends Logging {
       ExprChecks.binaryProjectAndAst(
         TypeSig.comparisonAstTypes,
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.comparable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.comparable)),
       (a, conf, p, r) => new BinaryAstExprMeta[EqualTo](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -1978,9 +1983,11 @@ object GpuOverrides extends Logging {
       ExprChecks.binaryProjectAndAst(
         TypeSig.comparisonAstTypes,
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable)),
       (a, conf, p, r) => new BinaryAstExprMeta[GreaterThan](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -1991,9 +1998,11 @@ object GpuOverrides extends Logging {
       ExprChecks.binaryProjectAndAst(
         TypeSig.comparisonAstTypes,
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable)),
       (a, conf, p, r) => new BinaryAstExprMeta[GreaterThanOrEqual](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -2039,9 +2048,11 @@ object GpuOverrides extends Logging {
       ExprChecks.binaryProjectAndAst(
         TypeSig.comparisonAstTypes,
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable)),
       (a, conf, p, r) => new BinaryAstExprMeta[LessThan](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -2052,9 +2063,11 @@ object GpuOverrides extends Logging {
       ExprChecks.binaryProjectAndAst(
         TypeSig.comparisonAstTypes,
         TypeSig.BOOLEAN, TypeSig.BOOLEAN,
-        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("lhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable),
-        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128,
+        ("rhs", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalPredicateSupportedTypes,
          TypeSig.orderable)),
       (a, conf, p, r) => new BinaryAstExprMeta[LessThanOrEqual](a, conf, p, r) {
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
@@ -3597,7 +3610,8 @@ object GpuOverrides extends Logging {
       "The backend for most select, withColumn and dropColumn statements",
       ExecChecks(
         (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-            TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(),
+            TypeSig.ARRAY + TypeSig.DECIMAL_128 +
+            GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
         TypeSig.all),
       (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)),
     exec[RangeExec](
@@ -3638,7 +3652,8 @@ object GpuOverrides extends Logging {
           TypeEnum.DECIMAL, "128bit decimal only supported for Orc and Parquet") +
           TypeSig.STRUCT.withPsNote(TypeEnum.STRUCT, "Only supported for Parquet") +
           TypeSig.MAP.withPsNote(TypeEnum.MAP, "Only supported for Parquet") +
-          TypeSig.ARRAY.withPsNote(TypeEnum.ARRAY, "Only supported for Parquet")).nested(),
+          TypeSig.ARRAY.withPsNote(TypeEnum.ARRAY, "Only supported for Parquet") +
+          GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
       TypeSig.all),
       (p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) {
         override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] =
@@ -3720,7 +3735,8 @@ object GpuOverrides extends Logging {
     exec[FilterExec](
       "The backend for most filter statements",
       ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
-          TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
+          TypeSig.ARRAY + TypeSig.DECIMAL_128 +
+          GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all),
       (filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
         override def convertToGpu(): GpuExec =
           GpuFilterExec(childExprs.head.convertToGpu(), childPlans.head.convertIfNeeded())
@@ -3932,8 +3948,8 @@ object GpuOverrides extends Logging {
       (mapPy, conf, p, r) => new GpuMapInPandasExecMeta(mapPy, conf, p, r)),
     exec[InMemoryTableScanExec](
       "Implementation of InMemoryTableScanExec to use GPU accelerated caching",
-      ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT
-        + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all),
+      ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY +
+          TypeSig.MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all),
       (scan, conf, p, r) => new InMemoryTableScanMeta(scan, conf, p, r)),
     neverReplaceExec[AlterNamespaceSetPropertiesExec]("Namespace metadata operation"),
     neverReplaceExec[CreateNamespaceExec]("Namespace metadata operation"),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index 22783601194..de88333337e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -85,7 +85,6 @@ trait SparkShims {
   def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]]
   def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]]
   def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]]
-  def getFileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map()
 
   def newBroadcastQueryStageExec(
       old: BroadcastQueryStageExec,

From a8dd591abc5ea8bcd97066a5af675c552538543a Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Mon, 11 Apr 2022 14:27:41 +0800
Subject: [PATCH 2/3] Optimize imports

---
 .../scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
index 2a0d2143cc1..4b3769d7de7 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util.DateFormatter
 import org.apache.spark.sql.connector.read.Scan
-import org.apache.spark.sql.execution.{FileSourceScanExec, _}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources._

From 7e8ad86be77e857192c673d1e2a692ae200d8405 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Mon, 11 Apr 2022 18:28:48 +0800
Subject: [PATCH 3/3] Rename class names and update comments

---
 .../com/nvidia/spark/rapids/shims/Spark320PlusShims.scala | 8 ++++----
 .../com/nvidia/spark/rapids/shims/Spark33XShims.scala     | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
index 4b3769d7de7..6c364951532 100644
--- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
+++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala
@@ -337,14 +337,14 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
       "Reading data from files, often from Hive tables",
       ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
           TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
-      (fsse, conf, p, r) => new FileSourceScanExecMeta320Plus(fsse, conf, p, r)),
+      (fsse, conf, p, r) => new FileSourceScanExecMeta(fsse, conf, p, r)),
     GpuOverrides.exec[BatchScanExec](
       "The backend for most file input",
       ExecChecks(
         (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
             TypeSig.DECIMAL_128).nested(),
         TypeSig.all),
-      (p, conf, parent, r) => new BatchScanExecMeta320Plus(p, conf, parent, r))
+      (p, conf, parent, r) => new BatchScanExecMeta(p, conf, parent, r))
   ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
 }
@@ -414,7 +414,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
     GpuFileSourceScanExec.tagSupport(meta)
   }
 
-  class FileSourceScanExecMeta320Plus(plan: FileSourceScanExec,
+  class FileSourceScanExecMeta(plan: FileSourceScanExec,
       conf: RapidsConf,
       parent: Option[RapidsMeta[_, _, _]],
       rule: DataFromReplacementRule)
@@ -487,7 +487,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
     }
   }
 
-  class BatchScanExecMeta320Plus(p: BatchScanExec,
+  class BatchScanExecMeta(p: BatchScanExec,
       conf: RapidsConf,
       parent: Option[RapidsMeta[_, _, _]],
       rule: DataFromReplacementRule)
diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
index 28cfe275397..8dea131480b 100644
--- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
+++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala
@@ -73,7 +73,7 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
     super.tagFileSourceScanExec(meta)
   }
 
-  // 330+ supports DAYTIME interval types
+  // GPU supports ANSI interval types from 330
   override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
     val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
       GpuOverrides.expr[RoundCeil](
@@ -195,14 +195,14 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
         (TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
             TypeSig.DECIMAL_128 + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
         TypeSig.all),
-      (p, conf, parent, r) => new BatchScanExecMeta320Plus(p, conf, parent, r)),
+      (p, conf, parent, r) => new BatchScanExecMeta(p, conf, parent, r)),
       GpuOverrides.exec[FileSourceScanExec](
         "Reading data from files, often from Hive tables",
         ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
           TypeSig.ARRAY + TypeSig.DECIMAL_128 +
          GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(),
           TypeSig.all),
-      (fsse, conf, p, r) => new FileSourceScanExecMeta320Plus(fsse, conf, p, r))
+      (fsse, conf, p, r) => new FileSourceScanExecMeta(fsse, conf, p, r))
     ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
     super.getExecs ++ map
   }
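
Reviewer note, appended for context and not part of the patches themselves: the series replaces per-version copies of whole override rules with small additional*SupportedTypes hooks on GpuTypeShims that shared rule code composes into its TypeSig checks. The self-contained Scala sketch below is a toy model of that hook pattern only; ShimHookSketch, TypeSigLite, and every value in it are illustrative stand-ins under that assumption, not the plugin's real TypeSig API.

  // Toy model: the base shim contributes no extra types, the 330+ shim
  // overrides only the hooks where GPU support was added, and shared rule
  // code composes the hook once instead of duplicating whole rules per version.
  object ShimHookSketch {
    final case class TypeSigLite(types: Set[String]) {
      def +(other: TypeSigLite): TypeSigLite = TypeSigLite(types ++ other.types)
    }
    val none: TypeSigLite = TypeSigLite(Set.empty)
    val commonCudfTypes: TypeSigLite = TypeSigLite(Set("INT", "LONG", "STRING"))
    val daytime: TypeSigLite = TypeSigLite(Set("DAYTIME"))
    val ansiIntervals: TypeSigLite = TypeSigLite(Set("DAYTIME", "YEARMONTH"))

    // Mirrors the base GpuTypeShims: every hook defaults to none.
    trait GpuTypeShimsLike {
      def additionalPredicateSupportedTypes: TypeSigLite = none
      def additionalCommonOperatorSupportedTypes: TypeSigLite = none
    }

    // Mirrors the 330+ shim: override only what gained GPU support.
    object Shims330Plus extends GpuTypeShimsLike {
      override def additionalPredicateSupportedTypes: TypeSigLite = daytime
      override def additionalCommonOperatorSupportedTypes: TypeSigLite = ansiIntervals
    }

    def main(args: Array[String]): Unit = {
      // One shared rule definition, specialized per version through the hook.
      def predicateCheck(shims: GpuTypeShimsLike): TypeSigLite =
        commonCudfTypes + shims.additionalPredicateSupportedTypes
      println(predicateCheck(new GpuTypeShimsLike {}).types) // base shim: no intervals
      println(predicateCheck(Shims330Plus).types)            // 330+ shim: adds DAYTIME
    }
  }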