diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index a64f432983e..0feefaec297 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -15230,7 +15230,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15251,7 +15251,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15273,7 +15273,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15294,7 +15294,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15389,7 +15389,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15410,7 +15410,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15432,7 +15432,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
@@ -15453,7 +15453,7 @@ are limited.
NS |
NS |
|
-NS |
+PS UTC is only supported TZ for child TIMESTAMP; unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT |
NS |
diff --git a/integration_tests/src/main/python/hash_aggregate_test.py b/integration_tests/src/main/python/hash_aggregate_test.py
index f862a11d996..54e2f913784 100644
--- a/integration_tests/src/main/python/hash_aggregate_test.py
+++ b/integration_tests/src/main/python/hash_aggregate_test.py
@@ -1704,3 +1704,46 @@ def test_groupby_std_variance_partial_replace_fallback(data_gen,
exist_classes=','.join(exist_clz),
non_exist_classes=','.join(non_exist_clz),
conf=local_conf)
+
+#
+# test min max on single level structure
+#
+gens_for_max_min = [byte_gen, short_gen, int_gen, long_gen,
+ FloatGen(no_nans = True), DoubleGen(no_nans = True),
+ string_gen, boolean_gen,
+ date_gen, timestamp_gen,
+ DecimalGen(precision=12, scale=2),
+ DecimalGen(precision=36, scale=5),
+ null_gen]
+@ignore_order(local=True)
+@pytest.mark.parametrize('data_gen', gens_for_max_min, ids=idfn)
+def test_min_max_for_single_level_struct(data_gen):
+ df_gen = [
+ ('a', StructGen([
+ ('aa', data_gen),
+ ('ab', data_gen)])),
+ ('b', RepeatSeqGen(IntegerGen(), length=20))]
+
+ # test max
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, df_gen),
+ "hash_agg_table",
+ 'select b, max(a) from hash_agg_table group by b',
+ _no_nans_float_conf)
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, df_gen),
+ "hash_agg_table",
+ 'select max(a) from hash_agg_table',
+ _no_nans_float_conf)
+
+ # test min
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, df_gen, length=1024),
+ "hash_agg_table",
+ 'select b, min(a) from hash_agg_table group by b',
+ _no_nans_float_conf)
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, df_gen, length=1024),
+ "hash_agg_table",
+ 'select min(a) from hash_agg_table',
+ _no_nans_float_conf)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 7637686d898..31401832bbb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2234,14 +2234,27 @@ object GpuOverrides extends Logging {
}),
expr[Max](
"Max aggregate operator",
- ExprChecks.fullAgg(
- TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable,
- Seq(ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
+ ExprChecksImpl(
+ ExprChecks.reductionAndGroupByAgg(
+ // Max supports single level struct, e.g.: max(struct(string, string))
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
+ .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
+ TypeSig.orderable,
+ Seq(ParamCheck("input",
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
+ .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
- TypeSig.orderable))
- ),
+ TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
+ ++
+ ExprChecks.windowOnly(
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
+ TypeSig.orderable,
+ Seq(ParamCheck("input",
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
+ .withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
+ .withPsNote(TypeEnum.FLOAT, nanAggPsNote),
+ TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts),
(max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = max.child.dataType
@@ -2256,14 +2269,27 @@ object GpuOverrides extends Logging {
}),
expr[Min](
"Min aggregate operator",
- ExprChecks.fullAgg(
- TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable,
- Seq(ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
+ ExprChecksImpl(
+ ExprChecks.reductionAndGroupByAgg(
+ // Min supports single level struct, e.g.: max(struct(string, string))
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
+ .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
+ TypeSig.orderable,
+ Seq(ParamCheck("input",
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
+ .nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
- TypeSig.orderable))
- ),
+ TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
+ ++
+ ExprChecks.windowOnly(
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
+ TypeSig.orderable,
+ Seq(ParamCheck("input",
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
+ .withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
+ .withPsNote(TypeEnum.FLOAT, nanAggPsNote),
+ TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts),
(a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType