diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index 9d236adcf71..516ba905617 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -329,6 +329,17 @@ def test_hash_reduction_sum_count_action(data_gen):
         lambda spark: gen_df(spark, data_gen, length=100).agg(f.sum('b'))
     )
 
+# Make sure that we can do computation in the group by columns
+@ignore_order
+def test_computation_in_grpby_columns():
+    conf = {'spark.rapids.sql.batchSizeBytes' : '1000'}
+    data_gen = [
+        ('a', RepeatSeqGen(StringGen('a{1,20}'), length=50)),
+        ('b', short_gen)]
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark: gen_df(spark, data_gen).groupby(f.substring(f.col('a'), 2, 10)).agg(f.sum('b')),
+        conf = conf)
+
 @shuffle_test
 @approximate_float
 @ignore_order
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
index 58896717179..a90b3309163 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
@@ -423,8 +423,8 @@ class GpuHashAggregateIterator(
     }
 
     val shims = ShimLoader.getSparkShims
-    val ordering = groupingExpressions.map(shims.sortOrder(_, Ascending, NullsFirst))
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val ordering = groupingAttributes.map(shims.sortOrder(_, Ascending, NullsFirst))
     val aggBufferAttributes = groupingAttributes ++
         aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     val sorter = new GpuSorter(ordering, aggBufferAttributes)
@@ -644,8 +644,15 @@ class GpuHashAggregateIterator(
     private val postStep = new mutable.ArrayBuffer[Expression]()
     private val postStepAttr = new mutable.ArrayBuffer[Attribute]()
 
-    // we add the grouping expression first, which bind as pass-through
-    preStep ++= groupingExpressions
+    // we add the grouping expression first, which should bind as pass-through
+    if (forceMerge) {
+      // a grouping expression can do actual computation, but we cannot do that computation again
+      // on a merge, nor would we want to if we could. So use the attributes instead of the
+      // original expression when we are forcing a merge.
+      preStep ++= groupingAttributes
+    } else {
+      preStep ++= groupingExpressions
+    }
     postStep ++= groupingAttributes
     postStepAttr ++= groupingAttributes
     postStepDataTypes ++=
@@ -679,7 +686,7 @@ class GpuHashAggregateIterator(
 
     // a bound expression that is applied before the cuDF aggregate
     private val preStepBound = if (forceMerge) {
-      GpuBindReferences.bindGpuReferences(preStep, aggBufferAttributes)
+      GpuBindReferences.bindGpuReferences(preStep.toList, aggBufferAttributes.toList)
     } else {
       GpuBindReferences.bindGpuReferences(preStep, inputAttributes)
     }
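
The forceMerge change above can be illustrated outside the plugin. Below is a minimal plain-Scala sketch, not plugin code: the object and method names are invented, and Spark/cuDF types are replaced with plain strings. It shows why a computed grouping expression such as substring(a, 2, 10) must not be re-evaluated during a merge pass: the partial aggregation already produced the computed key, so applying the expression again would take a substring of a substring and shift rows into the wrong group, which is why the patch binds groupingAttributes rather than groupingExpressions when forceMerge is set.

object GroupingMergeSketch {
  // Hypothetical stand-in for the grouping expression in the new Python test:
  // Spark's substring(col, 2, 10) is 1-based, i.e. characters 2 through 11.
  def groupingExpr(value: String): String = value.slice(1, 11)

  def main(args: Array[String]): Unit = {
    val raw = "abcdefghijklmnopqrst"        // raw value of grouping column 'a'

    // Partial aggregation evaluates the grouping expression once on the raw input,
    // so the partially aggregated batch carries the computed key.
    val partialKey = groupingExpr(raw)      // "bcdefghijk"

    // Wrong merge-side pre-step: re-applying the expression to the computed key
    // takes a substring of a substring and produces a different key.
    val reEvaluatedKey = groupingExpr(partialKey)  // "cdefghijk"

    // Right merge-side pre-step: pass the key through untouched, which is what
    // binding the grouping attributes (instead of the expressions) achieves.
    val passThroughKey = partialKey

    println(s"partial key:      $partialKey")
    println(s"re-evaluated key: $reEvaluatedKey")   // differs -> rows land in the wrong group
    println(s"pass-through key: $passThroughKey")
  }
}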