Implement ExistenceJoin Iterator using an auxiliary left semijoin #4796

Merged
34 commits
a080f9a
wip
gerashegalov Feb 9, 2022
cf135d7
Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/…
gerashegalov Feb 9, 2022
4ad4616
Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/…
gerashegalov Feb 15, 2022
deba9a1
wip
gerashegalov Feb 16, 2022
664ba4e
wip
gerashegalov Feb 16, 2022
a5da440
Use verboseStringWithSuffix(1000) for TreeNode match
gerashegalov Feb 16, 2022
b55a3e1
Use Scala Int.MinValue
gerashegalov Feb 16, 2022
6293dc1
config.md update for spark.rapids.sql.join.existence.enabled
gerashegalov Feb 16, 2022
8c9597b
Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/…
gerashegalov Feb 16, 2022
2517f73
undo buildall changes
gerashegalov Feb 16, 2022
193dbac
restore whitespace
gerashegalov Feb 16, 2022
5ed19db
existence join test with rhs duplicates
gerashegalov Feb 17, 2022
1c332ec
semijoin-based implementation
gerashegalov Feb 18, 2022
a02992e
undo cosmetic change
gerashegalov Feb 18, 2022
7e6bfcd
fix tagForGpu in GpuHashJoin
gerashegalov Feb 18, 2022
e333159
test updates
gerashegalov Feb 19, 2022
89c9a0e
draft
gerashegalov Feb 22, 2022
f4fe704
Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/…
gerashegalov Feb 23, 2022
465a957
undo gatherer changes
gerashegalov Feb 23, 2022
1a44866
mem leaks fixed
gerashegalov Feb 24, 2022
baea237
lhs dupes, conditional join
gerashegalov Feb 24, 2022
29704bf
mixedLeftSemiJoinGatherMap
gerashegalov Feb 25, 2022
b19c98f
mnemonic test ids
gerashegalov Feb 25, 2022
2948a66
refactoring
gerashegalov Feb 25, 2022
6d55b6a
comment fix
gerashegalov Feb 25, 2022
a700f33
wip
gerashegalov Feb 28, 2022
67d8616
broadcast hash join test
gerashegalov Mar 2, 2022
e4c0a40
undo import py
gerashegalov Mar 2, 2022
3a5a173
undo explain
gerashegalov Mar 2, 2022
2f88244
undo explain
gerashegalov Mar 2, 2022
5393301
fix bhj test id
gerashegalov Mar 2, 2022
02b037c
Update comment in join_test.py
gerashegalov Mar 2, 2022
9f2ed60
review comments
gerashegalov Mar 3, 2022
df4ad01
Merge remote-tracking branch 'origin/branch-22.04' into pr/gerashegal…
gerashegalov Mar 3, 2022
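For context: Spark plans an ExistenceJoin when an EXISTS (or IN) subquery appears inside a disjunction, where a plain left semi join cannot express the predicate; every left row is kept and a boolean `exists` column records whether it had a match. A minimal sketch of a query shape that triggers it (table names hypothetical, mirroring the join_test.py change below):

```scala
// Hypothetical tables lhs(_1, _2, _3) and rhs(_1, _2, _3). The OR branch keeps rows
// that fail the EXISTS check, so the optimizer cannot rewrite this as a LeftSemi
// join and plans ExistenceJoin(exists#N) instead.
spark.sql(
  """SELECT *
    |FROM lhs AS l
    |WHERE l._2 >= 80
    |   OR EXISTS (SELECT * FROM rhs AS r WHERE r._2 = l._2)
  """.stripMargin)
```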
1 change: 1 addition & 0 deletions docs/configs.md
@@ -98,6 +98,7 @@ Name | Description | Default Value
<a name="sql.incompatibleDateFormats.enabled"></a>spark.rapids.sql.incompatibleDateFormats.enabled|When parsing strings as dates and timestamps in functions like unix_timestamp, some formats are fully supported on the GPU and some are unsupported and will fall back to the CPU. Some formats behave differently on the GPU than the CPU. Spark on the CPU interprets date formats with unsupported trailing characters as nulls, while Spark on the GPU will parse the date with invalid trailing characters. More detail can be found at [parsing strings as dates or timestamps](compatibility.md#parsing-strings-as-dates-or-timestamps).|false
<a name="sql.incompatibleOps.enabled"></a>spark.rapids.sql.incompatibleOps.enabled|For operations that work, but are not 100% compatible with the Spark equivalent set if they should be enabled by default or disabled by default.|false
<a name="sql.join.cross.enabled"></a>spark.rapids.sql.join.cross.enabled|When set to true cross joins are enabled on the GPU|true
<a name="sql.join.existence.enabled"></a>spark.rapids.sql.join.existence.enabled|When set to true existence joins are enabled on the GPU|true
<a name="sql.join.fullOuter.enabled"></a>spark.rapids.sql.join.fullOuter.enabled|When set to true full outer joins are enabled on the GPU|true
<a name="sql.join.inner.enabled"></a>spark.rapids.sql.join.inner.enabled|When set to true inner joins are enabled on the GPU|true
<a name="sql.join.leftAnti.enabled"></a>spark.rapids.sql.join.leftAnti.enabled|When set to true left anti joins are enabled on the GPU|true
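The new flag defaults to on; a minimal sketch of turning it off for a session, assuming a live `SparkSession` named `spark`:

```scala
// Fall back to the CPU for existence joins while keeping other GPU joins enabled
// (configuration name taken from the configs.md row above).
spark.conf.set("spark.rapids.sql.join.existence.enabled", "false")
```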
60 changes: 44 additions & 16 deletions integration_tests/src/main/python/join_test.py
@@ -779,27 +779,55 @@ def do_join(spark):
# If the condition is something like an AND, it makes the result a subset of a SemiJoin, and
# the optimizer won't use ExistenceJoin.
@ignore_order(local=True)
-@pytest.mark.parametrize(
-    "allowFallback", [
-        pytest.param('true',
-            marks=pytest.mark.allow_non_gpu('SortMergeJoinExec')),
-        pytest.param('false',
-            marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/589"))
-    ], ids=idfn
-)
-def test_existence_join(allowFallback, spark_tmp_table_factory):
@pytest.mark.parametrize('numComplementsToExists', [0, 1, 2], ids=(lambda val: f"complements:{val}"))
@pytest.mark.parametrize('aqeEnabled', [
    pytest.param(False, id='aqe:off'),
    # workaround: somehow AQE retains RDDScanExec, preventing the parent ShuffleExchangeExec
    # from being executed on the GPU
    pytest.param(True, marks=pytest.mark.allow_non_gpu('ShuffleExchangeExec'), id='aqe:on')
])
@pytest.mark.parametrize('conditionalJoin', [False, True], ids=['ast:off', 'ast:on'])
def test_existence_join(numComplementsToExists, aqeEnabled, conditionalJoin, spark_tmp_table_factory):
    leftTable = spark_tmp_table_factory.get()
    rightTable = spark_tmp_table_factory.get()
    def do_join(spark):
-        # create non-overlapping ranges to have a mix of exists=true and exists=false
-        spark.createDataFrame([v] for v in range(2, 10)).createOrReplaceTempView(leftTable)
-        spark.createDataFrame([v] for v in range(0, 8)).createOrReplaceTempView(rightTable)
        # left-hand side rows
        lhs_upper_bound = 10
        lhs_data = list((f"left_{v}", v * 10, v * 100) for v in range(2, lhs_upper_bound))
        # duplicate without a match
        lhs_data.append(('left_1', 10, 100))
        # duplicate with a match
        lhs_data.append(('left_2', 20, 200))
        lhs_data.append(('left_null', None, None))
        df_left = spark.createDataFrame(lhs_data)
        df_left.createOrReplaceTempView(leftTable)

        rhs_data = list((f"right_{v}", v * 10, v * 100) for v in range(0, 8))
        rhs_data.append(('right_null', None, None))
        # duplicate every row in the rhs to verify it does not affect
        # the number of output rows, which should equal the number of lhs rows
        rhs_data_with_dupes = []
        for dupe in rhs_data:
            rhs_data_with_dupes.extend([dupe, dupe])

        df_right = spark.createDataFrame(rhs_data_with_dupes)
        df_right.createOrReplaceTempView(rightTable)
        cond = "<=" if conditionalJoin else "="
        res = spark.sql((
            "select * "
            "from {} as l "
-            "where l._1 < 0 "
-            " OR l._1 in (select * from {} as r)"
-            ).format(leftTable, rightTable))
            f"where l._2 >= {10 * (lhs_upper_bound - numComplementsToExists)}"
            " or exists (select * from {} as r where r._2 = l._2 AND r._3 {} l._3)"
        ).format(leftTable, rightTable, cond))
        return res
-    assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, r".+Join ExistenceJoin\(exists#[0-9]+\).+")

    if conditionalJoin:
        existenceJoinRegex = r"ExistenceJoin\(exists#[0-9]+\), \(.+ <= .+\)"
    else:
        existenceJoinRegex = r"ExistenceJoin\(exists#[0-9]+\)"

    assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, existenceJoinRegex,
        conf={
            "spark.sql.adaptive.enabled": aqeEnabled,
        })
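For reference, a sketch of what the stricter conditional regex is intended to capture; the plan line is illustrative only, since the exact rendering varies by Spark version and join implementation:

```scala
// Both the exists#N attribute ID and the AST-compiled `<=` condition must appear
// in the captured plan string for the conditional variant to count as a match.
val conditionalRegex = """ExistenceJoin\(exists#[0-9]+\), \(.+ <= .+\)""".r
val planLine = "GpuShuffledHashJoin [_2#5], [_2#11], ExistenceJoin(exists#20), (_3#6 <= _3#12)"
assert(conditionalRegex.findFirstIn(planLine).nonEmpty)
```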
@@ -417,8 +417,9 @@ object ExecutionPlanCaptureCallback {
      case p if p.expressions.exists(containsExpression(_, className, regexMap)) =>
        true
      case p: SparkPlan =>
        val sparkPlanStringForRegex = p.verboseStringWithSuffix(1000)
        regexMap.getOrElseUpdate(className, className.r)
-          .findFirstIn(p.simpleStringWithNodeId())
          .findFirstIn(sparkPlanStringForRegex)
          .nonEmpty
    }.nonEmpty
  }
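The switch above matters because the node's simple string elides the join details that the test regexes need. A sketch of the motivation, with illustrative renderings (exact output depends on the Spark version):

```scala
import org.apache.spark.sql.execution.SparkPlan

// simpleStringWithNodeId() yields just the node name and ID, e.g.
// "GpuShuffledHashJoin (11)", while verboseStringWithSuffix(1000) also renders the
// join type and condition, e.g. "... ExistenceJoin(exists#20), (_3#6 <= _3#12) ...",
// which is the text the ExistenceJoin regexes in join_test.py match against.
def planStringForRegex(plan: SparkPlan): String = plan.verboseStringWithSuffix(1000)
```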
@@ -697,6 +697,11 @@ object RapidsConf {
    .booleanConf
    .createWithDefault(true)

  val ENABLE_EXISTENCE_JOIN = conf("spark.rapids.sql.join.existence.enabled")
    .doc("When set to true existence joins are enabled on the GPU")
    .booleanConf
    .createWithDefault(true)

  val ENABLE_PROJECT_AST = conf("spark.rapids.sql.projectAstEnabled")
    .doc("Enable project operations to use cudf AST expressions when possible.")
    .internal()
@@ -1562,6 +1567,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

  lazy val areLeftAntiJoinsEnabled: Boolean = get(ENABLE_LEFT_ANTI_JOIN)

  lazy val areExistenceJoinsEnabled: Boolean = get(ENABLE_EXISTENCE_JOIN)

  lazy val isCastDecimalToFloatEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_FLOAT)

  lazy val isCastFloatToDecimalEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_DECIMAL)
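Downstream, the new accessor is read off a `RapidsConf` built from plain string pairs; a minimal sketch, assuming an active `SparkSession` named `spark` (RapidsConf itself is plugin-internal):

```scala
// RapidsConf takes a Map[String, String] (see the class signature above), so the
// session configuration can be fed to it directly.
val rapidsConf = new RapidsConf(spark.conf.getAll)
val existenceOnGpu: Boolean = rapidsConf.areExistenceJoinsEnabled
```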
@@ -15,9 +15,12 @@
 */
package org.apache.spark.sql.rapids.execution

-import ai.rapids.cudf.{DType, GroupByAggregation, NullEquality, NullPolicy, NvtxColor, ReductionAggregation, Table}
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{ColumnVector, DType, GatherMap, GroupByAggregation, NullEquality, NullPolicy, NvtxColor, ReductionAggregation, Scalar, Table}
import ai.rapids.cudf.ast.CompiledExpression
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}

@@ -50,6 +53,9 @@ object JoinTypeChecks {
      case LeftAnti if !conf.areLeftAntiJoinsEnabled =>
        meta.willNotWorkOnGpu("left anti joins have been disabled. To enable set " +
          s"${RapidsConf.ENABLE_LEFT_ANTI_JOIN.key} to true")
      case ExistenceJoin(_) if !conf.areExistenceJoinsEnabled =>
        meta.willNotWorkOnGpu("existence joins have been disabled. To enable set " +
          s"${RapidsConf.ENABLE_EXISTENCE_JOIN.key} to true")
      case _ => // not disabled
    }
  }
@@ -107,7 +113,7 @@ object GpuHashJoin extends Arm {
    JoinTypeChecks.tagForGpu(joinType, meta)
    joinType match {
      case _: InnerLike =>
-      case RightOuter | LeftOuter | LeftSemi | LeftAnti =>
      case RightOuter | LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) =>
        conditionMeta.foreach(meta.requireAstForGpuOn)
      case FullOuter =>
        conditionMeta.foreach(meta.requireAstForGpuOn)

@@ -622,7 +628,114 @@ trait GpuHashJoin extends GpuExec {
    (joinLeft.output.size, boundCondition)
  }

-  def doJoin(
  private def existenceJoinIterator(
      builtBatch: ColumnarBatch,
      stream: Iterator[ColumnarBatch]) = new Iterator[ColumnarBatch]()
    with AutoCloseable with Arm {

    var closed: Boolean = false

    // iteration-independent resources
    val resources = ArrayBuffer[AutoCloseable]()
    def use[T <: AutoCloseable](ac: T): T = {
      resources += ac
      ac
    }

    val compiledConditionRes: Option[(Table, CompiledExpression)] = boundCondition.map(gpuExpr => (
      use(GpuColumnVector.from(builtBatch)),
      use(gpuExpr.convertToAst(numFirstConditionTableColumns).compile())
    ))

    val rightKeysTab = use(
      withResource(GpuProjectExec.project(builtBatch, boundBuildKeys))(GpuColumnVector.from(_)))

    val falseScalar = use(Scalar.fromBool(false))
    val trueScalar = use(Scalar.fromBool(true))

    override def hasNext: Boolean = {
      val streamHasNext = stream.hasNext
      if (!streamHasNext) {
        close()
      }
      streamHasNext
    }

    override def next(): ColumnarBatch = {
      try {
        withResource(stream.next()) { leftColumnarBatch =>
          existenceJoinNextBatch(leftColumnarBatch)
        }
      } catch {
        case t: Throwable =>
          close()
          throw t
      }
    }

    override def close(): Unit = if (!closed) {
      closed = true
      val resourcesReversed = resources.reverse
      resourcesReversed.safeClose()
    }

    private def leftKeysTable(leftColumnarBatch: ColumnarBatch): Table = {
      withResource(GpuProjectExec.project(leftColumnarBatch, boundStreamKeys))(
        leftKeys => GpuColumnVector.from(leftKeys))
    }

    private def conditionalBatchLeftSemiJoin(
        leftColumnarBatch: ColumnarBatch,
        rightTab: Table,
        leftKeysTab: Table,
        compiledCondition: CompiledExpression): GatherMap = {
      withResource(GpuColumnVector.from(leftColumnarBatch))(leftTab =>
        Table.mixedLeftSemiJoinGatherMap(
          leftKeysTab,
          rightKeysTab,
          leftTab,
          rightTab,
          compiledCondition,
          if (compareNullsEqual) NullEquality.EQUAL else NullEquality.UNEQUAL))
    }

    // map of the stream-side (left) row indices that have at least one match on the build side
    private def existsScatterMap(leftColumnarBatch: ColumnarBatch): GatherMap = {
      withResource(leftKeysTable(leftColumnarBatch))(leftKeysTab =>
        compiledConditionRes.map { case (rightTab, compiledCondition) =>
          conditionalBatchLeftSemiJoin(leftColumnarBatch, rightTab, leftKeysTab, compiledCondition)
        }.getOrElse {
          leftKeysTab.leftSemiJoinGatherMap(rightKeysTab, compareNullsEqual)
        })
    }

    private def falseColumnTable(numLeftRows: Int): Table = {
      withResource(ColumnVector.fromScalar(falseScalar, numLeftRows))(
        new Table(_)
      )
    }

    // single-column table: false everywhere, then true scattered into the matched row indices
    private def existsTable(leftColumnarBatch: ColumnarBatch): Table = {
      withResource(existsScatterMap(leftColumnarBatch)) { existsScatterMap =>
        val numLeftRows = leftColumnarBatch.numRows
        withResource(falseColumnTable(numLeftRows)) { allFalseTable =>
          val numExistsTrueRows = existsScatterMap.getRowCount.toInt
          withResource(existsScatterMap.toColumnView(0, numExistsTrueRows)) { existsView =>
            Table.scatter(Array(trueScalar), existsView, allFalseTable, false)
          }
        }
      }
    }

    private def existenceJoinNextBatch(leftColumnarBatch: ColumnarBatch): ColumnarBatch = {
      // left columns with exists
      withResource(existsTable(leftColumnarBatch)) { existsTable =>
        val resCols = GpuColumnVector.extractBases(leftColumnarBatch) :+ existsTable.getColumn(0)
        val resTypes = GpuColumnVector.extractTypes(leftColumnarBatch) :+ BooleanType
        withResource(new Table(resCols: _*))(resTab => GpuColumnVector.from(resTab, resTypes))
      }
    }
  }

  private def hashJoinLikeIterator(
      builtBatch: ColumnarBatch,
      stream: Iterator[ColumnarBatch],
      targetSize: Long,
@@ -657,7 +770,7 @@

    // The HashJoinIterator takes ownership of the built keys and built data. It will close
    // them when it is done
-    val joinIterator = if (boundCondition.isDefined) {
    if (boundCondition.isDefined) {
      // ConditionalHashJoinIterator will close the compiled condition
      val compiledCondition =
        boundCondition.get.convertToAst(numFirstConditionTableColumns).compile()

@@ -669,6 +782,27 @@
        streamedPlan.output, realTarget, joinType, buildSide, compareNullsEqual, spillCallback,
        opTime, joinTime)
    }
  }

  def doJoin(
      builtBatch: ColumnarBatch,
      stream: Iterator[ColumnarBatch],
      targetSize: Long,
      spillCallback: SpillCallback,
      numOutputRows: GpuMetric,
      joinOutputRows: GpuMetric,
      numOutputBatches: GpuMetric,
      opTime: GpuMetric,
      joinTime: GpuMetric): Iterator[ColumnarBatch] = {

    // The HashJoinIterator takes ownership of the built keys and built data. It will close
    // them when it is done
    val joinIterator = if (joinType.isInstanceOf[ExistenceJoin]) {
      existenceJoinIterator(builtBatch, stream)
    } else {
      hashJoinLikeIterator(builtBatch, stream, targetSize, spillCallback, numOutputRows,
        joinOutputRows, numOutputBatches, opTime, joinTime)
    }

    joinIterator.map { cb =>
      joinOutputRows += cb.numRows()
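Taken together, the per-batch `exists` computation in the unconditional case reduces to three cudf calls. A condensed sketch using only the API already visited in this diff, assuming `leftKeysTab`, `rightKeysTab`, `numLeftRows`, and `compareNullsEqual` as defined in the iterator above, with resource management and the AST-conditioned `mixedLeftSemiJoinGatherMap` variant elided:

```scala
import ai.rapids.cudf.{ColumnVector, GatherMap, Scalar, Table}

// 1) Left semi join of stream-side (left) keys against build-side (right) keys:
//    the gather map lists the left row indices that found at least one match,
//    which is why rhs duplicates cannot inflate the output row count.
val matches: GatherMap = leftKeysTab.leftSemiJoinGatherMap(rightKeysTab, compareNullsEqual)

// 2) An all-false boolean column, one entry per left row.
val allFalse: Table = withResource(Scalar.fromBool(false)) { f =>
  withResource(ColumnVector.fromScalar(f, numLeftRows))(new Table(_))
}

// 3) Scatter `true` into the matched positions; the result becomes the `exists`
//    column that existenceJoinNextBatch appends to the left columns.
val exists: Table = withResource(Scalar.fromBool(true)) { t =>
  Table.scatter(Array(t), matches.toColumnView(0, matches.getRowCount.toInt), allFalse, false)
}
```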