Revert "Use cudf to compute exact hash join output row sizes (#3288)" #3657

Merged 1 commit on Sep 24, 2021
@@ -48,12 +48,6 @@ abstract class AbstractGpuJoinIterator(
/** Returns whether there are any more batches on the stream side of the join */
protected def hasNextStreamBatch: Boolean

/**
* Called when a result batch is about to be produced. Any custom resources that were allocated
* to produce a result should be freed or made spillable here.
*/
protected def freeIntermediateResources(): Unit = {}

/**
* Called to setup the next join gatherer instance when the previous instance is done or
* there is no previous instance.
@@ -124,7 +118,6 @@ abstract class AbstractGpuJoinIterator(
if (ret.isDefined) {
// We are about to return something. We got everything we need from it so now let it spill
// if there is more to be gathered later on.
freeIntermediateResources()
gathererStore.foreach(_.allowSpilling())
}
ret
@@ -237,12 +230,25 @@ abstract class SplittableJoinIterator(
* the splits in the stream-side input
* @param cb stream-side input batch to split
* @param numBatches number of splits to produce with approximately the same number of rows each
* @param oom a prior OOM exception that this will try to recover from by splitting
*/
protected def splitAndSave(
cb: ColumnarBatch,
numBatches: Int): Unit = {
numBatches: Int,
oom: Option[OutOfMemoryError] = None): Unit = {
val batchSize = cb.numRows() / numBatches
logInfo(s"Split stream batch into $numBatches batches of about $batchSize rows")
if (oom.isDefined && batchSize < 100) {
// We just need some kind of cutoff to not get stuck in a loop if the batches get to be too
// small but we want to at least give it a chance to work (mostly for tests where the
// targetSize can be set really small)
throw oom.get
}
val msg = s"Split stream batch into $numBatches batches of about $batchSize rows"
if (oom.isDefined) {
logWarning(s"OOM Encountered: $msg")
} else {
logInfo(msg)
}
val splits = withResource(GpuColumnVector.from(cb)) { tab =>
val splitIndexes = (1 until numBatches).map(num => num * batchSize)
tab.contiguousSplit(splitIndexes: _*)
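For reference, the split offsets handed to contiguousSplit above are just multiples of the integer batch size, with a floor (rethrow the saved OOM once a split would drop below 100 rows) so the retry cannot loop forever. Below is a minimal, self-contained sketch of that offset arithmetic; the row counts are hypothetical and plain Ints stand in for a ColumnarBatch.

object SplitIndexSketch {
  // Mirrors the offset math in splitAndSave: evenly spaced cut points,
  // with the final split absorbing any integer-division remainder.
  def splitIndexes(numRows: Int, numBatches: Int): Seq[Int] = {
    val batchSize = numRows / numBatches
    (1 until numBatches).map(_ * batchSize)
  }

  def main(args: Array[String]): Unit = {
    // A 1,000,000-row stream batch split four ways is cut at these offsets.
    println(splitIndexes(1000000, 4)) // Vector(250000, 500000, 750000)
  }
}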
@@ -15,7 +15,7 @@
*/
package org.apache.spark.sql.rapids.execution

import ai.rapids.cudf.{HashJoin, NvtxColor, Table}
import ai.rapids.cudf.{DType, GroupByAggregation, NullPolicy, NvtxColor, ReductionAggregation, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback

@@ -233,145 +233,145 @@ class HashJoinIterator(
joinTime = joinTime,
streamTime = streamTime,
totalTime = totalTime) {
private var builtHash: Option[HashJoin] = None
private var streamKeysTable: Option[Table] = None

override def close(): Unit = {
super.close()
freeIntermediateResources()
}

override def freeIntermediateResources(): Unit = {
builtHash.foreach(_.close())
builtHash = None
streamKeysTable.foreach(_.close())
streamKeysTable = None
}

override def setupNextGatherer(startNanoTime: Long): Option[JoinGatherer] = {
// fetching the next stream batch invalidates any previous stream keys
streamKeysTable.foreach(_.close())
streamKeysTable = None
super.setupNextGatherer(startNanoTime)
}

override def splitAndSave(cb: ColumnarBatch, numBatches: Int): Unit = {
// splitting the stream batch invalidates the stream keys
streamKeysTable.foreach(_.close())
streamKeysTable = None
super.splitAndSave(cb, numBatches)
// We can cache this because the build side is not changing
private lazy val streamMagnificationFactor = joinType match {
case _: InnerLike | LeftOuter | RightOuter =>
withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
guessStreamMagnificationFactor(builtKeys)
}
case _ =>
// existence joins don't change size, and FullOuter cannot be split
1.0
}

override def computeNumJoinRows(cb: ColumnarBatch): Long = {
withResource(new NvtxWithMetrics("hash join build", NvtxColor.ORANGE, joinTime)) { _ =>
joinType match {
case LeftSemi | LeftAnti | FullOuter =>
// Semi or Anti joins do not explode, worst-case they return the same number of rows.
// For full joins, currently we only support the entire stream table at once. There's no
// reason to predict the output rows since the stream batch is not splittable.
// The cudf API to return a full outer join row count performs excess computation, so
// just returns an answer here that will be ignored later.
cb.numRows()
case _ =>
val hashTable = maybeBuildHashTable()
assert(streamKeysTable.isEmpty, "stream keys table already exists")
val streamKeys = maybeProjectStreamKeys(cb)
joinType match {
case _: InnerLike =>
streamKeys.innerJoinRowCount(hashTable)
case LeftOuter =>
assert(buildSide == GpuBuildRight, s"$joinType with $buildSide")
streamKeys.leftJoinRowCount(hashTable)
case RightOuter =>
assert(buildSide == GpuBuildLeft, s"$joinType with $buildSide")
streamKeys.leftJoinRowCount(hashTable)
case _ =>
throw new IllegalStateException(s"unexpected join type: $joinType")
}
}
// TODO: Replace this estimate with exact join row counts using the corresponding cudf APIs
// being added in https://github.com/rapidsai/cudf/issues/9053.
joinType match {
case _: InnerLike | LeftOuter | RightOuter =>
Math.ceil(cb.numRows() * streamMagnificationFactor).toLong
case _ => cb.numRows()
}
}
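The restored row-count estimate is simply the stream batch size scaled by the cached magnification factor. A short sketch of that arithmetic with made-up numbers:

// Hypothetical inputs: a 500,000-row stream batch joined against a build side
// whose keys repeat about 1.8 times on average.
val streamRows = 500000L
val magnificationFactor = 1.8
val estimatedJoinRows = math.ceil(streamRows * magnificationFactor).toLong
// estimatedJoinRows == 900000. Semi, anti and existence joins skip the scaling
// and reuse streamRows, since they never produce more rows than the stream input.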

override def createGatherer(
cb: ColumnarBatch,
numJoinRows: Option[Long]): Option[JoinGatherer] = {
withResource(new NvtxWithMetrics("hash join gather maps", NvtxColor.ORANGE, joinTime)) { _ =>
try {
withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
joinGatherer(builtKeys, built, cb)
}
} catch {
// This should work for all join types except for FullOuter. There should be no need
// to do this for any of the existence joins because the output rows will never be
// larger than the input rows on the stream side.
case oom: OutOfMemoryError if joinType.isInstanceOf[InnerLike]
|| joinType == LeftOuter
|| joinType == RightOuter =>
// Because this is just an estimate, it is possible for us to get this wrong, so
// make sure we at least split the batch in half.
val numBatches = Math.max(2, estimatedNumBatches(cb))

// Split batch and return no gatherer so the outer loop will try again
splitAndSave(cb, numBatches, Some(oom))
None
}
}
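Returning None after splitAndSave relies on the surrounding iterator looping until a gatherer can be produced from the smaller saved splits. The sketch below is a stand-alone illustration of that contract only; the queue and the size threshold are hypothetical stand-ins for the real SplittableJoinIterator state, not the actual implementation.

import scala.collection.mutable

object SplitRetrySketch {
  final case class Batch(numRows: Int)

  private val pending = mutable.Queue(Batch(1000000))

  // Stand-in for createGatherer: "OOMs" on anything above 300,000 rows and
  // splits the batch in half instead of producing a result.
  private def tryJoin(cb: Batch): Option[String] =
    if (cb.numRows > 300000) { splitAndSave(cb, 2); None }
    else Some(s"gathered ${cb.numRows} rows")

  private def splitAndSave(cb: Batch, numBatches: Int): Unit = {
    val batchSize = cb.numRows / numBatches
    (1 to numBatches).foreach(_ => pending.enqueue(Batch(batchSize)))
  }

  def main(args: Array[String]): Unit = {
    // The outer loop keeps pulling batches; splits re-enter the queue until
    // every piece is small enough to join.
    while (pending.nonEmpty) {
      tryJoin(pending.dequeue()).foreach(println)
    }
  }
}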

private def joinGathererLeftRight(
leftKeys: Table,
leftData: LazySpillableColumnarBatch,
rightKeys: Table,
rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
withResource(new NvtxWithMetrics("hash join gather map", NvtxColor.ORANGE, joinTime)) { _ =>
val maps = joinType match {
case FullOuter | LeftSemi | LeftAnti =>
assert(builtHash.isEmpty, s"$joinType but somehow precomputed a hash table")
assert(streamKeysTable.isEmpty, s"$joinType but somehow already projected stream keys")
withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
withResource(GpuColumnVector.from(builtKeys)) { builtTable =>
withResource(GpuProjectExec.project(cb, boundStreamKeys)) { streamKeys =>
withResource(GpuColumnVector.from(streamKeys)) { streamTable =>
joinType match {
case FullOuter =>
val maps = streamTable.fullJoinGatherMaps(builtTable, compareNullsEqual)
if (buildSide == GpuBuildLeft) maps.reverse else maps
case LeftSemi =>
assert(buildSide == GpuBuildRight, s"$joinType with $buildSide")
Array(streamTable.leftSemiJoinGatherMap(builtTable, compareNullsEqual))
case LeftAnti =>
assert(buildSide == GpuBuildRight, s"$joinType with $buildSide")
Array(streamTable.leftAntiJoinGatherMap(builtTable, compareNullsEqual))
case _ =>
throw new IllegalStateException(s"Unexpected join type $joinType")
}
}
}
}
}
case LeftOuter => leftKeys.leftJoinGatherMaps(rightKeys, compareNullsEqual)
case RightOuter =>
// Reverse the output of the join, because we expect the right gather map to
// always be on the right
rightKeys.leftJoinGatherMaps(leftKeys, compareNullsEqual).reverse
case _: InnerLike => leftKeys.innerJoinGatherMaps(rightKeys, compareNullsEqual)
case LeftSemi => Array(leftKeys.leftSemiJoinGatherMap(rightKeys, compareNullsEqual))
case LeftAnti => Array(leftKeys.leftAntiJoinGatherMap(rightKeys, compareNullsEqual))
case FullOuter => leftKeys.fullJoinGatherMaps(rightKeys, compareNullsEqual)
case _ =>
val hashTable = maybeBuildHashTable()
val streamKeys = maybeProjectStreamKeys(cb)
joinType match {
case _: InnerLike =>
val maps = streamKeys.innerJoinGatherMaps(hashTable)
if (buildSide == GpuBuildLeft) maps.reverse else maps
case LeftOuter =>
assert(buildSide == GpuBuildRight, s"$joinType with $buildSide")
streamKeys.leftJoinGatherMaps(hashTable)
case RightOuter =>
assert(buildSide == GpuBuildLeft, s"$joinType with $buildSide")
// Reverse the output of the join, because we expect the right gather map to
// always be on the right
streamKeys.leftJoinGatherMaps(hashTable).reverse
case _ =>
throw new IllegalStateException(s"Unexpected join type: $joinType")
}
throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" +
s" supported")
}
closeOnExcept(LazySpillableColumnarBatch(cb, spillCallback, "stream_data")) { streamData =>
val buildData = LazySpillableColumnarBatch.spillOnly(built)
val (left, right) = buildSide match {
case GpuBuildLeft => (buildData, streamData)
case GpuBuildRight => (streamData, buildData)
}
makeGatherer(maps, left, right)
makeGatherer(maps, leftData, rightData)
}
}

private def joinGathererLeftRight(
leftKeys: ColumnarBatch,
leftData: LazySpillableColumnarBatch,
rightKeys: ColumnarBatch,
rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
withResource(GpuColumnVector.from(leftKeys)) { leftKeysTab =>
withResource(GpuColumnVector.from(rightKeys)) { rightKeysTab =>
joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData)
}
}
}

private def maybeBuildHashTable(): HashJoin = {
builtHash.getOrElse {
withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys =>
withResource(GpuColumnVector.from(builtKeys)) { keysTable =>
val hashTable = new HashJoin(keysTable, compareNullsEqual)
builtHash = Some(hashTable)
hashTable
}
private def joinGatherer(
buildKeys: ColumnarBatch,
buildData: LazySpillableColumnarBatch,
streamKeys: ColumnarBatch,
streamData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
buildSide match {
case GpuBuildLeft =>
joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData)
case GpuBuildRight =>
joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData)
}
}

private def joinGatherer(
buildKeys: ColumnarBatch,
buildData: LazySpillableColumnarBatch,
streamCb: ColumnarBatch): Option[JoinGatherer] = {
withResource(GpuProjectExec.project(streamCb, boundStreamKeys)) { streamKeys =>
closeOnExcept(LazySpillableColumnarBatch(streamCb, spillCallback, "stream_data")) { sd =>
joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, sd)
}
}
}

private def maybeProjectStreamKeys(cb: ColumnarBatch): Table = {
streamKeysTable.getOrElse {
withResource(GpuProjectExec.project(cb, boundStreamKeys)) { streamKeysBatch =>
val table = GpuColumnVector.from(streamKeysBatch)
streamKeysTable = Some(table)
table
private def countGroups(keys: ColumnarBatch): Table = {
withResource(GpuColumnVector.from(keys)) { keysTable =>
keysTable.groupBy(0 until keysTable.getNumberOfColumns: _*)
.aggregate(GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(0))
}
}

/**
* Guess the magnification factor for a stream side batch.
* This is temporary until cudf gives us APIs to get the actual gather map size.
*/
private def guessStreamMagnificationFactor(builtKeys: ColumnarBatch): Double = {
// Based off of the keys on the build side guess at how many output rows there
// will be for each input row on the stream side. This does not take into account
// the join type, data skew or even if the keys actually match.
withResource(countGroups(builtKeys)) { builtCount =>
val counts = builtCount.getColumn(builtCount.getNumberOfColumns - 1)
withResource(counts.reduce(ReductionAggregation.mean(), DType.FLOAT64)) { scalarAverage =>
scalarAverage.getDouble
}
}
}
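The factor is just the average group size of the build-side keys. The plain-Scala sketch below computes the same statistic on an in-memory key list (the cudf groupBy plus mean reduction above is what actually runs on the GPU); the keys are hypothetical.

// Illustrative only: average build rows per distinct key, computed with Scala
// collections instead of cudf.
val buildKeys = Seq("a", "a", "a", "b", "c", "c")
val groupSizes = buildKeys.groupBy(identity).values.map(_.size)
val streamMagnificationFactor = groupSizes.sum.toDouble / groupSizes.size
// Groups of size 3, 1 and 2 average out to 2.0: each matching stream row is
// expected to gather roughly two build rows, ignoring join type and skew.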

private def estimatedNumBatches(cb: ColumnarBatch): Int = joinType match {
case _: InnerLike | LeftOuter | RightOuter =>
// We want the gather map size to be around the target size. There are two gather maps
// that are made up of ints, so estimate how many rows per batch on the stream side
// will produce the desired gather map size.
val approximateStreamRowCount = ((targetSize.toDouble / 2) /
DType.INT32.getSizeInBytes) / streamMagnificationFactor
val estimatedRowsPerStreamBatch = Math.min(Int.MaxValue, approximateStreamRowCount)
Math.ceil(cb.numRows() / estimatedRowsPerStreamBatch).toInt
case _ => 1
}
}
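Since each of the two gather maps stores one 4-byte int per output row, the target row count per split in estimatedNumBatches falls out of the arithmetic below; the target size and magnification factor are hypothetical values chosen only to make the numbers concrete.

// Hypothetical: a 256 MiB target batch size and a magnification factor of 2.0.
val targetSize = 256L * 1024 * 1024
val bytesPerGatherMapRow = 4 // DType.INT32.getSizeInBytes
val streamMagnificationFactor = 2.0
val approximateStreamRowCount =
  ((targetSize.toDouble / 2) / bytesPerGatherMapRow) / streamMagnificationFactor
// approximateStreamRowCount == 16,777,216 rows per split, so a 50,000,000-row
// stream batch would be re-split into ceil(5.0e7 / 1.68e7) = 3 batches.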

trait GpuHashJoin extends GpuExec {