From 30467f255a83529d1dc5f98ef055f6c318473efc Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Mon, 27 Jun 2022 15:08:29 -0500 Subject: [PATCH] Improve sort removal heuristic for sort aggregate (#5917) * Improve sort removal heuristic for sort aggregate Signed-off-by: Robert (Bobby) Evans --- docs/configs.md | 2 +- .../src/main/python/hash_aggregate_test.py | 13 +++++++ .../com/nvidia/spark/rapids/RapidsConf.scala | 2 +- .../com/nvidia/spark/rapids/aggregate.scala | 37 ++++++++++++------- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/docs/configs.md b/docs/configs.md index 66164fc8d7e..8428f069d60 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -101,7 +101,7 @@ Name | Description | Default Value spark.rapids.sql.hasNans|Config to indicate if your data has NaN's. Cudf doesn't currently support NaN's properly so you can get corrupt data if you have NaN's in your data and it runs on the GPU.|true spark.rapids.sql.hashOptimizeSort.enabled|Whether sorts should be inserted after some hashed operations to improve output ordering. This can improve output file sizes when saving to columnar formats.|false spark.rapids.sql.improvedFloatOps.enabled|For some floating point operations spark uses one way to compute the value and the underlying cudf implementation can use an improved algorithm. In some cases this can result in cudf producing an answer when spark overflows.|true -spark.rapids.sql.improvedTimeOps.enabled|When set to true, some operators will avoid overflowing by converting epoch days directly to seconds without first converting to microseconds|false +spark.rapids.sql.improvedTimeOps.enabled|When set to true, some operators will avoid overflowing by converting epoch days directly to seconds without first converting to microseconds|false spark.rapids.sql.incompatibleDateFormats.enabled|When parsing strings as dates and timestamps in functions like unix_timestamp, some formats are fully supported on the GPU and some are unsupported and will fall back to the CPU. Some formats behave differently on the GPU than the CPU. Spark on the CPU interprets date formats with unsupported trailing characters as nulls, while Spark on the GPU will parse the date with invalid trailing characters. More detail can be found at [parsing strings as dates or timestamps](compatibility.md#parsing-strings-as-dates-or-timestamps).|false spark.rapids.sql.incompatibleOps.enabled|For operations that work, but are not 100% compatible with the Spark equivalent set if they should be enabled by default or disabled by default.|true spark.rapids.sql.join.cross.enabled|When set to true cross joins are enabled on the GPU|true diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index ceb3dd76f0b..8b20d2a908f 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -1104,6 +1104,19 @@ def test_groupby_first_last(data_gen): # We set parallelism 1 to prevent nondeterministic results because of distributed setup. lambda spark: agg_fn(gen_df(spark, gen_fn, num_slices=1))) +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', all_gen + _struct_only_nested_gens, ids=idfn) +def test_sorted_groupby_first_last(data_gen): + gen_fn = [('a', RepeatSeqGen(LongGen(), length=20)), ('b', data_gen)] + # sort by more than the group by columns to be sure that first/last don't remove the ordering + agg_fn = lambda df: df.orderBy('a', 'b').groupBy('a').agg( + f.first('b'), f.last('b'), f.first('b', True), f.last('b', True)) + assert_gpu_and_cpu_are_equal_collect( + # First and last are not deterministic when they are run in a real distributed setup. + # We set parallelism 1 to prevent nondeterministic results because of distributed setup. + lambda spark: agg_fn(gen_df(spark, gen_fn, num_slices=1)), + conf = {'spark.rapids.sql.explain': 'ALL'}) + @ignore_order @pytest.mark.parametrize('data_gen', all_gen, ids=idfn) @pytest.mark.parametrize('count_func', [f.count, f.countDistinct]) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index d1630f2484a..65714baf775 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -526,7 +526,7 @@ object RapidsConf { val IMPROVED_TIMESTAMP_OPS = conf("spark.rapids.sql.improvedTimeOps.enabled") .doc("When set to true, some operators will avoid overflowing by converting epoch days " + - " directly to seconds without first converting to microseconds") + "directly to seconds without first converting to microseconds") .booleanConf .createWithDefault(false) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index be90b0a3972..3a9c9595ab2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -1313,24 +1313,35 @@ class GpuSortAggregateExecMeta( // Make sure this is the last check - if this is SortAggregate, the children can be sorts and we // want to validate they can run on GPU and remove them before replacing this with a - // HashAggregate. We don't want to do this if there is a first or last aggregate, - // because dropping the sort will make them no longer deterministic. - // In the future we might be able to pull the sort functionality into the aggregate so - // we can sort a single batch at a time and sort the combined result as well which would help - // with data skew. - val hasFirstOrLast = agg.aggregateExpressions.exists { agg => - agg.aggregateFunction match { - case _: First | _: Last => true - case _ => false - } - } - if (canThisBeReplaced && !hasFirstOrLast) { + // HashAggregate. + if (canThisBeReplaced) { childPlans.foreach { plan => if (plan.wrapped.isInstanceOf[SortExec]) { if (!plan.canThisBeReplaced) { willNotWorkOnGpu("one of the preceding SortExec's cannot be replaced") } else { - plan.shouldBeRemoved("replacing sort aggregate with hash aggregate") + // But if this includes a first or last aggregate and the sort includes more than what + // the group by requires we cannot drop the sort. For example + // if the group by is on a single key "a", but the ordering is on "a" and "b", then + // we have to keep the sort, so that the rows are ordered to take "b" into account + // before first/last work on it. + val hasFirstOrLast = agg.aggregateExpressions.exists { agg => + agg.aggregateFunction match { + case _: First | _: Last => true + case _ => false + } + } + val shouldRemoveSort = if (hasFirstOrLast) { + val sortedOrder = plan.wrapped.asInstanceOf[SortExec].sortOrder + val groupByRequiredOrdering = agg.requiredChildOrdering.head + sortedOrder == groupByRequiredOrdering + } else { + true + } + + if (shouldRemoveSort) { + plan.shouldBeRemoved("replacing sort aggregate with hash aggregate") + } } } }