Skip to content

Commit

Permalink
Support relational operators for decimal type (NVIDIA#1173)
Browse files Browse the repository at this point in the history
Signed-off-by: Niranjan Artal <nartal@nvidia.com>

Co-authored-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
nartal1 and revans2 authored Dec 1, 2020
1 parent 5b53c39 commit 515dd1a
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 16 deletions.
16 changes: 8 additions & 8 deletions integration_tests/src/main/python/cmp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pyspark.sql.types import *
import pyspark.sql.functions as f

@pytest.mark.parametrize('data_gen', eq_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens_with_decimal_gen, ids=idfn)
def test_eq(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
Expand All @@ -31,7 +31,7 @@ def test_eq(data_gen):
s2 == f.col('b'),
f.lit(None).cast(data_type) == f.col('a'),
f.col('b') == f.lit(None).cast(data_type),
f.col('a') == f.col('b')))
f.col('a') == f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', eq_gens, ids=idfn)
def test_eq_ns(data_gen):
Expand All @@ -45,7 +45,7 @@ def test_eq_ns(data_gen):
f.col('b').eqNullSafe(f.lit(None).cast(data_type)),
f.col('a').eqNullSafe(f.col('b'))))

@pytest.mark.parametrize('data_gen', eq_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', eq_gens_with_decimal_gen, ids=idfn)
def test_ne(data_gen):
(s1, s2) = gen_scalars(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
data_type = data_gen.data_type
Expand All @@ -55,7 +55,7 @@ def test_ne(data_gen):
s2 != f.col('b'),
f.lit(None).cast(data_type) != f.col('a'),
f.col('b') != f.lit(None).cast(data_type),
f.col('a') != f.col('b')))
f.col('a') != f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_lt(data_gen):
Expand All @@ -67,7 +67,7 @@ def test_lt(data_gen):
s2 < f.col('b'),
f.lit(None).cast(data_type) < f.col('a'),
f.col('b') < f.lit(None).cast(data_type),
f.col('a') < f.col('b')))
f.col('a') < f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_lte(data_gen):
Expand All @@ -79,7 +79,7 @@ def test_lte(data_gen):
s2 <= f.col('b'),
f.lit(None).cast(data_type) <= f.col('a'),
f.col('b') <= f.lit(None).cast(data_type),
f.col('a') <= f.col('b')))
f.col('a') <= f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_gt(data_gen):
Expand All @@ -91,7 +91,7 @@ def test_gt(data_gen):
s2 > f.col('b'),
f.lit(None).cast(data_type) > f.col('a'),
f.col('b') > f.lit(None).cast(data_type),
f.col('a') > f.col('b')))
f.col('a') > f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', orderable_gens, ids=idfn)
def test_gte(data_gen):
Expand All @@ -103,7 +103,7 @@ def test_gte(data_gen):
s2 >= f.col('b'),
f.lit(None).cast(data_type) >= f.col('a'),
f.col('b') >= f.lit(None).cast(data_type),
f.col('a') >= f.col('b')))
f.col('a') >= f.col('b')), conf=allow_negative_scale_of_decimal_conf)

