Support UCX when AQE is enabled
Signed-off-by: Andy Grove <andygrove@nvidia.com>
andygrove committed Aug 27, 2020
1 parent b7ad292 commit c2b410b
Showing 8 changed files with 82 additions and 13 deletions.
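For context, a minimal sketch of the setup this commit targets: running the RAPIDS shuffle manager (which provides the UCX transport) together with adaptive query execution. The shuffle manager class name appears in the diff below; the plugin class and the AQE flag are standard Spark/RAPIDS settings, and any further UCX transport configuration is outside this change.

import org.apache.spark.sql.SparkSession

// Sketch only: assumes Spark 3.0.1 and the RAPIDS accelerator jars on the classpath.
// Before this commit the plugin fell back to SortShuffleManager whenever AQE was on.
val spark = SparkSession.builder()
  .appName("rapids-shuffle-with-aqe")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.shuffle.manager",
    "com.nvidia.spark.rapids.spark301.RapidsShuffleManager")
  .config("spark.sql.adaptive.enabled", "true")
  .getOrCreate()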
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
import scala.collection.mutable.ListBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object TpcdsLikeBench extends Logging {

@@ -47,21 +47,31 @@ object TpcdsLikeBench extends Logging {
println(s"*** Cold run $i took $elapsed msec.")
}

var df: DataFrame = null
var results: Array[Row] = null
val hotRunElapsed = new ListBuffer[Long]()
for (i <- 0 until numHotRuns) {
println(s"*** Start hot run $i:")
val start = System.nanoTime()
TpcdsLikeSpark.run(spark, query).collect
df = TpcdsLikeSpark.run(spark, query)
results = df.collect
val end = System.nanoTime()
val elapsed = NANOSECONDS.toMillis(end - start)
hotRunElapsed.append(elapsed)
println(s"*** Hot run $i took $elapsed msec.")
}

// for easier comparison between running with different configs, show query plan, sample data,
// and row count
df.show()
println(s"Row count: ${results.length}")

// show summary of performance at end
for (i <- 0 until numColdRuns) {
println(s"Cold run $i for query $query took ${coldRunElapsed(i)} msec.")
}
println(s"Average cold run took ${coldRunElapsed.sum.toDouble/numColdRuns} msec.")
println(s"Average cold run for query $query took " +
s"${coldRunElapsed.sum.toDouble/numColdRuns} msec.")

for (i <- 0 until numHotRuns) {
println(s"Hot run $i for query $query took ${hotRunElapsed(i)} msec.")
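The hunk above times each run with System.nanoTime, forces execution with collect, and now keeps the DataFrame and collected rows so it can print sample data and a row count for comparing configurations. A standalone sketch of that measurement pattern (the helper name is illustrative, not part of TpcdsLikeBench):

import java.util.concurrent.TimeUnit.NANOSECONDS

import org.apache.spark.sql.DataFrame

// Illustrative helper: time one run of a query by forcing the plan with collect,
// returning the elapsed milliseconds and the number of rows produced.
def timeOneRun(runQuery: () => DataFrame): (Long, Int) = {
  val start = System.nanoTime()
  val rows = runQuery().collect()
  val elapsedMs = NANOSECONDS.toMillis(System.nanoTime() - start)
  (elapsedMs, rows.length)
}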
@@ -36,10 +36,8 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
// NOTE: This type of reader is not possible for gpu shuffle, as we'd need
// to use the optimization within our manager, and we don't.
wrapped.getReaderForRange(RapidsShuffleInternalManagerBase.unwrapHandle(handle),
startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)
getReaderInternal(handle, startMapIndex, endMapIndex, startPartition, endPartition, context,
metrics)
}

def getReader[K, C](
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.shims.spark301

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.shuffle._
import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase

/**
* A shuffle manager optimized for the RAPIDS Plugin For Apache Spark.
* @note This is an internal class to obtain access to the private
* `ShuffleManager` and `SortShuffleManager` classes.
*/
class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
extends RapidsShuffleInternalManagerBase(conf, isDriver) {

override def getReaderForRange[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
getReaderInternal(handle, startMapIndex, endMapIndex, startPartition, endPartition, context,
metrics)
}

def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
getReaderInternal(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
}

}
@@ -20,6 +20,7 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark300.{GpuShuffledHashJoinMeta, GpuSortMergeJoinMeta, Spark300Shims}
import com.nvidia.spark.rapids.spark301.RapidsShuffleManager

import org.apache.spark.SparkEnv
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
@@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, GpuShuffleExchangeExecBase}
import org.apache.spark.storage.{BlockId, BlockManagerId}

class Spark301Shims extends Spark300Shims {

@@ -72,6 +74,15 @@ class Spark301Shims extends Spark300Shims {
classOf[RapidsShuffleManager].getCanonicalName
}

override def getMapSizesByExecutorId(
shuffleId: Int,
startMapIndex: Int,
endMapIndex: Int, startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
SparkEnv.get.mapOutputTracker.getMapSizesByRange(shuffleId,
startMapIndex, endMapIndex, startPartition, endPartition)
}

override def getGpuBroadcastExchangeExec(
mode: BroadcastMode,
child: SparkPlan): GpuBroadcastExchangeExecBase = {
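The new override above forwards to MapOutputTracker.getMapSizesByRange, which is how the shuffle layer answers per-map-range questions when AQE coalesces partitions or splits skewed ones. A hedged sketch of using that call; the package and helper name are hypothetical, and like the shim it must live under org.apache.spark because the tracker API is internal:

package org.apache.spark.sql.rapids.examples

import org.apache.spark.SparkEnv

object MapSizeEstimate {
  // Illustrative only: total bytes a reducer would fetch for a partition range
  // over a sub-range of map outputs (assumes an active SparkEnv, i.e. it runs
  // inside a Spark application).
  def estimateFetchBytes(
      shuffleId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int): Long = {
    SparkEnv.get.mapOutputTracker
      .getMapSizesByRange(shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
      .flatMap { case (_, blocks) => blocks.map(_._2) }
      .sum
  }
}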
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids.spark301

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.shims.spark300.RapidsShuffleInternalManager
import org.apache.spark.sql.rapids.shims.spark301.RapidsShuffleInternalManager

/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
sealed class RapidsShuffleManager(
@@ -17,7 +17,7 @@
package com.nvidia.spark.rapids.spark302

import org.apache.spark.SparkConf
import org.apache.spark.sql.rapids.shims.spark300.RapidsShuffleInternalManager
import org.apache.spark.sql.rapids.shims.spark301.RapidsShuffleInternalManager

/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
sealed class RapidsShuffleManager(
@@ -36,6 +36,7 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
getReaderInternal(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
getReaderInternal(handle, startMapIndex, endMapIndex, startPartition, endPartition, context,
metrics)
}
}
@@ -217,9 +217,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boolean)
if (!GpuShuffleEnv.isRapidsShuffleEnabled) {
fallThroughReasons += "external shuffle is enabled"
}
if (SQLConf.get.adaptiveExecutionEnabled) {
fallThroughReasons += "adaptive query execution is enabled"
}
if (fallThroughReasons.nonEmpty) {
logWarning(s"Rapids Shuffle Plugin is falling back to SortShuffleManager " +
s"because: ${fallThroughReasons.mkString(", ")}")
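This last hunk is the heart of the change: adaptive query execution is removed from the list of reasons that force the plugin back onto SortShuffleManager, leaving only checks like the one above. A simplified sketch of the remaining fall-through pattern (not the plugin's actual class):

import scala.collection.mutable.ListBuffer

// Illustrative only: gather the reasons the RAPIDS shuffle cannot be used; callers
// fall back to the sort shuffle if any are returned. After this commit,
// "adaptive query execution is enabled" is no longer such a reason.
def fallThroughReasonsFor(rapidsShuffleEnabled: Boolean): Seq[String] = {
  val reasons = new ListBuffer[String]()
  if (!rapidsShuffleEnabled) {
    reasons += "external shuffle is enabled"
  }
  reasons.toList
}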
