Skip to content

Commit

Permalink
Fix inconsistencies in AQE support for broadcast joins (#1042)
Browse files Browse the repository at this point in the history
* Fix inconsistencies with AQE support for broadcast joins

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* code cleanup and change test behavior for Spark 3.0.0

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* fix inconsistency

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* fix test failure with Spark 3.1.0

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* fix inconsistency

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* fix imports

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* fix regression

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* tighten up rules

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Move GpuBroadcastJoinMeta to com.nvidia.spark.rapids package

Signed-off-by: Andy Grove <andygrove@nvidia.com>

* Move GpuBroadcastJoinMeta to com.nvidia.spark.rapids package

Signed-off-by: Andy Grove <andygrove@nvidia.com>
  • Loading branch information
andygrove authored Oct 29, 2020
1 parent a855df7 commit ee0fff2
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
Expand All @@ -40,7 +40,7 @@ class GpuBroadcastHashJoinMeta(
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends SparkPlanMeta[BroadcastHashJoinExec](join, conf, parent, rule) {
extends GpuBroadcastJoinMeta[BroadcastHashJoinExec](join, conf, parent, rule) {

val leftKeys: Seq[BaseExprMeta[_]] =
join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
Expand Down Expand Up @@ -76,9 +76,7 @@ class GpuBroadcastHashJoinMeta(
case BuildLeft => left
case BuildRight => right
}
if (!buildSide.isInstanceOf[GpuBroadcastExchangeExec]) {
throw new IllegalStateException("the broadcast must be on the GPU too")
}
verifyBuildSideWasReplaced(buildSide)
GpuBroadcastHashJoinExec(
leftKeys.map(_.convertToGpu()),
rightKeys.map(_.convertToGpu()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ class Spark300Shims extends SparkShims {
}
}

/** Returns true when `plan` is a GPU broadcast nested loop join operator. */
override def isGpuBroadcastNestedLoopJoin(plan: SparkPlan): Boolean =
  plan.isInstanceOf[GpuBroadcastNestedLoopJoinExecBase]

override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
plan match {
case _: GpuShuffledHashJoinExec => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims.spark301

import com.nvidia.spark.rapids.{BaseExprMeta, ConfKeysAndIncompat, GpuBindReferences, GpuColumnVector, GpuExec, GpuOverrides, GpuProjectExec, RapidsConf, RapidsMeta, SparkPlanMeta}
import com.nvidia.spark.rapids.{BaseExprMeta, ConfKeysAndIncompat, GpuBindReferences, GpuBroadcastJoinMeta, GpuColumnVector, GpuExec, GpuOverrides, GpuProjectExec, RapidsConf, RapidsMeta, SparkPlanMeta}
import com.nvidia.spark.rapids.GpuMetricNames.{NUM_OUTPUT_BATCHES, NUM_OUTPUT_ROWS, TOTAL_TIME}
import com.nvidia.spark.rapids.shims.spark300.GpuHashJoin

Expand All @@ -29,15 +29,15 @@ import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExecBase, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch
import org.apache.spark.sql.vectorized.ColumnarBatch

class GpuBroadcastHashJoinMeta(
join: BroadcastHashJoinExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends SparkPlanMeta[BroadcastHashJoinExec](join, conf, parent, rule) {
extends GpuBroadcastJoinMeta[BroadcastHashJoinExec](join, conf, parent, rule) {

val leftKeys: Seq[BaseExprMeta[_]] =
join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
Expand All @@ -56,18 +56,12 @@ class GpuBroadcastHashJoinMeta(
case BuildRight => childPlans(1)
}

buildSide.wrapped match {
case _: BroadcastQueryStageExec =>
// this already ran on GPU

case _ =>
if (!buildSide.canThisBeReplaced) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
if (!canBuildSideBeReplaced(buildSide)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}

if (!canThisBeReplaced) {
buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}
if (!canThisBeReplaced) {
buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}
}

Expand All @@ -79,11 +73,7 @@ class GpuBroadcastHashJoinMeta(
case BuildLeft => left
case BuildRight => right
}
buildSide match {
case _: GpuBroadcastExchangeExecBase =>
case _: BroadcastQueryStageExec =>
case _ => throw new IllegalStateException("the broadcast must be on the GPU too")
}
verifyBuildSideWasReplaced(buildSide)
GpuBroadcastHashJoinExec(
leftKeys.map(_.convertToGpu()),
rightKeys.map(_.convertToGpu()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GpuBroadcastHashJoinMeta(
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends SparkPlanMeta[BroadcastHashJoinExec](join, conf, parent, rule) {
extends GpuBroadcastJoinMeta[BroadcastHashJoinExec](join, conf, parent, rule) {

val leftKeys: Seq[BaseExprMeta[_]] =
join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
Expand All @@ -61,19 +61,14 @@ class GpuBroadcastHashJoinMeta(
case BuildRight => childPlans(1)
}

buildSide.wrapped match {
case _: BroadcastQueryStageExec =>
// this already ran on GPU

case _ =>
if (!buildSide.canThisBeReplaced) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}
if (!canBuildSideBeReplaced(buildSide)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}

if (!canThisBeReplaced) {
buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}
if (!canThisBeReplaced) {
buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
}

}

override def convertToGpu(): GpuExec = {
Expand All @@ -84,11 +79,7 @@ class GpuBroadcastHashJoinMeta(
case BuildLeft => left
case BuildRight => right
}
buildSide match {
case _: GpuBroadcastExchangeExecBase =>
case _: BroadcastQueryStageExec =>
case _ => throw new IllegalStateException("the broadcast must be on the GPU too")
}
verifyBuildSideWasReplaced(buildSide)
GpuBroadcastHashJoinExec(
leftKeys.map(_.convertToGpu()),
rightKeys.map(_.convertToGpu()),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.rapids.execution.GpuBroadcastExchangeExecBase

/**
 * Common base for the metadata classes of broadcast joins. It holds the build-side
 * checks so that every broadcast join applies exactly the same rules when deciding
 * whether its broadcast input is (or will be) produced on the GPU.
 *
 * @param plan the Spark join operator being wrapped
 * @param conf the plugin configuration
 * @param parent the meta of the parent operator, if any
 * @param rule the replacement rule associated with this operator
 */
abstract class GpuBroadcastJoinMeta[INPUT <: SparkPlan](plan: INPUT,
    conf: RapidsConf,
    parent: Option[RapidsMeta[_, _, _]],
    rule: ConfKeysAndIncompat)
  extends SparkPlanMeta[INPUT](plan, conf, parent, rule) {

  /**
   * Classifies a build-side plan that has one of the recognised broadcast shapes.
   *
   * @return `Some(true)` when the node is a GPU broadcast exchange, or a reused exchange
   *         and/or broadcast query stage that resolves to one; `Some(false)` when it is
   *         a reused exchange whose child is not a GPU broadcast exchange; `None` for
   *         any other node, leaving the decision to the caller
   */
  private def gpuBroadcastStatus(node: SparkPlan): Option[Boolean] = node match {
    case BroadcastQueryStageExec(_, _: GpuBroadcastExchangeExecBase) => Some(true)
    case BroadcastQueryStageExec(_, reused: ReusedExchangeExec) =>
      Some(reused.child.isInstanceOf[GpuBroadcastExchangeExecBase])
    case reused: ReusedExchangeExec =>
      Some(reused.child.isInstanceOf[GpuBroadcastExchangeExecBase])
    case _: GpuBroadcastExchangeExecBase => Some(true)
    case _ => None
  }

  /**
   * Decides, while tagging, whether the build side of this join can be on the GPU.
   *
   * @param buildSide the meta wrapping the build-side child plan
   * @return the classification of the wrapped plan when it has a recognised broadcast
   *         shape; otherwise whatever `buildSide.canThisBeReplaced` reports
   */
  def canBuildSideBeReplaced(buildSide: SparkPlanMeta[_]): Boolean =
    buildSide.wrapped match {
      case node: SparkPlan =>
        gpuBroadcastStatus(node).getOrElse(buildSide.canThisBeReplaced)
      // anything the classifier cannot inspect defers to the meta's own tagging result
      case _ => buildSide.canThisBeReplaced
    }

  /**
   * Checks, at conversion time, that the build side really is a GPU broadcast.
   *
   * @param buildSide the already converted build-side plan
   * @throws IllegalStateException if the plan is not a GPU broadcast exchange (directly,
   *                               through a reused exchange, or through a broadcast
   *                               query stage)
   */
  def verifyBuildSideWasReplaced(buildSide: SparkPlan): Unit = {
    if (!gpuBroadcastStatus(buildSide).contains(true)) {
      throw new IllegalStateException("the broadcast must be on the GPU too")
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ trait SparkShims {
def getSparkShimVersion: ShimVersion
def isGpuHashJoin(plan: SparkPlan): Boolean
def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
def isGpuBroadcastNestedLoopJoin(plan: SparkPlan): Boolean
def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
def isBroadcastExchangeLike(plan: SparkPlan): Boolean
def isShuffleExchangeLike(plan: SparkPlan): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, IdentityBroadcastMode, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand All @@ -39,7 +40,7 @@ class GpuBroadcastNestedLoopJoinMeta(
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
extends SparkPlanMeta[BroadcastNestedLoopJoinExec](join, conf, parent, rule) {
extends GpuBroadcastJoinMeta[BroadcastNestedLoopJoinExec](join, conf, parent, rule) {

val condition: Option[BaseExprMeta[_]] =
join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
Expand All @@ -59,13 +60,13 @@ class GpuBroadcastNestedLoopJoinMeta(
case GpuBuildRight => childPlans(1)
}

if (!buildSide.canThisBeReplaced) {
if (!canBuildSideBeReplaced(buildSide)) {
willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
}

if (!canThisBeReplaced) {
buildSide.willNotWorkOnGpu(
"the GpuBroadcastNestedLoopJoin this feeds is not on the GPU")
"the BroadcastNestedLoopJoin this feeds is not on the GPU")
}
}

Expand All @@ -77,10 +78,8 @@ class GpuBroadcastNestedLoopJoinMeta(
val buildSide = gpuBuildSide match {
case GpuBuildLeft => left
case GpuBuildRight => right
}
if (!buildSide.isInstanceOf[GpuBroadcastExchangeExecBase]) {
throw new IllegalStateException("the broadcast must be on the GPU too")
}
verifyBuildSideWasReplaced(buildSide)
ShimLoader.getSparkShims.getGpuBroadcastNestedLoopJoinShim(
left, right, join,
join.joinType,
Expand Down Expand Up @@ -161,6 +160,9 @@ abstract class GpuBroadcastNestedLoopJoinExecBase(
}

/**
 * Resolves the GPU broadcast exchange behind `broadcast`, looking through a reused
 * exchange and/or a broadcast query stage when present.
 *
 * There is deliberately no default case: any other plan shape fails with a
 * `MatchError`, and a reused exchange whose child is not a GPU broadcast exchange
 * fails with a `ClassCastException`.
 */
def broadcastExchange: GpuBroadcastExchangeExecBase = broadcast match {
  case direct: GpuBroadcastExchangeExecBase => direct
  case reusedExchange: ReusedExchangeExec =>
    reusedExchange.child.asInstanceOf[GpuBroadcastExchangeExecBase]
  case BroadcastQueryStageExec(_, staged: GpuBroadcastExchangeExecBase) => staged
  case BroadcastQueryStageExec(_, reusedExchange: ReusedExchangeExec) =>
    reusedExchange.child.asInstanceOf[GpuBroadcastExchangeExecBase]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.TestUtils.{findOperator, operatorCount}
import com.nvidia.spark.rapids.TestUtils.{findOperator, findOperators}

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.joins.HashJoin
Expand All @@ -40,10 +40,10 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {
df5.collect()
val plan = df5.queryExecution.executedPlan

val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
val bhjCount = findOperators(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
assert(bhjCount.size === 1)

val shjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin)
val shjCount = findOperators(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin)
assert(shjCount.size === 1)
}, conf)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.TestUtils.findOperators

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.internal.SQLConf

/**
 * Checks how many GPU broadcast nested loop joins end up in the executed plan of a
 * broadcast cross join, with adaptive query execution (AQE) disabled and enabled.
 */
class BroadcastNestedLoopJoinSuite extends SparkQueryCompareTestSuite {

  /**
   * Runs a cross join against an explicitly broadcast right side and asserts the number
   * of GPU broadcast nested loop join operators found in the executed plan.
   *
   * @param aqeEnabled value to use for the adaptive execution setting
   * @param expectedCount expected number of GPU broadcast nested loop joins; passed by
   *                      name so it is evaluated inside the session, after the plan has
   *                      been inspected
   */
  private def assertGpuNestedLoopJoinCount(aqeEnabled: Boolean)(expectedCount: => Int) = {
    val conf = new SparkConf()
      .set("spark.rapids.sql.exec.BroadcastNestedLoopJoinExec", "true")
      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, aqeEnabled.toString)

    withGpuSparkSession(spark => {
      val df1 = longsDf(spark).repartition(2)
      val df2 = nonZeroLongsDf(spark).repartition(2)
      val df3 = df1.crossJoin(broadcast(df2))
      // execute the query before reading the executed plan
      df3.collect()
      val plan = df3.queryExecution.executedPlan

      val gpuJoins = findOperators(plan, ShimLoader.getSparkShims.isGpuBroadcastNestedLoopJoin)
      assert(gpuJoins.size === expectedCount)
    }, conf)
  }

  test("BroadcastNestedLoopJoinExec AQE off") {
    assertGpuNestedLoopJoinCount(aqeEnabled = false)(1)
  }

  test("BroadcastNestedLoopJoinExec AQE on") {
    assertGpuNestedLoopJoinCount(aqeEnabled = true) {
      ShimLoader.getSparkShims.getSparkShimVersion match {
        // we didn't start supporting GPU exchanges with AQE until 3.0.1
        case SparkShimVersion(3, 0, 0) => 0
        case _ => 1
      }
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object TestUtils extends Assertions with Arm {
}

/** Return the list of operators in the plan that match the predicate */
def operatorCount(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
def findOperators(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
def recurse(
plan: SparkPlan,
predicate: SparkPlan => Boolean,
Expand Down

0 comments on commit ee0fff2

Please sign in to comment.