diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index 56a593b901d..7c5688b4206 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -452,7 +452,16 @@ def test_hash_groupby_collect_with_single_distinct(data_gen): @incompat @pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) def test_hash_groupby_single_distinct_collect(data_gen): - # test distinct collect with other aggregations + # test distinct collect + sql = """select a, + sort_array(collect_list(distinct b)), + sort_array(collect_set(distinct b)) + from tbl group by a""" + assert_gpu_and_cpu_are_equal_sql( + df_fun=lambda spark: gen_df(spark, data_gen, length=100), + table_name="tbl", sql=sql) + + # test distinct collect with nonDistinct aggregations sql = """select a, sort_array(collect_list(distinct b)), sort_array(collect_set(b)), @@ -471,7 +480,7 @@ def test_hash_groupby_single_distinct_collect(data_gen): @approximate_float @ignore_order(local=True) @allow_non_gpu('SortAggregateExec', - 'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count', + 'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count', 'Coalesce', 'CollectList', 'CollectSet', 'AggregateExpression') @incompat @pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 25bde0b259d..a7826f935fb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -628,21 +628,43 @@ class GpuHashAggregateIterator( // boundInputReferences is used to pick out of the input batch the appropriate columns // for aggregation. 
// - // - PartialMerge with Partial mode: we use the inputProjections - // for Partial and non distinct merge expressions for PartialMerge. - // - Final or PartialMerge-only mode: we pick the columns in the order as handed to us. - // - Partial or Complete mode: we use the inputProjections + // - DistinctAggExpressions with nonDistinctAggExpressions in other mode: we switch the + // position of distinctAttributes and nonDistinctAttributes in childAttr. And we use the + // inputProjections for distinctAggExpressions. + // - Final mode, PartialMerge-only mode or no AggExpressions: we pick the columns in the order + // as handed to us. + // - Partial mode or Complete mode: we use the inputProjections. val boundInputReferences = - if (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) { - // The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side - // nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to - // switch the position of distinctAttributes and nonDistinctAttributes. + if (modeInfo.uniqueModes.length > 1 && aggregateExpressions.exists(_.isDistinct)) { + // This block takes care of AggregateExec which contains nonDistinctAggExpressions and + // distinctAggExpressions with different AggregateModes. All nonDistinctAggExpressions share + // one mode and all distinctAggExpressions are in another mode. The specific mode varies in + // different Spark runtimes, so this block applies a general condition to adapt different + // runtimes: // - // The schema of the 2nd stage's outputs: - // groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes + // 1. Apache Spark: The 3rd stage of AggWithOneDistinct + // The 3rd stage of AggWithOneDistinct, which consists of nonDistinctAggExpressions in + // PartialMerge mode and distinctAggExpressions in Partial mode. 
For this stage, we need to + // switch the position of distinctAttributes and nonDistinctAttributes if there exists at + // least one nonDistinctAggExpression. Because the positions of distinctAttributes are ahead + // of nonDistinctAttributes in the output of previous stage, since distinctAttributes are + // included in groupExpressions. + // To be specific, the schema of the 2nd stage's outputs is: + // (groupingAttributes ++ distinctAttributes) ++ nonDistinctAggBufferAttributes + // The schema of the 3rd stage's expressions is: + // groupingAttributes ++ nonDistinctAggExpressions(PartialMerge) ++ + // distinctAggExpressions(Partial) // - // The schema of the 3rd stage's expressions: - // nonDistinctMergeAggExpressions ++ distinctPartialAggExpressions + // 2. Databricks runtime: The final stage of AggWithOneDistinct + // Databricks runtime squeezes the 4-stage AggWithOneDistinct into 2 stages. Basically, it + // combines the 1st and 2nd stage into a "Partial" stage; and it combines the 3rd and 4th + // stage into a "Merge" stage. Similarly, nonDistinctAggExpressions are ahead of distinct + // ones in the layout of "Merge" stage's expressions: + // groupingAttributes ++ nonDistinctAggExpressions(Final) ++ DistinctAggExpressions(Complete) + // Meanwhile, as in Apache Spark, distinctAttributes are ahead of nonDistinctAggBufferAttributes + // in the output schema of the "Partial" stage. + // Therefore, this block also works on the final stage of AggWithOneDistinct under Databricks + // runtime. val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition( _.isDistinct) @@ -656,10 +678,11 @@ class GpuHashAggregateIterator( val distinctAttributes = childAttr.attrs.slice( groupingAttributes.length, childAttr.attrs.length - sizeOfNonDistAttr) - // With PartialMerge modes, we just pass through corresponding attributes of child plan into - // nonDistinctExpressions. + // For nonDistinctExpressions, they are in either PartialMerge or Final modes. 
With either + // mode, we just need to pass through childAttr. val nonDistinctExpressions = nonDistinctAttributes.asInstanceOf[Seq[Expression]] - // With Partial modes, the input projections are necessary for distinctExpressions. + // For distinctExpressions, they are in either Partial or Complete modes. With either mode, + // we need to apply the input projections on these AggExpressions. val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection) // Align the expressions of input projections and input attributes @@ -668,18 +691,24 @@ class GpuHashAggregateIterator( GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes) } else if (modeInfo.hasFinalMode || (modeInfo.hasPartialMergeMode && modeInfo.uniqueModes.length == 1)) { - // two possible conditions: + // This block takes care of two possible conditions: // 1. The Final stage, including the 2nd stage of NoDistinctAgg and 4th stage of // AggWithOneDistinct, which needs no input projections. Because the child outputs are // internal aggregation buffers, which are aligned for the final stage. - // // 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage // taking the child outputs as inputs without any projections. GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr) } else if (modeInfo.hasPartialMode || modeInfo.hasCompleteMode || modeInfo.uniqueModes.isEmpty) { - // The first aggregation stage (including Partial or Complete or no aggExpression), - // whose child node is not an AggregateExec. Therefore, input projections are essential. + // The first aggregation stage which contains AggExpressions (in either Partial or Complete + // mode). In this case, the input projections are essential. + // To be specific, there are three conditions matching this case: + // 1. The Partial (1st) stage of NoDistinctAgg + // 2. The Partial (1st) stage of AggWithOneDistinct + // 3. 
In Databricks runtime, the "Final" (2nd) stage of AggWithOneDistinct which only contains + // DistinctAggExpressions (without any nonDistinctAggExpressions) + // + // In addition, this block also fits for aggregation stages without any AggExpressions. val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions .flatMap(_.aggregateFunction.inputProjection) GpuBindReferences.bindGpuReferences(inputProjections, childAttr) @@ -1209,7 +1238,7 @@ case class GpuHashAggregateExec( override lazy val additionalMetrics: Map[String, GpuMetric] = Map( NUM_TASKS_FALL_BACKED -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_FALL_BACKED), AGG_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_AGG_TIME), - CONCAT_TIME-> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME), + CONCAT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME), SORT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_SORT_TIME) ) ++ spillMetrics @@ -1298,8 +1327,8 @@ case class GpuHashAggregateExec( // Used in de-duping and optimizer rules override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ - AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ - AttributeSet(aggregateBufferAttributes) + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) // AllTuples = distribution with a single partition and all tuples of the dataset are co-located. // Clustered = dataset with tuples co-located in the same partition if they share a specific value @@ -1340,7 +1369,7 @@ case class GpuHashAggregateExec( s"GpuHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" } else { s"GpuHashAggregate(keys=$keyString, functions=$functionString)," + - s" filters=${aggregateExpressions.map(_.filter)})" + s" filters=${aggregateExpressions.map(_.filter)})" } } //