Clean up CoalesceBatch to use withResource (#890)
Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Sep 30, 2020
1 parent 666f89b commit fe70dbb
Showing 1 changed file with 46 additions and 63 deletions.
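
The change swaps hand-rolled try/finally cleanup for the withResource and closeOnExcept helpers supplied by the plugin's Arm trait, which is why the first hunk adds the "with Arm" mixin. For readers unfamiliar with the pattern, here is a minimal sketch of the conventional shape of such helpers. The trait name ArmSketch is ours and the bodies are illustrative; the Arm trait in the repository is the authority for the real signatures.

// Illustrative sketch of Arm-style resource helpers (ArmSketch is a
// hypothetical name; see the repository's Arm trait for the real code).
trait ArmSketch {
  // Run block with r and always close r afterwards, even if block throws.
  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
    try {
      block(r)
    } finally {
      r.close()
    }
  }

  // Run block with r, closing r only when block throws. On success the
  // resource stays open and ownership rests with whatever block did with it.
  def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V = {
    try {
      block(r)
    } catch {
      case t: Throwable =>
        r.close()
        throw t
    }
  }
}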
@@ -123,10 +123,8 @@ abstract class AbstractGpuCoalesceIterator(
     collectTime: SQLMetric,
     concatTime: SQLMetric,
     totalTime: SQLMetric,
-    opName: String) extends Iterator[ColumnarBatch] with Logging {
+    opName: String) extends Iterator[ColumnarBatch] with Arm with Logging {
   private var batchInitialized: Boolean = false
-  private var collectMetric: Option[MetricRange] = None
-  private var totalMetric: Option[MetricRange] = None

   /**
    * Return true if there is something saved on deck for later processing.
@@ -150,43 +148,39 @@ abstract class AbstractGpuCoalesceIterator(

   /** We need to track the sizes of string columns to make sure we don't exceed 2GB */
   private val stringFieldIndices: Array[Int] = schema.fields.zipWithIndex
-      .filter(_._1.dataType == DataTypes.StringType)
-      .map(_._2)
+    .filter(_._1.dataType == DataTypes.StringType)
+    .map(_._2)

   /** Optional row limit */
   var batchRowLimit: Int = 0

   // note that TaskContext.get() can return null during unit testing so we wrap it in an
   // option here
   Option(TaskContext.get())
-    .foreach(_.addTaskCompletionListener[Unit]( _ => clearOnDeck()))
+    .foreach(_.addTaskCompletionListener[Unit](_ => clearOnDeck()))

-  override def hasNext: Boolean = {
-    if (collectMetric.isEmpty) {
-      // use one being not set as indicator that neither are intialized to avoid
-      // 2 checks or extra initialized variable
-      collectMetric = Some(new MetricRange(collectTime))
-      totalMetric = Some(new MetricRange(totalTime))
-    }
-    while (!hasOnDeck && iter.hasNext) {
-      val cb = iter.next()
-      val numRows = cb.numRows()
-      numInputBatches += 1
-      numInputRows += numRows
-      if (numRows > 0) {
-        saveOnDeck(cb)
-      } else {
-        cb.close()
+  private def iterHasNext: Boolean = withResource(new MetricRange(collectTime)) { _ =>
+    iter.hasNext
+  }
+
+  private def iterNext(): ColumnarBatch = withResource(new MetricRange(collectTime)) { _ =>
+    iter.next()
+  }
+
+  override def hasNext: Boolean = withResource(new MetricRange(totalTime)) { _ =>
+    while (!hasOnDeck && iterHasNext) {
+      closeOnExcept(iterNext()) { cb =>
+        val numRows = cb.numRows()
+        numInputBatches += 1
+        numInputRows += numRows
+        if (numRows > 0) {
+          saveOnDeck(cb)
+        } else {
+          cb.close()
+        }
       }
     }
-    val res = hasOnDeck
-    if (!res) {
-      totalMetric.foreach(_.close())
-      totalMetric = None
-      collectMetric.foreach(_.close())
-      collectMetric = None
-    }
-    res
+    hasOnDeck
   }

   /**
@@ -198,18 +192,21 @@ abstract class AbstractGpuCoalesceIterator(
    * Called to add a new batch to the final output batch. The batch passed in will
    * not be closed. If it needs to be closed it is the responsibility of the child class
    * to do it.
+   *
    * @param batch the batch to add in.
    */
   def addBatchToConcat(batch: ColumnarBatch): Unit

   /**
    * Calculate (or estimate) the size of each column in a batch in bytes.
+   *
    * @return Array of column sizes in bytes
    */
   def getColumnSizes(batch: ColumnarBatch): Array[Long]

   /**
    * Called after all of the batches have been added in.
+   *
    * @return the concated batches on the GPU.
    */
   def concatAllAndPutOnGPU(): ColumnarBatch
@@ -248,7 +245,7 @@ abstract class AbstractGpuCoalesceIterator(
    *
    * @return The coalesced batch
    */
-  override def next(): ColumnarBatch = {
+  override def next(): ColumnarBatch = withResource(new MetricRange(totalTime)) { _ =>
     // reset batch state
     batchInitialized = false
     batchRowLimit = 0
@@ -271,12 +268,9 @@ abstract class AbstractGpuCoalesceIterator(
         addBatch(batch)
       }

-      try {
-
-        // there is a hard limit of 2^31 rows
-        while (numRows < Int.MaxValue && !hasOnDeck && iter.hasNext) {
-
-          val cb = iter.next()
+      // there is a hard limit of 2^31 rows
+      while (numRows < Int.MaxValue && !hasOnDeck && iterHasNext) {
+        closeOnExcept(iterNext()) { cb =>
           val nextRows = cb.numRows()
           numInputBatches += 1

@@ -300,14 +294,14 @@ abstract class AbstractGpuCoalesceIterator(
           // memory usage.
           val wouldBeStringColumnSizes =
             stringFieldIndices.map(i => getColumnDataSize(cb, i, wouldBeColumnSizes(i)))
-                .zip(stringColumnSizes)
-                .map(pair => pair._1 + pair._2)
+              .zip(stringColumnSizes)
+              .map(pair => pair._1 + pair._2)

           if (wouldBeRows > Int.MaxValue) {
             if (goal == RequireSingleBatch) {
               throw new IllegalStateException("A single batch is required for this operation," +
-                s" but cuDF only supports ${Int.MaxValue} rows. At least $wouldBeRows are in" +
-                s" this partition. Please try increasing your partition count.")
+                s" but cuDF only supports ${Int.MaxValue} rows. At least $wouldBeRows" +
+                s" are in this partition. Please try increasing your partition count.")
             }
             saveOnDeck(cb)
           } else if (batchRowLimit > 0 && wouldBeRows > batchRowLimit) {
@@ -317,9 +311,9 @@ abstract class AbstractGpuCoalesceIterator(
           } else if (wouldBeStringColumnSizes.exists(size => size > Int.MaxValue)) {
             if (goal == RequireSingleBatch) {
               throw new IllegalStateException("A single batch is required for this operation," +
-                  s" but cuDF only supports ${Int.MaxValue} bytes in a single string column." +
-                  s" At least ${wouldBeStringColumnSizes.max} are in a single column in this" +
-                  s" partition. Please try increasing your partition count.")
+                s" but cuDF only supports ${Int.MaxValue} bytes in a single string column." +
+                s" At least ${wouldBeStringColumnSizes.max} are in a single column in this" +
+                s" partition. Please try increasing your partition count.")
             }
             saveOnDeck(cb)
           } else {
@@ -333,32 +327,21 @@ abstract class AbstractGpuCoalesceIterator(
             cb.close()
           }
         }
+      }

-        // enforce single batch limit when appropriate
-        if (goal == RequireSingleBatch && (hasOnDeck || iter.hasNext)) {
-          throw new IllegalStateException("A single batch is required for this operation." +
-            " Please try increasing your partition count.")
-        }
+      // enforce single batch limit when appropriate
+      if (goal == RequireSingleBatch && (hasOnDeck || iterHasNext)) {
+        throw new IllegalStateException("A single batch is required for this operation." +
+          " Please try increasing your partition count.")
+      }

-        numOutputRows += numRows
-        numOutputBatches += 1
-
-      } finally {
-        collectMetric.foreach(_.close())
-        collectMetric = None
-      }
-
-      val concatRange = new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime)
-      val ret = try {
+      numOutputRows += numRows
+      numOutputBatches += 1
+      withResource(new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime)) { _ =>
         concatAllAndPutOnGPU()
-      } finally {
-        concatRange.close()
       }
-      ret
     } finally {
       cleanupConcatIsDone()
-      totalMetric.foreach(_.close())
-      totalMetric = None
     }
   }

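Note how the rewrite retires the cross-call collectMetric/totalMetric Option fields: hasNext and next() now open their MetricRange timers inside withResource, so every call is self-contained and the timers are closed even when an exception escapes. Assuming MetricRange is an AutoCloseable timer that adds its elapsed time to an SQLMetric when closed (a reading of the diff, not a quote of the real class), the idea is roughly:

import org.apache.spark.sql.execution.metric.SQLMetric

// Rough sketch of a MetricRange-like timer; the plugin's real MetricRange
// may differ. Closing it adds the elapsed nanoseconds to the metric, so
// withResource(new MetricRangeSketch(m)) { _ => body } times body.
class MetricRangeSketch(metric: SQLMetric) extends AutoCloseable {
  private val start: Long = System.nanoTime()
  override def close(): Unit = metric.add(System.nanoTime() - start)
}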

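Also note why the loops wrap each batch fetched by iterNext() in closeOnExcept rather than withResource: on the success path the batch is handed off to saveOnDeck or addBatch and must stay open, so it should be closed only when an exception interrupts the hand-off. A small self-contained demo of that ownership-transfer behavior, reusing the hypothetical ArmSketch helper from above:

import scala.collection.mutable.ArrayBuffer

// Demonstrates closeOnExcept semantics with throwaway names; none of this
// is plugin API.
object CloseOnExceptDemo extends ArmSketch {
  final class Resource(val id: Int) extends AutoCloseable {
    override def close(): Unit = println(s"closed $id")
  }

  def main(args: Array[String]): Unit = {
    val saved = ArrayBuffer.empty[Resource]
    // Success path: nothing closes here; the buffer now owns the resource.
    closeOnExcept(new Resource(1)) { r => saved += r }
    // Failure path: the resource is closed before the exception propagates.
    try {
      closeOnExcept(new Resource(2)) { _ => sys.error("boom") }
    } catch {
      case _: RuntimeException => // "closed 2" was printed already
    }
    saved.foreach(_.close()) // prints "closed 1"
  }
}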
