Skip to content

Commit

Permalink
Improve sort removal heuristic for sort aggregate (#5917)
Browse files Browse the repository at this point in the history
* Improve sort removal heuristic for sort aggregate

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Jun 27, 2022
1 parent a5dfef2 commit 30467f2
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Name | Description | Default Value
<a name="sql.hasNans"></a>spark.rapids.sql.hasNans|Config to indicate if your data has NaN's. Cudf doesn't currently support NaN's properly so you can get corrupt data if you have NaN's in your data and it runs on the GPU.|true
<a name="sql.hashOptimizeSort.enabled"></a>spark.rapids.sql.hashOptimizeSort.enabled|Whether sorts should be inserted after some hashed operations to improve output ordering. This can improve output file sizes when saving to columnar formats.|false
<a name="sql.improvedFloatOps.enabled"></a>spark.rapids.sql.improvedFloatOps.enabled|For some floating point operations spark uses one way to compute the value and the underlying cudf implementation can use an improved algorithm. In some cases this can result in cudf producing an answer when spark overflows.|true
<a name="sql.improvedTimeOps.enabled"></a>spark.rapids.sql.improvedTimeOps.enabled|When set to true, some operators will avoid overflowing by converting epoch days directly to seconds without first converting to microseconds|false
<a name="sql.improvedTimeOps.enabled"></a>spark.rapids.sql.improvedTimeOps.enabled|When set to true, some operators will avoid overflowing by converting epoch days directly to seconds without first converting to microseconds|false
<a name="sql.incompatibleDateFormats.enabled"></a>spark.rapids.sql.incompatibleDateFormats.enabled|When parsing strings as dates and timestamps in functions like unix_timestamp, some formats are fully supported on the GPU and some are unsupported and will fall back to the CPU. Some formats behave differently on the GPU than the CPU. Spark on the CPU interprets date formats with unsupported trailing characters as nulls, while Spark on the GPU will parse the date with invalid trailing characters. More detail can be found at [parsing strings as dates or timestamps](compatibility.md#parsing-strings-as-dates-or-timestamps).|false
<a name="sql.incompatibleOps.enabled"></a>spark.rapids.sql.incompatibleOps.enabled|For operations that work, but are not 100% compatible with the Spark equivalent set if they should be enabled by default or disabled by default.|true
<a name="sql.join.cross.enabled"></a>spark.rapids.sql.join.cross.enabled|When set to true cross joins are enabled on the GPU|true
Expand Down
13 changes: 13 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,19 @@ def test_groupby_first_last(data_gen):
# We set parallelism 1 to prevent nondeterministic results because of distributed setup.
lambda spark: agg_fn(gen_df(spark, gen_fn, num_slices=1)))

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen + _struct_only_nested_gens, ids=idfn)
def test_sorted_groupby_first_last(data_gen):
gen_fn = [('a', RepeatSeqGen(LongGen(), length=20)), ('b', data_gen)]
# sort by more than the group by columns to be sure that first/last don't remove the ordering
agg_fn = lambda df: df.orderBy('a', 'b').groupBy('a').agg(
f.first('b'), f.last('b'), f.first('b', True), f.last('b', True))
assert_gpu_and_cpu_are_equal_collect(
# First and last are not deterministic when they are run in a real distributed setup.
# We set parallelism 1 to prevent nondeterministic results because of distributed setup.
lambda spark: agg_fn(gen_df(spark, gen_fn, num_slices=1)),
conf = {'spark.rapids.sql.explain': 'ALL'})

@ignore_order
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('count_func', [f.count, f.countDistinct])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ object RapidsConf {
val IMPROVED_TIMESTAMP_OPS =
conf("spark.rapids.sql.improvedTimeOps.enabled")
.doc("When set to true, some operators will avoid overflowing by converting epoch days " +
" directly to seconds without first converting to microseconds")
"directly to seconds without first converting to microseconds")
.booleanConf
.createWithDefault(false)

Expand Down
37 changes: 24 additions & 13 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1313,24 +1313,35 @@ class GpuSortAggregateExecMeta(

// Make sure this is the last check - if this is SortAggregate, the children can be sorts and we
// want to validate they can run on GPU and remove them before replacing this with a
// HashAggregate. We don't want to do this if there is a first or last aggregate,
// because dropping the sort will make them no longer deterministic.
// In the future we might be able to pull the sort functionality into the aggregate so
// we can sort a single batch at a time and sort the combined result as well which would help
// with data skew.
val hasFirstOrLast = agg.aggregateExpressions.exists { agg =>
agg.aggregateFunction match {
case _: First | _: Last => true
case _ => false
}
}
if (canThisBeReplaced && !hasFirstOrLast) {
// HashAggregate.
if (canThisBeReplaced) {
childPlans.foreach { plan =>
if (plan.wrapped.isInstanceOf[SortExec]) {
if (!plan.canThisBeReplaced) {
willNotWorkOnGpu("one of the preceding SortExec's cannot be replaced")
} else {
plan.shouldBeRemoved("replacing sort aggregate with hash aggregate")
// But if this includes a first or last aggregate and the sort includes more than what
// the group by requires we cannot drop the sort. For example
// if the group by is on a single key "a", but the ordering is on "a" and "b", then
// we have to keep the sort, so that the rows are ordered to take "b" into account
// before first/last work on it.
val hasFirstOrLast = agg.aggregateExpressions.exists { agg =>
agg.aggregateFunction match {
case _: First | _: Last => true
case _ => false
}
}
val shouldRemoveSort = if (hasFirstOrLast) {
val sortedOrder = plan.wrapped.asInstanceOf[SortExec].sortOrder
val groupByRequiredOrdering = agg.requiredChildOrdering.head
sortedOrder == groupByRequiredOrdering
} else {
true
}

if (shouldRemoveSort) {
plan.shouldBeRemoved("replacing sort aggregate with hash aggregate")
}
}
}
}
Expand Down

0 comments on commit 30467f2

Please sign in to comment.