-
Notifications
You must be signed in to change notification settings - Fork 232
/
decimalExpressions.scala
134 lines (123 loc) · 5.06 KB
/
decimalExpressions.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids
import ai.rapids.cudf
import ai.rapids.cudf.{ColumnVector, DecimalUtils, DType, Scalar}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.{DataType, DecimalType, LongType}
/**
 * GPU counterpart of Spark's CheckOverflow.
 *
 * It is not a 100% match for the CPU version: Spark computes in BigDecimal with unbounded
 * precision and then looks for an overflow, whereas this can only notice an overflow when the
 * value is outside what the Spark type allows but has not already overflowed what the CUDF type
 * can hold. For operations where that gap matters at a given precision, the operator is expected
 * to fall back to the CPU, or to check for overflow some other way before this runs.
 *
 * @param child the expression producing the decimal values to check
 * @param dataType the decimal type the result must conform to
 * @param nullOnOverflow forwarded (negated) as the last argument of
 *                       `GpuCast.checkNFixDecimalBounds`
 */
case class GpuCheckOverflow(
    child: Expression,
    dataType: DecimalType,
    nullOnOverflow: Boolean) extends GpuUnaryExpression {

  // The scale handed to DType.create is the negated Spark scale.
  private[this] val expectedCudfScale = -dataType.scale

  // Choose the narrowest decimal type id whose maximum precision covers the target precision.
  private[this] lazy val resultDType = {
    val targetPrecision = dataType.precision
    val typeId =
      if (targetPrecision > DType.DECIMAL64_MAX_PRECISION) {
        DType.DTypeEnum.DECIMAL128
      } else if (targetPrecision > DType.DECIMAL32_MAX_PRECISION) {
        DType.DTypeEnum.DECIMAL64
      } else {
        DType.DTypeEnum.DECIMAL32
      }
    DType.create(typeId, expectedCudfScale)
  }

  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
    val source = input.getBase
    val adjusted =
      if (resultDType.equals(source.getType)) {
        // Already the exact target type; just take another reference.
        source.incRefCount()
      } else {
        withResource(source.round(dataType.scale, cudf.RoundMode.HALF_UP)) { roundedCol =>
          // A cast is only needed when the backing type id differs from the target's.
          if (resultDType.getTypeId == source.getType.getTypeId) {
            roundedCol.incRefCount()
          } else {
            roundedCol.castTo(resultDType)
          }
        }
      }
    withResource(adjusted) { toCheck =>
      GpuCast.checkNFixDecimalBounds(toCheck, dataType, !nullOnOverflow)
    }
  }
}
/**
 * GPU counterpart of PromotePrecision. It is a no-op in Spark as well, so the input column is
 * handed back unchanged with an extra reference.
 */
case class GpuPromotePrecision(child: Expression) extends GpuUnaryExpression {
  override def dataType: DataType = child.dataType

  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
    val passThrough = input.getBase
    passThrough.incRefCount()
  }
}
/**
 * GPU version of UnscaledValue: produces a LongType column from the child's decimal column by
 * bit-casting the underlying storage to an integer type.
 */
case class GpuUnscaledValue(child: Expression) extends GpuUnaryExpression {
  override def dataType: DataType = LongType

  override def toString: String = s"UnscaledValue($child)"

  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
    val decimalCol = input.getBase
    if (!decimalCol.getType.isBackedByInt) {
      // Not int-backed: view the data as INT64 and materialize the view as a column.
      // NOTE(review): assumes the backing storage is 64 bits wide here - confirm callers never
      // pass a wider decimal.
      withResource(decimalCol.bitCastTo(DType.INT64))(_.copyToColumnVector())
    } else {
      // Int-backed: view as INT32, then cast up to INT64 to match LongType.
      withResource(decimalCol.bitCastTo(DType.INT32))(_.castTo(DType.INT64))
    }
  }
}
/**
 * GPU version of MakeDecimal: bit-casts the child's column to a decimal column of the requested
 * precision and scale, after comparing every value against the unscaled bounds of that type.
 *
 * @param child the expression producing the values to reinterpret
 * @param precision precision of the resulting decimal type
 * @param sparkScale scale of the resulting decimal type
 * @param nullOnOverflow when true, out-of-bounds rows become null; when false, any out-of-bounds
 *                       row raises an IllegalStateException
 */
case class GpuMakeDecimal(
    child: Expression,
    precision: Int,
    sparkScale: Int,
    nullOnOverflow: Boolean) extends GpuUnaryExpression {

  override def dataType: DecimalType = DecimalType(precision, sparkScale)

  override def nullable: Boolean = child.nullable || nullOnOverflow

  override def toString: String = s"MakeDecimal($child,$precision,$sparkScale)"

  // Smallest and largest unscaled long values permitted by the output decimal type.
  private lazy val (minValue, maxValue) = {
    val limits = DecimalUtils.bounds(dataType.precision, dataType.scale)
    val lowest = limits.getKey.unscaledValue().longValue()
    val highest = limits.getValue.unscaledValue().longValue()
    (lowest, highest)
  }

  override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
    val cudfDecimalType = DecimalUtils.createDecimalType(precision, sparkScale)
    val unscaled = input.getBase

    // Boolean column marking rows that are above maxValue or below minValue.
    def flagOutOfBounds(): ColumnVector = {
      withResource(Scalar.fromLong(maxValue)) { upperBound =>
        withResource(unscaled.greaterThan(upperBound)) { tooLarge =>
          withResource(Scalar.fromLong(minValue)) { lowerBound =>
            withResource(unscaled.lessThan(lowerBound)) { tooSmall =>
              tooLarge.or(tooSmall)
            }
          }
        }
      }
    }

    withResource(flagOutOfBounds()) { overflowMask =>
      withResource(unscaled.bitCastTo(cudfDecimalType)) { decimalView =>
        if (nullOnOverflow) {
          // Replace every flagged row with a null of the output type.
          withResource(Scalar.fromNull(cudfDecimalType)) { nullDecimal =>
            overflowMask.ifElse(nullDecimal, decimalView)
          }
        } else {
          // Fail if at least one row is flagged.
          val sawOverflow = withResource(overflowMask.any()) { anyFlag =>
            anyFlag.isValid && anyFlag.getBoolean
          }
          if (sawOverflow) {
            throw new IllegalStateException(GpuCast.INVALID_INPUT_MESSAGE)
          }
          decimalView.copyToColumnVector()
        }
      }
    }
  }
}