diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
index 78dc5c64349..a325473e335 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids
 
 import scala.collection.mutable.ArrayBuffer
 
-import ai.rapids.cudf.Table
+import ai.rapids.cudf.{NvtxColor, Table}
 import com.nvidia.spark.rapids.GpuMetricNames._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 
@@ -47,6 +47,10 @@ trait GpuBaseLimitExec extends LimitExec with GpuExec {
     throw new IllegalStateException(s"Row-based execution should not occur for $this")
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric(NUM_OUTPUT_ROWS)
+    val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES)
+    val totalTime = longMetric(TOTAL_TIME)
+
     val crdd = child.executeColumnar()
     crdd.mapPartitions { cbIter =>
       new Iterator[ColumnarBatch] {
@@ -56,13 +60,17 @@ trait GpuBaseLimitExec extends LimitExec with GpuExec {
 
         override def next(): ColumnarBatch = {
           val batch = cbIter.next()
-          val result = if (batch.numRows() > remainingLimit) {
-            sliceBatch(batch)
-          } else {
-            batch
+          withResource(new NvtxWithMetrics("limit", NvtxColor.ORANGE, totalTime)) { _ =>
+            val result = if (batch.numRows() > remainingLimit) {
+              sliceBatch(batch)
+            } else {
+              batch
+            }
+            numOutputBatches += 1
+            numOutputRows += result.numRows()
+            remainingLimit -= result.numRows()
+            result
           }
-          remainingLimit -= result.numRows()
-          result
         }
 
         def sliceBatch(batch: ColumnarBatch): ColumnarBatch = {
@@ -123,37 +131,8 @@ class GpuCollectLimitMeta(
 
     Seq(GpuOverrides.wrapPart(collectLimit.outputPartitioning, conf, Some(this)))
 
   override def convertToGpu(): GpuExec =
-    GpuCollectLimitExec(collectLimit.limit, childParts(0).convertToGpu(),
-      GpuLocalLimitExec(collectLimit.limit,
-        GpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty), childPlans(0).convertIfNeeded())))
-}
-
-case class GpuCollectLimitExec(
-    limit: Int, partitioning: GpuPartitioning,
-    child: SparkPlan) extends GpuBaseLimitExec {
-
-  private val serializer: Serializer = new GpuColumnarBatchSerializer(child.output.size)
-
-  private lazy val writeMetrics =
-    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
-  private lazy val readMetrics =
-    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+    GpuGlobalLimitExec(collectLimit.limit,
+      GpuShuffleExchangeExec(GpuSinglePartitioning(Seq.empty),
+        GpuLocalLimitExec(collectLimit.limit, childPlans(0).convertIfNeeded())))
-  lazy val shuffleMetrics = readMetrics ++ writeMetrics
-
-  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    val locallyLimited: RDD[ColumnarBatch] = super.doExecuteColumnar()
-
-    val shuffleDependency = GpuShuffleExchangeExec.prepareBatchShuffleDependency(
-      locallyLimited,
-      child.output,
-      partitioning,
-      serializer,
-      metrics ++ shuffleMetrics,
-      metrics ++ writeMetrics)
-
-    val shuffled = new ShuffledBatchRDD(shuffleDependency, metrics ++ shuffleMetrics, None)
-    shuffled.mapPartitions(_.take(limit))
-  }
-
-}
+}
\ No newline at end of file