@pytest.mark.parametrize('data_gen', eq_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
def test_isnull(data_gen):
Expand Down
42 changes: 41 additions & 1 deletion integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import copy
from datetime import date, datetime, timedelta, timezone
from decimal import *
import math
from pyspark.sql.types import *
import pyspark.sql.functions as f
Expand Down Expand Up @@ -208,6 +209,33 @@ def __init__(self, nullable=True, min_val = INT_MIN, max_val = INT_MAX,
def start(self, rand):
self._start(rand, lambda : rand.randint(self._min_val, self._max_val))

class DecimalGen(DataGen):
    """Generate Decimals, with some built in corner cases.

    Args:
        precision: total number of decimal digits. When None, precision 18 and
            scale 0 are used and any ``scale`` argument is ignored.
        scale: number of digits to the right of the decimal point; negative
            scales are accepted. Defaults to 0 when omitted.
        nullable: forwarded unchanged to ``DataGen``.
        special_cases: corner-case values forwarded to ``DataGen``. When None,
            zero plus the smallest and largest values representable with the
            chosen precision and scale are used.
    """
    def __init__(self, precision=None, scale=None, nullable=True, special_cases=None):
        if precision is None:
            #Maximum number of decimal digits a Long can represent is 18
            precision = 18
            scale = 0
        elif scale is None:
            # A precision without a scale used to fail on `-scale` below;
            # treat it as a whole-number decimal instead.
            scale = 0
        # The scale is expressed as a decimal exponent, e.g. scale 3 -> 'e-3'
        exponent = 'e' + str(-scale)
        if special_cases is None:
            # Built per instance so no list is shared between generators
            all_nines = '9' * precision
            special_cases = [Decimal('0'),
                             Decimal('-' + all_nines + exponent),
                             Decimal(all_nines + exponent)]
        super().__init__(DecimalType(precision, scale), nullable=nullable, special_cases=special_cases)
        self._scale = scale
        self._precision = precision
        # Between 1 and `precision` digits followed by the exponent. The pattern
        # carries no sign, so the randomly generated values are never negative.
        pattern = "[0-9]{1,"+ str(precision) + "}" + exponent
        self.base_strs = sre_yield.AllStrings(pattern, flags=0, charset=sre_yield.CHARSET, max_count=_MAX_CHOICES)

    def __repr__(self):
        """Return the base representation followed by '(precision,scale)'."""
        return super().__repr__() + '(' + str(self._precision) + ',' + str(self._scale) + ')'

    def start(self, rand):
        """Register a generator function that picks a random string matching
        the pattern and converts it to a Decimal."""
        strs = self.base_strs
        try:
            length = int(len(strs))
        except OverflowError:
            # len() could not report the number of candidate strings;
            # fall back to indexing only the first _MAX_CHOICES of them.
            length = _MAX_CHOICES
        self._start(rand, lambda : Decimal(strs[rand.randrange(0, length)]))

LONG_MIN = -(1 << 63)
LONG_MAX = (1 << 63) - 1
class LongGen(DataGen):
Expand Down Expand Up @@ -690,16 +718,23 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
boolean_gen = BooleanGen()
date_gen = DateGen()
timestamp_gen = TimestampGen()
decimal_gen_default = DecimalGen()
decimal_gen_neg_scale = DecimalGen(precision=7, scale=-3)
decimal_gen_scale_precision = DecimalGen(precision=7, scale=3)
decimal_gen_same_scale_precision = DecimalGen(precision=7, scale=7)

null_gen = NullGen()

numeric_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen]

integral_gens = [byte_gen, short_gen, int_gen, long_gen]
# A lot of mathematical expressions only support a double as input.
# Parametrizing the tests even for a single generator keeps them consistent.
double_gens = [double_gen]
double_n_long_gens = [double_gen, long_gen]
int_n_long_gens = [int_gen, long_gen]
decimal_gens = [decimal_gen_default, decimal_gen_neg_scale, decimal_gen_scale_precision,
decimal_gen_same_scale_precision]

# all of the basic gens
all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
Expand All @@ -708,13 +743,16 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
# TODO add in some array generators to this once that is supported for sorting
# a selection of generators that should be orderable (sortable and comparable)
orderable_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen] + decimal_gens

# TODO add in some array generators to this once that is supported for these operations
# a selection of generators that can be compared for equality
eq_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]

# Include decimal type while testing equalTo and notEqualTo
eq_gens_with_decimal_gen = eq_gens + decimal_gens

date_gens = [date_gen]
date_n_time_gens = [date_gen, timestamp_gen]

Expand Down Expand Up @@ -748,3 +786,5 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
MapGen(BooleanGen(nullable=False), boolean_gen, max_length=2),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen)]

