diff --git a/docs/configs.md b/docs/configs.md index c97fa82e7b3..17c7a3e4c56 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -297,8 +297,8 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.Year|`year`|Returns the year from a date or timestamp|true|None| spark.rapids.sql.expression.AggregateExpression| |Aggregate expression|true|None| spark.rapids.sql.expression.Average|`avg`, `mean`|Average aggregate operator|true|None| -spark.rapids.sql.expression.CollectList|`collect_list`|Collect a list of non-unique elements, only supported in rolling window in current.|true|None| -spark.rapids.sql.expression.CollectSet|`collect_set`|Collect a set of unique elements, only supported in rolling window in current.|true|None| +spark.rapids.sql.expression.CollectList|`collect_list`|Collect a list of non-unique elements, not supported in reduction.|true|None| +spark.rapids.sql.expression.CollectSet|`collect_set`|Collect a set of unique elements, not supported in reduction.|true|None| spark.rapids.sql.expression.Count|`count`|Count aggregate operator|true|None| spark.rapids.sql.expression.First|`first_value`, `first`|first aggregate operator|true|None| spark.rapids.sql.expression.Last|`last`, `last_value`|last aggregate operator|true|None| @@ -330,6 +330,7 @@ Name | Description | Default Value | Notes spark.rapids.sql.exec.UnionExec|The backend for the union operator|true|None| spark.rapids.sql.exec.CustomShuffleReaderExec|A wrapper of shuffle query stage|true|None| spark.rapids.sql.exec.HashAggregateExec|The backend for hash based aggregations|true|None| +spark.rapids.sql.exec.ObjectHashAggregateExec|The backend for hash based aggregations supporting TypedImperativeAggregate functions|true|None| spark.rapids.sql.exec.SortAggregateExec|The backend for sort based aggregations|true|None| spark.rapids.sql.exec.DataWritingCommandExec|Writing data|true|None| spark.rapids.sql.exec.BatchScanExec|The backend for most file input|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 7c22e577fb9..7db63b25910 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -495,8 +495,8 @@ Accelerator supports are described below. UDT -SortAggregateExec -The backend for sort based aggregations +ObjectHashAggregateExec +The backend for hash based aggregations supporting TypedImperativeAggregate functions None S S @@ -512,10 +512,33 @@ Accelerator supports are described below. S NS NS +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) NS -PS (missing nested BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, DATE, TIMESTAMP, DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) + + +SortAggregateExec +The backend for sort based aggregations +None +S +S +S +S +S +S +S +S +S* +S +S* +S NS NS +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +PS* (not allowed for grouping expressions; missing nested BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +NS DataWritingCommandExec @@ -817,29 +840,6 @@ Accelerator supports are described below. NS -WindowInPandasExec -The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled. For now it only supports row based window frame. -This is disabled by default because it only supports row based frame for now -S -S -S -S -S -S -S -S -S* -S -NS -NS -NS -NS -PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) -NS -NS -NS - - Executor Description Notes @@ -863,6 +863,29 @@ Accelerator supports are described below. UDT +WindowInPandasExec +The backend for Window Aggregation Pandas UDF, Accelerates the data transfer between the Java process and the Python process. It also supports scheduling GPU resources for the Python process when enabled. For now it only supports row based window frame. +This is disabled by default because it only supports row based frame for now +S +S +S +S +S +S +S +S +S* +S +NS +NS +NS +NS +PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) +NS +NS +NS + + WindowExec Window-operator backend None @@ -18515,9 +18538,9 @@ Accelerator support is described below. CollectList `collect_list` -Collect a list of non-unique elements, only supported in rolling window in current. +Collect a list of non-unique elements, not supported in reduction. None -aggregation +reduction input NS NS @@ -18560,25 +18583,25 @@ Accelerator support is described below. -reduction +aggregation input +S +S +S +S +S +S +S +S +S* +S +S* NS NS NS NS NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) NS @@ -18597,7 +18620,7 @@ Accelerator support is described below. -NS +PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) @@ -18648,9 +18671,9 @@ Accelerator support is described below. CollectSet `collect_set` -Collect a set of unique elements, only supported in rolling window in current. +Collect a set of unique elements, not supported in reduction. None -aggregation +reduction input NS NS @@ -18693,19 +18716,19 @@ Accelerator support is described below. -reduction +aggregation input -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS -NS +S +S +S +S +S +S +S +S +S* +S +S* NS NS NS @@ -18730,7 +18753,7 @@ Accelerator support is described below. -NS +PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index a08473ef5fb..0f02932a62c 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -301,6 +301,53 @@ def assert_gpu_fallback_write(write_func, assert_equal(from_cpu, from_gpu) +def assert_cpu_and_gpu_are_equal_collect_with_capture(func, + exist_classes='', + non_exist_classes='', + conf={}): + (bring_back, collect_type) = _prep_func_for_compare(func, 'COLLECT_WITH_DATAFRAME') + + conf = _prep_incompat_conf(conf) + + print('### CPU RUN ###') + cpu_start = time.time() + from_cpu, cpu_df = with_cpu_session(bring_back, conf=conf) + cpu_end = time.time() + print('### GPU RUN ###') + gpu_start = time.time() + from_gpu, gpu_df = with_gpu_session(bring_back, conf=conf) + gpu_end = time.time() + jvm = spark_jvm() + for clz in exist_classes.split(','): + jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertContains(gpu_df._jdf, clz) + for clz in non_exist_classes.split(','): + jvm.com.nvidia.spark.rapids.ExecutionPlanCaptureCallback.assertNotContain(gpu_df._jdf, clz) + print('### {}: GPU TOOK {} CPU TOOK {} ###'.format(collect_type, + gpu_end - gpu_start, cpu_end - cpu_start)) + if should_sort_locally(): + from_cpu.sort(key=_RowCmp) + from_gpu.sort(key=_RowCmp) + + assert_equal(from_cpu, from_gpu) + +def assert_cpu_and_gpu_are_equal_sql_with_capture(df_fun, + sql, + table_name, + exist_classes='', + non_exist_classes='', + conf=None, + debug=False): + if conf is None: + conf = {} + def do_it_all(spark): + df = df_fun(spark) + df.createOrReplaceTempView(table_name) + if debug: + return data_gen.debug_df(spark.sql(sql)) + else: + return spark.sql(sql) + assert_cpu_and_gpu_are_equal_collect_with_capture(do_it_all, exist_classes, non_exist_classes, conf) + def assert_gpu_fallback_collect(func, cpu_fallback_class_name, conf={}): diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py index ca1468c70e6..7d3b7df547f 100644 --- a/integration_tests/src/main/python/hash_aggregate_test.py +++ b/integration_tests/src/main/python/hash_aggregate_test.py @@ -14,7 +14,9 @@ import pytest -from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_fallback_collect +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql,\ + assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_sql_with_capture,\ + assert_cpu_and_gpu_are_equal_collect_with_capture from data_gen import * from functools import reduce from pyspark.sql.types import * @@ -375,6 +377,166 @@ def test_hash_reduction_pivot_without_nans(data_gen, conf): .agg(f.sum('c')), conf=conf) +_repeat_agg_column_for_collect_op = [ + RepeatSeqGen(BooleanGen(), length=15), + RepeatSeqGen(IntegerGen(), length=15), + RepeatSeqGen(LongGen(), length=15), + RepeatSeqGen(ShortGen(), length=15), + RepeatSeqGen(DateGen(), length=15), + RepeatSeqGen(TimestampGen(), length=15), + RepeatSeqGen(ByteGen(), length=15), + RepeatSeqGen(StringGen(), length=15), + RepeatSeqGen(FloatGen(), length=15), + RepeatSeqGen(DoubleGen(), length=15), + RepeatSeqGen(DecimalGen(precision=8, scale=3), length=15), + # case to verify the NAN_UNEQUAL strategy + RepeatSeqGen(FloatGen().with_special_case(math.nan, 200.0), length=5), +] + +_gen_data_for_collect_op = [[ + ('a', RepeatSeqGen(LongGen(), length=20)), + ('b', value_gen), + ('c', LongRangeGen())] for value_gen in _repeat_agg_column_for_collect_op +] + +# We wrapped sort_array functions on collect_list/collect_set because the orders of collected lists/sets are not +# deterministic. The annotation `ignore_order` only affects on the order between rows, while with collect ops we also +# need to guarantee the consistency of the row-wise order (the orders within each array produced by collect ops). +@approximate_float +@ignore_order(local=True) +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +@pytest.mark.parametrize('use_obj_hash_agg', [True, False], ids=idfn) +def test_hash_groupby_collect_list(data_gen, use_obj_hash_agg): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, data_gen, length=100) + .groupby('a') + .agg(f.sort_array(f.collect_list('b')), f.count('b')), + conf={'spark.sql.execution.useObjectHashAggregateExec': str(use_obj_hash_agg).lower()}) + +@approximate_float +@ignore_order(local=True) +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +def test_hash_groupby_collect_set(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, data_gen, length=100) + .groupby('a') + .agg(f.sort_array(f.collect_set('b')), f.count('b'))) + +@approximate_float +@ignore_order(local=True) +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +def test_hash_groupby_collect_with_single_distinct(data_gen): + # test collect_ops with other distinct aggregations + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, data_gen, length=100) + .groupby('a') + .agg(f.sort_array(f.collect_list('b')), + f.sort_array(f.collect_set('b')), + f.countDistinct('c'), + f.count('c'))) + +@approximate_float +@ignore_order(local=True) +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +def test_hash_groupby_single_distinct_collect(data_gen): + # test distinct collect with other aggregations + sql = """select a, + sort_array(collect_list(distinct b)), + sort_array(collect_set(b)), + count(distinct b), + count(c) + from tbl group by a""" + assert_gpu_and_cpu_are_equal_sql( + df_fun=lambda spark: gen_df(spark, data_gen, length=100), + table_name="tbl", sql=sql) + +# Queries with multiple distinct aggregations will fallback to CPU if they also contain +# collect aggregations. Because Spark optimizer will insert expressions like `If` and `First` +# when rewriting distinct aggregates, while `GpuFirst` doesn't support the datatype of collect +# aggregations (ArrayType). +# TODO: support GPUFirst on ArrayType https://github.com/NVIDIA/spark-rapids/issues/3097 +@approximate_float +@ignore_order(local=True) +@allow_non_gpu('SortAggregateExec', + 'SortArray', 'Alias', 'Literal', 'First', 'If', 'EqualTo', 'Count', + 'CollectList', 'CollectSet', 'AggregateExpression') +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +def test_hash_groupby_collect_with_multi_distinct_fallback(data_gen): + def spark_fn(spark_session): + return gen_df(spark_session, data_gen, length=100).groupby('a').agg( + f.sort_array(f.collect_list('b')), + f.sort_array(f.collect_set('b')), + f.countDistinct('b'), + f.countDistinct('c')) + assert_gpu_fallback_collect(func=spark_fn, cpu_fallback_class_name='SortAggregateExec') + +@approximate_float +@ignore_order(local=True) +@allow_non_gpu('ObjectHashAggregateExec', 'ShuffleExchangeExec', + 'HashPartitioning', 'SortArray', 'Alias', 'Literal', + 'Count', 'CollectList', 'CollectSet', 'AggregateExpression') +@incompat +@pytest.mark.parametrize('data_gen', _gen_data_for_collect_op, ids=idfn) +@pytest.mark.parametrize('conf', [_nans_float_conf_partial, _nans_float_conf_final], ids=idfn) +@pytest.mark.parametrize('aqe_enabled', ['true', 'false'], ids=idfn) +def test_hash_groupby_collect_partial_replace_fallback(data_gen, conf, aqe_enabled): + local_conf = conf.copy() + local_conf.update({'spark.sql.adaptive.enabled': aqe_enabled}) + # test without Distinct + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark: gen_df(spark, data_gen, length=100) + .groupby('a') + .agg(f.sort_array(f.collect_list('b')), f.sort_array(f.collect_set('b'))), + exist_classes='CollectList,CollectSet', + non_exist_classes='GpuCollectList,GpuCollectSet', + conf=local_conf) + # test with single Distinct + assert_cpu_and_gpu_are_equal_collect_with_capture( + lambda spark: gen_df(spark, data_gen, length=100) + .groupby('a') + .agg(f.sort_array(f.collect_list('b')), + f.sort_array(f.collect_set('b')), + f.countDistinct('c'), + f.count('c')), + exist_classes='CollectList,CollectSet', + non_exist_classes='GpuCollectList,GpuCollectSet', + conf=local_conf) + +@ignore_order(local=True) +@allow_non_gpu('ObjectHashAggregateExec', 'ShuffleExchangeExec', 'HashAggregateExec', + 'HashPartitioning', 'SortArray', 'Alias', 'Literal', + 'CollectList', 'CollectSet', 'Max', 'AggregateExpression') +@pytest.mark.parametrize('conf', [_nans_float_conf_final, _nans_float_conf_partial], ids=idfn) +@pytest.mark.parametrize('aqe_enabled', ['true', 'false'], ids=idfn) +def test_hash_groupby_collect_partial_replace_fallback_with_other_agg(conf, aqe_enabled): + # This test is to ensure "associated fallback" will not affect another Aggregate plans. + local_conf = conf.copy() + local_conf.update({'spark.sql.adaptive.enabled': aqe_enabled}) + + assert_cpu_and_gpu_are_equal_sql_with_capture( + lambda spark: gen_df(spark, [('k1', RepeatSeqGen(LongGen(), length=20)), + ('k2', RepeatSeqGen(LongGen(), length=20)), + ('v', LongRangeGen())], length=100), + exist_classes='GpuMax,Max,CollectList,CollectSet', + non_exist_classes='GpuObjectHashAggregateExec,GpuCollectList,GpuCollectSet', + table_name='table', + sql=""" + select k1, + sort_array(collect_set(k2)), + sort_array(collect_list(max_v)) + from + (select k1, k2, + max(v) as max_v + from table group by k1, k2 + )t + group by k1""", + conf=local_conf) + @approximate_float @ignore_order @incompat diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 4cbcc250c7f..be12958bc04 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -2677,32 +2677,42 @@ object GpuOverrides { override def convertToGpu(): GpuExpression = GpuPosExplode(childExprs.head.convertToGpu()) }), expr[CollectList]( - "Collect a list of non-unique elements, only supported in rolling window in current.", - // GpuCollectList is not yet supported under GroupBy and Reduction context. - ExprChecks.aggNotGroupByOrReduction( + "Collect a list of non-unique elements, not supported in reduction.", + // GpuCollectList is not yet supported in Reduction context. + ExprChecks.aggNotReduction( TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT), TypeSig.ARRAY.nested(TypeSig.all), Seq(ParamCheck("input", TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL), TypeSig.all))), - (c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) { - override def convertToGpu(): GpuExpression = GpuCollectList( - childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset) + (c, conf, p, r) => new TypedImperativeAggExprMeta[CollectList](c, conf, p, r) { + override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { + GpuCollectList(childExprs.head, c.mutableAggBufferOffset, c.inputAggBufferOffset) + } + override def aggBufferAttribute: AttributeReference = { + val aggBuffer = c.aggBufferAttributes.head + aggBuffer.copy(dataType = c.dataType)(aggBuffer.exprId, aggBuffer.qualifier) + } }), expr[CollectSet]( - "Collect a set of unique elements, only supported in rolling window in current.", - // GpuCollectSet is not yet supported under GroupBy and Reduction context. + "Collect a set of unique elements, not supported in reduction.", + // GpuCollectSet is not yet supported in Reduction context. // Compared to CollectList, StructType is NOT in GpuCollectSet because underlying // method drop_list_duplicates doesn't support nested types. - ExprChecks.aggNotGroupByOrReduction( + ExprChecks.aggNotReduction( TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL), TypeSig.ARRAY.nested(TypeSig.all), Seq(ParamCheck("input", TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all))), - (c, conf, p, r) => new ExprMeta[CollectSet](c, conf, p, r) { - override def convertToGpu(): GpuExpression = GpuCollectSet( - childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset) + (c, conf, p, r) => new TypedImperativeAggExprMeta[CollectSet](c, conf, p, r) { + override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { + GpuCollectSet(childExprs.head, c.mutableAggBufferOffset, c.inputAggBufferOffset) + } + override def aggBufferAttribute: AttributeReference = { + val aggBuffer = c.aggBufferAttributes.head + aggBuffer.copy(dataType = c.dataType)(aggBuffer.exprId, aggBuffer.qualifier) + } }), expr[GetJsonObject]( "Extracts a json object from path", @@ -3045,13 +3055,28 @@ object GpuOverrides { .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"), TypeSig.all), (agg, conf, p, r) => new GpuHashAggregateMeta(agg, conf, p, r)), + exec[ObjectHashAggregateExec]( + "The backend for hash based aggregations supporting TypedImperativeAggregate functions", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL) + .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") + .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"), + TypeSig.all), + (agg, conf, p, r) => new GpuObjectHashAggregateExecMeta(agg, conf, p, r)), exec[SortAggregateExec]( "The backend for sort based aggregations", ExecChecks( - (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.MAP) - .nested(TypeSig.STRING), + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT) + .nested(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL) + .withPsNote(TypeEnum.ARRAY, "not allowed for grouping expressions") + .withPsNote(TypeEnum.MAP, "not allowed for grouping expressions") + .withPsNote(TypeEnum.STRUCT, "not allowed for grouping expressions"), TypeSig.all), - (agg, conf, p, r) => new GpuSortAggregateMeta(agg, conf, p, r)), + (agg, conf, p, r) => new GpuSortAggregateExecMeta(agg, conf, p, r)), exec[SortExec]( "The backend for the sort operator", // The SortOrder TypeSig will govern what types can actually be used as sorting key data type. diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala index 6f9ca0586ed..597bd988fa8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala @@ -45,6 +45,12 @@ class GpuSortMeta( parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) extends SparkPlanMeta[SortExec](sort, conf, parent, rule) { + + // Uses output attributes of child plan because SortExec will not change the attributes, + // and we need to propagate possible type conversions on the output attributes of + // GpuSortAggregateExec. + override protected val useOutputAttributesOfChild: Boolean = true + override def convertToGpu(): GpuExec = { GpuSortExec(childExprs.map(_.convertToGpu()).asInstanceOf[Seq[SortOrder]], sort.global, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index b243901ef96..59c9c57ae64 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.{DataFrame, SparkSessionExtensions} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.util.QueryExecutionListener @@ -330,6 +330,28 @@ object ExecutionPlanCaptureCallback { assertDidFallBack(executedPlan, fallbackCpuClass) } + def assertContains(gpuPlan: SparkPlan, className: String): Unit = { + val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(gpuPlan)) + assert(containsPlan(executedPlan, className), + s"Could not find $className in the Spark plan\n$executedPlan") + } + + def assertContains(df: DataFrame, gpuClass: String): Unit = { + val executedPlan = df.queryExecution.executedPlan + assertContains(executedPlan, gpuClass) + } + + def assertNotContain(gpuPlan: SparkPlan, className: String): Unit = { + val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(gpuPlan)) + assert(!containsPlan(executedPlan, className), + s"We found $className in the Spark plan\n$executedPlan") + } + + def assertNotContain(df: DataFrame, gpuClass: String): Unit = { + val executedPlan = df.queryExecution.executedPlan + assertNotContain(executedPlan, gpuClass) + } + private def didFallBack(exp: Expression, fallbackCpuClass: String): Boolean = { !exp.isInstanceOf[GpuExpression] && PlanUtils.getBaseNameFromClass(exp.getClass.getName) == fallbackCpuClass || @@ -341,6 +363,21 @@ object ExecutionPlanCaptureCallback { !executedPlan.isInstanceOf[GpuExec] && PlanUtils.sameClass(executedPlan, fallbackCpuClass) || executedPlan.expressions.exists(didFallBack(_, fallbackCpuClass)) } + + private def containsExpression(exp: Expression, className: String): Boolean = { + PlanUtils.getBaseNameFromClass(exp.getClass.getName) == className || + exp.children.exists(containsExpression(_, className)) + } + + private def containsPlan(plan: SparkPlan, className: String): Boolean = { + val p = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(plan)) match { + case p: QueryStageExec => p.plan + case p => p + } + PlanUtils.sameClass(p, className) || + p.expressions.exists(containsExpression(_, className)) || + p.children.exists(containsPlan(_, className)) + } } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index c6cbcad64a7..acc70c77e97 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -21,7 +21,7 @@ import java.time.ZoneId import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, LambdaFunction, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.connector.read.Scan @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.types.DataType trait DataFromReplacementRule { @@ -672,18 +671,18 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT, } /** - * Gets output attributes of current SparkPlanMeta, which is supposed to be called - * in the tag methods of ExecChecks. + * Gets output attributes of current SparkPlanMeta, which is supposed to be called during + * type checking for the current plan. * - * By default, it simply returns the output of wrapped plan. But for specific plans, they can - * override outputTypeMetas to apply custom conversions on the output of wrapped plan. + * By default, it simply returns the output of wrapped plan. For specific plans, they can + * override outputTypeMetas to apply custom conversions on the output of wrapped plan. For plans + * which just pass through the schema of childPlan, they can set useOutputAttributesOfChild to + * true, in order to propagate the custom conversions of childPlan if they exist. */ def outputAttributes: Seq[Attribute] = outputTypeMetas match { case Some(typeMetas) => - if (typeMetas.length != wrapped.output.length) { - throw new IllegalArgumentException( - "The length of outputTypeMetas doesn't match to the length of plan's output") - } + require(typeMetas.length == wrapped.output.length, + "The length of outputTypeMetas doesn't match to the length of plan's output") wrapped.output.zip(typeMetas).map { case (ar, meta) if meta.typeConverted => addConvertedDataType(ar.name, meta) @@ -692,6 +691,21 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT, case (ar, _) => ar } + case None if useOutputAttributesOfChild => + require(wrapped.children.length == 1, + "useOutputAttributesOfChild ONLY works on UnaryPlan") + // We will check whether the child plan can be replaced or not. We only pass through the + // outputAttributes of the child plan when it is GPU enabled. Otherwise, we should fetch the + // outputAttributes from the wrapped plan, because type overriding of RapidsMeta is + // specialized for the GPU runtime. + // + // We can safely call childPlan.canThisBeReplaced here, because outputAttributes is called + // via tagSelfForGpu. At this point, tagging of the child plan has already happened. + if (childPlans.head.canThisBeReplaced) { + childPlans.head.outputAttributes + } else { + wrapped.output + } case None => wrapped.output } @@ -699,7 +713,12 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT, /** * Overrides this method to implement custom conversions for specific plans. */ - protected def outputTypeMetas: Option[Seq[DataTypeMeta]] = None + protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] = None + + /** + * Whether to pass through the outputAttributes of childPlan's meta, only for UnaryPlan + */ + protected val useOutputAttributesOfChild: Boolean = false } /** @@ -818,13 +837,13 @@ object DataTypeMeta { /** * create DataTypeMeta from Expression */ - def apply(expr: Expression): DataTypeMeta = { + def apply(expr: Expression, overrideType: Option[DataType]): DataTypeMeta = { val wrapped = try { Some(expr.dataType) } catch { case _: java.lang.UnsupportedOperationException => None } - new DataTypeMeta(wrapped) + new DataTypeMeta(wrapped, overrideType) } } @@ -857,9 +876,20 @@ abstract class BaseExprMeta[INPUT <: Expression]( * tag methods of expression-level type checks. * * By default, it simply returns the data type of wrapped expression. But for specific - * expressions, they can override this method to apply custom transitions on the data type. + * expressions, they can easily override data type for type checking through calling the + * method `overrideDataType`. */ - def typeMeta: DataTypeMeta = DataTypeMeta(wrapped.asInstanceOf[Expression]) + def typeMeta: DataTypeMeta = DataTypeMeta(wrapped.asInstanceOf[Expression], overrideType) + + /** + * Overrides the data type of the wrapped expression during type checking. + * + * NOTICE: This method will NOT modify the wrapped expression itself. Therefore, the actual + * transition on data type is still necessary when converting this expression to GPU. + */ + def overrideDataType(dt: DataType): Unit = overrideType = Some(dt) + + private var overrideType: Option[DataType] = None lazy val context: ExpressionContext = expr match { case _: LambdaFunction => LambdaExprContext @@ -944,6 +974,24 @@ abstract class ImperativeAggExprMeta[INPUT <: ImperativeAggregate]( def convertToGpu(childExprs: Seq[Expression]): GpuExpression } +/** + * Base class for metadata around `TypedImperativeAggregate`. + */ +abstract class TypedImperativeAggExprMeta[INPUT <: TypedImperativeAggregate[_]]( + expr: INPUT, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends ImperativeAggExprMeta[INPUT](expr, conf, parent, rule) { + + /** + * Returns aggregation buffer with the actual data type under GPU runtime. This method is + * called to override the data types of typed imperative aggregation buffers during GPU + * overriding. + */ + def aggBufferAttribute: AttributeReference +} + /** * Base class for metadata around `BinaryExpression`. */ @@ -1025,4 +1073,4 @@ final class RuleNotFoundExprMeta[INPUT <: Expression]( override def convertToGpu(): GpuExpression = throw new IllegalStateException("Cannot be converted to GPU") -} \ No newline at end of file +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 1702436def0..84712649360 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1204,28 +1204,27 @@ object ExprChecks { ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) /** - * An aggregation check where window operations are supported by the plugin, but Spark - * also supports group by and reduction on these. - * This is now really for 'collect_list' which is only supported by windowing. + * An aggregation check where group by and window operations are supported by the plugin, but + * Spark also supports reduction on these. */ - def aggNotGroupByOrReduction( + def aggNotReduction( outputCheck: TypeSig, sparkOutputSig: TypeSig, paramCheck: Seq[ParamCheck] = Seq.empty, repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = { - val notWindowParamCheck = paramCheck.map { pc => + val noneParamCheck = paramCheck.map { pc => ParamCheck(pc.name, TypeSig.none, pc.spark) } - val notWindowRepeat = repeatingParamCheck.map { pc => + val noneRepeatCheck = repeatingParamCheck.map { pc => RepeatingParamCheck(pc.name, TypeSig.none, pc.spark) } ExprChecksImpl(Map( - (GroupByAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, notWindowParamCheck, notWindowRepeat)), (ReductionAggExprContext, - ContextChecks(TypeSig.none, sparkOutputSig, notWindowParamCheck, notWindowRepeat)), + ContextChecks(TypeSig.none, sparkOutputSig, noneParamCheck, noneRepeatCheck)), + (GroupByAggExprContext, + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)), (WindowAggExprContext, - ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) + ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck)))) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 883c82911a4..25bde0b259d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -21,7 +21,7 @@ import java.util import scala.collection.mutable import ai.rapids.cudf -import ai.rapids.cudf.{NvtxColor, Scalar} +import ai.rapids.cudf.{DType, NvtxColor, Scalar} import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -29,13 +29,15 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, If, NamedExpression, NullsFirst} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, AttributeSeq, AttributeSet, Expression, ExprId, If, NamedExpression, NullsFirst} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, HashPartitioning, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.{ExplainUtils, SortExec, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.rapids.{CudfAggregate, GpuAggregateExpression} import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types.{ArrayType, DataType, LongType, MapType, StructType} @@ -773,15 +775,28 @@ class GpuHashAggregateIterator( // later. Cast here to the type that the aggregate expects (e.g. Long in case of count) val dataTypes = groupingExpressions.map(_.dataType) ++ aggregates.map(_.dataType) - val resCols = new mutable.ArrayBuffer[ColumnVector](result.getNumberOfColumns) - for (i <- 0 until result.getNumberOfColumns) { - val rapidsType = GpuColumnVector.getNonNestedRapidsType(dataTypes(i)) - // cast will be cheap if type matches, only does refCount++ in that case - closeOnExcept(result.getColumn(i).castTo(rapidsType)) { castedCol => - resCols += GpuColumnVector.from(castedCol, dataTypes(i)) + val resCols = mutable.ArrayBuffer.empty[ColumnVector] + closeOnExcept(resCols) { resCols => + (0 until result.getNumberOfColumns).foldLeft(resCols) { case (ret, i) => + val column = result.getColumn(i) + val rapidsType = dataTypes(i) match { + case dt if GpuColumnVector.isNonNestedSupportedType(dt) => + GpuColumnVector.getNonNestedRapidsType(dataTypes(i)) + case dt: ArrayType if GpuColumnVector.typeConversionAllowed(column, dt) => + DType.LIST + case dt: StructType if GpuColumnVector.typeConversionAllowed(column, dt) => + DType.STRUCT + case dt => + throw new IllegalArgumentException(s"Can NOT convert $column to data type $dt.") + } + // cast will be cheap if type matches, only does refCount++ in that case + withResource(column.castTo(rapidsType)) { castedCol => + ret += GpuColumnVector.from(castedCol.incRefCount(), dataTypes(i)) + } + ret } + new ColumnarBatch(resCols.toArray, result.getRowCount.toInt) } - new ColumnarBatch(resCols.toArray, result.getRowCount.toInt) } } else { // Reduction aggregate @@ -829,15 +844,15 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( val agg: BaseAggregateExec - private val requiredChildDistributionExpressions: Option[Seq[BaseExprMeta[_]]] = + protected val requiredChildDistributionExpressions: Option[Seq[BaseExprMeta[_]]] = aggRequiredChildDistributionExpressions.map(_.map(GpuOverrides.wrapExpr(_, conf, Some(this)))) - private val groupingExpressions: Seq[BaseExprMeta[_]] = + protected val groupingExpressions: Seq[BaseExprMeta[_]] = agg.groupingExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) - private val aggregateExpressions: Seq[BaseExprMeta[_]] = + protected val aggregateExpressions: Seq[BaseExprMeta[_]] = agg.aggregateExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) - private val aggregateAttributes: Seq[BaseExprMeta[_]] = + protected val aggregateAttributes: Seq[BaseExprMeta[_]] = agg.aggregateAttributes.map(GpuOverrides.wrapExpr(_, conf, Some(this))) - private val resultExpressions: Seq[BaseExprMeta[_]] = + protected val resultExpressions: Seq[BaseExprMeta[_]] = agg.resultExpressions.map(GpuOverrides.wrapExpr(_, conf, Some(this))) override val childExprs: Seq[BaseExprMeta[_]] = @@ -927,6 +942,191 @@ abstract class GpuBaseAggregateMeta[INPUT <: SparkPlan]( } } +/** + * Base class for metadata around `SortAggregateExec` and `ObjectHashAggregateExec`, which may + * contain TypedImperativeAggregate functions in aggregate expressions. + */ +abstract class GpuTypedImperativeSupportedAggregateExecMeta[INPUT <: SparkPlan]( + plan: INPUT, + aggRequiredChildDistributionExpressions: Option[Seq[Expression]], + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) extends GpuBaseAggregateMeta[INPUT](plan, + aggRequiredChildDistributionExpressions, conf, parent, rule) { + + private val mayNeedAggBufferConversion: Boolean = + agg.aggregateExpressions.exists { expr => + expr.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]] && + (expr.mode == Partial || expr.mode == PartialMerge) + } + + // overriding data types of Aggregation Buffers if necessary + if (mayNeedAggBufferConversion) overrideAggBufTypes() + + override protected lazy val outputTypeMetas: Option[Seq[DataTypeMeta]] = + if (mayNeedAggBufferConversion) { + Some(resultExpressions.map(_.typeMeta)) + } else { + None + } + + override def tagPlanForGpu(): Unit = { + // when AQE is enabled and we are planning a new query stage, we need to look at meta-data + // previously stored on the spark plan to determine whether this plan can run on GPU + wrapped.getTagValue(gpuSupportedTag).foreach(_.foreach(willNotWorkOnGpu)) + + super.tagPlanForGpu() + + // We can not run part of TypedImperativeAggregate functions on GPU, because GPU buffers + // are inconsistent with CPU buffers. Therefore, we have to fall back all Aggregate stages + // to CPU once any of them did fallback, in order to guarantee no partial-accelerated + // TypedImperativeAggregate function. + // + // This fallback procedure adapts AQE. As what GpuExchanges do, it leverages the + // `gpuSupportedTag` to store the information about whether instances are GPU-supported + // or not, which is produced by the side effect of `willNotWorkOnGpu`. When AQE is on, + // during the preparation stage, there will be a run of GpuOverrides on the entire plan to + // trigger these side effects if necessary, before AQE splits the entire query into several + // query stages. + GpuTypedImperativeSupportedAggregateExecMeta.checkAndFallbackEntirely(this) + } + + override def convertToGpu(): GpuExec = { + if (mayNeedAggBufferConversion) { + // transforms the data types of aggregate attributes with typeMeta + val aggAttributes = aggregateAttributes.map { + case meta if meta.typeMeta.typeConverted => + val ref = meta.wrapped.asInstanceOf[AttributeReference] + val converted = ref.copy( + dataType = meta.typeMeta.dataType.get)(ref.exprId, ref.qualifier) + GpuOverrides.wrapExpr(converted, conf, Some(this)) + case meta => meta + } + // transforms the data types of result expressions with typeMeta + val retExpressions = resultExpressions.map { + case meta if meta.typeMeta.typeConverted => + val ref = meta.wrapped.asInstanceOf[AttributeReference] + val converted = ref.copy( + dataType = meta.typeMeta.dataType.get)(ref.exprId, ref.qualifier) + GpuOverrides.wrapExpr(converted, conf, Some(this)) + case meta => meta + } + GpuHashAggregateExec( + requiredChildDistributionExpressions.map(_.map(_.convertToGpu())), + groupingExpressions.map(_.convertToGpu()), + aggregateExpressions.map(_.convertToGpu()).asInstanceOf[Seq[GpuAggregateExpression]], + aggAttributes.map(_.convertToGpu()).asInstanceOf[Seq[Attribute]], + retExpressions.map(_.convertToGpu()).asInstanceOf[Seq[NamedExpression]], + childPlans.head.convertIfNeeded(), + conf.gpuTargetBatchSizeBytes) + } else { + super.convertToGpu() + } + } + + /** + * The method replaces data types of aggregation buffers created by TypedImperativeAggregate + * functions with the actual data types used in the GPU runtime. + * + * Firstly, this method traverses aggregateFunctions, to search attributes referring to + * aggregation buffers of TypedImperativeAggregate functions. + * Then, we extract the desired (actual) data types on GPU runtime for these attributes, + * and map them to expression IDs of attributes. + * At last, we traverse aggregateAttributes and resultExpressions, overriding data type in + * RapidsMeta if necessary, in order to ensure TypeChecks tagging exact data types in runtime. + */ + private def overrideAggBufTypes(): Unit = { + val desiredAggBufTypes = mutable.HashMap.empty[ExprId, DataType] + val desiredInputAggBufTypes = mutable.HashMap.empty[ExprId, DataType] + // Collects exprId from TypedImperativeAggBufferAttributes, and maps them to the data type + // of `TypedImperativeAggExprMeta.aggBufferAttribute`. + agg.aggregateExpressions.zipWithIndex.foreach { + case (expr, i) if expr.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]] => + + val aggFn = expr.aggregateFunction + val aggMeta = aggregateExpressions(i).childExprs.head + .asInstanceOf[TypedImperativeAggExprMeta[_]] + val desiredDataType = aggMeta.aggBufferAttribute.dataType + + var buf = aggFn.aggBufferAttributes.head + desiredAggBufTypes(buf.exprId) = desiredDataType + + buf = aggFn.inputAggBufferAttributes.head + desiredInputAggBufTypes(buf.exprId) = desiredDataType + + case _ => + } + + // Overrides the data types of typed imperative aggregation buffers for type checking + aggregateAttributes.foreach { attrMeta => + attrMeta.wrapped match { + case ar: AttributeReference if desiredAggBufTypes.contains(ar.exprId) => + attrMeta.overrideDataType(desiredAggBufTypes(ar.exprId)) + case _ => + } + } + resultExpressions.foreach { retMeta => + retMeta.wrapped match { + case ar: AttributeReference if desiredInputAggBufTypes.contains(ar.exprId) => + retMeta.overrideDataType(desiredInputAggBufTypes(ar.exprId)) + case _ => + } + } + } +} + +object GpuTypedImperativeSupportedAggregateExecMeta { + + private val entireAggFallbackCheck = TreeNodeTag[Boolean]( + "rapids.gpu.checkAndFallbackAggregateExecEntirely") + + private def checkAndFallbackEntirely( + meta: GpuTypedImperativeSupportedAggregateExecMeta[_]): Unit = { + // We only run the check for final stages which contain TypedImperativeAggregate. + val needToCheck = meta.agg.aggregateExpressions.exists(e => e.mode == Final && + e.aggregateFunction.isInstanceOf[TypedImperativeAggregate[_]]) + if (!needToCheck) return + // Avoid duplicated check and fallback. + val checked = meta.agg.getTagValue[Boolean](entireAggFallbackCheck).contains(true) + if (checked) return + + meta.agg.setTagValue(entireAggFallbackCheck, true) + val logicalPlan = meta.agg.logicalLink.get + val stageMetas = mutable.ListBuffer[GpuBaseAggregateMeta[_]]() + // Go through all Aggregate stages to check whether all stages is GPU supported. If not, + // we fall back all GPU supported stages to CPU. + if (recursiveCheckForFallback(meta, logicalPlan, stageMetas)) { + stageMetas.foreach { + case aggMeta if aggMeta.canThisBeReplaced => + aggMeta.willNotWorkOnGpu("Associated fallback for TypedImperativeAggregate") + case _ => + } + } + } + + /** + * Recursively collect all PlanMetas of input LogicalPlan. At the same time, and check whether + * existing plans which can NOT be replaced among them. If any, return true as the label of + * the entire fallback. + */ + private def recursiveCheckForFallback( + currentMeta: SparkPlanMeta[_], + logical: LogicalPlan, + metaOfAllStages: mutable.ListBuffer[GpuBaseAggregateMeta[_]]): Boolean = { + currentMeta match { + case aggMeta: GpuBaseAggregateMeta[_] if aggMeta.agg.logicalLink.contains(logical) => + metaOfAllStages += aggMeta + val childCheck = recursiveCheckForFallback(aggMeta.childPlans.head, + logical, metaOfAllStages) + !aggMeta.canThisBeReplaced || childCheck + case unaryMeta: SparkPlanMeta[_] if unaryMeta.childPlans.length == 1 => + recursiveCheckForFallback(unaryMeta.childPlans.head, logical, metaOfAllStages) + case _ => + false + } + } +} + class GpuHashAggregateMeta( override val agg: HashAggregateExec, conf: RapidsConf, @@ -935,13 +1135,13 @@ class GpuHashAggregateMeta( extends GpuBaseAggregateMeta(agg, agg.requiredChildDistributionExpressions, conf, parent, rule) -class GpuSortAggregateMeta( +class GpuSortAggregateExecMeta( override val agg: SortAggregateExec, conf: RapidsConf, parent: Option[RapidsMeta[_, _, _]], rule: DataFromReplacementRule) - extends GpuBaseAggregateMeta(agg, agg.requiredChildDistributionExpressions, - conf, parent, rule) { + extends GpuTypedImperativeSupportedAggregateExecMeta(agg, + agg.requiredChildDistributionExpressions, conf, parent, rule) { override def tagPlanForGpu(): Unit = { super.tagPlanForGpu() @@ -972,6 +1172,14 @@ class GpuSortAggregateMeta( } } +class GpuObjectHashAggregateExecMeta( + override val agg: ObjectHashAggregateExec, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends GpuTypedImperativeSupportedAggregateExecMeta(agg, + agg.requiredChildDistributionExpressions, conf, parent, rule) + /** * The GPU version of HashAggregateExec * diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index ce34150722a..024d941cdb0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -253,6 +253,50 @@ class CudfMin(ref: Expression) extends CudfAggregate(ref) { override def toString(): String = "CudfMin" } +class CudfCollectList(ref: Expression) extends CudfAggregate(ref) { + override lazy val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CollectList is not yet supported in reduction") + override lazy val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CollectList is not yet supported in reduction") + override lazy val updateAggregate: Aggregation = Aggregation.collectList() + override lazy val mergeAggregate: Aggregation = Aggregation.mergeLists() + override def toString(): String = "CudfCollectList" + override def dataType: DataType = ArrayType(ref.dataType, containsNull = false) + override def nullable: Boolean = false +} + +class CudfMergeLists(ref: Expression) extends CudfAggregate(ref) { + override lazy val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("MergeLists is not yet supported in reduction") + override lazy val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("MergeLists is not yet supported in reduction") + override lazy val updateAggregate: Aggregation = Aggregation.mergeLists() + override lazy val mergeAggregate: Aggregation = Aggregation.mergeLists() + override def toString(): String = "CudfMergeLists" +} + +class CudfCollectSet(ref: Expression) extends CudfAggregate(ref) { + override lazy val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CollectSet is not yet supported in reduction") + override lazy val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CollectSet is not yet supported in reduction") + override lazy val updateAggregate: Aggregation = Aggregation.collectSet() + override lazy val mergeAggregate: Aggregation = Aggregation.mergeSets() + override def toString(): String = "CudfCollectSet" + override def dataType: DataType = ArrayType(ref.dataType, containsNull = false) + override def nullable: Boolean = false +} + +class CudfMergeSets(ref: Expression) extends CudfAggregate(ref) { + override lazy val updateReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CudfMergeSets is not yet supported in reduction") + override lazy val mergeReductionAggregate: cudf.ColumnVector => cudf.Scalar = + throw new UnsupportedOperationException("CudfMergeSets is not yet supported in reduction") + override lazy val updateAggregate: Aggregation = Aggregation.mergeSets() + override lazy val mergeAggregate: Aggregation = Aggregation.mergeSets() + override def toString(): String = "CudfMergeSets" +} + abstract class CudfFirstLastBase(ref: Expression) extends CudfAggregate(ref) { val includeNulls: NullPolicy val offset: Int @@ -756,7 +800,7 @@ case class GpuLast(child: Expression, ignoreNulls: Boolean) trait GpuCollectBase[T <: Aggregation with RollingAggregation[T]] extends GpuAggregateFunction with GpuAggregateWindowFunction[T] { - def childExpression: Expression + def child: Expression // Collect operations are non-deterministic since their results depend on the // actual order of input rows. @@ -764,23 +808,29 @@ trait GpuCollectBase[T <: Aggregation with RollingAggregation[T]] extends GpuAgg override def nullable: Boolean = false - override def dataType: DataType = ArrayType(childExpression.dataType, false) + override def dataType: DataType = ArrayType(child.dataType, containsNull = false) - override def children: Seq[Expression] = childExpression :: Nil + override def children: Seq[Expression] = child :: Nil // WINDOW FUNCTION - override val windowInputProjection: Seq[Expression] = Seq(childExpression) + override val windowInputProjection: Seq[Expression] = Seq(child) // Make them lazy to avoid being initialized when creating a GpuCollectOp. override lazy val initialValues: Seq[GpuExpression] = throw new UnsupportedOperationException - override lazy val updateExpressions: Seq[Expression] = throw new UnsupportedOperationException - override lazy val mergeExpressions: Seq[GpuExpression] = throw new UnsupportedOperationException - override lazy val evaluateExpression: Expression = throw new UnsupportedOperationException - override val inputProjection: Seq[Expression] = Seq(childExpression) - override def aggBufferAttributes: Seq[AttributeReference] = { - throw new UnsupportedOperationException - } + override val inputProjection: Seq[Expression] = Seq(child) + + // Unlike other GpuAggregateFunction, GpuCollectFunction will change the type of input data in + // update stage (childType => Array[childType]). And the input type of merge expression is not + // same as update expression. Meanwhile, they still share the same ordinal in terms of cuDF + // table. + // Therefore, we create two separate buffers for update and merge. And they are pointed to + // the same ordinal since they share the same exprId. + protected final lazy val inputBuf: AttributeReference = + AttributeReference("inputBuf", child.dataType)() + + protected final lazy val outputBuf: AttributeReference = + inputBuf.copy("outputBuf", dataType)(inputBuf.exprId, inputBuf.qualifier) } /** @@ -790,11 +840,19 @@ trait GpuCollectBase[T <: Aggregation with RollingAggregation[T]] extends GpuAgg * with the CPU version and automated checks. */ case class GpuCollectList( - childExpression: Expression, + child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends GpuCollectBase[CollectListAggregation] { + override lazy val updateExpressions: Seq[GpuExpression] = new CudfCollectList(inputBuf) :: Nil + + override lazy val mergeExpressions: Seq[GpuExpression] = new CudfMergeLists(outputBuf) :: Nil + + override lazy val evaluateExpression: Expression = outputBuf + + override def aggBufferAttributes: Seq[AttributeReference] = outputBuf :: Nil + override def prettyName: String = "collect_list" override def windowAggregation( @@ -810,11 +868,19 @@ case class GpuCollectList( * with the CPU version and automated checks. */ case class GpuCollectSet( - childExpression: Expression, + child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends GpuCollectBase[CollectSetAggregation] { + override lazy val updateExpressions: Seq[GpuExpression] = new CudfCollectSet(inputBuf) :: Nil + + override lazy val mergeExpressions: Seq[GpuExpression] = new CudfMergeSets(outputBuf) :: Nil + + override lazy val evaluateExpression: Expression = outputBuf + + override def aggBufferAttributes: Seq[AttributeReference] = outputBuf :: Nil + override def prettyName: String = "collect_set" override def windowAggregation( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala index 5f29e293065..9e47fdbe0ff 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExec.scala @@ -49,6 +49,28 @@ class GpuShuffleMeta( override val childParts: scala.Seq[PartMeta[_]] = Seq(GpuOverrides.wrapPart(shuffle.outputPartitioning, conf, Some(this))) + // Propagate possible type conversions on the output attributes of map-side plans to + // reduce-side counterparts. We can pass through the outputs of child because Shuffle will + // not change the data schema. And we need to pass through because Shuffle itself and + // reduce-side plans may failed to pass the type check for tagging CPU data types rather + // than their GPU counterparts. + // + // Taking AggregateExec with TypedImperativeAggregate function as example: + // Assume I have a query: SELECT a, COLLECT_LIST(b) FROM table GROUP BY a, which physical plan + // looks like: + // ObjectHashAggregate(keys=[a#10], functions=[collect_list(b#11, 0, 0)], + // output=[a#10, collect_list(b)#17]) + // +- Exchange hashpartitioning(a#10, 200), true, [id=#13] + // +- ObjectHashAggregate(keys=[a#10], functions=[partial_collect_list(b#11, 0, 0)], + // output=[a#10, buf#21]) + // +- LocalTableScan [a#10, b#11] + // + // We will override the data type of buf#21 in GpuNoHashAggregateMeta. Otherwise, the partial + // Aggregate will fall back to CPU because buf#21 produce a GPU-unsupported type: BinaryType. + // Just like the partial Aggregate, the ShuffleExchange will also fall back to CPU unless we + // apply the same type overriding as its child plan: the partial Aggregate. + override protected val useOutputAttributesOfChild: Boolean = true + override def tagPlanForGpu(): Unit = { // when AQE is enabled and we are planning a new query stage, we need to look at meta-data // previously stored on the spark plan to determine whether this exchange can run on GPU