diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
index 336d9b4f5d3..ecdcf9b9450 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
@@ -160,6 +160,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
   private val iter = new RemoveEmptyBatchIterator(origIter, numInputBatches)
   private var onDeck: Option[ColumnarBatch] = None
   private var batchInitialized: Boolean = false
+  private var collectMetric: Option[MetricRange] = None
+  private var totalMetric: Option[MetricRange] = None
 
   /** We need to track the sizes of string columns to make sure we don't exceed 2GB */
   private val stringFieldIndices: Array[Int] = schema.fields.zipWithIndex
@@ -174,7 +176,22 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
   Option(TaskContext.get())
     .foreach(_.addTaskCompletionListener[Unit](_ => onDeck.foreach(_.close())))
 
-  override def hasNext: Boolean = onDeck.isDefined || iter.hasNext
+  override def hasNext: Boolean = {
+    if (collectMetric.isEmpty) {
+      // Use collectMetric being unset as the indicator that neither metric is
+      // initialized, avoiding a second check or a separate "initialized" flag.
+      collectMetric = Some(new MetricRange(collectTime))
+      totalMetric = Some(new MetricRange(totalTime))
+    }
+    val res = onDeck.isDefined || iter.hasNext
+    if (!res) {
+      collectMetric.foreach(_.close())
+      collectMetric = None
+      totalMetric.foreach(_.close())
+      totalMetric = None
+    }
+    res
+  }
 
   /**
    * Called first to initialize any state needed for a new batch to be created.
@@ -236,9 +253,6 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
    * @return The coalesced batch
    */
   override def next(): ColumnarBatch = {
-
-    val total = new MetricRange(totalTime)
-
     // reset batch state
     batchInitialized = false
     batchRowLimit = 0
@@ -261,7 +275,6 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
         numBytes += columnSizes.sum
       }
 
-    val collect = new MetricRange(collectTime)
     try {
       // there is a hard limit of 2^31 rows
 
@@ -339,7 +352,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
         s"and $numBytes bytes")
     } finally {
-      collect.close()
+      collectMetric.foreach(_.close())
+      collectMetric = None
     }
 
     val concatRange = new NvtxWithMetrics(s"$opName concat", NvtxColor.CYAN, concatTime)
 
@@ -351,7 +365,8 @@ abstract class AbstractGpuCoalesceIterator(origIter: Iterator[ColumnarBatch],
       ret
     } finally {
      cleanupConcatIsDone()
-      total.close()
+      totalMetric.foreach(_.close())
+      totalMetric = None
     }
   }
 
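
For reviewers, the net effect of the patch: the collectTime and totalTime MetricRange timers that were previously opened in next() are now opened lazily on the first hasNext() call and closed once the iterator is exhausted, so time spent pulling upstream batches inside hasNext() is still attributed to the metrics. Below is a minimal, self-contained sketch of the same lifecycle pattern; Timer, TimedIterator, and Demo are illustrative stand-ins, not the plugin's real MetricRange API.

// Sketch only: Timer stands in for the plugin's MetricRange; not the real API.
class Timer(name: String) extends AutoCloseable {
  private val start = System.nanoTime()
  // Report elapsed wall time when the range is closed.
  override def close(): Unit =
    println(s"$name: ${(System.nanoTime() - start) / 1e6} ms")
}

class TimedIterator[T](underlying: Iterator[T]) extends Iterator[T] {
  // None doubles as the "not yet initialized" flag, as in the diff above.
  private var timer: Option[Timer] = None

  override def hasNext: Boolean = {
    if (timer.isEmpty) {
      timer = Some(new Timer("total"))  // open lazily on the first call
    }
    val res = underlying.hasNext
    if (!res) {
      timer.foreach(_.close())  // close exactly once, when input is exhausted
      timer = None
    }
    res
  }

  override def next(): T = underlying.next()
}

object Demo extends App {
  // Iterator.map is lazy, so the sleeps run between the first hasNext()
  // call and the final hasNext() that closes the timer.
  val it = new TimedIterator(Iterator(1, 2, 3).map { i => Thread.sleep(10); i })
  while (it.hasNext) println(it.next())  // prints roughly "total: 30 ms" after the loop
}

One consequence of this shape, in both the sketch and the patch: calling hasNext() again after exhaustion briefly re-opens and immediately closes a fresh range. That is harmless for timing metrics, but worth keeping in mind if the pattern is reused elsewhere.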