Fix hash_aggregate test failures due to TypedImperativeAggregate (#3178)
* fix

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

* update comments

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

* update

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

* fix scala style

Signed-off-by: sperlingxx <lovedreamf@gmail.com>

* Revert "fix scala style"

This reverts commit 1658125.

* fix scala style

Signed-off-by: sperlingxx <lovedreamf@gmail.com>
sperlingxx authored Aug 11, 2021
1 parent 3a8dfcd commit 32d7892
Showing 2 changed files with 63 additions and 25 deletions.
13 changes: 11 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -452,7 +452,16 @@ def test_hash_groupby_collect_with_single_distinct(data_gen):
 @incompat
 @pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn)
 def test_hash_groupby_single_distinct_collect(data_gen):
-    # test distinct collect with other aggregations
+    # test distinct collect
+    sql = """select a,
+        sort_array(collect_list(distinct b)),
+        sort_array(collect_set(distinct b))
+        from tbl group by a"""
+    assert_gpu_and_cpu_are_equal_sql(
+        df_fun=lambda spark: gen_df(spark, data_gen, length=100),
+        table_name="tbl", sql=sql)
+
+    # test distinct collect with nonDistinct aggregations
     sql = """select a,
         sort_array(collect_list(distinct b)),
         sort_array(collect_set(b)),
@@ -471,7 +480,7 @@ def test_hash_groupby_single_distinct_collect(data_gen):
 @approximate_float
 @ignore_order(local=True)
 @allow_non_gpu('SortAggregateExec',
-    'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count',
+    'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count', 'Coalesce',
     'CollectList', 'CollectSet', 'AggregateExpression')
 @incompat
 @pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn)
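As context for the Scala changes below, it helps to see the plan shape these tests exercise. A minimal sketch, not part of the commit, assuming plain Apache Spark on the classpath; the object name, table name, and columns are illustrative:

// Sketch: inspect the multi-stage plan Spark builds for a single distinct aggregate.
import org.apache.spark.sql.SparkSession

object DistinctAggPlanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("distinct-agg-plan-demo")
      .getOrCreate()
    import spark.implicits._

    Seq((1, "x"), (1, "y"), (2, "x")).toDF("a", "b").createOrReplaceTempView("tbl")

    // Mixing a distinct collect with a non-distinct aggregate makes Spark plan the
    // stacked aggregation stages discussed in aggregate.scala below: one stage runs
    // non-distinct expressions in PartialMerge mode and distinct ones in Partial
    // mode at the same time.
    val df = spark.sql(
      """select a,
        |       sort_array(collect_list(distinct b)),
        |       count(b)
        |from tbl group by a""".stripMargin)

    df.explain()  // physical plan: several stacked HashAggregate/ObjectHashAggregate stages
    df.show()

    spark.stop()
  }
}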
75 changes: 52 additions & 23 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -628,21 +628,43 @@ class GpuHashAggregateIterator(
     // boundInputReferences is used to pick out of the input batch the appropriate columns
     // for aggregation.
     //
-    // - PartialMerge with Partial mode: we use the inputProjections
-    //   for Partial and non distinct merge expressions for PartialMerge.
-    // - Final or PartialMerge-only mode: we pick the columns in the order as handed to us.
-    // - Partial or Complete mode: we use the inputProjections
+    // - DistinctAggExpressions with nonDistinctAggExpressions in another mode: we switch the
+    //   positions of distinctAttributes and nonDistinctAttributes in childAttr, and we use the
+    //   inputProjections for the nonDistinctAggExpressions.
+    // - Final mode, PartialMerge-only mode, or no AggExpressions: we pick the columns in the
+    //   order as handed to us.
+    // - Partial mode or Complete mode: we use the inputProjections.
     val boundInputReferences =
-      if (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) {
-        // The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side
-        // nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to
-        // switch the position of distinctAttributes and nonDistinctAttributes.
+      if (modeInfo.uniqueModes.length > 1 && aggregateExpressions.exists(_.isDistinct)) {
+        // This block takes care of an AggregateExec which contains nonDistinctAggExpressions and
+        // distinctAggExpressions with different AggregateModes. All nonDistinctAggExpressions
+        // share one mode and all distinctAggExpressions are in another mode. The specific modes
+        // vary across Spark runtimes, so this block applies a general condition to adapt to the
+        // different runtimes:
         //
-        // The schema of the 2nd stage's outputs:
-        // groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes
+        // 1. Apache Spark: the 3rd stage of AggWithOneDistinct
+        // This stage consists of nonDistinctAggExpressions in PartialMerge mode and
+        // distinctAggExpressions in Partial mode. For this stage, we need to switch the
+        // positions of distinctAttributes and nonDistinctAttributes if there exists at least
+        // one nonDistinctAggExpression, because distinctAttributes sit ahead of
+        // nonDistinctAttributes in the output of the previous stage (distinctAttributes are
+        // included in the groupingExpressions there).
+        // To be specific, the schema of the 2nd stage's outputs is:
+        //   (groupingAttributes ++ distinctAttributes) ++ nonDistinctAggBufferAttributes
+        // and the schema of the 3rd stage's expressions is:
+        //   groupingAttributes ++ nonDistinctAggExpressions(PartialMerge) ++
+        //   distinctAggExpressions(Partial)
         //
-        // The schema of the 3rd stage's expressions:
-        // nonDistinctMergeAggExpressions ++ distinctPartialAggExpressions
+        // 2. Databricks runtime: the final stage of AggWithOneDistinct
+        // The Databricks runtime squeezes the 4-stage AggWithOneDistinct into 2 stages: it
+        // combines the 1st and 2nd stages into a "Partial" stage, and the 3rd and 4th stages
+        // into a "Merge" stage. Similarly, nonDistinctAggExpressions sit ahead of the distinct
+        // ones in the layout of the "Merge" stage's expressions:
+        //   groupingAttributes ++ nonDistinctAggExpressions(Final) ++
+        //   distinctAggExpressions(Complete)
+        // Meanwhile, as in Apache Spark, distinctAttributes sit ahead of
+        // nonDistinctAggBufferAttributes in the output schema of the "Partial" stage.
+        // Therefore, this block also covers the final stage of AggWithOneDistinct under the
+        // Databricks runtime.

         val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
           _.isDistinct)
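The attribute switch described in the comment above can be shown standalone. A minimal sketch with toy types, runnable as a Scala script; Attr and reorderForDistinctStage are illustrative names, not the plugin's API:

// The previous stage emits grouping ++ distinctAttrs ++ nonDistinctAggBuffers, while
// this stage's expressions expect grouping ++ nonDistinct ++ distinct, so the bound
// input references must swap the two attribute groups.
case class Attr(name: String)

def reorderForDistinctStage(
    childAttrs: Seq[Attr],
    numGrouping: Int,
    numNonDistinctBuffers: Int): Seq[Attr] = {
  val grouping = childAttrs.take(numGrouping)
  // non-distinct aggregation buffers sit at the tail of the child output
  val nonDistinct = childAttrs.takeRight(numNonDistinctBuffers)
  // distinct attributes are whatever sits between the grouping keys and the buffers
  val distinct = childAttrs.slice(numGrouping, childAttrs.length - numNonDistinctBuffers)
  grouping ++ nonDistinct ++ distinct
}

// e.g. child output (a, b, cnt_buf) with grouping = (a), distinct = (b),
// buffers = (cnt_buf) becomes (a, cnt_buf, b): grouping ++ nonDistinct ++ distinct.
assert(reorderForDistinctStage(Seq(Attr("a"), Attr("b"), Attr("cnt_buf")), 1, 1)
  == Seq(Attr("a"), Attr("cnt_buf"), Attr("b")))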
@@ -656,10 +678,11 @@ class GpuHashAggregateIterator(
         val distinctAttributes = childAttr.attrs.slice(
           groupingAttributes.length, childAttr.attrs.length - sizeOfNonDistAttr)

-        // With PartialMerge modes, we just pass through corresponding attributes of child plan
-        // into nonDistinctExpressions.
+        // The nonDistinctExpressions are in either PartialMerge or Final mode. With either mode,
+        // we just need to pass through childAttr.
         val nonDistinctExpressions = nonDistinctAttributes.asInstanceOf[Seq[Expression]]
-        // With Partial modes, the input projections are necessary for distinctExpressions.
+        // The distinctExpressions are in either Partial or Complete mode. With either mode, we
+        // need to apply the input projections on these AggExpressions.
         val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)

         // Align the expressions of input projections and input attributes
@@ -668,18 +691,24 @@ class GpuHashAggregateIterator(
           GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes)
       } else if (modeInfo.hasFinalMode ||
           (modeInfo.hasPartialMergeMode && modeInfo.uniqueModes.length == 1)) {
-        // two possible conditions:
+        // This block takes care of two possible conditions:
         // 1. The Final stage, including the 2nd stage of NoDistinctAgg and 4th stage of
         // AggWithOneDistinct, which needs no input projections. Because the child outputs are
         // internal aggregation buffers, which are aligned for the final stage.
         //
         // 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage
         // taking the child outputs as inputs without any projections.
         GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr)
       } else if (modeInfo.hasPartialMode || modeInfo.hasCompleteMode ||
           modeInfo.uniqueModes.isEmpty) {
-        // The first aggregation stage (including Partial or Complete or no aggExpression),
-        // whose child node is not an AggregateExec. Therefore, input projections are essential.
+        // The first aggregation stage which contains AggExpressions (in either Partial or
+        // Complete mode). In this case, the input projections are essential. To be specific,
+        // four conditions match this case:
+        // 1. The Partial (1st) stage of NoDistinctAgg
+        // 2. The Partial (1st) stage of AggWithOneDistinct
+        // 3. In the Databricks runtime, the "Final" (2nd) stage of AggWithOneDistinct, when it
+        //    only contains DistinctAggExpressions (without any nonDistinctAggExpressions)
+        // 4. Aggregation stages without any AggExpressions
         val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
           .flatMap(_.aggregateFunction.inputProjection)
         GpuBindReferences.bindGpuReferences(inputProjections, childAttr)
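Taken together, the three branches form a dispatch on which AggregateModes a stage contains. A toy, script-style sketch of that dispatch (the merge-side condition is simplified relative to the real modeInfo checks; all names are illustrative):

// Toy model of AggregateMode, standing in for the plugin's ModeInfo.
sealed trait AggMode
case object Partial extends AggMode
case object PartialMerge extends AggMode
case object Final extends AggMode
case object Complete extends AggMode

def chooseBinding(uniqueModes: Seq[AggMode], hasDistinct: Boolean): String =
  if (uniqueModes.length > 1 && hasDistinct) {
    "switch distinct/nonDistinct attribute positions"  // mixed-mode stage
  } else if (uniqueModes == Seq(Final) || uniqueModes == Seq(PartialMerge)) {
    "pass child attributes through unchanged"          // merge-side stage
  } else {
    "apply inputProjection"                            // update-side stage, or no AggExpressions
  }

// Apache Spark's 4-stage AggWithOneDistinct hits every branch:
assert(chooseBinding(Seq(Partial), hasDistinct = false)
  == "apply inputProjection")                          // 1st stage
assert(chooseBinding(Seq(PartialMerge), hasDistinct = false)
  == "pass child attributes through unchanged")        // 2nd stage
assert(chooseBinding(Seq(PartialMerge, Partial), hasDistinct = true)
  == "switch distinct/nonDistinct attribute positions")  // 3rd stage
assert(chooseBinding(Seq(Final), hasDistinct = false)
  == "pass child attributes through unchanged")        // 4th stage
// Databricks' squeezed "Merge" stage lands in the first branch as well:
assert(chooseBinding(Seq(Final, Complete), hasDistinct = true)
  == "switch distinct/nonDistinct attribute positions")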
@@ -1209,7 +1238,7 @@ case class GpuHashAggregateExec(
   override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
     NUM_TASKS_FALL_BACKED -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_FALL_BACKED),
     AGG_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_AGG_TIME),
-    CONCAT_TIME-> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME),
+    CONCAT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME),
     SORT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_SORT_TIME)
   ) ++ spillMetrics

@@ -1298,8 +1327,8 @@ case class GpuHashAggregateExec(
   // Used in de-duping and optimizer rules
   override def producedAttributes: AttributeSet =
     AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
+      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+      AttributeSet(aggregateBufferAttributes)

   // AllTuples = distribution with a single partition and all tuples of the dataset are co-located.
   // Clustered = dataset with tuples co-located in the same partition if they share a specific value
@@ -1340,7 +1369,7 @@ case class GpuHashAggregateExec(
s"GpuHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
} else {
s"GpuHashAggregate(keys=$keyString, functions=$functionString)," +
s" filters=${aggregateExpressions.map(_.filter)})"
s" filters=${aggregateExpressions.map(_.filter)})"
}
}
//
