diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
index 91d022f76c7..10dd76035a1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
@@ -123,10 +123,8 @@ abstract class AbstractGpuCoalesceIterator(
     collectTime: SQLMetric,
     concatTime: SQLMetric,
     totalTime: SQLMetric,
-    opName: String) extends Iterator[ColumnarBatch] with Logging {
+    opName: String) extends Iterator[ColumnarBatch] with Arm with Logging {
   private var batchInitialized: Boolean = false
-  private var collectMetric: Option[MetricRange] = None
-  private var totalMetric: Option[MetricRange] = None
 
   /**
    * Return true if there is something saved on deck for later processing.
@@ -150,8 +148,8 @@ abstract class AbstractGpuCoalesceIterator(
 
   /** We need to track the sizes of string columns to make sure we don't exceed 2GB */
   private val stringFieldIndices: Array[Int] = schema.fields.zipWithIndex
-    .filter(_._1.dataType == DataTypes.StringType)
-    .map(_._2)
+      .filter(_._1.dataType == DataTypes.StringType)
+      .map(_._2)
 
   /** Optional row limit */
   var batchRowLimit: Int = 0
@@ -159,34 +157,30 @@ abstract class AbstractGpuCoalesceIterator(
   // note that TaskContext.get() can return null during unit testing so we wrap it in an
   // option here
   Option(TaskContext.get())
-    .foreach(_.addTaskCompletionListener[Unit]( _ => clearOnDeck()))
-
-  override def hasNext: Boolean = {
-    if (collectMetric.isEmpty) {
-      // use one being not set as indicator that neither are intialized to avoid
-      // 2 checks or extra initialized variable
-      collectMetric = Some(new MetricRange(collectTime))
-      totalMetric = Some(new MetricRange(totalTime))
-    }
-    while (!hasOnDeck && iter.hasNext) {
-      val cb = iter.next()
-      val numRows = cb.numRows()
-      numInputBatches += 1
-      numInputRows += numRows
-      if (numRows > 0) {
-        saveOnDeck(cb)
-      } else {
-        cb.close()
+      .foreach(_.addTaskCompletionListener[Unit](_ => clearOnDeck()))
+
+  private def iterHasNext: Boolean = withResource(new MetricRange(collectTime)) { _ =>
+    iter.hasNext
+  }
+
+  private def iterNext(): ColumnarBatch = withResource(new MetricRange(collectTime)) { _ =>
+    iter.next()
+  }
+
+  override def hasNext: Boolean = withResource(new MetricRange(totalTime)) { _ =>
+    while (!hasOnDeck && iterHasNext) {
+      closeOnExcept(iterNext()) { cb =>
+        val numRows = cb.numRows()
+        numInputBatches += 1
+        numInputRows += numRows
+        if (numRows > 0) {
+          saveOnDeck(cb)
+        } else {
+          cb.close()
+        }
       }
     }
-    val res = hasOnDeck
-    if (!res) {
-      totalMetric.foreach(_.close())
-      totalMetric = None
-      collectMetric.foreach(_.close())
-      collectMetric = None
-    }
-    res
+    hasOnDeck
   }
 
   /**
@@ -198,18 +192,21 @@ abstract class AbstractGpuCoalesceIterator(
    * Called to add a new batch to the final output batch. The batch passed in will
    * not be closed. If it needs to be closed it is the responsibility of the child class
    * to do it.
+   *
    * @param batch the batch to add in.
    */
   def addBatchToConcat(batch: ColumnarBatch): Unit
 
   /**
    * Calculate (or estimate) the size of each column in a batch in bytes.
+   *
    * @return Array of column sizes in bytes
    */
   def getColumnSizes(batch: ColumnarBatch): Array[Long]
 
   /**
    * Called after all of the batches have been added in.
+   *
   * @return the concated batches on the GPU.
    */
   def concatAllAndPutOnGPU(): ColumnarBatch
@@ -248,7 +245,7 @@ abstract class AbstractGpuCoalesceIterator(
    *
    * @return The coalesced batch
    */
-  override def next(): ColumnarBatch = {
+  override def next(): ColumnarBatch = withResource(new MetricRange(totalTime)) { _ =>
     // reset batch state
     batchInitialized = false
     batchRowLimit = 0
@@ -271,12 +268,9 @@ abstract class AbstractGpuCoalesceIterator(
         addBatch(batch)
       }
 
-      try {
-
-        // there is a hard limit of 2^31 rows
-        while (numRows < Int.MaxValue && !hasOnDeck && iter.hasNext) {
-
-          val cb = iter.next()
+      // there is a hard limit of 2^31 rows
+      while (numRows < Int.MaxValue && !hasOnDeck && iterHasNext) {
+        closeOnExcept(iterNext()) { cb =>
           val nextRows = cb.numRows()
           numInputBatches += 1
 
@@ -300,14 +294,14 @@ abstract class AbstractGpuCoalesceIterator(
           // memory usage.
           val wouldBeStringColumnSizes =
             stringFieldIndices.map(i => getColumnDataSize(cb, i, wouldBeColumnSizes(i)))
-              .zip(stringColumnSizes)
-              .map(pair => pair._1 + pair._2)
+                .zip(stringColumnSizes)
+                .map(pair => pair._1 + pair._2)
 
           if (wouldBeRows > Int.MaxValue) {
             if (goal == RequireSingleBatch) {
               throw new IllegalStateException("A single batch is required for this operation," +
-                s" but cuDF only supports ${Int.MaxValue} rows. At least $wouldBeRows are in" +
-                s" this partition. Please try increasing your partition count.")
+                  s" but cuDF only supports ${Int.MaxValue} rows. At least $wouldBeRows" +
+                  s" are in this partition. Please try increasing your partition count.")
             }
             saveOnDeck(cb)
           } else if (batchRowLimit > 0 && wouldBeRows > batchRowLimit) {
@@ -317,9 +311,9 @@ abstract class AbstractGpuCoalesceIterator(
           } else if (wouldBeStringColumnSizes.exists(size => size > Int.MaxValue)) {
             if (goal == RequireSingleBatch) {
               throw new IllegalStateException("A single batch is required for this operation," +
-                s" but cuDF only supports ${Int.MaxValue} bytes in a single string column." +
-                s" At least ${wouldBeStringColumnSizes.max} are in a single column in this" +
-                s" partition. Please try increasing your partition count.")
+                  s" but cuDF only supports ${Int.MaxValue} bytes in a single string column." +
+                  s" At least ${wouldBeStringColumnSizes.max} are in a single column in this" +
+                  s" partition. Please try increasing your partition count.")
             }
             saveOnDeck(cb)
           } else {
@@ -333,32 +327,21 @@ abstract class AbstractGpuCoalesceIterator(
             cb.close()
           }
         }
+      }
 
-        // enforce single batch limit when appropriate
-        if (goal == RequireSingleBatch && (hasOnDeck || iter.hasNext)) {
-          throw new IllegalStateException("A single batch is required for this operation." +
+      // enforce single batch limit when appropriate
+      if (goal == RequireSingleBatch && (hasOnDeck || iterHasNext)) {
+        throw new IllegalStateException("A single batch is required for this operation." +
             " Please try increasing your partition count.")
-        }
-
-        numOutputRows += numRows
-        numOutputBatches += 1
-
-      } finally {
-        collectMetric.foreach(_.close())
-        collectMetric = None
       }
 
-      val concatRange = new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime)
-      val ret = try {
+      numOutputRows += numRows
+      numOutputBatches += 1
+      withResource(new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime)) { _ =>
         concatAllAndPutOnGPU()
-      } finally {
-        concatRange.close()
       }
-      ret
     } finally {
       cleanupConcatIsDone()
-      totalMetric.foreach(_.close())
-      totalMetric = None
     }
   }
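
For reviewers unfamiliar with the `Arm` trait this diff now mixes in: the refactor replaces the hand-rolled `Option[MetricRange]` fields and nested `try`/`finally` blocks with two scope-based helpers. The sketch below shows the approximate shape of those helpers as the call sites above use them; the signatures are assumptions inferred from usage, not the plugin's actual source.

```scala
// Sketch of the Arm helpers used above (assumed signatures, not the plugin source).
trait Arm {
  // Run `block` with `r`, always closing `r` afterwards, on success or failure.
  def withResource[T <: AutoCloseable, V](r: T)(block: T => V): V = {
    try {
      block(r)
    } finally {
      r.close()
    }
  }

  // Close `r` only if `block` throws; on success the callee is assumed to have
  // taken ownership (here: saveOnDeck/addBatch take ownership of the batch).
  def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V = {
    try {
      block(r)
    } catch {
      case t: Throwable =>
        r.close()
        throw t
    }
  }
}
```

The practical effect is that `hasNext` and `next` no longer share mutable metric state across calls: each call opens and closes its own `MetricRange`, and an input batch is closed exactly once whether the loop body succeeds or throws.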