diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala index fc6c9b89f28..508fb077a0f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import scala.collection.mutable + import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange, Table} -import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, MetricsLevel} +import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, MetricsLevel, RapidsBuffer, SpillableColumnarBatch, SpillPriorities} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.{Dependency, NarrowDependency, Partition, SparkContext, TaskContext} @@ -45,7 +47,7 @@ class GpuSerializableBatch(batch: ColumnarBatch) } private def writeObject(out: ObjectOutputStream): Unit = { - withResource (new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ => + withResource(new NvtxRange("SerializeBatch", NvtxColor.PURPLE)) { _ => if (internalBatch == null) { throw new IllegalStateException("Cannot re-serialize a batch this way...") } else { @@ -67,7 +69,7 @@ class GpuSerializableBatch(batch: ColumnarBatch) private def readObject(in: ObjectInputStream): Unit = { GpuSemaphore.acquireIfNecessary(TaskContext.get()) - withResource (new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ => + withResource(new NvtxRange("DeserializeBatch", NvtxColor.PURPLE)) { _ => val schemaArray = in.readObject().asInstanceOf[Array[DataType]] withResource(JCudfSerialization.readTableFrom(in)) { tableInfo => val tmp = tableInfo.getTable @@ -116,8 +118,8 @@ class GpuCartesianRDD( numOutputBatches: GpuMetric, filterTime: GpuMetric, totalTime: GpuMetric, - var rdd1 : RDD[GpuSerializableBatch], - var rdd2 : RDD[GpuSerializableBatch]) + var rdd1: RDD[GpuSerializableBatch], + var rdd2: RDD[GpuSerializableBatch]) extends RDD[ColumnarBatch](sc, Nil) with Serializable with Arm { @@ -141,15 +143,49 @@ class GpuCartesianRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val currSplit = split.asInstanceOf[GpuCartesianPartition] + + // create a buffer to cache stream-side data in a spillable manner + val spillBatchBuffer = mutable.ArrayBuffer[SpillableColumnarBatch]() + // sentinel variable to label whether stream-side data is cached or not + var streamSideCached = false + // a pointer to track buildTableOnFlight + var buildTableOnFlight: Option[Table] = None + + // Add a taskCompletionListener to ensure the release of GPU memory. This listener will work + // if the CompletionIterator does not fully iterate before the task completes, which may + // happen if there exists specific plans like `LimitExec`. + context.addTaskCompletionListener[Unit]((_: TaskContext) => { + spillBatchBuffer.safeClose() + buildTableOnFlight.foreach(_.close()) + }) + rdd1.iterator(currSplit.s1, context).flatMap { lhs => val table = withResource(lhs) { lhs => GpuColumnVector.from(lhs.getBatch) } - // Ideally instead of looping through and recomputing rdd2 for - // each batch in rdd1 we would instead cache rdd2 in a way that - // it could spill to disk so we can avoid re-computation + buildTableOnFlight = Some(table) + // Introduce sentinel `streamSideCached` to record whether stream-side data is cached or + // not, because predicate `spillBatchBuffer.isEmpty` will always be true if + // `rdd2.iterator` is an empty iterator. + val streamIterator = if (!streamSideCached) { + streamSideCached = true + // lazily compute and cache stream-side data + rdd2.iterator(currSplit.s2, context).map { serializableBatch => + closeOnExcept(spillBatchBuffer) { buffer => + val batch = SpillableColumnarBatch(serializableBatch.getBatch, + SpillPriorities.ACTIVE_ON_DECK_PRIORITY, + RapidsBuffer.defaultSpillCallback) + buffer += batch + batch.getColumnarBatch() + } + } + } else { + // fetch stream-side data directly if they are cached + spillBatchBuffer.toIterator.map(_.getColumnarBatch()) + } + val ret = GpuBroadcastNestedLoopJoinExecBase.innerLikeJoin( - rdd2.iterator(currSplit.s2, context).map(i => i.getBatch), + streamIterator, table, GpuBuildLeft, boundCondition, @@ -161,7 +197,14 @@ class GpuCartesianRDD( filterTime, totalTime) - CompletionIterator[ColumnarBatch, Iterator[ColumnarBatch]](ret, table.close()) + CompletionIterator[ColumnarBatch, Iterator[ColumnarBatch]](ret, { + // clean up spill batch buffer + spillBatchBuffer.safeClose() + spillBatchBuffer.clear() + // clean up build table + table.close() + buildTableOnFlight = None + }) } } @@ -190,10 +233,10 @@ object GpuNoColumnCrossJoin extends Arm { // Hash aggregate explodes the rows out, so if we go too large // it can blow up. The size of a Long is 8 bytes so we just go with // that as our estimate, no nulls. - val maxRowCount = targetSizeBytes/8 + val maxRowCount = targetSizeBytes / 8 def divideIntoBatches(rows: Long): Iterable[ColumnarBatch] = { - val numBatches = (rows + maxRowCount - 1)/maxRowCount + val numBatches = (rows + maxRowCount - 1) / maxRowCount (0L until numBatches).map(i => { val ret = new ColumnarBatch(new Array[ColumnVector](0)) if ((i + 1) * maxRowCount > rows) { @@ -232,9 +275,10 @@ case class GpuCartesianProductExec( right: SparkPlan, condition: Option[Expression], targetSizeBytes: Long) extends BinaryExecNode with GpuExec { + import GpuMetric._ - override def output: Seq[Attribute]= left.output ++ right.output + override def output: Seq[Attribute] = left.output ++ right.output override def verboseStringWithOperatorId(): String = { val joinCondStr = if (condition.isDefined) s"${condition.get}" else "None"