Skip to content

Commit

Permalink
Support max on single-level struct in aggregation context (#4434)
Browse files Browse the repository at this point in the history
* Support max on single-level struct in aggregation context
* Refactor
* Support min and max on single-level
* Update test case after Cudf fixed bug about null

Signed-off-by: Chong Gao <res_life@163.com>
  • Loading branch information
Chong Gao authored Jan 15, 2022
1 parent 7d8b6d4 commit a63b483
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 20 deletions.
16 changes: 8 additions & 8 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15230,7 +15230,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15251,7 +15251,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15273,7 +15273,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15294,7 +15294,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -15389,7 +15389,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15410,7 +15410,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15432,7 +15432,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand All @@ -15453,7 +15453,7 @@ are limited.
<td><b>NS</b></td>
<td><b>NS</b></td>
<td> </td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, ARRAY, STRUCT, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
43 changes: 43 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,3 +1704,46 @@ def test_groupby_std_variance_partial_replace_fallback(data_gen,
exist_classes=','.join(exist_clz),
non_exist_classes=','.join(non_exist_clz),
conf=local_conf)

#
# test min max on single level structure
#
gens_for_max_min = [byte_gen, short_gen, int_gen, long_gen,
FloatGen(no_nans = True), DoubleGen(no_nans = True),
string_gen, boolean_gen,
date_gen, timestamp_gen,
DecimalGen(precision=12, scale=2),
DecimalGen(precision=36, scale=5),
null_gen]
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', gens_for_max_min, ids=idfn)
def test_min_max_for_single_level_struct(data_gen):
df_gen = [
('a', StructGen([
('aa', data_gen),
('ab', data_gen)])),
('b', RepeatSeqGen(IntegerGen(), length=20))]

# test max
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, df_gen),
"hash_agg_table",
'select b, max(a) from hash_agg_table group by b',
_no_nans_float_conf)
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, df_gen),
"hash_agg_table",
'select max(a) from hash_agg_table',
_no_nans_float_conf)

# test min
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, df_gen, length=1024),
"hash_agg_table",
'select b, min(a) from hash_agg_table group by b',
_no_nans_float_conf)
assert_gpu_and_cpu_are_equal_sql(
lambda spark : gen_df(spark, df_gen, length=1024),
"hash_agg_table",
'select min(a) from hash_agg_table',
_no_nans_float_conf)
Original file line number Diff line number Diff line change
Expand Up @@ -2234,14 +2234,27 @@ object GpuOverrides extends Logging {
}),
expr[Max](
"Max aggregate operator",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
ExprChecksImpl(
ExprChecks.reductionAndGroupByAgg(
// Max supports single level struct, e.g.: max(struct(string, string))
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts),
(max, conf, p, r) => new AggExprMeta[Max](max, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = max.child.dataType
Expand All @@ -2256,14 +2269,27 @@ object GpuOverrides extends Logging {
}),
expr[Min](
"Min aggregate operator",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL, TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
ExprChecksImpl(
ExprChecks.reductionAndGroupByAgg(
// Min supports single level struct, e.g.: max(struct(string, string))
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + TypeSig.STRUCT)
.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))
),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts
++
ExprChecks.windowOnly(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL),
TypeSig.orderable,
Seq(ParamCheck("input",
(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL)
.withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
.withPsNote(TypeEnum.FLOAT, nanAggPsNote),
TypeSig.orderable))).asInstanceOf[ExprChecksImpl].contexts),
(a, conf, p, r) => new AggExprMeta[Min](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
val dataType = a.child.dataType
Expand Down

0 comments on commit a63b483

Please sign in to comment.