Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <firestarmanllc@gmail.com>
  • Loading branch information
firestarman committed Sep 7, 2023
1 parent f5bc430 commit a66b5dc
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 73 deletions.
139 changes: 67 additions & 72 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,72 @@ import org.apache.spark.sql.execution.{CollectLimitExec, LimitExec, SparkPlan}
import org.apache.spark.sql.execution.exchange.ENSURE_REQUIREMENTS
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

object GpuBaseLimitExec {
class GpuBaseLimitIterator(
input: Iterator[ColumnarBatch],
limit: Int,
offset: Int,
opTime: GpuMetric,
numOutputBatches: GpuMetric,
numOutputRows: GpuMetric) extends Iterator[ColumnarBatch] {
private var remainingLimit = limit - offset
private var remainingOffset = offset

override def hasNext: Boolean = {
// A limit of -1 means "no limit"; otherwise stop once the row quota is exhausted.
val withinLimit = limit == -1 || remainingLimit > 0
withinLimit && input.hasNext
}

/**
 * Returns the next output batch after applying this partition's offset and limit.
 *
 * First skips `offset` rows from the head of the partition, closing every batch
 * that is skipped in full; then emits either the surviving batch unchanged or a
 * slice of it, so that at most `limit - offset` rows are produced in total.
 *
 * @throws NoSuchElementException if called when `hasNext` is false
 */
override def next(): ColumnarBatch = {
if (!this.hasNext) {
throw new NoSuchElementException("Next on empty iterator")
}

var batch = input.next()
// Remember the column count of the first batch so an empty result can be built
// even after all input batches have been consumed by the skip loop below.
val numCols = batch.numCols()

// In each partition, we need to skip `offset` rows
while (batch != null && remainingOffset >= batch.numRows()) {
remainingOffset -= batch.numRows()
// This batch is skipped entirely, so release it before fetching the next one.
batch.safeClose()
batch = if (this.hasNext) {
input.next()
} else {
null
}
}

// If the last batch is null, then we have offset >= numRows in this partition.
// In such case, we should return an empty batch
// NOTE(review): `new ArrayBuffer[GpuColumnVector](numCols)` sets an initial
// capacity, not a size, so `.toArray` yields a zero-length column array and the
// returned batch has no columns — confirm downstream accepts a column-less
// empty batch here.
if (batch == null || batch.numRows() == 0) {
return new ColumnarBatch(new ArrayBuffer[GpuColumnVector](numCols).toArray, 0)
}

// Here 0 <= remainingOffset < batch.numRow(), we need to get batch[remainingOffset:]
withResource(new NvtxWithMetrics("limit and offset", NvtxColor.ORANGE, opTime)) { _ =>
var result: ColumnarBatch = null
// limit < 0 (limit == -1) denotes there is no limitation, so when
// (remainingOffset == 0 && (remainingLimit >= batch.numRows() || limit < 0)) is true,
// we can take this batch completely
if (remainingOffset == 0 && (remainingLimit >= batch.numRows() || limit < 0)) {
result = batch
} else {
// otherwise, we need to slice batch with (remainingOffset, remainingLimit).
// And remainingOffset > 0 will be used only once, for the latter batches in this
// partition, set remainingOffset = 0
val length = if (remainingLimit >= batch.numRows() || limit < 0) {
batch.numRows()
} else {
remainingLimit
}
// Make the batch spillable before slicing so a retry under GPU memory
// pressure does not lose the data; closeOnExcept guards the batch if
// making it spillable itself fails.
val scb = closeOnExcept(batch) { _ =>
SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
// Takes ownership of scb and closes it, even on retry failure.
result = sliceBatchAndCloseWithRetry(scb, remainingOffset, length)
remainingOffset = 0
}
// Account for what was just emitted so hasNext can terminate the iteration.
remainingLimit -= result.numRows()
numOutputBatches += 1
numOutputRows += result.numRows()
result
}
}

private def sliceBatchAndCloseWithRetry(
spillBatch: SpillableColumnarBatch,
Expand All @@ -56,76 +121,6 @@ object GpuBaseLimitExec {
}
}
}

/**
 * Builds an iterator that applies a per-partition `offset` and `limit` to a
 * stream of columnar batches: the first `offset` rows are skipped and at most
 * `limit - offset` rows are emitted (a `limit` of -1 means no upper bound).
 *
 * @param input            the partition's input batches; skipped/sliced batches are closed here
 * @param limit            total row cap including the offset, or -1 for unlimited
 * @param offset           number of leading rows to drop in this partition
 * @param opTime           metric tracking time spent in the slice operation
 * @param numOutputBatches metric incremented once per emitted batch
 * @param numOutputRows    metric incremented by the rows of each emitted batch
 */
def apply(
input: Iterator[ColumnarBatch],
limit: Int,
offset: Int,
opTime: GpuMetric,
numOutputBatches: GpuMetric,
numOutputRows: GpuMetric): Iterator[ColumnarBatch] = {
new Iterator[ColumnarBatch] {
private var remainingLimit = limit - offset
private var remainingOffset = offset

override def hasNext: Boolean = (limit == -1 || remainingLimit > 0) && input.hasNext

override def next(): ColumnarBatch = {
if (!this.hasNext) {
throw new NoSuchElementException("Next on empty iterator")
}

var batch = input.next()
val numCols = batch.numCols()

// In each partition, we need to skip `offset` rows
while (batch != null && remainingOffset >= batch.numRows()) {
remainingOffset -= batch.numRows()
batch.safeClose()
batch = if (this.hasNext) {
input.next()
} else {
null
}
}

// If the last batch is null, then we have offset >= numRows in this partition.
// In such case, we should return an empty batch
// NOTE(review): `new ArrayBuffer[GpuColumnVector](numCols)` sets an initial
// capacity, not a size, so the resulting batch has zero columns — confirm
// downstream accepts a column-less empty batch.
if (batch == null || batch.numRows() == 0) {
return new ColumnarBatch(new ArrayBuffer[GpuColumnVector](numCols).toArray, 0)
}

// Here 0 <= remainingOffset < batch.numRow(), we need to get batch[remainingOffset:]
withResource(new NvtxWithMetrics("limit and offset", NvtxColor.ORANGE, opTime)) { _ =>
var result: ColumnarBatch = null
// limit < 0 (limit == -1) denotes there is no limitation, so when
// (remainingOffset == 0 && (remainingLimit >= batch.numRows() || limit < 0)) is true,
// we can take this batch completely
if (remainingOffset == 0 && (remainingLimit >= batch.numRows() || limit < 0)) {
result = batch
} else {
// otherwise, we need to slice batch with (remainingOffset, remainingLimit).
// And remainingOffset > 0 will be used only once, for the latter batches in this
// partition, set remainingOffset = 0
val length = if (remainingLimit >= batch.numRows() || limit < 0) {
batch.numRows()
} else {
remainingLimit
}
// Make the batch spillable before slicing so a retry under memory
// pressure does not lose the data.
val scb = closeOnExcept(batch) { _ =>
SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
}
result = sliceBatchAndCloseWithRetry(scb, remainingOffset, length)
remainingOffset = 0
}
remainingLimit -= result.numRows()
numOutputBatches += 1
numOutputRows += result.numRows()
result
}
}
}
}
}

/**
Expand Down Expand Up @@ -161,7 +156,7 @@ trait GpuBaseLimitExec extends LimitExec with GpuExec with ShimUnaryExecNode {
val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
rdd.mapPartitions { iter =>
GpuBaseLimitExec(iter, limit, offset, opTime, numOutputBatches, numOutputRows)
new GpuBaseLimitIterator(iter, limit, offset, opTime, numOutputBatches, numOutputRows)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LimitRetrySuite extends RmmSparkRetrySuiteBase {
test("GPU limit with retry OOM") {
val totalRows = 24
Seq((20, 5), (50, 5)).foreach { case (limit, offset) =>
val limitIter = GpuBaseLimitExec(
val limitIter = new GpuBaseLimitIterator(
// 3 batches as input, and each has 8 rows
(0 until totalRows).grouped(8).map(buildBatch(_)).toList.toIterator,
limit, offset, NoopMetric, NoopMetric, NoopMetric)
Expand Down

0 comments on commit a66b5dc

Please sign in to comment.