Update Spark 3.1 shim for ShuffleOrigin shuffle parameter (NVIDIA#1206)
Signed-off-by: Jason Lowe <jlowe@nvidia.com>
jlowe authored Nov 25, 2020
1 parent 7bcd2f6 commit 60df2a7
Showing 7 changed files with 71 additions and 14 deletions.
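
The thrust of the change: the shim entry point that builds a GPU shuffle exchange no longer takes a pre-computed canChangeNumPartitions flag but an optional reference to the CPU ShuffleExchangeExec, so each Spark-version shim can extract whatever detail its Spark exposes (canChangeNumPartitions on 3.0.1 and the Databricks 3.0.1 shim, shuffleOrigin on 3.1). A condensed before/after view of the trait method, assembled from the hunks below:

```scala
// Before: the caller decided up front whether partition counts could change.
def getGpuShuffleExchangeExec(
    outputPartitioning: Partitioning,
    child: SparkPlan,
    canChangeNumPartitions: Boolean = true): GpuShuffleExchangeExecBase

// After: the caller hands over the CPU shuffle node (if there is one) and the
// version-specific shim extracts what it needs from it.
def getGpuShuffleExchangeExec(
    outputPartitioning: Partitioning,
    child: SparkPlan,
    cpuShuffle: Option[ShuffleExchangeExec] = None): GpuShuffleExchangeExecBase
```

Passing the whole CPU node keeps the shim interface stable if a future Spark version adds yet another shuffle-exchange field the GPU plan needs to mirror.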
@@ -96,7 +96,7 @@ class Spark300Shims extends SparkShims {
override def getGpuShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = {
cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = {
GpuShuffleExchangeExec(outputPartitioning, child)
}

@@ -108,21 +108,21 @@ class Spark300Shims extends SparkShims {
override def isGpuHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuHashJoin => true
case p => false
case _ => false
}
}

override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuBroadcastHashJoinExec => true
case p => false
case _ => false
}
}

override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuShuffledHashJoinExec => true
case p => false
case _ => false
}
}

@@ -381,7 +381,7 @@ class Spark300Shims extends SparkShims

override def getFileScanRDD(
sparkSession: SparkSession,
readFunction: (PartitionedFile) => Iterator[InternalRow],
readFunction: PartitionedFile => Iterator[InternalRow],
filePartitions: Seq[FilePartition]): RDD[InternalRow] = {
new FileScanRDD(sparkSession, readFunction, filePartitions)
}
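
A side note on the getFileScanRDD hunk above: this is a purely stylistic Scala change. A one-parameter function type may be written with or without parentheses around the parameter, so both annotations denote the same type; a standalone REPL-style sketch (not from the repo):

```scala
// Both annotations describe the same function type; only the style differs.
val styleA: (Int) => String = i => i.toString
val styleB: Int => String = styleA  // assignable, identical type
```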
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.types.DataType
@@ -102,7 +102,8 @@ class Spark301Shims extends Spark300Shims {
override def getGpuShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = {
cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = {
val canChangeNumPartitions = cpuShuffle.forall(_.canChangeNumPartitions)
GpuShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions)
}
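
In the Spark 3.0.1 shim above, the old boolean is derived from the optional CPU node with Option.forall, which yields true for None and therefore preserves the previous default of canChangeNumPartitions = true when no CPU shuffle is supplied. A standalone sketch of that behavior, with the Spark class stubbed out:

```scala
// Stand-in for ShuffleExchangeExec, purely to show how Option.forall behaves.
case class CpuShuffleStub(canChangeNumPartitions: Boolean)

def derive(cpuShuffle: Option[CpuShuffleStub]): Boolean =
  cpuShuffle.forall(_.canChangeNumPartitions)

derive(None)                         // true  -- matches the old default of true
derive(Some(CpuShuffleStub(true)))   // true
derive(Some(CpuShuffleStub(false)))  // false -- honors the CPU plan's setting
```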

@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.python.WindowInPandasExec
@@ -216,7 +217,8 @@ class Spark301dbShims extends Spark301Shims {
override def getGpuShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean): GpuShuffleExchangeExecBase = {
cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = {
val canChangeNumPartitions = cpuShuffle.forall(_.canChangeNumPartitions)
GpuShuffleExchangeExec(outputPartitioning, child, canChangeNumPartitions)
}

@@ -0,0 +1,44 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.spark310

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase

case class GpuShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
shuffleOrigin: ShuffleOrigin)
extends GpuShuffleExchangeExecBase(outputPartitioning, child) with ShuffleExchangeLike {

override def numMappers: Int = shuffleDependencyColumnar.rdd.getNumPartitions

override def numPartitions: Int = shuffleDependencyColumnar.partitioner.numPartitions

override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] = {
throw new UnsupportedOperationException
}

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
Statistics(dataSize)
}
}
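
The new Spark 3.1-only wrapper above carries shuffleOrigin as a constructor field so the GPU exchange satisfies the ShuffleExchangeLike contract that Spark 3.1's adaptive execution inspects (for example when deciding whether a shuffle's partitions may be coalesced). A hypothetical construction helper, not part of the commit, assuming the ENSURE_REQUIREMENTS origin that Spark assigns to exchanges inserted to satisfy distribution requirements:

```scala
import com.nvidia.spark.rapids.shims.spark310.GpuShuffleExchangeExec
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS

// Hypothetical helper, not from the commit: builds the 3.1 GPU exchange with
// the default origin used for exchanges Spark inserts via EnsureRequirements.
def gpuExchangeWithDefaultOrigin(
    partitioning: Partitioning,
    child: SparkPlan): GpuShuffleExchangeExec =
  GpuShuffleExchangeExec(partitioning, child, ENSURE_REQUIREMENTS)
```

In practice the Spark310Shims override further down is what constructs this node, copying the origin from the CPU plan it replaces.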
@@ -25,16 +25,18 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, GpuStringReplace, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
import org.apache.spark.sql.rapids.execution.{GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.sql.rapids.shims.spark310._
import org.apache.spark.sql.types._
import org.apache.spark.storage.{BlockId, BlockManagerId}
@@ -79,21 +81,21 @@ class Spark310Shims extends Spark301Shims {
override def isGpuHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuHashJoin => true
case p => false
case _ => false
}
}

override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuBroadcastHashJoinExec => true
case p => false
case _ => false
}
}

override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuShuffledHashJoinExec => true
case p => false
case _ => false
}
}

@@ -289,4 +291,11 @@ class Spark310Shims extends Spark301Shims {
GpuSchemaUtils.checkColumnNameDuplication(schema, colType, resolver)
}

override def getGpuShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan,
cpuShuffle: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase = {
val shuffleOrigin = cpuShuffle.map(_.shuffleOrigin).getOrElse(ENSURE_REQUIREMENTS)
GpuShuffleExchangeExec(outputPartitioning, child, shuffleOrigin)
}
}
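
The new Spark310Shims override copies the origin from the CPU exchange when one exists and otherwise falls back to ENSURE_REQUIREMENTS. A standalone sketch of that Option handling, with the Spark types stubbed out for illustration:

```scala
// Stand-ins, purely to illustrate the map/getOrElse fallback; not Spark types.
sealed trait OriginStub
case object EnsureRequirementsStub extends OriginStub
case object OtherOriginStub extends OriginStub
case class CpuExchangeStub(shuffleOrigin: OriginStub)

def deriveOrigin(cpuShuffle: Option[CpuExchangeStub]): OriginStub =
  cpuShuffle.map(_.shuffleOrigin).getOrElse(EnsureRequirementsStub)

deriveOrigin(None)                                   // EnsureRequirementsStub
deriveOrigin(Some(CpuExchangeStub(OtherOriginStub))) // OtherOriginStub
```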
@@ -29,6 +29,7 @@ import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.{GpuFileSourceScanExec, ShuffleManagerShimBase}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuBroadcastNestedLoopJoinExecBase, GpuShuffleExchangeExecBase}
@@ -101,7 +102,7 @@ trait SparkShims {
def getGpuShuffleExchangeExec(
outputPartitioning: Partitioning,
child: SparkPlan,
canChangeNumPartitions: Boolean = true): GpuShuffleExchangeExecBase
cpuShuffle: Option[ShuffleExchangeExec] = None): GpuShuffleExchangeExecBase

def getGpuShuffleExchangeExec(
queryStage: ShuffleQueryStageExec): GpuShuffleExchangeExecBase
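
Callers reach these shim methods through ShimLoader rather than a concrete shim class, so the call shape is the same on every Spark version and only the shim resolved at runtime differs. A hypothetical call-site sketch, not from the commit (the ShimLoader package path is assumed):

```scala
import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase

// Hypothetical call site: ShimLoader picks the Spark-version-specific shim at
// runtime, which builds the matching GPU exchange from the CPU node (if any).
def buildGpuExchange(
    partitioning: Partitioning,
    child: SparkPlan,
    cpuExchange: Option[ShuffleExchangeExec]): GpuShuffleExchangeExecBase =
  ShimLoader.getSparkShims.getGpuShuffleExchangeExec(partitioning, child, cpuExchange)
```

The GpuShuffleMeta hunk below is the concrete caller updated by this commit: it now passes Some(shuffle) instead of the pre-computed boolean.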
@@ -65,7 +65,7 @@ class GpuShuffleMeta(
ShimLoader.getSparkShims.getGpuShuffleExchangeExec(
childParts(0).convertToGpu(),
childPlans(0).convertIfNeeded(),
shuffle.canChangeNumPartitions)
Some(shuffle))
}

/**