Implement ExistenceJoin Iterator using an auxiliary left semijoin #4796

Merged
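
For context: ExistenceJoin is the join type Spark's optimizer produces for an IN or EXISTS subquery that is OR-ed with another predicate, so the rows cannot simply be filtered out as in a left semijoin; instead every left row is kept and annotated with a generated boolean "exists" column. A minimal sketch of a query shape that yields this plan, mirroring the integration test below (assuming a local SparkSession; the data and names are illustrative):

import org.apache.spark.sql.SparkSession

object ExistenceJoinDemo extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  import spark.implicits._

  Seq(2, 3, 4).toDF("a").createOrReplaceTempView("l")
  Seq(0, 1, 2).toDF("b").createOrReplaceTempView("r")

  // The IN subquery is OR-ed with another predicate, so it cannot be rewritten
  // to a LeftSemi join; Spark plans it as ExistenceJoin(exists#N) instead.
  val df = spark.sql("SELECT * FROM l WHERE a < 0 OR a IN (SELECT b FROM r)")
  df.explain(true) // look for ExistenceJoin(exists#...) in the optimized plan
}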
34 commits:
a080f9a  wip (gerashegalov, Feb 9, 2022)
cf135d7  Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/… (gerashegalov, Feb 9, 2022)
4ad4616  Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/… (gerashegalov, Feb 15, 2022)
deba9a1  wip (gerashegalov, Feb 16, 2022)
664ba4e  wip (gerashegalov, Feb 16, 2022)
a5da440  Use verboseStringWithSuffix(1000) for TreeNode match (gerashegalov, Feb 16, 2022)
b55a3e1  Use Scala Int.MinValue (gerashegalov, Feb 16, 2022)
6293dc1  config.md update for spark.rapids.sql.join.existence.enabled (gerashegalov, Feb 16, 2022)
8c9597b  Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/… (gerashegalov, Feb 16, 2022)
2517f73  undo buildall changes (gerashegalov, Feb 16, 2022)
193dbac  restore whitespace (gerashegalov, Feb 16, 2022)
5ed19db  existence join test with rhs duplicates (gerashegalov, Feb 17, 2022)
1c332ec  semijoin-based implementation (gerashegalov, Feb 18, 2022)
a02992e  undo cosmetic change (gerashegalov, Feb 18, 2022)
7e6bfcd  fix tagForGpu in GpuHashJoin (gerashegalov, Feb 18, 2022)
e333159  test updates (gerashegalov, Feb 19, 2022)
89c9a0e  draft (gerashegalov, Feb 22, 2022)
f4fe704  Merge remote-tracking branch 'origin/branch-22.04' into gerashegalov/… (gerashegalov, Feb 23, 2022)
465a957  undo gatherer changes (gerashegalov, Feb 23, 2022)
1a44866  mem leaks fixed (gerashegalov, Feb 24, 2022)
baea237  lhs dupes, conditional join (gerashegalov, Feb 24, 2022)
29704bf  mixedLeftSemiJoinGatherMap (gerashegalov, Feb 25, 2022)
b19c98f  mnemonic test ids (gerashegalov, Feb 25, 2022)
2948a66  refactoring (gerashegalov, Feb 25, 2022)
6d55b6a  comment fix (gerashegalov, Feb 25, 2022)
a700f33  wip (gerashegalov, Feb 28, 2022)
67d8616  broadcast hash join test (gerashegalov, Mar 2, 2022)
e4c0a40  undo import py (gerashegalov, Mar 2, 2022)
3a5a173  undo explain (gerashegalov, Mar 2, 2022)
2f88244  undo explain (gerashegalov, Mar 2, 2022)
5393301  fixe bhj test id (gerashegalov, Mar 2, 2022)
02b037c  Update comment in join_test.py (gerashegalov, Mar 2, 2022)
9f2ed60  review comments (gerashegalov, Mar 3, 2022)
df4ad01  Merge remote-tracking branch 'origin/branch-22.04' into pr/gerashegal… (gerashegalov, Mar 3, 2022)
27 changes: 22 additions & 5 deletions build/buildall
@@ -15,15 +15,22 @@
 # limitations under the License.
 #
 
-set -ex
+set -e
 
+shopt -s extglob
 
+# actually bloopInstall mojo version but it's just a nuance at the moment.
+BLOOP_VERSION=${BLOOP_VERSION:-"1.4.12"}
+BLOOP_SCALA_VERSION=${BLOOP_SCALA_VERSION:-"2.13"}
+SKIP_CLEAN=1
 
 function print_usage() {
   echo "Usage: buildall [OPTION]"
   echo "Options:"
   echo " -h, --help"
   echo "     print this help message"
+  echo " --clean"
+  echo "     include Maven clean phase"
   echo " -gb, --generate-bloop"
   echo "     generate projects for Bloop clients: IDE (Scala Metals, IntelliJ) or Bloop CLI"
   echo " -p=DIST_PROFILE, --profile=DIST_PROFILE"
@@ -50,7 +57,7 @@ function bloopInstall() {
   mkdir -p "$bloop_config_dir"
   rm -f "$bloop_config_dir"/*
 
-  mvn install ch.epfl.scala:maven-bloop_2.13:1.4.9:bloopInstall -pl dist -am \
+  mvn install ch.epfl.scala:maven-bloop_${BLOOP_SCALA_VERSION}:${BLOOP_VERSION}:bloopInstall -pl dist -am \
     -Dbloop.configDirectory="$bloop_config_dir" \
     -DdownloadSources=true \
     -Dbuildver="$bv" \
@@ -102,6 +109,14 @@ case "$1" in
   BUILD_PARALLEL="${1#*=}"
   ;;
 
+--debug)
+  set -x
+  ;;
+
+--clean)
+  SKIP_CLEAN="0"
+  ;;
+
 *)
   echo >&2 "Unknown arg: $1"
   print_usage
@@ -183,8 +198,10 @@ export BASE_VER=${SPARK_SHIM_VERSIONS[0]}
 export NUM_SHIMS=${#SPARK_SHIM_VERSIONS[@]}
 export BUILD_PARALLEL=${BUILD_PARALLEL:-4}
 
-echo Clean once across all modules
-mvn -q clean
+if [[ "$SKIP_CLEAN" != "1" ]]; then
+  echo Clean once across all modules
+  mvn -q clean
+fi
 
 echo "Building a combined dist jar with Shims for ${SPARK_SHIM_VERSIONS[@]} ..."
@@ -219,7 +236,7 @@ function build_single_shim() {
     -Dbuildver="$BUILD_VER" \
     -Drat.skip="$SKIP_CHECKS" \
     -Dmaven.javadoc.skip="$SKIP_CHECKS" \
-    -Dskip="$SKIP_CHECKS" \
+    -Dskip \
     -Dmaven.scalastyle.skip="$SKIP_CHECKS" \
    -pl aggregator -am > "$LOG_FILE" 2>&1 || {
      [[ "$LOG_FILE" != "/dev/tty" ]] && tail -20 "$LOG_FILE" || true
1 change: 1 addition & 0 deletions docs/configs.md
@@ -105,6 +105,7 @@ Name | Description | Default Value
 <a name="sql.incompatibleDateFormats.enabled"></a>spark.rapids.sql.incompatibleDateFormats.enabled|When parsing strings as dates and timestamps in functions like unix_timestamp, some formats are fully supported on the GPU and some are unsupported and will fall back to the CPU. Some formats behave differently on the GPU than the CPU. Spark on the CPU interprets date formats with unsupported trailing characters as nulls, while Spark on the GPU will parse the date with invalid trailing characters. More detail can be found at [parsing strings as dates or timestamps](compatibility.md#parsing-strings-as-dates-or-timestamps).|false
 <a name="sql.incompatibleOps.enabled"></a>spark.rapids.sql.incompatibleOps.enabled|For operations that work, but are not 100% compatible with the Spark equivalent set if they should be enabled by default or disabled by default.|false
 <a name="sql.join.cross.enabled"></a>spark.rapids.sql.join.cross.enabled|When set to true cross joins are enabled on the GPU|true
+<a name="sql.join.existence.enabled"></a>spark.rapids.sql.join.existence.enabled|When set to true existence joins are enabled on the GPU|true
 <a name="sql.join.fullOuter.enabled"></a>spark.rapids.sql.join.fullOuter.enabled|When set to true full outer joins are enabled on the GPU|true
 <a name="sql.join.inner.enabled"></a>spark.rapids.sql.join.inner.enabled|When set to true inner joins are enabled on the GPU|true
 <a name="sql.join.leftAnti.enabled"></a>spark.rapids.sql.join.leftAnti.enabled|When set to true left anti joins are enabled on the GPU|true
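
A sketch of turning the new flag off at the session level, e.g. to force CPU fallback for existence joins while debugging (the conf key comes from this PR; the session setup is illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
// Existence joins will then be tagged to fall back to the CPU.
spark.conf.set("spark.rapids.sql.join.existence.enabled", "false")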
24 changes: 10 additions & 14 deletions integration_tests/src/main/python/join_test.py
@@ -802,27 +802,23 @@ def do_join(spark):
 # If the condition is something like an AND, it makes the result a subset of a SemiJoin, and
 # the optimizer won't use ExistenceJoin.
 @ignore_order(local=True)
-@pytest.mark.parametrize(
-    "allowFallback", [
-        pytest.param('true',
-            marks=pytest.mark.allow_non_gpu('SortMergeJoinExec')),
-        pytest.param('false',
-            marks=pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/589"))
-    ], ids=idfn
-)
-def test_existence_join(allowFallback, spark_tmp_table_factory):
+def test_existence_join(spark_tmp_table_factory):
     leftTable = spark_tmp_table_factory.get()
     rightTable = spark_tmp_table_factory.get()
     def do_join(spark):
         # create non-overlapping ranges to have a mix of exists=true and exists=false
-        spark.createDataFrame([v] for v in range(2, 10)).createOrReplaceTempView(leftTable)
-        spark.createDataFrame([v] for v in range(0, 8)).createOrReplaceTempView(rightTable)
+        spark.createDataFrame([f"left_{v}", v * 10, v * 100] for v in range(2, 10)).createOrReplaceTempView(leftTable)
+        spark.createDataFrame([f"right_{v}", v * 10, v * 100] for v in range(0, 8)).createOrReplaceTempView(rightTable)
         res = spark.sql((
             "select * "
             "from {} as l "
             "where l._1 < 0 "
             " OR l._1 in (select * from {} as r)"
+            " or exists (select * from {} as r where r._2 = l._2 AND r._3 = l._3)"
-        ).format(leftTable, rightTable))
+        ).format(leftTable, rightTable, rightTable))
-        res.explain(True)
         return res
-    assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, r".+Join ExistenceJoin\(exists#[0-9]+\).+")
+    assert_cpu_and_gpu_are_equal_collect_with_capture(do_join, r"ExistenceJoin\(exists#[0-9]+\)",
+        conf={
+            "spark.rapids.sql.explain" : "ALL",
+            "spark.sql.adaptive.enabled" : False
+        })
@@ -23,7 +23,7 @@ import ai.rapids.cudf.{GatherMap, NvtxColor, OutOfBoundsPolicy}
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
@@ -305,7 +305,7 @@ abstract class SplittableJoinIterator(
       // In these cases, the map and the table are both the left side, and everything in the map
       // is a match on the left table, so we don't want to check for bounds.
       rightData.close()
-      JoinGatherer(lazyLeftMap, leftData, OutOfBoundsPolicy.DONT_CHECK)
+      JoinGatherer(lazyLeftMap, Some(leftData), OutOfBoundsPolicy.DONT_CHECK)
     case Some(right) =>
       // Inner joins -- manifest the intersection of both left and right sides. The gather maps
      // contain the number of rows that must be manifested, and every index
@@ -330,9 +330,15 @@
         case _ => OutOfBoundsPolicy.NULLIFY
       }
       val lazyRightMap = LazySpillableGatherMap(right, spillCallback, "right_map")
-      JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData,
-        leftOutOfBoundsPolicy, rightOutOfBoundsPolicy)
-  }
+      joinType match {
+        case ExistenceJoin(_) =>
+          rightData.close()
+          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, leftOutOfBoundsPolicy)
+        case _ =>
+          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData, leftOutOfBoundsPolicy,
+            rightOutOfBoundsPolicy)
+      }
+  }
   if (gatherer.isDone) {
     // Nothing matched...
     gatherer.close()
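
The ExistenceJoin branch above can close rightData as soon as the gather maps exist because the join manifests only left-side columns plus one generated boolean. A sketch of the output shape (illustrative types, not code from this PR):

import org.apache.spark.sql.types._

val leftSchema = StructType(Seq(
  StructField("_1", StringType),
  StructField("_2", LongType),
  StructField("_3", LongType)))
// Left columns pass through unchanged; only the exists#N attribute is appended.
val existenceOutput = leftSchema.add(StructField("exists", BooleanType, nullable = false))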
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
 
 package com.nvidia.spark.rapids
 
-import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
+import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}
 
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -136,21 +136,39 @@ trait JoinGatherer extends LazySpillable with Arm {
 
 object JoinGatherer extends Arm {
   def apply(gatherMap: LazySpillableGatherMap,
-      inputData: LazySpillableColumnarBatch,
-      outOfBoundsPolicy: OutOfBoundsPolicy): JoinGatherer =
-    new JoinGathererImpl(gatherMap, inputData, outOfBoundsPolicy)
+      inputData: Option[LazySpillableColumnarBatch],
+      outOfBoundsPolicy: OutOfBoundsPolicy
+  ): JoinGatherer = {
+    new JoinGathererImpl(gatherMap, inputData, outOfBoundsPolicy)
+  }
 
   def apply(leftMap: LazySpillableGatherMap,
       leftData: LazySpillableColumnarBatch,
       rightMap: LazySpillableGatherMap,
      rightData: LazySpillableColumnarBatch,
       outOfBoundsPolicyLeft: OutOfBoundsPolicy,
       outOfBoundsPolicyRight: OutOfBoundsPolicy): JoinGatherer = {
-    val left = JoinGatherer(leftMap, leftData, outOfBoundsPolicyLeft)
-    val right = JoinGatherer(rightMap, rightData, outOfBoundsPolicyRight)
+    val left = JoinGatherer(leftMap, Some(leftData), outOfBoundsPolicyLeft)
+    val right = JoinGatherer(rightMap, Some(rightData), outOfBoundsPolicyRight)
     MultiJoinGather(left, right)
   }
 
+  // existence join gather
+  def apply(
+      leftMap: LazySpillableGatherMap,
+      leftData: LazySpillableColumnarBatch,
+      rightMap: LazySpillableGatherMap,
+      outOfBoundsPolicyLeft: OutOfBoundsPolicy): JoinGatherer = {
+    val left = JoinGatherer(leftMap, Some(leftData), outOfBoundsPolicyLeft)
+    val right = JoinGatherer(rightMap)
+    MultiJoinGather(left, right)
+  }
+
+  // existence column gather
+  def apply(map: LazySpillableGatherMap): JoinGatherer = {
+    JoinGatherer(map, None, OutOfBoundsPolicy.DONT_CHECK)
+  }
+
   def getRowsInNextBatch(gatherer: JoinGatherer, targetSize: Long): Int = {
     withResource(new NvtxRange("calc gather size", NvtxColor.YELLOW)) { _ =>
       val rowsLeft = gatherer.numRowsLeft
@@ -496,28 +514,29 @@
  */
 class JoinGathererImpl(
     private val gatherMap: LazySpillableGatherMap,
-    private val data: LazySpillableColumnarBatch,
+    private val dataOpt: Option[LazySpillableColumnarBatch],
     boundsCheckPolicy: OutOfBoundsPolicy) extends JoinGatherer {
 
-  assert(data.numCols > 0, "data with no columns should have been filtered out already")
+  assert(dataOpt.forall(_.numCols > 0),
+    "data with no columns should have been filtered out already")
 
   // How much of the gather map we have output so far
   private var gatheredUpTo: Long = 0
   private val totalRows: Long = gatherMap.getRowCount
   private val (fixedWidthRowSizeBits, nullRowSizeBits) = {
-    val dts = data.dataTypes
+    val dts = dataOpt.map(_.dataTypes).getOrElse(Array[DataType](BooleanType))
     val fw = JoinGathererImpl.fixedWidthRowSizeBits(dts)
     val nullVal = JoinGathererImpl.nullRowSizeBits(dts)
     (fw, nullVal)
   }
 
   override def toString: String = {
-    s"GATHERER $gatheredUpTo/$totalRows $gatherMap $data"
+    s"GATHERER $gatheredUpTo/$totalRows $gatherMap $dataOpt"
   }
 
   override def realCheapPerRowSizeEstimate: Double = {
-    val totalInputRows: Int = data.numRows
-    val totalInputSize: Long = data.deviceMemorySize
+    val totalInputRows: Long = gatherMap.getRowCount
+    val totalInputSize: Long = dataOpt.map(_.deviceMemorySize).getOrElse(totalInputRows)
     // Avoid divide by 0 here and later on
     if (totalInputRows > 0 && totalInputSize > 0) {
       totalInputSize.toDouble / totalInputRows
@@ -532,12 +551,21 @@
     val start = gatheredUpTo
     assert((start + n) <= totalRows)
     val ret = withResource(gatherMap.toColumnView(start, n)) { gatherView =>
-      val batch = data.getBatch
-      val gatheredTable = withResource(GpuColumnVector.from(batch)) { table =>
-        table.gather(gatherView, boundsCheckPolicy)
-      }
-      withResource(gatheredTable) { gt =>
-        GpuColumnVector.from(gt, GpuColumnVector.extractTypes(batch))
+      dataOpt.map { data =>
+        val batch = data.getBatch
+        val gatheredTable = withResource(GpuColumnVector.from(batch)) { table =>
+          table.gather(gatherView, boundsCheckPolicy)
+        }
+        withResource(gatheredTable) { gt =>
+          GpuColumnVector.from(gt, GpuColumnVector.extractTypes(batch))
+        }
+      }.getOrElse {
+        withResource(gatherView.binaryOp(BinaryOp.GREATER,
+            Scalar.fromInt(Int.MinValue), DType.BOOL8)) { existenceColumn =>
+          withResource(new Table(existenceColumn)) { existenceTable =>
+            GpuColumnVector.from(existenceTable, Array[DataType](BooleanType))
+          }
+        }
       }
     }
     gatheredUpTo += n
@@ -550,12 +578,14 @@
   override def numRowsLeft: Long = totalRows - gatheredUpTo
 
   override def allowSpilling(): Unit = {
-    data.allowSpilling()
+    dataOpt.foreach(_.allowSpilling())
     gatherMap.allowSpilling()
   }
 
   override def getBitSizeMap(n: Int): ColumnView = {
-    val cb = data.getBatch
+    val cb = dataOpt.getOrElse {
+      throw new UnsupportedOperationException("unsupported on exists column gather")
+    }.getBatch
     val inputBitCounts = withResource(GpuColumnVector.from(cb)) { table =>
       withResource(table.rowBitCount()) { bits =>
         bits.castTo(DType.INT64)
@@ -590,7 +620,7 @@
 
   override def close(): Unit = {
     gatherMap.close()
-    data.close()
+    dataOpt.foreach(_.close())
   }
 }
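
The getOrElse branch above is the heart of the auxiliary-semijoin trick. The left semijoin produces a gather map over the left rows in which a matched row holds a valid index and an unmatched row holds the Int.MinValue sentinel, so comparing the map against the sentinel (the BinaryOp.GREATER call) materializes the boolean exists column without gathering any right-side data. A conceptual sketch in plain Scala of what that comparison computes (the real code does this on the GPU via cudf):

val leftRows      = Seq("l0", "l1", "l2", "l3")
val semiGatherMap = Seq(0, Int.MinValue, 2, Int.MinValue) // rows 0 and 2 matched
val existsColumn  = semiGatherMap.map(_ > Int.MinValue)   // BinaryOp.GREATER analogue
assert(existsColumn == Seq(true, false, true, false))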
@@ -417,8 +417,9 @@ object ExecutionPlanCaptureCallback {
       case p if p.expressions.exists(containsExpression(_, className, regexMap)) =>
         true
       case p: SparkPlan =>
+        val sparkPlanStringForRegex = p.verboseStringWithSuffix(1000)
         regexMap.getOrElseUpdate(className, className.r)
-          .findFirstIn(p.simpleStringWithNodeId())
+          .findFirstIn(sparkPlanStringForRegex)
           .nonEmpty
     }.nonEmpty
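
The switch to verboseStringWithSuffix matters because join-type detail such as ExistenceJoin(exists#42) appears only in a node's verbose description, not in simpleStringWithNodeId, so the test regex above could never match before this change. A small sketch (both plan strings are hypothetical):

val simple  = "SortMergeJoin (11)"
val verbose = "SortMergeJoin [_2#4L], [_2#10L], ExistenceJoin(exists#42), ((_3#5L = _3#11L))"
val planRegex = raw"ExistenceJoin\(exists#[0-9]+\)".r
assert(planRegex.findFirstIn(simple).isEmpty)   // simple string never matches
assert(planRegex.findFirstIn(verbose).nonEmpty) // verbose string does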
@@ -697,6 +697,11 @@ object RapidsConf {
     .booleanConf
     .createWithDefault(true)
 
+  val ENABLE_EXISTENCE_JOIN = conf("spark.rapids.sql.join.existence.enabled")
+    .doc("When set to true existence joins are enabled on the GPU")
+    .booleanConf
+    .createWithDefault(true)
+
   val ENABLE_PROJECT_AST = conf("spark.rapids.sql.projectAstEnabled")
     .doc("Enable project operations to use cudf AST expressions when possible.")
     .internal()
@@ -1603,6 +1608,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val areLeftAntiJoinsEnabled: Boolean = get(ENABLE_LEFT_ANTI_JOIN)
 
+  lazy val areExistenceJoinsEnabled: Boolean = get(ENABLE_EXISTENCE_JOIN)
+
   lazy val isCastDecimalToFloatEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_FLOAT)
 
   lazy val isCastFloatToDecimalEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_DECIMAL)
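
The areExistenceJoinsEnabled flag pairs with the tagForGpu fix from commit 7e6bfcd: when the flag is off, the join is tagged as unsupported and falls back to the CPU. A hypothetical, self-contained sketch of that gating pattern (these names are illustrative, not the plugin's real API):

final case class ExistenceJoin(existsAttr: String)

def tagJoinForGpu(joinType: Any, existenceJoinsEnabled: Boolean): Option[String] =
  joinType match {
    case ExistenceJoin(_) if !existenceJoinsEnabled =>
      Some("existence joins are disabled by spark.rapids.sql.join.existence.enabled")
    case _ => None // no objection here; other join types are validated elsewhere
  }

assert(tagJoinForGpu(ExistenceJoin("exists#1"), existenceJoinsEnabled = false).nonEmpty)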
@@ -173,11 +173,11 @@ class CrossJoinIterator(
       case (_, 0) =>
         rightBatch.close()
         rightMap.close()
-        JoinGatherer(leftMap, leftBatch, OutOfBoundsPolicy.DONT_CHECK)
+        JoinGatherer(leftMap, Some(leftBatch), OutOfBoundsPolicy.DONT_CHECK)
       case (0, _) =>
         leftBatch.close()
         leftMap.close()
-        JoinGatherer(rightMap, rightBatch, OutOfBoundsPolicy.DONT_CHECK)
+        JoinGatherer(rightMap, Some(rightBatch), OutOfBoundsPolicy.DONT_CHECK)
       case (_, _) =>
         JoinGatherer(leftMap, leftBatch, rightMap, rightBatch,
           OutOfBoundsPolicy.DONT_CHECK, OutOfBoundsPolicy.DONT_CHECK)