diff --git a/build/buildall b/build/buildall index ab2497dd7f0..d96d6b74dd4 100755 --- a/build/buildall +++ b/build/buildall @@ -242,7 +242,7 @@ function build_single_shim() { -Dskip \ -Dmaven.scalastyle.skip="$SKIP_CHECKS" \ -pl aggregator -am > "$LOG_FILE" 2>&1 || { - [[ "$LOG_FILE" != "/dev/tty" ]] && tail -20 "$LOG_FILE" || true + [[ "$LOG_FILE" != "/dev/tty" ]] && echo "$LOG_FILE:" && tail -20 "$LOG_FILE" || true exit 255 } } diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 78da3f027ef..298fbbfae61 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -6,12 +6,12 @@ nav_order: 10 --- # RAPIDS Accelerator for Apache Spark Advanced Configuration -Most users will not need to modify the configuration options listed below. +Most users will not need to modify the configuration options listed below. They are documented here for completeness and advanced usage. The following configuration options are supported by the RAPIDS Accelerator for Apache Spark. -For commonly used configurations and examples of setting options, please refer to the +For commonly used configurations and examples of setting options, please refer to the [RAPIDS Accelerator for Configuration](../configs.md) page. diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 4f997f8431d..b61cb45b327 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1883,3 +1883,15 @@ def test_hash_aggregate_complete_with_grouping_expressions(): lambda spark : spark.range(10).withColumn("id2", f.col("id")), "hash_agg_complete_table", "select id, avg(id) from hash_agg_complete_table group by id, id2 + 1") + +@ignore_order(local=True) +@pytest.mark.parametrize('cast_key_to', ["byte", "short", "int", + "long", "string", "DECIMAL(38,5)"], ids=idfn) +def test_hash_agg_force_pre_sort(cast_key_to): + def do_it(spark): + gen = StructGen([("key", UniqueLongGen()), ("value", long_gen)], nullable=False) + df = gen_df(spark, gen) + return df.selectExpr("CAST((key div 10) as " + cast_key_to + ") as key", "value").groupBy("key").sum("value") + assert_gpu_and_cpu_are_equal_collect(do_it, + conf={'spark.rapids.sql.agg.forceSinglePassPartialSort': True, + 'spark.rapids.sql.agg.singlePassPartialSortEnabled': True}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala index dfb023633cb..bf2d2474dfe 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,6 +63,32 @@ object GpuBatchUtils { estimateGpuMemory(field.dataType, field.nullable, rowCount) } + /** + * Get the minimum size a column could be that matches these conditions. + */ + def minGpuMemory(dataType:DataType, nullable: Boolean, rowCount: Long): Long = { + val validityBufferSize = if (nullable) { + calculateValidityBufferSize(rowCount) + } else { + 0 + } + + val dataSize = dataType match { + case DataTypes.BinaryType | DataTypes.StringType | _: MapType | _: ArrayType=> + // For nested types (like list or string) the smallest possible size is when + // each row is empty (length 0). In that case there is no data, just offsets + // and all of the offsets are 0. + calculateOffsetBufferSize(rowCount) + case dt: StructType => + dt.fields.map { f => + minGpuMemory(f.dataType, f.nullable, rowCount) + }.sum + case dt => + dt.defaultSize * rowCount + } + dataSize + validityBufferSize + } + def estimateGpuMemory(dataType: DataType, nullable: Boolean, rowCount: Long): Long = { val validityBufferSize = if (nullable) { calculateValidityBufferSize(rowCount) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala index 6e3e818e29b..5d7fcfa47ae 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDeviceManager.scala @@ -61,6 +61,11 @@ object GpuDeviceManager extends Logging { */ def getDeviceId(): Option[Int] = deviceId + @volatile private var poolSizeLimit = 0L + + // Never split below 100 MiB (but this is really just for testing) + def getSplitUntilSize: Long = Math.max(poolSizeLimit / 8, 100 * 1024 * 1024) + // Attempt to set and acquire the gpu, return true if acquired, false otherwise def tryToSetGpuDeviceAndAcquire(addr: Int): Boolean = { try { @@ -149,6 +154,7 @@ object GpuDeviceManager extends Logging { chunkedPackMemoryResource.foreach(_.close) chunkedPackMemoryResource = None + poolSizeLimit = 0L RapidsBufferCatalog.close() GpuShuffleEnv.shutdown() @@ -338,6 +344,7 @@ object GpuDeviceManager extends Logging { Cuda.setDevice(gpuId) try { + poolSizeLimit = poolAllocation Rmm.initialize(init, logConf, poolAllocation) } catch { case firstEx: CudfException if ((init & RmmAllocationMode.CUDA_ASYNC) != 0) => { @@ -351,6 +358,7 @@ object GpuDeviceManager extends Logging { logError( "Failed to initialize RMM with either ASYNC or ARENA allocators. Exiting...") secondEx.addSuppressed(firstEx) + poolSizeLimit = 0L throw secondEx } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala index 2ddbc5f398b..93da5b5f101 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala @@ -136,8 +136,7 @@ case class GpuSortExec( val singleBatch = sortType == FullSortSingleBatch child.executeColumnar().mapPartitions { cbIter => if (outOfCore) { - val cpuOrd = new LazilyGeneratedOrdering(sorter.cpuOrdering) - val iter = GpuOutOfCoreSortIterator(cbIter, sorter, cpuOrd, + val iter = GpuOutOfCoreSortIterator(cbIter, sorter, targetSize, opTime, sortTime, outputBatch, outputRows) TaskContext.get().addTaskCompletionListener(_ -> iter.close()) iter @@ -239,7 +238,6 @@ class Pending(cpuOrd: LazilyGeneratedOrdering) extends AutoCloseable { case class GpuOutOfCoreSortIterator( iter: Iterator[ColumnarBatch], sorter: GpuSorter, - cpuOrd: LazilyGeneratedOrdering, targetSize: Long, opTime: GpuMetric, sortTime: GpuMetric, @@ -247,6 +245,7 @@ case class GpuOutOfCoreSortIterator( outputRows: GpuMetric) extends Iterator[ColumnarBatch] with AutoCloseable { + private val cpuOrd = new LazilyGeneratedOrdering(sorter.cpuOrdering) // A priority queue of data that is not merged yet. private val pending = new Pending(cpuOrd) @@ -328,16 +327,16 @@ case class GpuOutOfCoreSortIterator( targetRowCount until rows by targetRowCount } // Get back the first row so we can sort the batches - val gatherIndexes = if (hasFullySortedData) { + val lowerGatherIndexes = if (hasFullySortedData) { // The first batch is sorted so don't gather a row for it splitIndexes } else { Seq(0) ++ splitIndexes } - val boundaries = - withResource(new NvtxRange("boundaries", NvtxColor.ORANGE)) { _ => - withResource(ColumnVector.fromInts(gatherIndexes: _*)) { gatherMap => + val lowerBoundaries = + withResource(new NvtxRange("lower boundaries", NvtxColor.ORANGE)) { _ => + withResource(ColumnVector.fromInts(lowerGatherIndexes: _*)) { gatherMap => withResource(sortedTbl.gather(gatherMap)) { boundariesTab => convertBoundaries(boundariesTab) } @@ -355,9 +354,9 @@ case class GpuOutOfCoreSortIterator( } closeOnExcept(sortedCb) { _ => - assert(boundaries.length == stillPending.length) + assert(lowerBoundaries.length == stillPending.length) closeOnExcept(pendingObs) { _ => - stillPending.zip(boundaries).foreach { + stillPending.zip(lowerBoundaries).foreach { case (ct: ContiguousTable, lower: UnsafeRow) => if (ct.getRowCount > 0) { val sp = SpillableColumnarBatch(ct, sorter.projectedBatchTypes, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 76029f72f80..548dc5884b3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1273,6 +1273,23 @@ object RapidsConf { .booleanConf .createWithDefault(true) + val ENABLE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] = + conf("spark.rapids.sql.agg.singlePassPartialSortEnabled") + .doc("Enable or disable a single pass partial sort optimization where if a heuristic " + + "indicates it would be good we pre-sort the data before a partial agg and then " + + "do the agg in a single pass with no merge, so there is no spilling") + .internal() + .booleanConf + .createWithDefault(true) + + val FORCE_SINGLE_PASS_PARTIAL_SORT_AGG: ConfEntryWithDefault[Boolean] = + conf("spark.rapids.sql.agg.forceSinglePassPartialSort") + .doc("Force a single pass partial sort agg to happen in all cases that it could, " + + "no matter what the heuristic says. This is really just for testing.") + .internal() + .booleanConf + .createWithDefault(false) + val ENABLE_REGEXP = conf("spark.rapids.sql.regexp.enabled") .doc("Specifies whether supported regular expressions will be evaluated on the GPU. " + "Unsupported expressions will fall back to CPU. However, there are some known edge cases " + @@ -2007,12 +2024,12 @@ object RapidsConf { println(s"") // scalastyle:off line.size.limit println("""# RAPIDS Accelerator for Apache Spark Advanced Configuration - |Most users will not need to modify the configuration options listed below. + |Most users will not need to modify the configuration options listed below. |They are documented here for completeness and advanced usage. | |The following configuration options are supported by the RAPIDS Accelerator for Apache Spark. | - |For commonly used configurations and examples of setting options, please refer to the + |For commonly used configurations and examples of setting options, please refer to the |[RAPIDS Accelerator for Configuration](../configs.md) page. |""".stripMargin) // scalastyle:on line.size.limit @@ -2554,6 +2571,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isRangeWindowDecimalEnabled: Boolean = get(ENABLE_RANGE_WINDOW_DECIMAL) + lazy val allowSinglePassPartialSortAgg: Boolean = get(ENABLE_SINGLE_PASS_PARTIAL_SORT_AGG) + + lazy val forceSinglePassPartialSortAgg: Boolean = get(FORCE_SINGLE_PASS_PARTIAL_SORT_AGG) + lazy val isRegExpEnabled: Boolean = get(ENABLE_REGEXP) lazy val maxRegExpStateMemory: Long = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index cf04b4ebc2f..14fbba8e4f1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1418,7 +1418,8 @@ final class RuleNotFoundExprMeta[INPUT <: Expression]( willNotWorkOnGpu(s"GPU does not currently support the operator ${expr.getClass}") override def convertToGpu(): GpuExpression = - throw new IllegalStateException("Cannot be converted to GPU") + throw new IllegalStateException(s"Cannot be converted to GPU ${expr.getClass} " + + s"${expr.dataType} $expr") } /** Base class for metadata around `RunnableCommand`. */ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index cd0a65a7e6b..df20db39c75 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -24,8 +24,9 @@ import scala.collection.mutable import ai.rapids.cudf import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} -import com.nvidia.spark.rapids.GpuHashAggregateIterator.{computeAggregateAndClose, concatenateBatches, AggHelper} +import com.nvidia.spark.rapids.GpuAggregateIterator.{computeAggregateAndClose, computeAggregateWithoutPreprocessAndClose, concatenateBatches} import com.nvidia.spark.rapids.GpuMetric._ +import com.nvidia.spark.rapids.GpuOverrides.pluginSupportedOrderableSig import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry, withRetryNoSplit} import com.nvidia.spark.rapids.shims.{AggregationTagging, ShimUnaryExecNode} @@ -36,7 +37,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, ExprId, If, NamedExpression, NullsFirst, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, HashPartitioning, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.rapids.{CpuToGpuAggregateBufferConverter, CudfAggregate, GpuAggregateExpression, GpuToCpuAggregateBufferConverter} import org.apache.spark.sql.rapids.execution.{GpuShuffleMeta, TrampolineUtil} -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch object AggregateUtils { @@ -132,7 +132,12 @@ case class GpuHashAggregateMetrics( opTime: GpuMetric, computeAggTime: GpuMetric, concatTime: GpuMetric, - sortTime: GpuMetric) + sortTime: GpuMetric, + numAggOps: GpuMetric, + numPreSplits: GpuMetric, + singlePassTasks: GpuMetric, + heuristicTime: GpuMetric) { +} /** Utility class to convey information on the aggregation modes being used */ case class AggregateModeInfo( @@ -154,235 +159,282 @@ object AggregateModeInfo { } } -object GpuHashAggregateIterator extends Logging { - /** - * Internal class used in `computeAggregates` for the pre, agg, and post steps - * - * @param inputAttributes input attributes to identify the input columns from the input batches - * @param groupingExpressions expressions used for producing the grouping keys - * @param aggregateExpressions GPU aggregate expressions used to produce the aggregations - * @param forceMerge if true, we are merging two pre-aggregated batches, so we should use - * the merge steps for each aggregate function - * @param isSorted if the batch is sorted this is set to true and is passed to cuDF - * as an optimization hint - * @param useTieredProject if true, used tiered project for input projections - */ - class AggHelper( - inputAttributes: Seq[Attribute], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[GpuAggregateExpression], - forceMerge: Boolean, - isSorted: Boolean = false, - useTieredProject : Boolean = true) { +/** + * Internal class used in `computeAggregates` for the pre, agg, and post steps + * + * @param inputAttributes input attributes to identify the input columns from the input batches + * @param groupingExpressions expressions used for producing the grouping keys + * @param aggregateExpressions GPU aggregate expressions used to produce the aggregations + * @param forceMerge if true, we are merging two pre-aggregated batches, so we should use + * the merge steps for each aggregate function + * @param isSorted if the batch is sorted this is set to true and is passed to cuDF + * as an optimization hint + * @param useTieredProject if true, used tiered project for input projections + */ +class AggHelper( + inputAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[GpuAggregateExpression], + forceMerge: Boolean, + isSorted: Boolean = false, + useTieredProject: Boolean = true) extends Serializable { - // `CudfAggregate` instances to apply, either update or merge aggregates - // package private for testing - private[rapids] val cudfAggregates = new mutable.ArrayBuffer[CudfAggregate]() + private var doSortAgg = isSorted - // integers for each column the aggregate is operating on - // package private for testing - private[rapids] val aggOrdinals = new mutable.ArrayBuffer[Int] + def setSort(isSorted: Boolean): Unit = { + doSortAgg = isSorted + } - // grouping ordinals are the indices of the tables to aggregate that need to be - // the grouping key - // package private for testing - private[rapids] val groupingOrdinals: Array[Int] = groupingExpressions.indices.toArray + // `CudfAggregate` instances to apply, either update or merge aggregates + // package private for testing + private[rapids] val cudfAggregates = new mutable.ArrayBuffer[CudfAggregate]() + + // integers for each column the aggregate is operating on + // package private for testing + private[rapids] val aggOrdinals = new mutable.ArrayBuffer[Int] + + // grouping ordinals are the indices of the tables to aggregate that need to be + // the grouping key + // package private for testing + private[rapids] val groupingOrdinals: Array[Int] = groupingExpressions.indices.toArray + + // the resulting data type from the cuDF aggregate (from + // the update or merge aggregate, be it reduction or group by) + private[rapids] val postStepDataTypes = new mutable.ArrayBuffer[DataType]() + + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val aggBufferAttributes = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + + // `GpuAggregateFunction` can add a pre and post step for update + // and merge aggregates. + private val preStep = new mutable.ArrayBuffer[Expression]() + private val postStep = new mutable.ArrayBuffer[Expression]() + private val postStepAttr = new mutable.ArrayBuffer[Attribute]() + + // we add the grouping expression first, which should bind as pass-through + if (forceMerge) { + // a grouping expression can do actual computation, but we cannot do that computation again + // on a merge, nor would we want to if we could. So use the attributes instead of the + // original expression when we are forcing a merge. + preStep ++= groupingAttributes + } else { + preStep ++= groupingExpressions + } + postStep ++= groupingAttributes + postStepAttr ++= groupingAttributes + postStepDataTypes ++= + groupingExpressions.map(_.dataType) + + private var ix = groupingAttributes.length + for (aggExp <- aggregateExpressions) { + val aggFn = aggExp.aggregateFunction + if ((aggExp.mode == Partial || aggExp.mode == Complete) && !forceMerge) { + val ordinals = (ix until ix + aggFn.updateAggregates.length) + aggOrdinals ++= ordinals + ix += ordinals.length + val updateAggs = aggFn.updateAggregates + postStepDataTypes ++= updateAggs.map(_.dataType) + cudfAggregates ++= updateAggs + preStep ++= aggFn.inputProjection + postStep ++= aggFn.postUpdate + postStepAttr ++= aggFn.postUpdateAttr + } else { + val ordinals = (ix until ix + aggFn.mergeAggregates.length) + aggOrdinals ++= ordinals + ix += ordinals.length + val mergeAggs = aggFn.mergeAggregates + postStepDataTypes ++= mergeAggs.map(_.dataType) + cudfAggregates ++= mergeAggs + preStep ++= aggFn.preMerge + postStep ++= aggFn.postMerge + postStepAttr ++= aggFn.postMergeAttr + } + } - // the resulting data type from the cuDF aggregate (from - // the update or merge aggregate, be it reduction or group by) - private[rapids] val postStepDataTypes = new mutable.ArrayBuffer[DataType]() + // a bound expression that is applied before the cuDF aggregate + private val preStepAttributes = if (forceMerge) { + aggBufferAttributes + } else { + inputAttributes + } + val preStepBound = GpuBindReferences.bindGpuReferencesTiered(preStep.toList, + preStepAttributes.toList, useTieredProject) - private val groupingAttributes = groupingExpressions.map(_.toAttribute) - private val aggBufferAttributes = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + // a bound expression that is applied after the cuDF aggregate + private val postStepBound = GpuBindReferences.bindGpuReferencesTiered(postStep.toList, + postStepAttr.toList, useTieredProject) - // `GpuAggregateFunction` can add a pre and post step for update - // and merge aggregates. - private val preStep = new mutable.ArrayBuffer[Expression]() - private val postStep = new mutable.ArrayBuffer[Expression]() - private val postStepAttr = new mutable.ArrayBuffer[Attribute]() - - // we add the grouping expression first, which should bind as pass-through - if (forceMerge) { - // a grouping expression can do actual computation, but we cannot do that computation again - // on a merge, nor would we want to if we could. So use the attributes instead of the - // original expression when we are forcing a merge. - preStep ++= groupingAttributes - } else { - preStep ++= groupingExpressions - } - postStep ++= groupingAttributes - postStepAttr ++= groupingAttributes - postStepDataTypes ++= - groupingExpressions.map(_.dataType) - - private var ix = groupingAttributes.length - for (aggExp <- aggregateExpressions) { - val aggFn = aggExp.aggregateFunction - if ((aggExp.mode == Partial || aggExp.mode == Complete) && !forceMerge) { - val ordinals = (ix until ix + aggFn.updateAggregates.length) - aggOrdinals ++= ordinals - ix += ordinals.length - val updateAggs = aggFn.updateAggregates - postStepDataTypes ++= updateAggs.map(_.dataType) - cudfAggregates ++= updateAggs - preStep ++= aggFn.inputProjection - postStep ++= aggFn.postUpdate - postStepAttr ++= aggFn.postUpdateAttr - } else { - val ordinals = (ix until ix + aggFn.mergeAggregates.length) - aggOrdinals ++= ordinals - ix += ordinals.length - val mergeAggs = aggFn.mergeAggregates - postStepDataTypes ++= mergeAggs.map(_.dataType) - cudfAggregates ++= mergeAggs - preStep ++= aggFn.preMerge - postStep ++= aggFn.postMerge - postStepAttr ++= aggFn.postMergeAttr - } - } + /** + * Apply the "pre" step: preMerge for merge, or pass-through in the update case + * + * @param toAggregateBatch - input (to the agg) batch from the child directly in the + * merge case, or from the `inputProjection` in the update case. + * @return a pre-processed batch that can be later cuDF aggregated + */ + def preProcess( + toAggregateBatch: ColumnarBatch, + metrics: GpuHashAggregateMetrics): SpillableColumnarBatch = { + val inputBatch = SpillableColumnarBatch(toAggregateBatch, + SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - // a bound expression that is applied before the cuDF aggregate - private val preStepAttributes = if (forceMerge) { - aggBufferAttributes - } else { - inputAttributes - } - private val preStepBound = GpuBindReferences.bindGpuReferencesTiered(preStep.toList, - preStepAttributes.toList, useTieredProject) - - // a bound expression that is applied after the cuDF aggregate - private val postStepBound = GpuBindReferences.bindGpuReferencesTiered(postStep.toList, - postStepAttr.toList, useTieredProject) - - /** - * Apply the "pre" step: preMerge for merge, or pass-through in the update case - * @param toAggregateBatch - input (to the agg) batch from the child directly in the - * merge case, or from the `inputProjection` in the update case. - * @return a pre-processed batch that can be later cuDF aggregated - */ - def preProcess( - toAggregateBatch: ColumnarBatch, - metrics: GpuHashAggregateMetrics): SpillableColumnarBatch = { - val inputBatch = SpillableColumnarBatch(toAggregateBatch, - SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - - val projectedCb = withResource(new NvtxRange("pre-process", NvtxColor.DARK_GREEN)) { _ => - preStepBound.projectAndCloseWithRetrySingleBatch(inputBatch) - } - SpillableColumnarBatch( - projectedCb, - SpillPriorities.ACTIVE_BATCHING_PRIORITY) + val projectedCb = withResource(new NvtxRange("pre-process", NvtxColor.DARK_GREEN)) { _ => + preStepBound.projectAndCloseWithRetrySingleBatch(inputBatch) } + SpillableColumnarBatch( + projectedCb, + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } - def aggregate(preProcessed: ColumnarBatch): ColumnarBatch = { - if (groupingOrdinals.nonEmpty) { - performGroupByAggregation(preProcessed) - } else { - performReduction(preProcessed) - } + def aggregate(preProcessed: ColumnarBatch, numAggs: GpuMetric): ColumnarBatch = { + val ret = if (groupingOrdinals.nonEmpty) { + performGroupByAggregation(preProcessed) + } else { + performReduction(preProcessed) } + numAggs += 1 + ret + } - def aggregate( - metrics: GpuHashAggregateMetrics, - preProcessed: SpillableColumnarBatch): SpillableColumnarBatch = { - val aggregatedSeq = - withRetry(preProcessed, splitSpillableInHalfByRows) { preProcessedAttempt => + def aggregateWithoutCombine(metrics: GpuHashAggregateMetrics, + preProcessed: Iterator[SpillableColumnarBatch]): Iterator[SpillableColumnarBatch] = { + val computeAggTime = metrics.computeAggTime + val opTime = metrics.opTime + val numAggs = metrics.numAggOps + preProcessed.flatMap { sb => + withRetry(sb, splitSpillableInHalfByRows) { preProcessedAttempt => + withResource(new NvtxWithMetrics("computeAggregate", NvtxColor.CYAN, computeAggTime, + opTime)) { _ => withResource(preProcessedAttempt.getColumnarBatch()) { cb => SpillableColumnarBatch( - aggregate(cb), - SpillPriorities.ACTIVE_BATCHING_PRIORITY) - } - }.toSeq - - // We need to merge the aggregated batches into 1 before calling post process, - // if the aggregate code had to split on a retry - if (aggregatedSeq.size > 1) { - val concatted = concatenateBatches(metrics, aggregatedSeq) - withRetryNoSplit(concatted) { attempt => - withResource(attempt.getColumnarBatch()) { cb => - SpillableColumnarBatch( - aggregate(cb), + aggregate(cb, numAggs), SpillPriorities.ACTIVE_BATCHING_PRIORITY) } } - } else { - aggregatedSeq.head } } + } - /** - * Invoke reduction functions as defined in each `CudfAggreagte` - * @param preProcessed - a batch after the "pre" step - * @return - */ - def performReduction(preProcessed: ColumnarBatch): ColumnarBatch = { - withResource(new NvtxRange("reduce", NvtxColor.BLUE)) { _ => - val cvs = mutable.ArrayBuffer[GpuColumnVector]() - cudfAggregates.zipWithIndex.foreach { case (cudfAgg, ix) => - val aggFn = cudfAgg.reductionAggregate - val cols = GpuColumnVector.extractColumns(preProcessed) - val reductionCol = cols(aggOrdinals(ix)) - withResource(aggFn(reductionCol.getBase)) { res => - cvs += GpuColumnVector.from( - cudf.ColumnVector.fromScalar(res, 1), cudfAgg.dataType) - } + def aggregate( + metrics: GpuHashAggregateMetrics, + preProcessed: SpillableColumnarBatch): SpillableColumnarBatch = { + val numAggs = metrics.numAggOps + val aggregatedSeq = + withRetry(preProcessed, splitSpillableInHalfByRows) { preProcessedAttempt => + withResource(preProcessedAttempt.getColumnarBatch()) { cb => + SpillableColumnarBatch( + aggregate(cb, numAggs), + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } + }.toSeq + + // We need to merge the aggregated batches into 1 before calling post process, + // if the aggregate code had to split on a retry + if (aggregatedSeq.size > 1) { + val concatted = concatenateBatches(metrics, aggregatedSeq) + withRetryNoSplit(concatted) { attempt => + withResource(attempt.getColumnarBatch()) { cb => + SpillableColumnarBatch( + aggregate(cb, numAggs), + SpillPriorities.ACTIVE_BATCHING_PRIORITY) } - new ColumnarBatch(cvs.toArray, 1) } + } else { + aggregatedSeq.head } + } - /** - * Used to produce a group-by aggregate - * @param preProcessed the batch after the "pre" step - * @return a Table that has been cuDF aggregated - */ - def performGroupByAggregation(preProcessed: ColumnarBatch): ColumnarBatch = { - withResource(new NvtxRange("groupby", NvtxColor.BLUE)) { _ => - withResource(GpuColumnVector.from(preProcessed)) { preProcessedTbl => - val groupOptions = cudf.GroupByOptions.builder() - .withIgnoreNullKeys(false) - .withKeysSorted(isSorted) - .build() - - val cudfAggsOnColumn = cudfAggregates.zip(aggOrdinals).map { - case (cudfAgg, ord) => cudfAgg.groupByAggregate.onColumn(ord) - } + /** + * Invoke reduction functions as defined in each `CudfAggreagte` + * + * @param preProcessed - a batch after the "pre" step + * @return + */ + def performReduction(preProcessed: ColumnarBatch): ColumnarBatch = { + withResource(new NvtxRange("reduce", NvtxColor.BLUE)) { _ => + val cvs = mutable.ArrayBuffer[GpuColumnVector]() + cudfAggregates.zipWithIndex.foreach { case (cudfAgg, ix) => + val aggFn = cudfAgg.reductionAggregate + val cols = GpuColumnVector.extractColumns(preProcessed) + val reductionCol = cols(aggOrdinals(ix)) + withResource(aggFn(reductionCol.getBase)) { res => + cvs += GpuColumnVector.from( + cudf.ColumnVector.fromScalar(res, 1), cudfAgg.dataType) + } + } + new ColumnarBatch(cvs.toArray, 1) + } + } - // perform the aggregate - val aggTbl = preProcessedTbl - .groupBy(groupOptions, groupingOrdinals:_*) - .aggregate(cudfAggsOnColumn: _*) + /** + * Used to produce a group-by aggregate + * + * @param preProcessed the batch after the "pre" step + * @return a Table that has been cuDF aggregated + */ + def performGroupByAggregation(preProcessed: ColumnarBatch): ColumnarBatch = { + withResource(new NvtxRange("groupby", NvtxColor.BLUE)) { _ => + withResource(GpuColumnVector.from(preProcessed)) { preProcessedTbl => + val groupOptions = cudf.GroupByOptions.builder() + .withIgnoreNullKeys(false) + .withKeysSorted(doSortAgg) + .build() + + val cudfAggsOnColumn = cudfAggregates.zip(aggOrdinals).map { + case (cudfAgg, ord) => cudfAgg.groupByAggregate.onColumn(ord) + } - withResource(aggTbl) { _ => - GpuColumnVector.from(aggTbl, postStepDataTypes.toArray) - } + // perform the aggregate + val aggTbl = preProcessedTbl + .groupBy(groupOptions, groupingOrdinals: _*) + .aggregate(cudfAggsOnColumn: _*) + + withResource(aggTbl) { _ => + GpuColumnVector.from(aggTbl, postStepDataTypes.toArray) } } } + } - /** - * Used to produce the outbound batch from the aggregate that could be - * shuffled or could be passed through the evaluateExpression if we are in the final - * stage. - * It takes a cuDF aggregated batch and applies the "post" step: - * postUpdate for update, or postMerge for merge - * @param resultBatch - cuDF aggregated batch - * @return output batch from the aggregate - */ - def postProcess( - aggregatedSpillable: SpillableColumnarBatch, - metrics: GpuHashAggregateMetrics): SpillableColumnarBatch = { - val postProcessed = - withResource(new NvtxRange("post-process", NvtxColor.ORANGE)) { _ => - postStepBound.projectAndCloseWithRetrySingleBatch(aggregatedSpillable) - } - SpillableColumnarBatch( - postProcessed, - SpillPriorities.ACTIVE_BATCHING_PRIORITY) + /** + * Used to produce the outbound batch from the aggregate that could be + * shuffled or could be passed through the evaluateExpression if we are in the final + * stage. + * It takes a cuDF aggregated batch and applies the "post" step: + * postUpdate for update, or postMerge for merge + * + * @param resultBatch - cuDF aggregated batch + * @return output batch from the aggregate + */ + def postProcess( + aggregatedSpillable: SpillableColumnarBatch, + metrics: GpuHashAggregateMetrics): SpillableColumnarBatch = { + val postProcessed = + withResource(new NvtxRange("post-process", NvtxColor.ORANGE)) { _ => + postStepBound.projectAndCloseWithRetrySingleBatch(aggregatedSpillable) + } + SpillableColumnarBatch( + postProcessed, + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } + + def postProcess(input: Iterator[SpillableColumnarBatch], + metrics: GpuHashAggregateMetrics): Iterator[SpillableColumnarBatch] = { + val computeAggTime = metrics.computeAggTime + val opTime = metrics.opTime + input.map { aggregated => + withResource(new NvtxWithMetrics("post-process", NvtxColor.ORANGE, computeAggTime, + opTime)) { _ => + val postProcessed = postStepBound.projectAndCloseWithRetrySingleBatch(aggregated) + SpillableColumnarBatch( + postProcessed, + SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } } } +} +object GpuAggregateIterator extends Logging { /** * @note abstracted away for a unit test.. * @param helper @@ -430,6 +482,30 @@ object GpuHashAggregateIterator extends Logging { } } + def computeAggregateWithoutPreprocessAndClose( + metrics: GpuHashAggregateMetrics, + inputBatches: Iterator[ColumnarBatch], + helper: AggHelper): Iterator[SpillableColumnarBatch] = { + val computeAggTime = metrics.computeAggTime + val opTime = metrics.opTime + // 1) a pre-processing step required before we go into the cuDF aggregate, This has already + // been done and is skipped + + val spillableInput = inputBatches.map { cb => + withResource(new MetricRange(computeAggTime, opTime)) { _ => + SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_BATCHING_PRIORITY) + } + } + + // 2) perform the aggregation + // OOM retry means we could get a list of batches + val aggregatedSpillable = helper.aggregateWithoutCombine(metrics, spillableInput) + + // 3) a post-processing step required in some scenarios, casting or picking + // apart a struct + helper.postProcess(aggregatedSpillable, metrics) + } + /** * Concatenates batches after extracting them from `SpllableColumnarBatch` * @note the input batches are not closed as part of this operation @@ -469,6 +545,142 @@ object GpuHashAggregateIterator extends Logging { } } +object GpuAggFirstPassIterator { + def apply(cbIter: Iterator[ColumnarBatch], + aggHelper: AggHelper, + metrics: GpuHashAggregateMetrics): Iterator[SpillableColumnarBatch] = { + val preprocessProjectIter = cbIter.map { cb => + val sb = SpillableColumnarBatch (cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + aggHelper.preStepBound.projectAndCloseWithRetrySingleBatch (sb) + } + computeAggregateWithoutPreprocessAndClose(metrics, preprocessProjectIter, aggHelper) + } +} + +// Partial mode: +// * boundFinalProjections: is a pass-through of the agg buffer +// * boundResultReferences: is a pass-through of the merged aggregate +// +// Final mode: +// * boundFinalProjections: on merged batches, finalize aggregates +// (GpuAverage => CudfSum/CudfCount) +// * boundResultReferences: project the result expressions Spark expects in the output. +// +// Complete mode: +// * boundFinalProjections: on merged batches, finalize aggregates +// (GpuAverage => CudfSum/CudfCount) +// * boundResultReferences: project the result expressions Spark expects in the output. +case class BoundExpressionsModeAggregates( + boundFinalProjections: Option[Seq[GpuExpression]], + boundResultReferences: Seq[Expression]) + +object GpuAggFinalPassIterator { + + /** + * `setupReferences` binds input, final and result references for the aggregate. + * - input: used to obtain columns coming into the aggregate from the child + * - final: some aggregates like average use this to specify an expression to produce + * the final output of the aggregate. Average keeps sum and count throughout, + * and at the end it has to divide the two, to produce the single sum/count result. + * - result: used at the end to output to our parent + */ + def setupReferences( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[GpuAggregateExpression], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + modeInfo: AggregateModeInfo): BoundExpressionsModeAggregates = { + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val aggBufferAttributes = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + + val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { + val finalProjections = groupingAttributes ++ + aggregateExpressions.map(_.aggregateFunction.evaluateExpression) + Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes)) + } else { + None + } + + // allAttributes can be different things, depending on aggregation mode: + // - Partial mode: grouping key + cudf aggregates (e.g. no avg, intead sum::count + // - Final mode: grouping key + spark aggregates (e.g. avg) + val finalAttributes = groupingAttributes ++ aggregateAttributes + + // boundResultReferences is used to project the aggregated input batch(es) for the result. + // - Partial mode: it's a pass through. We take whatever was aggregated and let it come + // out of the node as is. + // - Final or Complete mode: we use resultExpressions to pick out the correct columns that + // finalReferences has pre-processed for us + val boundResultReferences = if (modeInfo.hasPartialMode || modeInfo.hasPartialMergeMode) { + GpuBindReferences.bindGpuReferences( + resultExpressions, + resultExpressions.map(_.toAttribute)) + } else if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { + GpuBindReferences.bindGpuReferences( + resultExpressions, + finalAttributes) + } else { + GpuBindReferences.bindGpuReferences( + resultExpressions, + groupingAttributes) + } + BoundExpressionsModeAggregates( + boundFinalProjections, + boundResultReferences) + } + + private[this] def reorderFinalBatch(finalBatch: ColumnarBatch, + boundExpressions: BoundExpressionsModeAggregates, + metrics: GpuHashAggregateMetrics): ColumnarBatch = { + // Perform the last project to get the correct shape that Spark expects. Note this may + // add things like literals that were not part of the aggregate into the batch. + closeOnExcept(GpuProjectExec.projectAndClose(finalBatch, + boundExpressions.boundResultReferences, NoopMetric)) { ret => + metrics.numOutputRows += ret.numRows() + metrics.numOutputBatches += 1 + ret + } + } + + def makeIter(cbIter: Iterator[ColumnarBatch], + boundExpressions: BoundExpressionsModeAggregates, + metrics: GpuHashAggregateMetrics): Iterator[ColumnarBatch] = { + val aggTime = metrics.computeAggTime + val opTime = metrics.opTime + cbIter.map { batch => + withResource(new NvtxWithMetrics("finalize agg", NvtxColor.DARK_GREEN, aggTime, + opTime)) { _ => + val finalBatch = boundExpressions.boundFinalProjections.map { exprs => + GpuProjectExec.projectAndClose(batch, exprs, NoopMetric) + }.getOrElse(batch) + reorderFinalBatch(finalBatch, boundExpressions, metrics) + } + } + } + + def makeIterFromSpillable(sbIter: Iterator[SpillableColumnarBatch], + boundExpressions: BoundExpressionsModeAggregates, + metrics: GpuHashAggregateMetrics): Iterator[ColumnarBatch] = { + val aggTime = metrics.computeAggTime + val opTime = metrics.opTime + sbIter.map { sb => + withResource(new NvtxWithMetrics("finalize agg", NvtxColor.DARK_GREEN, aggTime, + opTime)) { _ => + val finalBatch = boundExpressions.boundFinalProjections.map { exprs => + GpuProjectExec.projectAndCloseWithRetrySingleBatch(sb, exprs) + }.getOrElse { + withRetryNoSplit(sb) { _ => + sb.getColumnarBatch() + } + } + reorderFinalBatch(finalBatch, boundExpressions, metrics) + } + } + } +} + + /** * Iterator that takes another columnar batch iterator as input and emits new columnar batches that * are aggregated based on the specified grouping and aggregation expressions. This iterator tries @@ -484,7 +696,7 @@ object GpuHashAggregateIterator extends Logging { * `buildSortFallbackIterator` is used to sort the aggregated batches by the grouping keys and * performs a final merge aggregation pass on the sorted batches. * - * @param cbIter iterator providing the input columnar batches + * @param firstPassIter iterator that has done a first aggregation pass over the input data. * @param inputAttributes input attributes to identify the input columns from the input batches * @param groupingExpressions expressions used for producing the grouping keys * @param aggregateExpressions GPU aggregate expressions used to produce the aggregations @@ -495,8 +707,8 @@ object GpuHashAggregateIterator extends Logging { * @param configuredTargetBatchSize user-specified value for the targeted input batch size * @param useTieredProject user-specified option to enable tiered projections */ -class GpuHashAggregateIterator( - cbIter: Iterator[ColumnarBatch], +class GpuMergeAggregateIterator( + firstPassIter: Iterator[SpillableColumnarBatch], inputAttributes: Seq[Attribute], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[GpuAggregateExpression], @@ -507,29 +719,7 @@ class GpuHashAggregateIterator( configuredTargetBatchSize: Long, useTieredProject: Boolean) extends Iterator[ColumnarBatch] with AutoCloseable with Logging { - - // Partial mode: - // 1. boundInputReferences: picks column from raw input - // 2. boundFinalProjections: is a pass-through of the agg buffer - // 3. boundResultReferences: is a pass-through of the merged aggregate - // - // Final mode: - // 1. boundInputReferences: is a pass-through of the merged aggregate - // 2. boundFinalProjections: on merged batches, finalize aggregates - // (GpuAverage => CudfSum/CudfCount) - // 3. boundResultReferences: project the result expressions Spark expects in the output. - // - // Complete mode: - // 1. boundInputReferences: picks column from raw input - // 2. boundFinalProjections: on merged batches, finalize aggregates - // (GpuAverage => CudfSum/CudfCount) - // 3. boundResultReferences: project the result expressions Spark expects in the output. - private case class BoundExpressionsModeAggregates( - boundFinalProjections: Option[Seq[GpuExpression]], - boundResultReferences: Seq[Expression]) - private[this] val isReductionOnly = groupingExpressions.isEmpty - private[this] val boundExpressions = setupReferences() private[this] val targetMergeBatchSize = computeTargetMergeBatchSize(configuredTargetBatchSize) private[this] val aggregatedBatches = new util.ArrayDeque[SpillableColumnarBatch] private[this] var outOfCoreIter: Option[GpuOutOfCoreSortIterator] = None @@ -545,14 +735,14 @@ class GpuHashAggregateIterator( override def hasNext: Boolean = { sortFallbackIter.map(_.hasNext).getOrElse { // reductions produce a result even if the input is empty - hasReductionOnlyBatch || !aggregatedBatches.isEmpty || cbIter.hasNext + hasReductionOnlyBatch || !aggregatedBatches.isEmpty || firstPassIter.hasNext } } override def next(): ColumnarBatch = { - val batch = sortFallbackIter.map(_.next()).getOrElse { + sortFallbackIter.map(_.next()).getOrElse { // aggregate and merge all pending inputs - if (cbIter.hasNext) { + if (firstPassIter.hasNext) { aggregateInputBatches() tryMergeAggregatedBatches() } @@ -576,8 +766,6 @@ class GpuHashAggregateIterator( } } } - - finalProjectBatch(batch) } override def close(): Unit = { @@ -596,12 +784,9 @@ class GpuHashAggregateIterator( /** Aggregate all input batches and place the results in the aggregatedBatches queue. */ private def aggregateInputBatches(): Unit = { - val aggHelper = new AggHelper( - inputAttributes, groupingExpressions, aggregateExpressions, - forceMerge = false, useTieredProject = useTieredProject) - while (cbIter.hasNext) { - aggregatedBatches.add( - computeAggregateAndClose(metrics, cbIter.next(), aggHelper)) + // cache everything in the first pass + while (firstPassIter.hasNext) { + aggregatedBatches.add(firstPassIter.next()) } } @@ -746,7 +931,6 @@ class GpuHashAggregateIterator( outOfCoreIter = Some(GpuOutOfCoreSortIterator( aggregatedBatchIter, sorter, - LazilyGeneratedOrdering.forSchema(TrampolineUtil.fromAttributes(groupingAttributes)), configuredTargetBatchSize, opTime = metrics.opTime, sortTime = metrics.sortTime, @@ -808,82 +992,6 @@ class GpuHashAggregateIterator( } new ColumnarBatch(vecs.toArray, 1) } - - /** - * Project a merged aggregated batch result to the layout that Spark expects - * i.e.: select avg(foo) from bar group by baz will produce: - * Partial mode: 3 columns => [bar, sum(foo) as sum_foo, count(foo) as count_foo] - * Final mode: 2 columns => [bar, sum(sum_foo) / sum(count_foo)] - */ - private def finalProjectBatch(batch: ColumnarBatch): ColumnarBatch = { - val aggTime = metrics.computeAggTime - val opTime = metrics.opTime - withResource(new NvtxWithMetrics("finalize agg", NvtxColor.DARK_GREEN, aggTime, - opTime)) { _ => - val finalBatch = boundExpressions.boundFinalProjections.map { exprs => - GpuProjectExec.projectAndClose(batch, exprs, NoopMetric) - }.getOrElse(batch) - - // Perform the last project to get the correct shape that Spark expects. Note this may - // add things like literals that were not part of the aggregate into the batch. - closeOnExcept(GpuProjectExec.projectAndClose(finalBatch, - boundExpressions.boundResultReferences, NoopMetric)) { ret => - metrics.numOutputRows += ret.numRows() - metrics.numOutputBatches += 1 - ret - } - } - } - - /** - * `setupReferences` binds input, final and result references for the aggregate. - * - input: used to obtain columns coming into the aggregate from the child - * - final: some aggregates like average use this to specify an expression to produce - * the final output of the aggregate. Average keeps sum and count throughout, - * and at the end it has to divide the two, to produce the single sum/count result. - * - result: used at the end to output to our parent - */ - private def setupReferences(): BoundExpressionsModeAggregates = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) - val aggBufferAttributes = groupingAttributes ++ - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - - val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { - val finalProjections = groupingAttributes ++ - aggregateExpressions.map(_.aggregateFunction.evaluateExpression) - Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes)) - } else { - None - } - - // allAttributes can be different things, depending on aggregation mode: - // - Partial mode: grouping key + cudf aggregates (e.g. no avg, intead sum::count - // - Final mode: grouping key + spark aggregates (e.g. avg) - val finalAttributes = groupingAttributes ++ aggregateAttributes - - // boundResultReferences is used to project the aggregated input batch(es) for the result. - // - Partial mode: it's a pass through. We take whatever was aggregated and let it come - // out of the node as is. - // - Final or Complete mode: we use resultExpressions to pick out the correct columns that - // finalReferences has pre-processed for us - val boundResultReferences = if (modeInfo.hasPartialMode || modeInfo.hasPartialMergeMode) { - GpuBindReferences.bindGpuReferences( - resultExpressions, - resultExpressions.map(_.toAttribute)) - } else if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { - GpuBindReferences.bindGpuReferences( - resultExpressions, - finalAttributes) - } else { - GpuBindReferences.bindGpuReferences( - resultExpressions, - groupingAttributes) - } - BoundExpressionsModeAggregates( - boundFinalProjections, - boundResultReferences) - } - } object GpuBaseAggregateMeta { @@ -1046,16 +1154,99 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( } } + private val orderable = + (pluginSupportedOrderableSig + TypeSig.DECIMAL_128 + TypeSig.STRUCT).nested() + override def convertToGpu(): GpuExec = { + lazy val aggModes = agg.aggregateExpressions.map(_.mode).toSet + lazy val canUsePartialSortAgg = aggModes.forall { mode => + mode == Partial || mode == PartialMerge + } && agg.groupingExpressions.nonEmpty // Don't do this for a reduce... + + lazy val groupingCanBeSorted = agg.groupingExpressions.forall { expr => + orderable.isSupportedByPlugin(expr.dataType) + } + + // This is a short term heuristic until we can better understand the cost + // of sort vs the cost of doing aggregations so we can better decide. + lazy val hasSingleBasicGroupingKey = agg.groupingExpressions.length == 1 && + agg.groupingExpressions.headOption.map(_.dataType).exists { + case StringType | BooleanType | ByteType | ShortType | IntegerType | + LongType | _: DecimalType | DateType | TimestampType => true + case _ => false + } + + val gpuChild = childPlans.head.convertIfNeeded() + val gpuAggregateExpressions = + aggregateExpressions.map(_.convertToGpu().asInstanceOf[GpuAggregateExpression]) + val gpuGroupingExpressions = + groupingExpressions.map(_.convertToGpu().asInstanceOf[NamedExpression]) + val useTiered = conf.isTieredProjectEnabled + + // Sorting before we do an aggregation helps if the size of the input data is + // smaller than the size of the output data. But it is an aggregation how can + // the output be larger than the input? This happens if there are a lot of + // aggregations relative to the number of input columns. So if we will have + // to cache/spill data we would rather cache and potentially spill the smaller + // data. By sorting the data on a partial agg we don't have to worry about + // combining the output of multiple batch because the final agg will take care + // of it, and in the worst case we will end up with one extra row per-batch that + // would need to be shuffled to the final aggregate. An okay tradeoff. + // + // The formula for the estimated output size after aggregation is + // input * pre-project-growth * agg-reduction. + // + // We are going to estimate the pre-project-growth here, and if it looks like + // it is going to be larger than the input, then we will dynamically estimate + // the agg-reduction in each task and make a choice there what to do. + + lazy val estimatedPreProcessGrowth = { + val inputAggBufferAttributes = + GpuHashAggregateExecBase.calcInputAggBufferAttributes(gpuAggregateExpressions) + val inputAttrs = GpuHashAggregateExecBase.calcInputAttributes(gpuAggregateExpressions, + gpuChild.output, inputAggBufferAttributes) + val preProcessAggHelper = new AggHelper( + inputAttrs, gpuGroupingExpressions, gpuAggregateExpressions, + forceMerge = false, useTieredProject = useTiered) + + // We are going to estimate the growth by looking at the estimated size the output could + // be compared to the estimated size of the input (both based off of the schemas). + // It if far from perfect, but it should work okay-ish. + val numRowsForEstimate = 1000000 // 1 million rows... + val estimatedInputSize = gpuChild.output.map { attr => + GpuBatchUtils.estimateGpuMemory(attr.dataType, attr.nullable, numRowsForEstimate) + }.sum + val estimatedOutputSize = preProcessAggHelper.preStepBound.outputTypes.map { dt => + GpuBatchUtils.estimateGpuMemory(dt, true, numRowsForEstimate) + }.sum + if (estimatedInputSize == 0 && estimatedOutputSize == 0) { + 1.0 + } else if (estimatedInputSize == 0) { + 100.0 + } else { + estimatedOutputSize.toDouble / estimatedInputSize + } + } + + val allowSinglePassAgg = (conf.forceSinglePassPartialSortAgg || + (conf.allowSinglePassPartialSortAgg && + hasSingleBasicGroupingKey && + estimatedPreProcessGrowth > 1.1)) && + canUsePartialSortAgg && + groupingCanBeSorted + GpuHashAggregateExec( aggRequiredChildDistributionExpressions, - groupingExpressions.map(_.convertToGpu().asInstanceOf[NamedExpression]), - aggregateExpressions.map(_.convertToGpu().asInstanceOf[GpuAggregateExpression]), + gpuGroupingExpressions, + gpuAggregateExpressions, aggregateAttributes.map(_.convertToGpu().asInstanceOf[Attribute]), resultExpressions.map(_.convertToGpu().asInstanceOf[NamedExpression]), - childPlans.head.convertIfNeeded(), + gpuChild, conf.gpuTargetBatchSizeBytes, - conf.isTieredProjectEnabled) + useTiered, + estimatedPreProcessGrowth, + conf.forceSinglePassPartialSortAgg, + allowSinglePassAgg) } } @@ -1137,7 +1328,11 @@ abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: BaseAggrega retExpressions.map(_.convertToGpu().asInstanceOf[NamedExpression]), childPlans.head.convertIfNeeded(), conf.gpuTargetBatchSizeBytes, - conf.isTieredProjectEnabled) + conf.isTieredProjectEnabled, + // For now we are just going to go with the original hash aggregation + 1.0, + false, + false) } else { super.convertToGpu() } @@ -1437,8 +1632,43 @@ class GpuObjectHashAggregateExecMeta( extends GpuTypedImperativeSupportedAggregateExecMeta(agg, agg.requiredChildDistributionExpressions, conf, parent, rule) +object GpuHashAggregateExecBase { + + def calcInputAttributes(aggregateExpressions: Seq[GpuAggregateExpression], + childOutput: Seq[Attribute], + inputAggBufferAttributes: Seq[Attribute]): Seq[Attribute] = { + val modes = aggregateExpressions.map(_.mode).distinct + if (modes.contains(Final) || modes.contains(PartialMerge)) { + // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's + // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the + // output to bind references used by `mergeAggregates`. But if we copy the + // aggregate function somehow after aggregate planning, the `DeclarativeAggregate` will + // be replaced by a new instance with new `inputAggBufferAttributes`. Then Final and + // PartialMerge aggregate can't bind the references used by `mergeAggregates` with the output + // of the partial aggregate, as they use the `inputAggBufferAttributes` of the + // original `DeclarativeAggregate` before copy. Instead, we shall use + // `inputAggBufferAttributes` after copy to match the new `mergeExpressions`. + val aggAttrs = inputAggBufferAttributes + childOutput.dropRight(aggAttrs.length) ++ aggAttrs + } else { + childOutput + } + } + + def calcInputAggBufferAttributes(aggregateExpressions: Seq[GpuAggregateExpression]): + Seq[Attribute] = { + aggregateExpressions + // there are exactly four cases needs `inputAggBufferAttributes` from child according to the + // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final, + // Partial -> PartialMerge, PartialMerge -> PartialMerge. + .filter(a => a.mode == Final || a.mode == PartialMerge) + .flatMap(_.aggregateFunction.aggBufferAttributes) + } +} + /** - * The GPU version of HashAggregateExec + * The GPU version of SortAggregateExec that is intended for partial aggregations that are not + * reductions and so it sorts the input data ahead of time to do it in a single pass. * * @param requiredChildDistributionExpressions this is unchanged by the GPU. It is used in * EnsureRequirements to be able to add shuffle nodes @@ -1460,38 +1690,21 @@ case class GpuHashAggregateExec( resultExpressions: Seq[NamedExpression], child: SparkPlan, configuredTargetBatchSize: Long, - configuredTieredProjectEnabled: Boolean) extends ShimUnaryExecNode with GpuExec { + configuredTieredProjectEnabled: Boolean, + estimatedPreProcessGrowth: Double, + forceSinglePassAgg: Boolean, + allowSinglePassAgg: Boolean) extends ShimUnaryExecNode with GpuExec { // lifted directly from `BaseAggregateExec.inputAttributes`, edited comment. - def inputAttributes: Seq[Attribute] = { - val modes = aggregateExpressions.map(_.mode).distinct - if (modes.contains(Final) || modes.contains(PartialMerge)) { - // SPARK-31620: when planning aggregates, the partial aggregate uses aggregate function's - // `inputAggBufferAttributes` as its output. And Final and PartialMerge aggregate rely on the - // output to bind references used by `mergeAggregates`. But if we copy the - // aggregate function somehow after aggregate planning, the `DeclarativeAggregate` will - // be replaced by a new instance with new `inputAggBufferAttributes`. Then Final and - // PartialMerge aggregate can't bind the references used by `mergeAggregates` with the output - // of the partial aggregate, as they use the `inputAggBufferAttributes` of the - // original `DeclarativeAggregate` before copy. Instead, we shall use - // `inputAggBufferAttributes` after copy to match the new `mergeExpressions`. - val aggAttrs = inputAggBufferAttributes - child.output.dropRight(aggAttrs.length) ++ aggAttrs - } else { - child.output - } - } + def inputAttributes: Seq[Attribute] = + GpuHashAggregateExecBase.calcInputAttributes(aggregateExpressions, + child.output, + inputAggBufferAttributes) - private val inputAggBufferAttributes: Seq[Attribute] = { - aggregateExpressions - // there're exactly four cases needs `inputAggBufferAttributes` from child according to the - // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final, - // Partial -> PartialMerge, PartialMerge -> PartialMerge. - .filter(a => a.mode == Final || a.mode == PartialMerge) - .flatMap(_.aggregateFunction.aggBufferAttributes) - } + private val inputAggBufferAttributes: Seq[Attribute] = + GpuHashAggregateExecBase.calcInputAggBufferAttributes(aggregateExpressions) - private lazy val uniqueModes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct + protected lazy val uniqueModes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct protected override val outputRowsLevel: MetricsLevel = ESSENTIAL_LEVEL protected override val outputBatchesLevel: MetricsLevel = MODERATE_LEVEL @@ -1500,7 +1713,11 @@ case class GpuHashAggregateExec( OP_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_OP_TIME), AGG_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_AGG_TIME), CONCAT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_CONCAT_TIME), - SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME) + SORT_TIME -> createNanoTimingMetric(DEBUG_LEVEL, DESCRIPTION_SORT_TIME), + "NUM_AGGS" -> createMetric(DEBUG_LEVEL, "num agg operations"), + "NUM_PRE_SPLITS" -> createMetric(DEBUG_LEVEL, "num pre splits"), + "NUM_TASKS_SINGLE_PASS" -> createMetric(MODERATE_LEVEL, "number of single pass tasks"), + "HEURISTIC_TIME" -> createNanoTimingMetric(DEBUG_LEVEL, "time in heuristic") ) // requiredChildDistributions are CPU expressions, so remove it from the GPU expressions list @@ -1526,7 +1743,11 @@ case class GpuHashAggregateExec( opTime = gpuLongMetric(OP_TIME), computeAggTime = gpuLongMetric(AGG_TIME), concatTime = gpuLongMetric(CONCAT_TIME), - sortTime = gpuLongMetric(SORT_TIME)) + sortTime = gpuLongMetric(SORT_TIME), + numAggOps = gpuLongMetric("NUM_AGGS"), + numPreSplits = gpuLongMetric("NUM_PRE_SPLITS"), + singlePassTasks = gpuLongMetric("NUM_TASKS_SINGLE_PASS"), + heuristicTime = gpuLongMetric("HEURISTIC_TIME")) // cache in a local variable to avoid serializing the full child plan val inputAttrs = inputAttributes @@ -1540,18 +1761,24 @@ case class GpuHashAggregateExec( val rdd = child.executeColumnar() + val localForcePre = forceSinglePassAgg + val localAllowPre = allowSinglePassAgg + val expectedOrdering = expectedChildOrderingIfNeeded + val alreadySorted = SortOrder.orderingSatisfies(child.outputOrdering, + expectedOrdering) && expectedOrdering.nonEmpty + val localEstimatedPreProcessGrowth = estimatedPreProcessGrowth + + val boundGroupExprs = GpuBindReferences.bindGpuReferencesTiered(groupingExprs, inputAttrs, true) + rdd.mapPartitions { cbIter => - new GpuHashAggregateIterator( - cbIter, - inputAttrs, - groupingExprs, - aggregateExprs, - aggregateAttrs, - resultExprs, - modeInfo, - aggMetrics, - targetBatchSize, - useTieredProject) + val postBoundReferences = GpuAggFinalPassIterator.setupReferences(groupingExprs, + aggregateExprs, aggregateAttrs, resultExprs, modeInfo) + + new DynamicGpuPartialSortAggregateIterator(cbIter, inputAttrs, groupingExprs, + boundGroupExprs, aggregateExprs, aggregateAttrs, resultExprs, modeInfo, + localEstimatedPreProcessGrowth, alreadySorted, expectedOrdering, + postBoundReferences, targetBatchSize, aggMetrics, useTieredProject, + localForcePre, localAllowPre) } } @@ -1586,7 +1813,7 @@ case class GpuHashAggregateExec( protected def replaceAlias(attr: AttributeReference): Option[Attribute] = { outputExpressions.collectFirst { - case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) => + case a@Alias(child: AttributeReference, _) if child.semanticEquals(attr) => a.toAttribute } } @@ -1642,4 +1869,152 @@ case class GpuHashAggregateExec( // // End copies from HashAggregateExec // + + // We are not going to override requiredChildOrdering because we don't want to + // always sort the data. So we will insert in the sort ourselves if we need to + def expectedChildOrderingIfNeeded: Seq[SortOrder] = { + groupingExpressions.map(SortOrder(_, Ascending)) + } } + +class DynamicGpuPartialSortAggregateIterator( + cbIter: Iterator[ColumnarBatch], + inputAttrs: Seq[Attribute], + groupingExprs: Seq[NamedExpression], + boundGroupExprs: GpuTieredProject, + aggregateExprs: Seq[GpuAggregateExpression], + aggregateAttrs: Seq[Attribute], + resultExprs: Seq[NamedExpression], + modeInfo: AggregateModeInfo, + estimatedPreGrowth: Double, + alreadySorted: Boolean, + ordering: Seq[SortOrder], + postBoundReferences: BoundExpressionsModeAggregates, + configuredTargetBatchSize: Long, + metrics: GpuHashAggregateMetrics, + useTiered: Boolean, + forceSinglePassAgg: Boolean, + allowSinglePassAgg: Boolean) extends Iterator[ColumnarBatch] { + private var aggIter: Option[Iterator[ColumnarBatch]] = None + private[this] val isReductionOnly = boundGroupExprs.outputTypes.isEmpty + + // When doing a reduction we don't have the aggIter setup for the very first time + // so we have to match what happens for the normal reduction operations. + override def hasNext: Boolean = aggIter.map(_.hasNext) + .getOrElse(isReductionOnly || cbIter.hasNext) + + private[this] def estimateCardinality(cb: ColumnarBatch): Int = { + withResource(boundGroupExprs.project(cb)) { groupingKeys => + withResource(GpuColumnVector.from(groupingKeys)) { table => + table.distinctCount() + } + } + } + + private[this] def firstBatchHeuristic( + cbIter: Iterator[ColumnarBatch], + helper: AggHelper): (Iterator[ColumnarBatch], Boolean) = { + // we need to decide if we are going to sort the data or not, so the very + // first thing we need to do is get a batch and make a choice. + withResource(new NvtxWithMetrics("dynamic sort heuristic", NvtxColor.BLUE, + metrics.opTime, metrics.heuristicTime)) { _ => + val cb = cbIter.next() + lazy val estimatedGrowthAfterAgg: Double = closeOnExcept(cb) { cb => + val numRows = cb.numRows() + val cardinality = estimateCardinality(cb) + val minPreGrowth = PreProjectSplitIterator.calcMinOutputSize(cb, + helper.preStepBound).toDouble / GpuColumnVector.getTotalDeviceMemoryUsed(cb) + (math.max(minPreGrowth, estimatedPreGrowth) * cardinality) / numRows + } + val wrappedIter = Seq(cb).toIterator ++ cbIter + (wrappedIter, estimatedGrowthAfterAgg > 1.0) + } + } + + private[this] def singlePassSortedAgg( + inputIter: Iterator[ColumnarBatch], + preProcessAggHelper: AggHelper): Iterator[ColumnarBatch] = { + // The data is already sorted so just do the sorted agg either way... + val sortedIter = if (alreadySorted) { + inputIter + } else { + val sorter = new GpuSorter(ordering, inputAttrs) + GpuOutOfCoreSortIterator(inputIter, + sorter, + configuredTargetBatchSize, + opTime = metrics.opTime, + sortTime = metrics.sortTime, + outputBatches = NoopMetric, + outputRows = NoopMetric) + } + + // After sorting we want to split the input for the project so that + // we don't get ourselves in trouble. + val sortedSplitIter = new PreProjectSplitIterator(sortedIter, + inputAttrs.map(_.dataType).toArray, preProcessAggHelper.preStepBound, + metrics.opTime, metrics.numPreSplits) + + val firstPassIter = GpuAggFirstPassIterator(sortedSplitIter, preProcessAggHelper, metrics) + + // Technically on a partial-agg, which this only works for, this last iterator should + // be a noop except for some metrics. But for consistency between all of the + // agg paths and to be more resilient in the future with code changes we include a final pass + // iterator here. + GpuAggFinalPassIterator.makeIterFromSpillable(firstPassIter, postBoundReferences, metrics) + } + + private[this] def fullHashAggWithMerge( + inputIter: Iterator[ColumnarBatch], + preProcessAggHelper: AggHelper): Iterator[ColumnarBatch] = { + // We still want to split the input, because the heuristic may not be perfect and + // this is relatively light weight + val splitInputIter = new PreProjectSplitIterator(inputIter, + inputAttrs.map(_.dataType).toArray, preProcessAggHelper.preStepBound, + metrics.opTime, metrics.numPreSplits) + + val firstPassIter = GpuAggFirstPassIterator(splitInputIter, preProcessAggHelper, metrics) + + val mergeIter = new GpuMergeAggregateIterator( + firstPassIter, + inputAttrs, + groupingExprs, + aggregateExprs, + aggregateAttrs, + resultExprs, + modeInfo, + metrics, + configuredTargetBatchSize, + useTiered) + + GpuAggFinalPassIterator.makeIter(mergeIter, postBoundReferences, metrics) + } + + override def next(): ColumnarBatch = { + if (aggIter.isEmpty) { + val preProcessAggHelper = new AggHelper( + inputAttrs, groupingExprs, aggregateExprs, + forceMerge = false, isSorted = true, useTieredProject = useTiered) + val (inputIter, doSinglePassAgg) = if (allowSinglePassAgg) { + if (forceSinglePassAgg || alreadySorted) { + (cbIter, true) + } else { + firstBatchHeuristic(cbIter, preProcessAggHelper) + } + } else { + (cbIter, false) + } + val newIter = if (doSinglePassAgg) { + metrics.singlePassTasks += 1 + singlePassSortedAgg(inputIter, preProcessAggHelper) + } else { + // Not sorting so go back to that + preProcessAggHelper.setSort(false) + fullHashAggWithMerge(inputIter, preProcessAggHelper) + } + aggIter = Some(newIter) + } + aggIter.map(_.next()).getOrElse { + throw new NoSuchElementException() + } + } +} \ No newline at end of file diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 6d6d304606a..8f912e68b52 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -21,10 +21,10 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf import ai.rapids.cudf._ -import com.nvidia.spark.rapids.Arm.withResource +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry} +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRetry, withRetryNoSplit} import com.nvidia.spark.rapids.shims._ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} @@ -210,6 +210,116 @@ trait GpuProjectExecLike extends ShimUnaryExecNode with GpuExec { override def outputBatching: CoalesceGoal = GpuExec.outputBatching(child) } +/** + * An iterator that is intended to split the input to or output of a project on rows. + * In practice this is only used for splitting the input prior to a project in some + * very special cases. If the projected size of the output is so large that it would + * risk us not being able to split it later on if we ran into trouble. + * @param iter the input iterator of columnar batches. + * @param schema the schema of that input so we can make things spillable if needed + * @param opTime metric for how long this took + * @param numSplitsMetric metric for the number of splits that happened. + */ +abstract class AbstractProjectSplitIterator(iter: Iterator[ColumnarBatch], + schema: Array[DataType], + opTime: GpuMetric, + numSplitsMetric: GpuMetric) extends Iterator[ColumnarBatch] { + private[this] val pending = new scala.collection.mutable.Queue[SpillableColumnarBatch]() + + override def hasNext: Boolean = pending.nonEmpty || iter.hasNext + + protected def calcNumSplits(cb: ColumnarBatch): Int + + override def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException() + } else if (pending.nonEmpty) { + opTime.ns { + withRetryNoSplit(pending.dequeue()) { sb => + sb.getColumnarBatch() + } + } + } else { + val cb = iter.next() + opTime.ns { + val numSplits = closeOnExcept(cb) { cb => + calcNumSplits(cb) + } + if (numSplits <= 1) { + cb + } else { + numSplitsMetric += numSplits - 1 + val sb = SpillableColumnarBatch(cb, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + val tables = withRetryNoSplit(sb) { sb => + withResource(sb.getColumnarBatch()) { cb => + withResource(GpuColumnVector.from(cb)) { table => + val rows = table.getRowCount.toInt + val rowsPerSplit = math.ceil(rows.toDouble / numSplits).toInt + val splitIndexes = rowsPerSplit until rows by rowsPerSplit + table.contiguousSplit(splitIndexes: _*) + } + } + } + withResource(tables) { tables => + val ret = tables.head.getTable + tables.tail.foreach { ct => + pending.enqueue( + SpillableColumnarBatch(ct, schema, SpillPriorities.ACTIVE_BATCHING_PRIORITY)) + } + GpuColumnVector.from(ret, schema) + } + } + } + } + } +} + +object PreProjectSplitIterator { + def calcMinOutputSize(cb: ColumnarBatch, boundExprs: GpuTieredProject): Long = { + val numRows = cb.numRows() + boundExprs.outputTypes.zipWithIndex.map { + case (dataType, index) => + if (GpuBatchUtils.isFixedWidth(dataType)) { + GpuBatchUtils.minGpuMemory(dataType, true, numRows) + } else { + boundExprs.getPassThroughIndex(index).map { inputIndex => + cb.column(inputIndex).asInstanceOf[GpuColumnVector].getBase.getDeviceMemorySize + }.getOrElse { + GpuBatchUtils.minGpuMemory(dataType, true, numRows) + } + } + }.sum + } +} + +/** + * An iterator that can be used to split the input of a project before it happens to prevent + * situations where the output could not be split later on. In testing we tried to see what + * would happen if we split it to the target batch size, but there was a very significant + * performance degradation when that happened. For now this is only used in a few specific + * places and not everywhere. In the future this could be extended, but if we do that there + * are some places where we don't want a split, like a project before a window operation. + * @param iter the input iterator of columnar batches. + * @param schema the schema of that input so we can make things spillable if needed + * @param boundExprs the bound project so we can get a good idea of the output size. + * @param opTime metric for how long this took + * @param numSplits the number of splits that happened. + */ +class PreProjectSplitIterator( + iter: Iterator[ColumnarBatch], + schema: Array[DataType], + boundExprs: GpuTieredProject, + opTime: GpuMetric, + numSplits: GpuMetric) extends AbstractProjectSplitIterator(iter, schema, opTime, numSplits) { + + override def calcNumSplits(cb: ColumnarBatch): Int = { + val minOutputSize = PreProjectSplitIterator.calcMinOutputSize(cb, boundExprs) + // If the minimum size is too large we will split before doing the project, to help avoid + // extreme cases where the output size is so large that we cannot split it afterwards. + math.max(1, math.ceil(minOutputSize / GpuDeviceManager.getSplitUntilSize.toDouble).toInt) + } +} + case class GpuProjectExec( // NOTE for Scala 2.12.x and below we enforce usage of (eager) List to prevent running // into a deep recursion during serde of lazy lists. See @@ -364,6 +474,38 @@ case class GpuProjectAstExec( } } + lazy val outputTypes = exprTiers.last.map(_.dataType).toArray + + private[this] def getPassThroughIndex(tierIndex: Int, + expr: Expression, + exprIndex: Int): Option[Int] = expr match { + case GpuAlias(child, _) => + getPassThroughIndex(tierIndex, child, exprIndex) + case GpuBoundReference(index, _, _) => + if (tierIndex <= 0) { + // We are at the input tier so the bound attribute is good!!! + Some(index) + } else { + // Not at the input yet + val newTier = tierIndex - 1 + val newExpr = exprTiers(newTier)(index) + getPassThroughIndex(newTier, newExpr, index) + } + case _ => + None + } + + /** + * Given an output index check to see if this is just going to be a pass through to a + * specific input column index. + * @param index the output column index to check + * @return the index of the input column that it passes through to or else None + */ + def getPassThroughIndex(index: Int): Option[Int] = { + val startTier = exprTiers.length - 1 + getPassThroughIndex(startTier, exprTiers.last(index), index) + } + private [this] def projectWithRetrySingleBatchInternal(sb: SpillableColumnarBatch, closeInputBatch: Boolean): ColumnarBatch = { if (areAllDeterministic) { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 2321aa12387..c13b24689e4 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -286,7 +286,7 @@ case class GpuAggregateExpression(origAggregateFunction: GpuAggregateFunction, override def sql: String = aggregateFunction.sql(isDistinct) } -trait CudfAggregate { +trait CudfAggregate extends Serializable { // we use this to get the ordinal of the bound reference, s.t. we can ask cudf to perform // the aggregate on that column val reductionAggregate: cudf.ColumnVector => cudf.Scalar @@ -325,7 +325,7 @@ class CudfSum(override val dataType: DataType) extends CudfAggregate { // sum(shorts): bigint // Aggregate [sum(shorts#33) AS sum(shorts)#50L] // - @transient val rapidsSumType: DType = GpuColumnVector.getNonNestedRapidsType(dataType) + @transient lazy val rapidsSumType: DType = GpuColumnVector.getNonNestedRapidsType(dataType) override val reductionAggregate: cudf.ColumnVector => cudf.Scalar = (col: cudf.ColumnVector) => col.sum(rapidsSumType) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala index 04e120caaaf..37697382eba 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Cast, Concat, Expression, Literal, NullsFirst, ScalaUDF, SortOrder, UnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.connector.write.DataWriter import org.apache.spark.sql.execution.datasources.{BucketingUtils, PartitioningUtils, WriteTaskResult} import org.apache.spark.sql.rapids.GpuFileFormatDataWriter.{shouldSplitToFitMaxRecordsPerFile, splitToFitMaxRecordsAndClose} @@ -831,7 +830,6 @@ class GpuDynamicPartitionDataConcurrentWriter( val gpuSortOrder: Seq[SortOrder] = spec.sortOrder val output: Seq[Attribute] = spec.output val sorter = new GpuSorter(gpuSortOrder, output) - val cpuOrd = new LazilyGeneratedOrdering(sorter.cpuOrdering) // use noop metrics below val sortTime = NoopMetric @@ -841,7 +839,7 @@ class GpuDynamicPartitionDataConcurrentWriter( val targetSize = GpuSortExec.targetSize(spec.batchSize) // out of core sort the entire iterator - GpuOutOfCoreSortIterator(iterator, sorter, cpuOrd, targetSize, + GpuOutOfCoreSortIterator(iterator, sorter, targetSize, opTime, sortTime, outputBatch, outputRows) } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala index 209c78cad1e..60a469032dc 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView} -import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, GpuUnaryExpression} +import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuProjectExec, GpuTieredProject, GpuUnaryExpression} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.{HashUtils, ShimExpression} @@ -55,6 +55,20 @@ object GpuMurmur3Hash { } } } + + def computeTiered(batch: ColumnarBatch, + boundExpr: GpuTieredProject, + seed: Int = 42): ColumnVector = { + withResource(boundExpr.project(batch)) { args => + val bases = GpuColumnVector.extractBases(args) + val normalized = bases.safeMap { cv => + HashUtils.normalizeInput(cv).asInstanceOf[ColumnView] + } + withResource(normalized) { _ => + ColumnVector.spark32BitMurmurHash3(seed, normalized) + } + } + } } case class GpuMurmur3Hash(children: Seq[Expression], seed: Int) extends GpuExpression diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuOutOfCoreSortRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuOutOfCoreSortRetrySuite.scala index 372e13da0ab..ecdf64c066e 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuOutOfCoreSortRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuOutOfCoreSortRetrySuite.scala @@ -22,7 +22,6 @@ import com.nvidia.spark.rapids.jni.{RetryOOM, SplitAndRetryOOM} import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder} -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.types.IntegerType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -44,7 +43,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 1024) withResource(outCoreIter) { _ => withResource(outCoreIter.next()) { cb => @@ -59,7 +57,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 1024, firstPassSortExp = new RetryOOM()) withResource(outCoreIter) { _ => @@ -75,7 +72,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 1024, firstPassSortExp = new SplitAndRetryOOM()) withResource(outCoreIter) { _ => @@ -91,7 +87,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 1024, firstPassSplitExp = new RetryOOM()) withResource(outCoreIter) { _ => @@ -107,7 +102,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 1024, firstPassSplitExp = new SplitAndRetryOOM()) withResource(outCoreIter) { _ => @@ -121,7 +115,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch, buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 400, mergeSortExp = new RetryOOM()) withResource(outCoreIter) { _ => @@ -139,7 +132,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch, buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 400, mergeSortExp = new SplitAndRetryOOM()) withResource(outCoreIter) { _ => @@ -153,7 +145,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch, buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 400, concatOutExp = new RetryOOM()) withResource(outCoreIter) { _ => @@ -171,7 +162,6 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga val outCoreIter = new GpuOutOfCoreSortIteratorThatThrows( Iterator(buildBatch, buildBatch), gpuSorter, - new LazilyGeneratedOrdering(gpuSorter.cpuOrdering), targetSize = 400, concatOutExp = new SplitAndRetryOOM()) withResource(outCoreIter) { _ => @@ -184,14 +174,13 @@ class GpuOutOfCoreSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSuga private class GpuOutOfCoreSortIteratorThatThrows( iter: Iterator[ColumnarBatch], sorter: GpuSorter, - cpuOrd: LazilyGeneratedOrdering, targetSize: Long, firstPassSortExp: Throwable = null, firstPassSplitExp: Throwable = null, mergeSortExp: Throwable = null, concatOutExp: Throwable = null, expMaxCount: Int = 1) - extends GpuOutOfCoreSortIterator(iter, sorter, cpuOrd, targetSize, + extends GpuOutOfCoreSortIterator(iter, sorter, targetSize, NoopMetric, NoopMetric, NoopMetric, NoopMetric){ private var expCnt = expMaxCount diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala index 65af8b6e6bb..04ebad76e9d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashAggregateRetrySuite.scala @@ -55,7 +55,8 @@ class HashAggregateRetrySuite val mockMetrics = mock[GpuHashAggregateMetrics] when(mockMetrics.opTime).thenReturn(NoopMetric) when(mockMetrics.concatTime).thenReturn(NoopMetric) - val aggHelper = spy(new GpuHashAggregateIterator.AggHelper( + when(mockMetrics.numAggOps).thenReturn(NoopMetric) + val aggHelper = spy(new AggHelper( Seq.empty, Seq.empty, Seq.empty, forceMerge = false, isSorted = false)) @@ -69,13 +70,13 @@ class HashAggregateRetrySuite // attempt a cuDF reduction withResource(input) { _ => - GpuHashAggregateIterator.aggregate( + GpuAggregateIterator.aggregate( aggHelper, input, mockMetrics) } } - def makeGroupByAggHelper(forceMerge: Boolean): GpuHashAggregateIterator.AggHelper = { - val aggHelper = spy(new GpuHashAggregateIterator.AggHelper( + def makeGroupByAggHelper(forceMerge: Boolean): AggHelper = { + val aggHelper = spy(new AggHelper( Seq.empty, Seq.empty, Seq.empty, forceMerge = forceMerge, isSorted = false)) @@ -104,9 +105,10 @@ class HashAggregateRetrySuite val mockMetrics = mock[GpuHashAggregateMetrics] when(mockMetrics.opTime).thenReturn(NoopMetric) when(mockMetrics.concatTime).thenReturn(NoopMetric) + when(mockMetrics.numAggOps).thenReturn(NoopMetric) // attempt a cuDF group by - GpuHashAggregateIterator.aggregate( + GpuAggregateIterator.aggregate( makeGroupByAggHelper(forceMerge = false), input, mockMetrics)