Merge pull request #1319 from NVIDIA/branch-0.3
[auto-merge] branch-0.3 to branch-0.4 [skip ci] [bot]
nvauto authored Dec 8, 2020
2 parents a8f944b + e0d6d77 commit 37147f2
Showing 4 changed files with 64 additions and 9 deletions.
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.datasources.{FilePartition, HadoopFsRelation, PartitionDirectory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
@@ -109,7 +110,18 @@ class Spark301dbShims extends Spark301Shims
// partition filters and data filters are not run on the GPU
override val childExprs: Seq[ExprMeta[_]] = Seq.empty

override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this)
override def tagPlanForGpu(): Unit = {
// This is a very specific check to make any of the Delta log metadata queries
// fall back and run on the CPU, since there are some incompatibilities
// between Databricks Spark and Apache Spark.
if (wrapped.relation.fileFormat.isInstanceOf[JsonFileFormat] &&
wrapped.relation.location.getClass.getCanonicalName() ==
"com.databricks.sql.transaction.tahoe.DeltaLogFileIndex") {
this.entirePlanWillNotWork("Plans that read Delta Index JSON files cannot run " +
"any part of the plan on the GPU!")
}
GpuFileSourceScanExec.tagSupport(this)
}

override def convertToGpu(): GpuExec = {
val sparkSession = wrapped.relation.sparkSession
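Aside on the tagPlanForGpu check in this file: the Databricks location class is matched by its canonical name rather than with isInstanceOf, because DeltaLogFileIndex is a Databricks-only class that is not on the Apache Spark compile-time classpath. A minimal sketch of the same pattern, using a made-up class name:

// Sketch only: detecting a class that may not exist on the compile-time
// classpath by comparing canonical names. "com.example.PrivateFileIndex"
// is a hypothetical stand-in, not a real class.
def isPrivateIndex(location: AnyRef): Boolean =
  location.getClass.getCanonicalName == "com.example.PrivateFileIndex"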
@@ -2267,16 +2267,25 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
if (conf.isSqlEnabled) {
val wrap = GpuOverrides.wrapPlan(plan, conf, None)
wrap.tagForGpu()
wrap.runAfterTagRules()
val reasonsToNotReplaceEntirePlan = wrap.getReasonsNotToReplaceEntirePlan
val exp = conf.explain
if (!exp.equalsIgnoreCase("NONE")) {
val explain = wrap.explain(exp.equalsIgnoreCase("ALL"))
if (!explain.isEmpty) {
logWarning(s"\n$explain")
if (conf.allowDisableEntirePlan && reasonsToNotReplaceEntirePlan.nonEmpty) {
if (!exp.equalsIgnoreCase("NONE")) {
logWarning("Can't replace any part of this plan due to: " +
s"${reasonsToNotReplaceEntirePlan.mkString(",")}")
}
plan
} else {
wrap.runAfterTagRules()
if (!exp.equalsIgnoreCase("NONE")) {
val explain = wrap.explain(exp.equalsIgnoreCase("ALL"))
if (!explain.isEmpty) {
logWarning(s"\n$explain")
}
}
val convertedPlan = wrap.convertIfNeeded()
addSortsIfNeeded(convertedPlan, conf)
}
val convertedPlan = wrap.convertIfNeeded()
addSortsIfNeeded(convertedPlan, conf)
} else {
plan
}
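Net effect of the block above: when the safeguard is enabled and any node reported a whole-plan reason, the rule logs the reasons and returns the original CPU plan untouched, skipping conversion entirely; otherwise the normal per-exec conversion runs. A simplified sketch of that branch, with the plugin types stubbed out by a type parameter (this is the shape of the decision, not the plugin's API):

// Sketch only: P stands in for SparkPlan; convert stands in for the
// wrapped meta's convertIfNeeded plus addSortsIfNeeded.
def overridePlan[P](plan: P, allowDisableEntirePlan: Boolean,
    reasonsNotToReplaceEntirePlan: Seq[String])(convert: P => P): P =
  if (allowDisableEntirePlan && reasonsNotToReplaceEntirePlan.nonEmpty) {
    plan // leave the entire query on the CPU
  } else {
    convert(plan) // normal per-exec GPU conversion
  }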
11 changes: 11 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -785,6 +785,15 @@ object RapidsConf {
.booleanConf
.createWithDefault(false)

val ALLOW_DISABLE_ENTIRE_PLAN = conf("spark.rapids.allowDisableEntirePlan")
.internal()
.doc("The plugin has the ability to detect possibe incompatibility with some specific " +
"queries and cluster configurations. In those cases the plugin will disable GPU support " +
"for the entire query. Set this to false if you want to override that behavior, but use " +
"with caution.")
.booleanConf
.createWithDefault(true)

private def printSectionHeader(category: String): Unit =
println(s"\n### $category")

@@ -1072,6 +1081,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val cudfVersionOverride: Boolean = get(CUDF_VERSION_OVERRIDE)

lazy val allowDisableEntirePlan: Boolean = get(ALLOW_DISABLE_ENTIRE_PLAN)

lazy val getCloudSchemes: Option[Seq[String]] = get(CLOUD_SCHEMES)

def isOperatorEnabled(key: String, incompat: Boolean, isDisabledByDefault: Boolean): Boolean = {
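Since spark.rapids.allowDisableEntirePlan is a plain boolean conf, it can presumably be overridden like any other Spark setting; a hedged usage sketch, assuming a spark-shell style SparkSession named spark in scope:

// Sketch only: turning off the internal whole-plan safeguard.
// The plugin enables it by default to work around known incompatibilities,
// so overriding it is at the user's own risk.
spark.conf.set("spark.rapids.allowDisableEntirePlan", "false")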
@@ -120,7 +120,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
def convertToCpu(): BASE = wrapped

private var cannotBeReplacedReasons: Option[mutable.Set[String]] = None

private var cannotReplaceAnyOfPlanReasons: Option[mutable.Set[String]] = None
private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None

val gpuSupportedTag = TreeNodeTag[Set[String]]("rapids.gpu.supported")
@@ -141,6 +141,14 @@
}
}

/**
* Call this when a condition is found that prevents any part of the plan
* from running on the GPU.
*/
final def entirePlanWillNotWork(because: String): Unit = {
cannotReplaceAnyOfPlanReasons.get.add(because)
}

final def shouldBeRemoved(because: String): Unit =
shouldBeRemovedReasons.get.add(because)

@@ -154,6 +162,15 @@
*/
final def canThisBeReplaced: Boolean = cannotBeReplacedReasons.exists(_.isEmpty)

/**
* Returns the list of reasons the entire plan can't be replaced. An empty
* set means the entire plan is OK to be replaced; do the normal checking
* per exec and children.
*/
final def entirePlanExcludedReasons: Seq[String] = {
cannotReplaceAnyOfPlanReasons.getOrElse(mutable.Set.empty).toSeq
}

/**
* Returns true iff all of the expressions and their children could be replaced.
*/
@@ -184,6 +201,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
def initReasons(): Unit = {
cannotBeReplacedReasons = Some(mutable.Set[String]())
shouldBeRemovedReasons = Some(mutable.Set[String]())
cannotReplaceAnyOfPlanReasons = Some(mutable.Set[String]())
}

/**
@@ -511,6 +529,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
}
}

def getReasonsNotToReplaceEntirePlan: Seq[String] = {
val childReasons = childPlans.flatMap(_.getReasonsNotToReplaceEntirePlan)
entirePlanExcludedReasons ++ childReasons
}

private def fixUpJoinConsistencyIfNeeded(): Unit = {
childPlans.foreach(_.fixUpJoinConsistencyIfNeeded())
wrapped match {
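getReasonsNotToReplaceEntirePlan aggregates reasons recursively over childPlans, so a reason reported anywhere in the tree bubbles up to the root, where GpuOverrides checks it. A self-contained sketch of that traversal with a toy node type (not the plugin's RapidsMeta class):

import scala.collection.mutable

// Sketch only: ToyMeta mirrors the recursive reason collection above.
case class ToyMeta(ownReasons: mutable.Set[String], children: Seq[ToyMeta]) {
  def reasonsNotToReplaceEntirePlan: Seq[String] =
    ownReasons.toSeq ++ children.flatMap(_.reasonsNotToReplaceEntirePlan)
}

// A reason reported at a leaf scan reaches the root of the plan tree.
val leaf = ToyMeta(mutable.Set("reads Delta Index JSON files"), Nil)
val root = ToyMeta(mutable.Set.empty[String], Seq(leaf))
assert(root.reasonsNotToReplaceEntirePlan == Seq("reads Delta Index JSON files"))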
