Skip to content

Commit

Permalink
DecimalType support for IfElse and Coalesce (NVIDIA#1453)
Browse files Browse the repository at this point in the history
Turned on the DecimalType support for IfElse and Coalesce
and added DecimalGen in tests for both ops

Signed-off-by: Raza Jafri <rjafri@nvidia.com>
  • Loading branch information
razajafri authored Jan 6, 2021
1 parent d264e5b commit e26cea6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
10 changes: 5 additions & 5 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -2717,7 +2717,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -2738,7 +2738,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down Expand Up @@ -5740,7 +5740,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -5761,7 +5761,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand All @@ -5782,7 +5782,7 @@ Accelerator support is described below.
<td>S</td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td>S*</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/conditionals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

all_gens = all_gen + [NullGen()]

@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
def test_if_else(data_gen):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
null_lit = get_null_lit_string(data_gen.data_type)
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_nvl(data_gen):
'nvl(a, {})'.format(null_lit)))

#nvl is translated into a 2 param version of coalesce
@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
def test_coalesce(data_gen):
num_cols = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,9 +1059,9 @@ object GpuOverrides {
expr[Coalesce] (
"Returns the first non-null argument if exists. Otherwise, null",
ExprChecks.projectNotLambda(
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, TypeSig.all,
repeatingParamCheck = Some(RepeatingParamCheck("param",
TypeSig.commonCudfTypes + TypeSig.NULL,
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[Coalesce](a, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCoalesce(childExprs.map(_.convertToGpu()))
Expand Down Expand Up @@ -1537,11 +1537,14 @@ object GpuOverrides {
}),
expr[If](
"IF expression",
ExprChecks.projectNotLambda(TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all,
ExprChecks.projectNotLambda(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.all,
Seq(ParamCheck("predicate", TypeSig.psNote(TypeEnum.BOOLEAN,
"literal values are not supported"), TypeSig.BOOLEAN),
ParamCheck("trueValue", TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all),
ParamCheck("falseValue", TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all))),
ParamCheck("trueValue", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.all),
ParamCheck("falseValue", TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[If](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
if (isLit(a.predicate)) {
Expand Down

0 comments on commit e26cea6

Please sign in to comment.