allow_negative_scale_of_decimal_conf = {'spark.sql.legacy.allowNegativeScaleOfDecimal': 'true'}
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,8 @@ object GpuOverrides {
(a, conf, p, r) => new BinaryExprMeta[EqualTo](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true)
allowNull = true,
allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuEqualTo(lhs, rhs)
Expand All @@ -1304,7 +1305,8 @@ object GpuOverrides {
(a, conf, p, r) => new BinaryExprMeta[GreaterThan](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true)
allowNull = true,
allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThan(lhs, rhs)
Expand All @@ -1314,7 +1316,8 @@ object GpuOverrides {
(a, conf, p, r) => new BinaryExprMeta[GreaterThanOrEqual](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true)
allowNull = true,
allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThanOrEqual(lhs, rhs)
Expand Down Expand Up @@ -1366,8 +1369,10 @@ object GpuOverrides {
"< operator",
(a, conf, p, r) => new BinaryExprMeta[LessThan](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =

GpuOverrides.isSupportedType(t,
allowNull = true)
allowNull = true,
allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThan(lhs, rhs)
Expand All @@ -1377,7 +1382,8 @@ object GpuOverrides {
(a, conf, p, r) => new BinaryExprMeta[LessThanOrEqual](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowNull = true)
allowNull = true,
allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThanOrEqual(lhs, rhs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import ai.rapids.cudf.DType

import org.apache.spark.sql.rapids.{GpuEqualTo, GpuGreaterThan, GpuGreaterThanOrEqual, GpuLessThan, GpuLessThanOrEqual}
import org.apache.spark.sql.types.{DataTypes, Decimal, DecimalType}

/**
 * Tests the GPU relational operators (==, >, >=, <, <=) on decimal inputs.
 *
 * Each test hands a GPU expression, the input/output types and a reference function
 * over Spark `Decimal` values to the check helpers inherited from GpuExpressionTestSuite.
 * Every operator is exercised as column vs column, column vs literal and literal vs column.
 */
class DecimalBinaryOpSuite extends GpuExpressionTestSuite {
// Two decimal columns with the same precision but different scales (4 and 2)
private val schema = FuzzerUtils.createSchema(Seq(
DecimalType(DType.DECIMAL32_MAX_PRECISION, 4),
DecimalType(DType.DECIMAL32_MAX_PRECISION, 2)))
// Scalar operand used by the column-vs-literal and literal-vs-column cases.
// NOTE(review): the value is built from a Double while the literal's declared type
// has scale 5 - confirm the two are meant to differ.
private val litValue = Decimal(12345.6789)
private val lit = GpuLiteral(litValue, DecimalType(DType.DECIMAL64_MAX_PRECISION, 5))
private val leftExpr = GpuBoundReference(0, schema.head.dataType, nullable = true)
private val rightExpr = GpuBoundReference(1, schema(1).dataType, nullable = true)

// Two decimal columns of identical type (same precision, scale 3)
private val schemaSame = FuzzerUtils.createSchema(Seq(
DecimalType(DType.DECIMAL32_MAX_PRECISION, 3),
DecimalType(DType.DECIMAL32_MAX_PRECISION, 3)))
private val leftExprSame = GpuBoundReference(0, schemaSame.head.dataType, nullable = true)
private val rightExprSame = GpuBoundReference(1, schemaSame(1).dataType, nullable = true)

test("GpuEqualTo") {
// column == column, first with differing scales, then with identical types
val expectedFun = (l: Decimal, r: Decimal) => Option(l == r)
checkEvaluateGpuBinaryExpression(GpuEqualTo(leftExpr, rightExpr),
schema.head.dataType, schema(1).dataType, DataTypes.BooleanType,
expectedFun, schema)
// Both type arguments use schemaSame.head; the two columns of schemaSame share one type
checkEvaluateGpuBinaryExpression(GpuEqualTo(leftExprSame, rightExprSame),
schemaSame.head.dataType, schemaSame.head.dataType, DataTypes.BooleanType,
expectedFun, schemaSame)

// column == literal
val expectedFunVS = (x: Decimal) => Option(x == litValue)
checkEvaluateGpuUnaryExpression(GpuEqualTo(leftExpr, lit),
schema.head.dataType, DataTypes.BooleanType, expectedFunVS, schema)
// literal == column
val expectedFunSV = (x: Decimal) => Option(litValue == x)
checkEvaluateGpuUnaryExpression(GpuEqualTo(lit, leftExpr),
schema.head.dataType, DataTypes.BooleanType, expectedFunSV, schema)
}

test("GpuGreaterThan") {
// column > column (differing scales only)
// NOTE(review): unlike GpuEqualTo, there is no identical-type (schemaSame) case here
val expectedFunVV = (l: Decimal, r: Decimal) => Option(l > r)
checkEvaluateGpuBinaryExpression(GpuGreaterThan(leftExpr, rightExpr),
schema.head.dataType, schema(1).dataType, DataTypes.BooleanType,
expectedFunVV, schema)

// column > literal
val expectedFunVS = (x: Decimal) => Option(x > litValue)
checkEvaluateGpuUnaryExpression(GpuGreaterThan(leftExpr, lit),
schema.head.dataType, DataTypes.BooleanType, expectedFunVS, schema)
// literal > column
val expectedFunSV = (x: Decimal) => Option(litValue > x)
checkEvaluateGpuUnaryExpression(GpuGreaterThan(lit, leftExpr),
schema.head.dataType, DataTypes.BooleanType, expectedFunSV, schema)
}

test("GpuGreaterThanOrEqual") {
// column >= column, first with differing scales, then with identical types
val expectedFunVV = (l: Decimal, r: Decimal) => Option(l >= r)
checkEvaluateGpuBinaryExpression(GpuGreaterThanOrEqual(leftExpr, rightExpr),
schema.head.dataType, schema(1).dataType, DataTypes.BooleanType,
expectedFunVV, schema)
// Both type arguments use schemaSame.head; the two columns of schemaSame share one type
checkEvaluateGpuBinaryExpression(GpuGreaterThanOrEqual(leftExprSame, rightExprSame),
schemaSame.head.dataType, schemaSame.head.dataType, DataTypes.BooleanType,
expectedFunVV, schemaSame)

// column >= literal
val expectedFunVS = (x: Decimal) => Option(x >= litValue)
checkEvaluateGpuUnaryExpression(GpuGreaterThanOrEqual(leftExpr, lit),
schema.head.dataType, DataTypes.BooleanType, expectedFunVS, schema)
// literal >= column
val expectedFunSV = (x: Decimal) => Option(litValue >= x)
checkEvaluateGpuUnaryExpression(GpuGreaterThanOrEqual(lit, leftExpr),
schema.head.dataType, DataTypes.BooleanType, expectedFunSV, schema)
}

test("GpuLessThan") {
// column < column (differing scales only)
// NOTE(review): unlike GpuLessThanOrEqual, there is no identical-type (schemaSame) case here
val expectedFunVV = (l: Decimal, r: Decimal) => Option(l < r)
checkEvaluateGpuBinaryExpression(GpuLessThan(leftExpr, rightExpr),
schema.head.dataType, schema(1).dataType, DataTypes.BooleanType,
expectedFunVV, schema)

// column < literal
val expectedFunVS = (x: Decimal) => Option(x < litValue)
checkEvaluateGpuUnaryExpression(GpuLessThan(leftExpr, lit),
schema.head.dataType, DataTypes.BooleanType, expectedFunVS, schema)
// literal < column
val expectedFunSV = (x: Decimal) => Option(litValue < x)
checkEvaluateGpuUnaryExpression(GpuLessThan(lit, leftExpr),
schema.head.dataType, DataTypes.BooleanType, expectedFunSV, schema)
}

test("GpuLessThanOrEqual") {
// column <= column, first with differing scales, then with identical types
val expectedFunVV = (l: Decimal, r: Decimal) => Option(l <= r)
checkEvaluateGpuBinaryExpression(GpuLessThanOrEqual(leftExpr, rightExpr),
schema.head.dataType, schema(1).dataType, DataTypes.BooleanType,
expectedFunVV, schema)
// Both type arguments use schemaSame.head; the two columns of schemaSame share one type
checkEvaluateGpuBinaryExpression(GpuLessThanOrEqual(leftExprSame, rightExprSame),
schemaSame.head.dataType, schemaSame.head.dataType, DataTypes.BooleanType,
expectedFunVV, schemaSame)

// column <= literal
val expectedFunVS = (x: Decimal) => Option(x <= litValue)
checkEvaluateGpuUnaryExpression(GpuLessThanOrEqual(leftExpr, lit),
schema.head.dataType, DataTypes.BooleanType, expectedFunVS, schema)
// literal <= column
val expectedFunSV = (x: Decimal) => Option(litValue <= x)
checkEvaluateGpuUnaryExpression(GpuLessThanOrEqual(lit, leftExpr),
schema.head.dataType, DataTypes.BooleanType, expectedFunSV, schema)
}
}
13 changes: 12 additions & 1 deletion tests/src/test/scala/com/nvidia/spark/rapids/FuzzerUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.util.Random
import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, MapType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -125,6 +125,17 @@ object FuzzerUtils {
case None => builder.appendNull()
}
})
case dt: DecimalType =>
rows.foreach(_ => {
maybeNull(rand, r.nextLong()) match {
case Some(value) =>
// bounding unscaledValue with precision
val invScale = (dt.precision to ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION)
.foldLeft(10L)((x, _) => x * 10)
builder.append(BigDecimal(value / invScale, dt.scale).bigDecimal)
case None => builder.appendNull()
}
})
}
}
builders.build(rowCount)
Expand Down
Loading

0 comments on commit 515dd1a

Please sign in to comment.