spillable cache for GpuCartesianRDD #1878

Merged · 5 commits · Mar 16, 2021
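The gist of the change below: instead of recomputing rdd2 for every left-hand batch, each stream-side batch is wrapped in a SpillableColumnarBatch, which lets the RAPIDS memory framework move it off the GPU under memory pressure, and later passes re-materialize the cached batches with getColumnarBatch(). A minimal sketch of that wrap/replay pattern, reusing only the APIs visible in the diff (the SpillableBatchCache helper itself is hypothetical):

import scala.collection.mutable

import com.nvidia.spark.rapids.{RapidsBuffer, SpillableColumnarBatch, SpillPriorities}

import org.apache.spark.sql.vectorized.ColumnarBatch

// Sketch only: hold batches in spillable form so GPU memory can be reclaimed,
// then read them back whenever the stream side needs to be replayed.
class SpillableBatchCache extends AutoCloseable {
  private val cached = mutable.ArrayBuffer[SpillableColumnarBatch]()

  // Wrap an incoming batch; ownership of `batch` moves to the spillable wrapper.
  def add(batch: ColumnarBatch): Unit =
    cached += SpillableColumnarBatch(batch,
      SpillPriorities.ACTIVE_ON_DECK_PRIORITY,
      RapidsBuffer.defaultSpillCallback)

  // Re-materialize the cached batches; the caller owns (and must close) each returned batch.
  def replay(): Iterator[ColumnarBatch] = cached.toIterator.map(_.getColumnarBatch())

  override def close(): Unit = {
    cached.foreach(_.close())
    cached.clear()
  }
}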
@@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable

import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange, Table}
import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, MetricsLevel}
import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, MetricsLevel, RapidsBuffer, SpillableColumnarBatch, SpillPriorities}
import com.nvidia.spark.rapids.RapidsPluginImplicits._

import org.apache.spark.{Dependency, NarrowDependency, Partition, SparkContext, TaskContext}
@@ -45,7 +47,7 @@ class GpuSerializableBatch(batch: ColumnarBatch)
}

private def writeObject(out: ObjectOutputStream): Unit = {
withResource (new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ =>
withResource(new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ =>
if (internalBatch == null) {
throw new IllegalStateException("Cannot re-serialize a batch this way...")
} else {
@@ -67,7 +69,7 @@ class GpuSerializableBatch(batch: ColumnarBatch)

private def readObject(in: ObjectInputStream): Unit = {
GpuSemaphore.acquireIfNecessary(TaskContext.get())
withResource (new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ =>
withResource(new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ =>
val schemaArray = in.readObject().asInstanceOf[Array[DataType]]
withResource(JCudfSerialization.readTableFrom(in)) { tableInfo =>
val tmp = tableInfo.getTable
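For readers outside the plugin codebase, the Arm helpers used throughout this file (withResource, closeOnExcept) manage AutoCloseable lifetimes. A minimal local sketch of their semantics, not the plugin's actual implementation: withResource always closes the resource whether the block returns or throws, while closeOnExcept closes it only on failure, leaving ownership with the caller on success.

// Local stand-ins mimicking the Arm helpers above (a sketch, not the plugin code).
object ArmSketch {
  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V =
    try {
      block(r)
    } finally {
      r.close() // always released, on success or failure
    }

  def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V =
    try {
      block(r)
    } catch {
      case t: Throwable =>
        r.close() // released only if the block throws; otherwise the caller keeps ownership
        throw t
    }
}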
@@ -116,8 +118,8 @@ class GpuCartesianRDD(
numOutputBatches: GpuMetric,
filterTime: GpuMetric,
totalTime: GpuMetric,
var rdd1 : RDD[GpuSerializableBatch],
var rdd2 : RDD[GpuSerializableBatch])
var rdd1: RDD[GpuSerializableBatch],
var rdd2: RDD[GpuSerializableBatch])
extends RDD[ColumnarBatch](sc, Nil)
with Serializable with Arm {

@@ -141,15 +143,49 @@ class GpuCartesianRDD(
override def compute(split: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
val currSplit = split.asInstanceOf[GpuCartesianPartition]

// a buffer that caches stream-side data in a spillable manner
val spillBatchBuffer = mutable.ArrayBuffer[SpillableColumnarBatch]()
// sentinel recording whether the stream-side data has been cached yet
var streamSideCached = false
// reference to the build-side table currently in flight, so it can be closed on task completion
var buildTableOnFlight: Option[Table] = None

// Add a taskCompletionListener to ensure GPU memory is released. The listener covers the case
// where the CompletionIterator is not fully iterated before the task completes, which may
// happen under plans such as `LimitExec` that stop consuming early.
context.addTaskCompletionListener[Unit]((_: TaskContext) => {
spillBatchBuffer.safeClose()
buildTableOnFlight.foreach(_.close())
})

rdd1.iterator(currSplit.s1, context).flatMap { lhs =>
val table = withResource(lhs) { lhs =>
GpuColumnVector.from(lhs.getBatch)
}
// Ideally, instead of looping through and recomputing rdd2 for
// each batch in rdd1, we would cache rdd2 in a way that
// allows it to spill to disk so we can avoid re-computation.
buildTableOnFlight = Some(table)
// Use the sentinel `streamSideCached` to record whether the stream-side data has been
// cached, because the predicate `spillBatchBuffer.isEmpty` would always be true if
// `rdd2.iterator` were an empty iterator.
val streamIterator = if (!streamSideCached) {
streamSideCached = true
// lazily compute and cache stream-side data
rdd2.iterator(currSplit.s2, context).map { serializableBatch =>
closeOnExcept(spillBatchBuffer) { buffer =>
val batch = SpillableColumnarBatch(serializableBatch.getBatch,
SpillPriorities.ACTIVE_ON_DECK_PRIORITY,
RapidsBuffer.defaultSpillCallback)
buffer += batch
batch.getColumnarBatch()
}
}
} else {
// fetch stream-side data directly if they are cached
spillBatchBuffer.toIterator.map(_.getColumnarBatch())
}

val ret = GpuBroadcastNestedLoopJoinExecBase.innerLikeJoin(
rdd2.iterator(currSplit.s2, context).map(i => i.getBatch),
streamIterator,
table,
GpuBuildLeft,
boundCondition,
@@ -161,7 +197,14 @@
filterTime,
totalTime)

CompletionIterator[ColumnarBatch, Iterator[ColumnarBatch]](ret, table.close())
CompletionIterator[ColumnarBatch, Iterator[ColumnarBatch]](ret, {
// clean up spill batch buffer
spillBatchBuffer.safeClose()
spillBatchBuffer.clear()
// clean up build table
table.close()
buildTableOnFlight = None
})
}
}
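The compute() change follows a cache-then-replay pattern: the first pass over rdd2 yields each stream-side batch while storing it in the spillable buffer, and every later pass replays from the buffer instead of recomputing rdd2. The boolean sentinel matters because checking spillBatchBuffer.isEmpty would recompute an empty rdd2 forever. A generic sketch of the pattern, with hypothetical names and no plugin dependencies:

import scala.collection.mutable

// Sketch: replay an expensive iterator from a cache after the first full pass.
class CachingReplay[T](compute: () => Iterator[T]) {
  private val cache = mutable.ArrayBuffer[T]()
  // Sentinel instead of `cache.isEmpty`: an empty source would otherwise look "uncached" forever.
  private var cachedOnce = false

  def iterator: Iterator[T] =
    if (!cachedOnce) {
      cachedOnce = true
      // populate lazily while yielding, so the compute cost is paid only once
      compute().map { elem => cache += elem; elem }
    } else {
      cache.toIterator
    }
}

In the PR the cached elements are SpillableColumnarBatch wrappers, so replaying also re-materializes GPU memory on demand.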

@@ -190,10 +233,10 @@ object GpuNoColumnCrossJoin extends Arm {
// Hash aggregate explodes the rows out, so if we go too large
// it can blow up. The size of a Long is 8 bytes so we just go with
// that as our estimate, no nulls.
val maxRowCount = targetSizeBytes/8
val maxRowCount = targetSizeBytes / 8

def divideIntoBatches(rows: Long): Iterable[ColumnarBatch] = {
val numBatches = (rows + maxRowCount - 1)/maxRowCount
val numBatches = (rows + maxRowCount - 1) / maxRowCount
(0L until numBatches).map(i => {
val ret = new ColumnarBatch(new Array[ColumnVector](0))
if ((i + 1) * maxRowCount > rows) {
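As a concrete illustration of the batching arithmetic above (the numbers are arbitrary examples, not values from the PR): with a 1 GiB target size the row cap is 1073741824 / 8 = 134217728 rows, and a cross join producing 300 million rows is emitted as 3 empty-column batches whose row counts sum to the total.

// Worked example of the maxRowCount / numBatches math (example values only).
val targetSizeBytes = 1L << 30                           // 1 GiB budget
val maxRowCount = targetSizeBytes / 8                    // 8 bytes per Long row => 134217728
val rows = 300000000L
val numBatches = (rows + maxRowCount - 1) / maxRowCount  // ceiling division => 3
val perBatchRows = (0L until numBatches).map { i =>
  if ((i + 1) * maxRowCount > rows) rows - i * maxRowCount else maxRowCount
}
assert(perBatchRows.sum == rows)                         // 134217728 + 134217728 + 31564544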
@@ -232,9 +275,10 @@ case class GpuCartesianProductExec(
right: SparkPlan,
condition: Option[Expression],
targetSizeBytes: Long) extends BinaryExecNode with GpuExec {

import GpuMetric._

override def output: Seq[Attribute]= left.output ++ right.output
override def output: Seq[Attribute] = left.output ++ right.output

override def verboseStringWithOperatorId(): String = {
val joinCondStr = if (condition.isDefined) s"${condition.get}" else "None"