From 25bad3d95ab3c464eb04df46e8a9300f81e75de3 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Fri, 27 Aug 2021 11:42:52 -0500 Subject: [PATCH] Use cudf to compute exact hash join output row sizes (#3288) Signed-off-by: Jason Lowe --- .../rapids/AbstractGpuJoinIterator.scala | 24 +- .../sql/rapids/execution/GpuHashJoin.scala | 234 +++++++++--------- 2 files changed, 126 insertions(+), 132 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala index 0f412a5781c..b4b9191a0f7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AbstractGpuJoinIterator.scala @@ -48,6 +48,12 @@ abstract class AbstractGpuJoinIterator( /** Returns whether there are any more batches on the stream side of the join */ protected def hasNextStreamBatch: Boolean + /** + * Called when a result batch is about to be produced. Any custom resources that were allocated + * to produce a result should be freed or made spillable here. + */ + protected def freeIntermediateResources(): Unit = {} + /** * Called to setup the next join gatherer instance when the previous instance is done or * there is no previous instance. @@ -118,6 +124,7 @@ abstract class AbstractGpuJoinIterator( if (ret.isDefined) { // We are about to return something. We got everything we need from it so now let it spill // if there is more to be gathered later on. + freeIntermediateResources() gathererStore.foreach(_.allowSpilling()) } ret @@ -230,25 +237,12 @@ abstract class SplittableJoinIterator( * the splits in the stream-side input * @param cb stream-side input batch to split * @param numBatches number of splits to produce with approximately the same number of rows each - * @param oom a prior OOM exception that this will try to recover from by splitting */ protected def splitAndSave( cb: ColumnarBatch, - numBatches: Int, - oom: Option[OutOfMemoryError] = None): Unit = { + numBatches: Int): Unit = { val batchSize = cb.numRows() / numBatches - if (oom.isDefined && batchSize < 100) { - // We just need some kind of cutoff to not get stuck in a loop if the batches get to be too - // small but we want to at least give it a chance to work (mostly for tests where the - // targetSize can be set really small) - throw oom.get - } - val msg = s"Split stream batch into $numBatches batches of about $batchSize rows" - if (oom.isDefined) { - logWarning(s"OOM Encountered: $msg") - } else { - logInfo(msg) - } + logInfo(s"Split stream batch into $numBatches batches of about $batchSize rows") val splits = withResource(GpuColumnVector.from(cb)) { tab => val splitIndexes = (1 until numBatches).map(num => num * batchSize) tab.contiguousSplit(splitIndexes: _*) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index f5f350ee91a..41a866ccc8a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -15,7 +15,7 @@ */ package org.apache.spark.sql.rapids.execution -import ai.rapids.cudf.{DType, GroupByAggregation, NullPolicy, NvtxColor, ReductionAggregation, Table} +import ai.rapids.cudf.{HashJoin, NvtxColor, Table} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.RapidsBuffer.SpillCallback @@ -233,145 +233,145 @@ class HashJoinIterator( joinTime = joinTime, streamTime = streamTime, totalTime = totalTime) { - // We can cache this because the build side is not changing - private lazy val streamMagnificationFactor = joinType match { - case _: InnerLike | LeftOuter | RightOuter => - withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => - guessStreamMagnificationFactor(builtKeys) - } - case _ => - // existence joins don't change size, and FullOuter cannot be split - 1.0 + private var builtHash: Option[HashJoin] = None + private var streamKeysTable: Option[Table] = None + + override def close(): Unit = { + super.close() + freeIntermediateResources() + } + + override def freeIntermediateResources(): Unit = { + builtHash.foreach(_.close()) + builtHash = None + streamKeysTable.foreach(_.close()) + streamKeysTable = None + } + + override def setupNextGatherer(startNanoTime: Long): Option[JoinGatherer] = { + // fetching the next stream batch invalidates any previous stream keys + streamKeysTable.foreach(_.close()) + streamKeysTable = None + super.setupNextGatherer(startNanoTime) + } + + override def splitAndSave(cb: ColumnarBatch, numBatches: Int): Unit = { + // splitting the stream batch invalidates the stream keys + streamKeysTable.foreach(_.close()) + streamKeysTable = None + super.splitAndSave(cb, numBatches) } override def computeNumJoinRows(cb: ColumnarBatch): Long = { - // TODO: Replace this estimate with exact join row counts using the corresponding cudf APIs - // being added in https://github.com/rapidsai/cudf/issues/9053. - joinType match { - case _: InnerLike | LeftOuter | RightOuter => - Math.ceil(cb.numRows() * streamMagnificationFactor).toLong - case _ => cb.numRows() + withResource(new NvtxWithMetrics("hash join build", NvtxColor.ORANGE, joinTime)) { _ => + joinType match { + case LeftSemi | LeftAnti | FullOuter => + // Semi or Anti joins do not explode, worst-case they return the same number of rows. + // For full joins, currently we only support the entire stream table at once. There's no + // reason to predict the output rows since the stream batch is not splittable. + // The cudf API to return a full outer join row count performs excess computation, so + // just returns an answer here that will be ignored later. + cb.numRows() + case _ => + val hashTable = maybeBuildHashTable() + assert(streamKeysTable.isEmpty, "stream keys table already exists") + val streamKeys = maybeProjectStreamKeys(cb) + joinType match { + case _: InnerLike => + streamKeys.innerJoinRowCount(hashTable) + case LeftOuter => + assert(buildSide == GpuBuildRight, s"$joinType with $buildSide") + streamKeys.leftJoinRowCount(hashTable) + case RightOuter => + assert(buildSide == GpuBuildLeft, s"$joinType with $buildSide") + streamKeys.leftJoinRowCount(hashTable) + case _ => + throw new IllegalStateException(s"unexpected join type: $joinType") + } + } } } override def createGatherer( cb: ColumnarBatch, numJoinRows: Option[Long]): Option[JoinGatherer] = { - try { - withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => - joinGatherer(builtKeys, built, cb) - } - } catch { - // This should work for all join types except for FullOuter. There should be no need - // to do this for any of the existence joins because the output rows will never be - // larger than the input rows on the stream side. - case oom: OutOfMemoryError if joinType.isInstanceOf[InnerLike] - || joinType == LeftOuter - || joinType == RightOuter => - // Because this is just an estimate, it is possible for us to get this wrong, so - // make sure we at least split the batch in half. - val numBatches = Math.max(2, estimatedNumBatches(cb)) - - // Split batch and return no gatherer so the outer loop will try again - splitAndSave(cb, numBatches, Some(oom)) - None - } - } - - private def joinGathererLeftRight( - leftKeys: Table, - leftData: LazySpillableColumnarBatch, - rightKeys: Table, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { - withResource(new NvtxWithMetrics("hash join gather map", NvtxColor.ORANGE, joinTime)) { _ => + withResource(new NvtxWithMetrics("hash join gather maps", NvtxColor.ORANGE, joinTime)) { _ => val maps = joinType match { - case LeftOuter => leftKeys.leftJoinGatherMaps(rightKeys, compareNullsEqual) - case RightOuter => - // Reverse the output of the join, because we expect the right gather map to - // always be on the right - rightKeys.leftJoinGatherMaps(leftKeys, compareNullsEqual).reverse - case _: InnerLike => leftKeys.innerJoinGatherMaps(rightKeys, compareNullsEqual) - case LeftSemi => Array(leftKeys.leftSemiJoinGatherMap(rightKeys, compareNullsEqual)) - case LeftAnti => Array(leftKeys.leftAntiJoinGatherMap(rightKeys, compareNullsEqual)) - case FullOuter => leftKeys.fullJoinGatherMaps(rightKeys, compareNullsEqual) + case FullOuter | LeftSemi | LeftAnti => + assert(builtHash.isEmpty, s"$joinType but somehow precomputed a hash table") + assert(streamKeysTable.isEmpty, s"$joinType but somehow already projected stream keys") + withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => + withResource(GpuColumnVector.from(builtKeys)) { builtTable => + withResource(GpuProjectExec.project(cb, boundStreamKeys)) { streamKeys => + withResource(GpuColumnVector.from(streamKeys)) { streamTable => + joinType match { + case FullOuter => + val maps = streamTable.fullJoinGatherMaps(builtTable, compareNullsEqual) + if (buildSide == GpuBuildLeft) maps.reverse else maps + case LeftSemi => + assert(buildSide == GpuBuildRight, s"$joinType with $buildSide") + Array(streamTable.leftSemiJoinGatherMap(builtTable, compareNullsEqual)) + case LeftAnti => + assert(buildSide == GpuBuildRight, s"$joinType with $buildSide") + Array(streamTable.leftAntiJoinGatherMap(builtTable, compareNullsEqual)) + case _ => + throw new IllegalStateException(s"Unexpected join type $joinType") + } + } + } + } + } case _ => - throw new NotImplementedError(s"Joint Type ${joinType.getClass} is not currently" + - s" supported") + val hashTable = maybeBuildHashTable() + val streamKeys = maybeProjectStreamKeys(cb) + joinType match { + case _: InnerLike => + val maps = streamKeys.innerJoinGatherMaps(hashTable) + if (buildSide == GpuBuildLeft) maps.reverse else maps + case LeftOuter => + assert(buildSide == GpuBuildRight, s"$joinType with $buildSide") + streamKeys.leftJoinGatherMaps(hashTable) + case RightOuter => + assert(buildSide == GpuBuildLeft, s"$joinType with $buildSide") + // Reverse the output of the join, because we expect the right gather map to + // always be on the right + streamKeys.leftJoinGatherMaps(hashTable).reverse + case _ => + throw new IllegalStateException(s"Unexpected join type: $joinType") + } } - makeGatherer(maps, leftData, rightData) - } - } - - private def joinGathererLeftRight( - leftKeys: ColumnarBatch, - leftData: LazySpillableColumnarBatch, - rightKeys: ColumnarBatch, - rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = { - withResource(GpuColumnVector.from(leftKeys)) { leftKeysTab => - withResource(GpuColumnVector.from(rightKeys)) { rightKeysTab => - joinGathererLeftRight(leftKeysTab, leftData, rightKeysTab, rightData) + closeOnExcept(LazySpillableColumnarBatch(cb, spillCallback, "stream_data")) { streamData => + val buildData = LazySpillableColumnarBatch.spillOnly(built) + val (left, right) = buildSide match { + case GpuBuildLeft => (buildData, streamData) + case GpuBuildRight => (streamData, buildData) + } + makeGatherer(maps, left, right) } } } - private def joinGatherer( - buildKeys: ColumnarBatch, - buildData: LazySpillableColumnarBatch, - streamKeys: ColumnarBatch, - streamData: LazySpillableColumnarBatch): Option[JoinGatherer] = { - buildSide match { - case GpuBuildLeft => - joinGathererLeftRight(buildKeys, buildData, streamKeys, streamData) - case GpuBuildRight => - joinGathererLeftRight(streamKeys, streamData, buildKeys, buildData) - } - } - - private def joinGatherer( - buildKeys: ColumnarBatch, - buildData: LazySpillableColumnarBatch, - streamCb: ColumnarBatch): Option[JoinGatherer] = { - withResource(GpuProjectExec.project(streamCb, boundStreamKeys)) { streamKeys => - closeOnExcept(LazySpillableColumnarBatch(streamCb, spillCallback, "stream_data")) { sd => - joinGatherer(buildKeys, LazySpillableColumnarBatch.spillOnly(buildData), streamKeys, sd) + private def maybeBuildHashTable(): HashJoin = { + builtHash.getOrElse { + withResource(GpuProjectExec.project(built.getBatch, boundBuiltKeys)) { builtKeys => + withResource(GpuColumnVector.from(builtKeys)) { keysTable => + val hashTable = new HashJoin(keysTable, compareNullsEqual) + builtHash = Some(hashTable) + hashTable + } } } } - private def countGroups(keys: ColumnarBatch): Table = { - withResource(GpuColumnVector.from(keys)) { keysTable => - keysTable.groupBy(0 until keysTable.getNumberOfColumns: _*) - .aggregate(GroupByAggregation.count(NullPolicy.INCLUDE).onColumn(0)) - } - } - - /** - * Guess the magnification factor for a stream side batch. - * This is temporary until cudf gives us APIs to get the actual gather map size. - */ - private def guessStreamMagnificationFactor(builtKeys: ColumnarBatch): Double = { - // Based off of the keys on the build side guess at how many output rows there - // will be for each input row on the stream side. This does not take into account - // the join type, data skew or even if the keys actually match. - withResource(countGroups(builtKeys)) { builtCount => - val counts = builtCount.getColumn(builtCount.getNumberOfColumns - 1) - withResource(counts.reduce(ReductionAggregation.mean(), DType.FLOAT64)) { scalarAverage => - scalarAverage.getDouble + private def maybeProjectStreamKeys(cb: ColumnarBatch): Table = { + streamKeysTable.getOrElse { + withResource(GpuProjectExec.project(cb, boundStreamKeys)) { streamKeysBatch => + val table = GpuColumnVector.from(streamKeysBatch) + streamKeysTable = Some(table) + table } } } - - private def estimatedNumBatches(cb: ColumnarBatch): Int = joinType match { - case _: InnerLike | LeftOuter | RightOuter => - // We want the gather map size to be around the target size. There are two gather maps - // that are made up of ints, so estimate how many rows per batch on the stream side - // will produce the desired gather map size. - val approximateStreamRowCount = ((targetSize.toDouble / 2) / - DType.INT32.getSizeInBytes) / streamMagnificationFactor - val estimatedRowsPerStreamBatch = Math.min(Int.MaxValue, approximateStreamRowCount) - Math.ceil(cb.numRows() / estimatedRowsPerStreamBatch).toInt - case _ => 1 - } } trait GpuHashJoin extends GpuExec {