
Fix hash_aggregate test failures due to TypedImperativeAggregate #3178

Merged
merged 6 commits into from
Aug 11, 2021
13 changes: 11 additions & 2 deletions integration_tests/src/main/python/hash_aggregate_test.py
@@ -452,7 +452,16 @@ def test_hash_groupby_collect_with_single_distinct(data_gen):
@incompat
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn)
def test_hash_groupby_single_distinct_collect(data_gen):
# test distinct collect with other aggregations
# test distinct collect
sql = """select a,
sort_array(collect_list(distinct b)),
sort_array(collect_set(distinct b))
from tbl group by a"""
assert_gpu_and_cpu_are_equal_sql(
df_fun=lambda spark: gen_df(spark, data_gen, length=100),
table_name="tbl", sql=sql)

# test distinct collect with nonDistinct aggregations
sql = """select a,
sort_array(collect_list(distinct b)),
sort_array(collect_set(b)),
@@ -471,7 +480,7 @@ def test_hash_groupby_single_distinct_collect(data_gen):
@approximate_float
@ignore_order(local=True)
@allow_non_gpu('SortAggregateExec',
'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count',
'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count', 'Coalesce',
'CollectList', 'CollectSet', 'AggregateExpression')
@incompat
@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn)
75 changes: 52 additions & 23 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -628,21 +628,43 @@ class GpuHashAggregateIterator(
// boundInputReferences is used to pick out of the input batch the appropriate columns
// for aggregation.
//
// - PartialMerge with Partial mode: we use the inputProjections
// for Partial and non distinct merge expressions for PartialMerge.
// - Final or PartialMerge-only mode: we pick the columns in the order as handed to us.
// - Partial or Complete mode: we use the inputProjections
// - DistinctAggExpressions with nonDistinctAggExpressions in other mode: we switch the
// position of distinctAttributes and nonDistinctAttributes in childAttr. And we use the
// inputProjections for nonDistinctAggExpressions.
// - Final mode, PartialMerge-only mode or no AggExpressions: we pick the columns in the order
// as handed to us.
// - Partial mode or Complete mode: we use the inputProjections.
val boundInputReferences =
if (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) {
// The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side
// nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to
// switch the position of distinctAttributes and nonDistinctAttributes.
if (modeInfo.uniqueModes.length > 1 && aggregateExpressions.exists(_.isDistinct)) {
Collaborator

This makes me a little nervous that we are missing something. The Spark aggregation code does not look at distinct at all. It really just looks at the individual modes for each operation. Why is it that we need to do this to get the aggregation right, but the Spark code does not?

Collaborator Author @sperlingxx (Aug 11, 2021)

For AggWithOneDistinct, Spark plans a 4-stage stack of AggregateExec. Each stage owns a unique set of modes:

  • Stage 1: Partial mode, only includes nonDistinct ones
  • Stage 2: PartialMerge mode, only includes nonDistinct ones
  • Stage 3: PartialMerge mode for nonDistinct ones and Partial mode for Distinct ones
  • Stage 4: Final mode for both nonDistinct and Distinct AggregateExpressions

In contrast, the Databricks runtime seems to apply a quite different planning strategy to AggWithOneDistinct. From the dumped plan trees, we infer that the Databricks runtime only plans a 2-stage stack for AggWithOneDistinct: a Map-stage and a Reduce-stage.

  • Map-stage: Partial mode, only includes nonDistinct ones
  • Reduce-stage: Final mode for nonDistinct ones and Complete mode for Distinct ones

Apparently, the Map-stage corresponds to Stage 1 and Stage 2; the Reduce-stage corresponds to Stage 3 and Stage 4.
The condition here was used to match Stage 3, so it checked whether modeInfo contains both PartialMerge and Partial. Currently, we want to adapt to the Databricks runtime. The input projections of the Reduce-stage are exactly the same as those of Stage 3, even though the two contain different AggregateModes. Therefore, we change the condition here to match the Reduce-stage of the Databricks runtime as well as Stage 3 of Spark. In fact, the condition modeInfo.uniqueModes.length > 1 alone is enough to distinguish Stage 3 and the Reduce-stage from the other stages. The latter condition, aggregateExpressions.exists(_.isDistinct), is there to increase robustness against unknown special cases.

In addition, the input projections for Stage 1 fully fit the Map-stage of the Databricks runtime, so nothing needs to change there to adapt to the Databricks runtime.
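As an illustration only (not the plugin code), here is a minimal Scala sketch of the stage/mode layouts described above, plus a toy predicate mirroring the condition in the diff. The AggMode, StagePlan, and needsAttributeSwitch names are made up for this sketch.

// Sketch: per-stage (mode, isDistinct) layouts for AggWithOneDistinct, and the
// condition that singles out the stage mixing a non-distinct mode with a distinct one.
object AggStageSketch {
  sealed trait AggMode
  case object Partial extends AggMode
  case object PartialMerge extends AggMode
  case object Final extends AggMode
  case object Complete extends AggMode

  // One (mode, isDistinct) pair per aggregate expression in a stage.
  type StagePlan = Seq[(AggMode, Boolean)]

  // Apache Spark: 4-stage stack.
  val sparkStages: Seq[StagePlan] = Seq(
    Seq((Partial, false)),                        // Stage 1
    Seq((PartialMerge, false)),                   // Stage 2
    Seq((PartialMerge, false), (Partial, true)),  // Stage 3: mixed modes
    Seq((Final, false), (Final, true)))           // Stage 4

  // Databricks runtime: 2-stage stack.
  val databricksStages: Seq[StagePlan] = Seq(
    Seq((Partial, false)),                        // Map-stage
    Seq((Final, false), (Complete, true)))        // Reduce-stage: mixed modes

  // Mirrors the new condition: more than one unique mode plus a distinct expression.
  def needsAttributeSwitch(stage: StagePlan): Boolean =
    stage.map(_._1).distinct.length > 1 && stage.exists(_._2)

  def main(args: Array[String]): Unit =
    // Prints true only for Spark's Stage 3 and Databricks' Reduce-stage.
    (sparkStages ++ databricksStages).foreach(s => println(needsAttributeSwitch(s)))
}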

Collaborator Author @sperlingxx (Aug 11, 2021)

Alternatively, a condition like (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) || (modeInfo.hasFinalMode && modeInfo.hasCompleteMode) may look more straightforward.
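For illustration, a small sketch of the two equivalent spellings side by side, written against a hypothetical ModeInfo shape (the field names are assumed here, not necessarily the plugin's exact API):

object ConditionSketch {
  // Hypothetical stand-in for the plugin's aggregate-mode summary.
  case class ModeInfo(uniqueModes: Seq[String],
      hasPartialMode: Boolean, hasPartialMergeMode: Boolean,
      hasFinalMode: Boolean, hasCompleteMode: Boolean)

  // Condition as written in the diff: more than one unique mode plus a distinct expression.
  def mixedModeWithDistinct(info: ModeInfo, anyDistinct: Boolean): Boolean =
    info.uniqueModes.length > 1 && anyDistinct

  // The more explicit alternative suggested above.
  def mixedModeExplicit(info: ModeInfo): Boolean =
    (info.hasPartialMergeMode && info.hasPartialMode) ||
      (info.hasFinalMode && info.hasCompleteMode)
}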

Collaborator

I am okay with this as a short term fix. The problem is not with your logic. The problem is that we keep hacking special cases onto something that should not need them.

Each aggregation comes with a mode. Each mode tells the aggregation what to do as a part of that stage. Originally the code assumed that there would only ever be one mode for all of the aggregations. I thought we had ripped that all out and each aggregation does the right thing.

To successfully do an aggregation there are a few steps used.

  1. Initial projection to get input data setup properly.
  2. Initial aggregation to produce partial result(s)
  3. Merge aggregation to combine partial results (This requires that the input schema and the output schema be the same)
  4. Final projection to take the combined partial results and produce a final result.

In general the steps follow the pattern 1, 2, 3*, 4, which means steps 1, 2, and 4 are required, and step 3 can be done as often as needed because its input and output schemas are the same.

Step 4 requires that all of the data for a given group by key is on the same task and has been merged into a single output row. There are several different ways to do this, which is why we end up with several aggregation modes.

  • Partial mode means that we do Step 1 and Step 2. Then we can do Step 3 as many times as needed depending on how we are doing memory management, and how many batches are needed.
  • PartialMerge mode means we can do Step 3 at least once and possibly more times depending on how we are doing memory management and how many batches are needed.
  • Final mode means that we do the same steps as with PartialMerge but do Step 4 when we are done doing the partial merges.
  • Complete mode is something only Databricks does, but it essentially means we do Step 1, Step 2, Step 3* (depending on memory management requirements), and Step 4 all at once.

I know that the details are a lot more complicated, but conceptually it should not be too difficult. I will file a follow on issue for us to figure this out.
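A conceptual sketch of those four steps and which of them each mode runs, following the description above; the Batch type and step bodies are placeholders rather than the plugin's implementation:

object AggModeSteps {
  type Batch = Seq[Map[String, Any]] // stand-in for a columnar batch

  def inputProjection(b: Batch): Batch = b    // Step 1: set up the input data
  def initialAggregation(b: Batch): Batch = b // Step 2: produce partial result(s)
  def mergeAggregation(b: Batch): Batch = b   // Step 3: input and output schemas match,
                                              //         so it can repeat as needed
  def finalProjection(b: Batch): Batch = b    // Step 4: produce the final result

  // Partial: steps 1 and 2, then step 3 as often as memory management requires.
  def partial(b: Batch): Batch = mergeAggregation(initialAggregation(inputProjection(b)))
  // PartialMerge: step 3 at least once, possibly more.
  def partialMerge(b: Batch): Batch = mergeAggregation(b)
  // Final: like PartialMerge, but finishes with step 4.
  def finalMode(b: Batch): Batch = finalProjection(mergeAggregation(b))
  // Complete (Databricks): steps 1, 2, 3*, and 4 all at once.
  def complete(b: Batch): Batch =
    finalProjection(mergeAggregation(initialAggregation(inputProjection(b))))
}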

Collaborator

I think the main ask is to not do this wholesale, assuming that a hash aggregate exec has a certain shape. If this function could decide per aggregate expression mode what the right binding should be, it should be more robust to new aggregate exec setups that mix and match modes (if we encounter new ones). That said, I don't think this is your fault, as the setupReferences code was built that way; it needs to be reworked separately.
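A hedged sketch of that per-expression idea: derive the binding from each aggregate expression's own mode rather than from the shape of the whole exec. Every name below is hypothetical and not part of the plugin.

object PerExpressionBindingSketch {
  sealed trait Mode
  case object Partial extends Mode
  case object PartialMerge extends Mode
  case object Final extends Mode
  case object Complete extends Mode

  case class AggExpr(name: String, mode: Mode, isDistinct: Boolean)
  case class BoundAgg(name: String, usesInputProjection: Boolean)

  // Decide per aggregate expression whether it reads projected input or passes
  // aggregation-buffer attributes through, based only on its own mode.
  def bindOne(e: AggExpr): BoundAgg = e.mode match {
    case Partial | Complete   => BoundAgg(e.name, usesInputProjection = true)
    case PartialMerge | Final => BoundAgg(e.name, usesInputProjection = false)
  }

  def bindAll(exprs: Seq[AggExpr]): Seq[BoundAgg] = exprs.map(bindOne)
}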

// This block takes care of AggregateExec which contains nonDistinctAggExpressions and
// distinctAggExpressions with different AggregateModes. All nonDistinctAggExpressions share
// one mode and all distinctAggExpressions are in another mode. The specific mode varies in
// different Spark runtimes, so this block applies a general condition to adapt different
// runtimes:
//
// The schema of the 2nd stage's outputs:
// groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes
// 1. Apache Spark: The 3rd stage of AggWithOneDistinct
// The 3rd stage of AggWithOneDistinct, which consists of nonDistinctAggExpressions in
// PartialMerge mode and distinctAggExpressions in Partial mode. For this stage, we need to
// switch the position of distinctAttributes and nonDistinctAttributes if there exists at
// least one nonDistinctAggExpression. Because the positions of distinctAttributes are ahead
// of nonDistinctAttributes in the output of previous stage, since distinctAttributes are
// included in groupExpressions.
// To be specific, the schema of the 2nd stage's outputs is:
// (groupingAttributes ++ distinctAttributes) ++ nonDistinctAggBufferAttributes
// The schema of the 3rd stage's expressions is:
// groupingAttributes ++ nonDistinctAggExpressions(PartialMerge) ++
// distinctAggExpressions(Partial)
//
// The schema of the 3rd stage's expressions:
// nonDistinctMergeAggExpressions ++ distinctPartialAggExpressions
// 2. Databricks runtime: The final stage of AggWithOneDistinct
// Databricks runtime squeezes the 4-stage AggWithOneDistinct into 2 stages. Basically, it
// combines the 1st and 2nd stage into a "Partial" stage; and it combines the 3rd and 4th
// stage into a "Merge" stage. Similarly, nonDistinctAggExpressions are ahead of distinct
// ones in the layout of "Merge" stage's expressions:
// groupingAttributes ++ nonDistinctAggExpressions(Final) ++ DistinctAggExpressions(Complete)
// Meanwhile, as in Apache Spark, distinctAttributes are ahead of nonDistinctAggBufferAttributes
// in the output schema of the "Partial" stage.
// Therefore, this block also works on the final stage of AggWithOneDistinct under Databricks
// runtime.

val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
_.isDistinct)
@@ -656,10 +678,11 @@ class GpuHashAggregateIterator(
val distinctAttributes = childAttr.attrs.slice(
groupingAttributes.length, childAttr.attrs.length - sizeOfNonDistAttr)

// With PartialMerge modes, we just pass through corresponding attributes of child plan into
// nonDistinctExpressions.
// For nonDistinctExpressions, they are in either PartialMerge or Final modes. With either
// mode, we just need to pass through childAttr.
val nonDistinctExpressions = nonDistinctAttributes.asInstanceOf[Seq[Expression]]
// With Partial modes, the input projections are necessary for distinctExpressions.
// For distinctExpressions, they are in either Partial or Complete modes. With either mode,
// we need to apply the input projections on these AggExpressions.
val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)

// Align the expressions of input projections and input attributes
@@ -668,18 +691,24 @@ class GpuHashAggregateIterator(
GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes)
} else if (modeInfo.hasFinalMode ||
(modeInfo.hasPartialMergeMode && modeInfo.uniqueModes.length == 1)) {
// two possible conditions:
// This block takes care of two possible conditions:
// 1. The Final stage, including the 2nd stage of NoDistinctAgg and the 4th stage of
// AggWithOneDistinct, which needs no input projections because the child outputs are
// internal aggregation buffers, already aligned for the final stage.
//
// 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage
// taking the child outputs as inputs without any projections.
GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr)
} else if (modeInfo.hasPartialMode || modeInfo.hasCompleteMode ||
modeInfo.uniqueModes.isEmpty) {
// The first aggregation stage (including Partial or Complete or no aggExpression),
// whose child node is not an AggregateExec. Therefore, input projections are essential.
// The first aggregation stage which contains AggExpressions (in either Partial or Complete
// mode). In this case, the input projections are essential.
// To be specific, there are three conditions matching this case:
// 1. The Partial (1st) stage of NoDistinctAgg
// 2. The Partial (1st) stage of AggWithOneDistinct
// 3. In Databricks runtime, the "Final" (2nd) stage of AggWithOneDistinct which only contains
// DistinctAggExpressions (without any nonDistinctAggExpressions)
//
// In addition, this block also fits for aggregation stages without any AggExpressions.
val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
.flatMap(_.aggregateFunction.inputProjection)
GpuBindReferences.bindGpuReferences(inputProjections, childAttr)
@@ -1209,7 +1238,7 @@ case class GpuHashAggregateExec(
override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
NUM_TASKS_FALL_BACKED -> createMetric(MODERATE_LEVEL, DESCRIPTION_NUM_TASKS_FALL_BACKED),
AGG_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_AGG_TIME),
CONCAT_TIME-> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME),
CONCAT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_CONCAT_TIME),
SORT_TIME -> createNanoTimingMetric(MODERATE_LEVEL, DESCRIPTION_SORT_TIME)
) ++ spillMetrics

@@ -1298,8 +1327,8 @@ case class GpuHashAggregateExec(
// Used in de-duping and optimizer rules
override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

// AllTuples = distribution with a single partition and all tuples of the dataset are co-located.
// Clustered = dataset with tuples co-located in the same partition if they share a specific value
@@ -1340,7 +1369,7 @@ case class GpuHashAggregateExec(
s"GpuHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)"
} else {
s"GpuHashAggregate(keys=$keyString, functions=$functionString)," +
s" filters=${aggregateExpressions.map(_.filter)})"
s" filters=${aggregateExpressions.map(_.filter)})"
}
}
//