diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 8c26e9e97c1..5ecc0f73153 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -2426,7 +2426,7 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
+S* |
S |
NS |
NS |
@@ -2447,7 +2447,7 @@ Accelerator support is described below.
S |
S* |
S |
-NS |
+S* |
S |
NS |
NS |
diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index 66126c6279f..1ec590e1f11 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -20,6 +20,8 @@
from pyspark.sql.types import *
import pyspark.sql.functions as f
+all_gens = all_gen + [NullGen()]
+
@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
def test_if_else(data_gen):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
@@ -34,7 +36,7 @@ def test_if_else(data_gen):
'IF(a, b, {})'.format(null_lit),
'IF(a, {}, c)'.format(null_lit)))
-@pytest.mark.parametrize('data_gen', all_basic_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', all_gens, ids=idfn)
def test_case_when(data_gen):
num_cmps = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 655ede78500..ec0dd982019 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -684,6 +684,8 @@ def to_cast_string(spark_type):
return 'TIMESTAMP'
elif isinstance(spark_type, StringType):
return 'STRING'
+ elif isinstance(spark_type, DecimalType):
+ return 'DECIMAL({}, {})'.format(spark_type.precision, spark_type.scale)
else:
raise RuntimeError('CAST TO TYPE {} NOT SUPPORTED YET'.format(spark_type))
@@ -796,3 +798,8 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]
allow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'}
+
+all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
+ FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(),
+ decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,
+ decimal_gen_64bit]
diff --git a/integration_tests/src/main/python/generate_expr_test.py b/integration_tests/src/main/python/generate_expr_test.py
index d83f16937cb..2abc75688a6 100644
--- a/integration_tests/src/main/python/generate_expr_test.py
+++ b/integration_tests/src/main/python/generate_expr_test.py
@@ -28,11 +28,6 @@ def four_op_df(spark, gen, length=2048, seed=0):
('c', gen),
('d', gen)], nullable=False), length=length, seed=seed)
-all_gen = [StringGen(), ByteGen(), ShortGen(), IntegerGen(), LongGen(),
- FloatGen(), DoubleGen(), BooleanGen(), DateGen(), TimestampGen(),
- decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,
- decimal_gen_64bit]
-
#sort locally because of https://github.com/NVIDIA/spark-rapids/issues/84
# After 3.1.0 is the min spark version we can drop this
@ignore_order(local=True)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 2fd76aef3b9..0cef1ddeb21 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -601,7 +601,7 @@ case class ExprChecksImpl(contexts: Map[ExpressionContext, ContextChecks])
* This is specific to CaseWhen, because it does not follow the typical parameter convention.
*/
object CaseWhenCheck extends ExprChecks {
- val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL
+ val check: TypeSig = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL
val sparkSig: TypeSig = TypeSig.all
override def tag(meta: RapidsMeta[_, _, _]): Unit = {