Refactor Spark33XShims to avoid code duplication #5195

Merged · 3 commits · Apr 11, 2022
@@ -95,8 +95,29 @@ object GpuTypeShims {
def isSupportedYearMonthType(dt: DataType): Boolean = false

/**
* Get additional supported types for this Shim
* Get additional arithmetic supported types for this Shim
*/
def additionalArithmeticSupportedTypes: TypeSig = TypeSig.none

/**
* Get additional predicate supported types for this Shim
*/
def additionalPredicateSupportedTypes: TypeSig = TypeSig.none

/**
* Get additional Csv supported types for this Shim
*/
def additionalCsvSupportedTypes: TypeSig = TypeSig.none

/**
* Get additional Parquet supported types for this Shim
*/
def additionalParquetSupportedTypes: TypeSig = TypeSig.none

/**
* Get additional common operators supported types for this Shim
* (filter, sample, project, alias, table scan, etc., which the GPU supports from Spark 330)
*/
def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.none

}
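
For context on how these hooks are meant to be consumed: shared checker code builds its type signature from a common base plus whatever the active shim reports as additionally supported, so only the shim file differs between Spark versions. The following is a minimal, self-contained sketch of that pattern; TypeSet, TypeShims, and Spark330TypeShims are simplified stand-ins invented for illustration, not the plugin's real TypeSig/GpuTypeShims classes.

// Minimal sketch of the per-shim "additional supported types" pattern.
// All names here are illustrative stand-ins, not the plugin's real API.
object TypeShimSketch {

  // Simplified stand-in for TypeSig: a set of type names that can be unioned.
  final case class TypeSet(types: Set[String]) {
    def +(other: TypeSet): TypeSet = TypeSet(types ++ other.types)
    override def toString: String = types.toSeq.sorted.mkString(", ")
  }

  object TypeSet {
    val none: TypeSet = TypeSet(Set.empty)
    val commonCudfTypes: TypeSet = TypeSet(Set("INT", "LONG", "DOUBLE", "STRING"))
    val ansiIntervals: TypeSet = TypeSet(Set("DAYTIME_INTERVAL", "YEARMONTH_INTERVAL"))
  }

  // Base shim (pre-330): no extra types, mirroring the defaults in the hunk above.
  trait TypeShims {
    def additionalArithmeticSupportedTypes: TypeSet = TypeSet.none
    def additionalParquetSupportedTypes: TypeSet = TypeSet.none
  }

  // 330+ shim: adds the ANSI interval types, mirroring the 330 overrides further down.
  object Spark330TypeShims extends TypeShims {
    override def additionalArithmeticSupportedTypes: TypeSet = TypeSet.ansiIntervals
    override def additionalParquetSupportedTypes: TypeSet = TypeSet.ansiIntervals
  }

  // Shared checker code stays version-agnostic: it only appends the shim's extras.
  def arithmeticCheck(shims: TypeShims): TypeSet =
    TypeSet.commonCudfTypes + shims.additionalArithmeticSupportedTypes

  def main(args: Array[String]): Unit = {
    println(arithmeticCheck(new TypeShims {}))   // base types only
    println(arithmeticCheck(Spark330TypeShims))  // base types plus ANSI intervals
  }
}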
@@ -337,93 +337,14 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
// if possible. Instead of regarding filters as childExprs of the current Meta, we create
// a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of
// FileSourceScan is independent of the replacement of the partitionFilters. It is
// possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters
// are on the GPU. And vice versa.
private lazy val partitionFilters = {
val convertBroadcast = (bc: SubqueryBroadcastExec) => {
val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
meta.tagForExplain()
meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
}
wrapped.partitionFilters.map { filter =>
filter.transformDown {
case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) =>
inSub.plan match {
case bc: SubqueryBroadcastExec =>
dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
case _ =>
dpe
}
}
}
}

// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty

override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this)

override def convertToCpu(): SparkPlan = {
wrapped.copy(partitionFilters = partitionFilters)
}

override def convertToGpu(): GpuExec = {
val sparkSession = wrapped.relation.sparkSession
val options = wrapped.relation.options

val location = AlluxioUtils.replacePathIfNeeded(
conf,
wrapped.relation,
partitionFilters,
wrapped.dataFilters)

val newRelation = HadoopFsRelation(
location,
wrapped.relation.partitionSchema,
wrapped.relation.dataSchema,
wrapped.relation.bucketSpec,
GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
options)(sparkSession)

GpuFileSourceScanExec(
newRelation,
wrapped.output,
wrapped.requiredSchema,
partitionFilters,
wrapped.optionalBucketSet,
wrapped.optionalNumCoalescedBuckets,
wrapped.dataFilters,
wrapped.tableIdentifier,
wrapped.disableBucketedScan)(conf)
}
}),
(fsse, conf, p, r) => new FileSourceScanExecMeta(fsse, conf, p, r)),
GpuOverrides.exec[BatchScanExec](
"The backend for most file input",
ExecChecks(
(TypeSig.commonCudfTypes + TypeSig.STRUCT + TypeSig.MAP + TypeSig.ARRAY +
TypeSig.DECIMAL_128).nested(),
TypeSig.all),
(p, conf, parent, r) => new SparkPlanMeta[BatchScanExec](p, conf, parent, r) {
override val childScans: scala.Seq[ScanMeta[_]] =
Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this)))

override def tagPlanForGpu(): Unit = {
if (!p.runtimeFilters.isEmpty) {
willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported")
}
}

override def convertToGpu(): GpuExec =
GpuBatchScanExec(p.output, childScans.head.convertToGpu(), p.runtimeFilters)
})
(p, conf, parent, r) => new BatchScanExecMeta(p, conf, parent, r))
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
}

@@ -492,4 +413,95 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
def tagFileSourceScanExec(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
GpuFileSourceScanExec.tagSupport(meta)
}

class FileSourceScanExecMeta(plan: FileSourceScanExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends SparkPlanMeta[FileSourceScanExec](plan, conf, parent, rule) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
// if possible. Instead of regarding filters as childExprs of the current Meta, we create
// a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of
// FileSourceScan is independent of the replacement of the partitionFilters. It is
// possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters
// are on the GPU. And vice versa.
private lazy val partitionFilters = {
val convertBroadcast = (bc: SubqueryBroadcastExec) => {
val meta = GpuOverrides.wrapAndTagPlan(bc, conf)
meta.tagForExplain()
meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec]
}
wrapped.partitionFilters.map { filter =>
filter.transformDown {
case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) =>
inSub.plan match {
case bc: SubqueryBroadcastExec =>
dpe.copy(inSub.copy(plan = convertBroadcast(bc)))
case reuse @ ReusedSubqueryExec(bc: SubqueryBroadcastExec) =>
dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc))))
case _ =>
dpe
}
}
}
}

// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty

override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this)

override def convertToCpu(): SparkPlan = {
wrapped.copy(partitionFilters = partitionFilters)
}

override def convertToGpu(): GpuExec = {
val sparkSession = wrapped.relation.sparkSession
val options = wrapped.relation.options

val location = AlluxioUtils.replacePathIfNeeded(
conf,
wrapped.relation,
partitionFilters,
wrapped.dataFilters)

val newRelation = HadoopFsRelation(
location,
wrapped.relation.partitionSchema,
wrapped.relation.dataSchema,
wrapped.relation.bucketSpec,
GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
options)(sparkSession)

GpuFileSourceScanExec(
newRelation,
wrapped.output,
wrapped.requiredSchema,
partitionFilters,
wrapped.optionalBucketSet,
wrapped.optionalNumCoalescedBuckets,
wrapped.dataFilters,
wrapped.tableIdentifier,
wrapped.disableBucketedScan)(conf)
}
}

class BatchScanExecMeta(p: BatchScanExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
extends SparkPlanMeta[BatchScanExec](p, conf, parent, rule) {
override val childScans: scala.Seq[ScanMeta[_]] =
Seq(GpuOverrides.wrapScan(p.scan, conf, Some(this)))

override def tagPlanForGpu(): Unit = {
if (!p.runtimeFilters.isEmpty) {
willNotWorkOnGpu("runtime filtering (DPP) on datasource V2 is not supported")
}
}

override def convertToGpu(): GpuExec =
GpuBatchScanExec(p.output, childScans.head.convertToGpu(), p.runtimeFilters)
}
}
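
The payoff of lifting these anonymous metas into named classes (FileSourceScanExecMeta, BatchScanExecMeta) is that a later shim, such as the 33x one this PR targets, can instantiate or extend them instead of copying the whole body. Below is a small, self-contained sketch of that reuse pattern; PlanMeta, FileScan, and the 330 subclass are invented stand-ins for illustration, not the plugin's real SparkPlanMeta hierarchy.

// Sketch of the refactor's shape: a named, reusable meta class replaces a
// copied anonymous class, so a later shim only overrides what actually differs.
// Every type here is an illustrative stand-in.
object MetaReuseSketch {

  final case class FileScan(name: String)

  // Stand-in for SparkPlanMeta: wraps a plan node and decides GPU placement.
  abstract class PlanMeta[T](val plan: T) {
    def tagPlanForGpu(): Unit
  }

  // Before this PR, each shim trait carried an identical anonymous PlanMeta body.
  // After, the body lives once in a named class...
  class FileScanMeta(plan: FileScan) extends PlanMeta[FileScan](plan) {
    override def tagPlanForGpu(): Unit =
      println(s"tagging ${plan.name} for GPU")
  }

  // ...and a newer shim extends it, adding only version-specific behavior.
  class FileScanMeta330(plan: FileScan) extends FileScanMeta(plan) {
    override def tagPlanForGpu(): Unit = {
      super.tagPlanForGpu()
      println("330-specific checks, e.g. ANSI interval types")
    }
  }

  def main(args: Array[String]): Unit = {
    new FileScanMeta(FileScan("base scan")).tagPlanForGpu()
    new FileScanMeta330(FileScan("330 scan")).tagPlanForGpu()
  }
}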
@@ -182,8 +182,28 @@ object GpuTypeShims {
def isSupportedYearMonthType(dt: DataType): Boolean = dt.isInstanceOf[YearMonthIntervalType]

/**
* Get additional supported types for this Shim
* Get additional arithmetic supported types for this Shim
*/
def additionalArithmeticSupportedTypes: TypeSig = TypeSig.ansiIntervals

/**
* Get additional predicate supported types for this Shim
*/
def additionalPredicateSupportedTypes: TypeSig = TypeSig.DAYTIME

/**
* Get additional Csv supported types for this Shim
*/
def additionalCsvSupportedTypes: TypeSig = TypeSig.DAYTIME

/**
* Get additional Parquet supported types for this Shim
*/
def additionalParquetSupportedTypes: TypeSig = TypeSig.ansiIntervals

/**
* Get additional common operators supported types for this Shim
* (filter, sample, project, alias, table scan, etc., which the GPU supports from Spark 330)
*/
def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.ansiIntervals
}
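
For a concrete feel of what these 330 additions enable, here is a small, hypothetical Spark snippet whose plan uses day-time interval arithmetic and a comparison over day-time intervals, i.e. the categories the overrides above extend. It assumes Spark 3.3.0+ with a local SparkSession; in a real run the plugin would be enabled separately (e.g. via spark.plugins=com.nvidia.spark.SQLPlugin), and the supported operators would then show up as Gpu* nodes in the plan.

// Hypothetical usage of the type categories extended above (Spark 3.3.0+ assumed).
import org.apache.spark.sql.SparkSession

object AnsiIntervalDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("ansi-interval-demo")
      .getOrCreate()

    // Day-time interval arithmetic (additionalArithmeticSupportedTypes) plus a
    // comparison over day-time intervals (additionalPredicateSupportedTypes).
    val df = spark.sql(
      """SELECT ts,
        |       ts + INTERVAL '1 12:00:00' DAY TO SECOND AS shifted
        |FROM VALUES (TIMESTAMP '2022-01-01 00:00:00') AS t(ts)
        |WHERE (ts + INTERVAL '1 12:00:00' DAY TO SECOND) - ts >
        |      INTERVAL '1 00:00:00' DAY TO SECOND
        |""".stripMargin)

    df.explain()                  // with the plugin enabled, look for Gpu* operators
    df.show(truncate = false)
    spark.stop()
  }
}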