Skip to content

Commit

Permalink
DecimalType support for Aggregate Count (NVIDIA#1476)
Browse files Browse the repository at this point in the history
Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri authored Jan 8, 2021
1 parent 8b4d931 commit 1958465
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
6 changes: 3 additions & 3 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -15483,7 +15483,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15526,7 +15526,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -15569,7 +15569,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
7 changes: 7 additions & 0 deletions integration_tests/src/main/python/hash_aggregate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,10 @@ def test_arithmetic_reductions(data_gen):
'avg(a)'),
conf = _no_nans_float_conf)

@ignore_order
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('count_func', [f.count, f.countDistinct])
def test_agg_count(data_gen, count_func):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, [('a', data_gen), ('b', data_gen)],
length=1024).groupBy('a').agg(count_func("b")))
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ object GpuOverrides {
ExprChecks.fullAgg(
TypeSig.LONG, TypeSig.LONG,
repeatingParamCheck = Some(RepeatingParamCheck(
"input", TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all))),
"input", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all))),
(count, conf, p, r) => new ExprMeta[Count](count, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (count.children.size > 1) {
Expand Down

0 comments on commit 1958465

Please sign in to comment.