Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Initial work on supporting DecimalType #1063

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ private static DType toRapidsOrNull(DataType type) {
return DType.TIMESTAMP_MICROSECONDS;
} else if (type instanceof StringType) {
return DType.STRING;
} else if (type instanceof DecimalType) {
// The check for decimal support has already been performed in the GPU plan overriding stage.
// So, we don't need to handle decimal support issues here.
DecimalType dt = (DecimalType) type;
if (dt.precision() > DType.DECIMAL64_MAX_PRECISION) {
return null;
} else {
// Map all DecimalType to DECIMAL64, to keep the underlying DType conversions consistent.
return DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale());
}
}
return null;
}
Expand Down Expand Up @@ -299,6 +309,14 @@ public static ColumnarBatch from(Table table, DataType[] colTypes) {
*/
private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataType colType) {
DType dt = cv.getDataType();
if (dt.isDecimalType()) {
if (!(colType instanceof DecimalType)) {
return false;
}
// check for overflow
int maxPrecision = dt.isBackedByLong() ? DType.DECIMAL64_MAX_PRECISION : DType.DECIMAL32_MAX_PRECISION;
return ((DecimalType) colType).precision() <= maxPrecision;
}
if (!dt.isNestedType()) {
return getRapidsType(colType).equals(dt);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public final RapidsHostColumnVector incRefCount() {
return this;
}

// Returns the underlying cudf HostColumnVector that backs this vector.
@Override
public final ai.rapids.cudf.HostColumnVector getBase() {
return cudfCv;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

import java.math.BigDecimal;
import java.math.RoundingMode;

/**
* A GPU accelerated version of the Spark ColumnVector.
Expand Down Expand Up @@ -158,8 +160,10 @@ public final ColumnarMap getMap(int ordinal) {
}

@Override
public final Decimal getDecimal(int rowId, int precision, int scale) {
throw new IllegalStateException("The decimal type is currently not supported by rapids cudf");
public Decimal getDecimal(int rowId, int precision, int scale) {
BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY);
assert bigDec.precision() <= precision : "Assert" + bigDec.precision() + " <= " + precision;
return Decimal.apply(bigDec);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ object GpuOverrides {
}

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowCalendarInterval = true)
GpuOverrides.isSupportedType(t, allowCalendarInterval = true, allowDecimal = true)
}),
expr[Signum](
"Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive",
Expand All @@ -610,6 +610,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[Alias](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand All @@ -622,6 +623,7 @@ object GpuOverrides {
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand Down Expand Up @@ -818,6 +820,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand All @@ -829,6 +832,7 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand Down Expand Up @@ -1177,19 +1181,19 @@ object GpuOverrides {
}),
expr[Add](
"Addition",
(a, conf, p, r) => new BinaryExprMeta[Add](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[Add](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuAdd(lhs, rhs)
}),
expr[Subtract](
"Subtraction",
(a, conf, p, r) => new BinaryExprMeta[Subtract](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[Subtract](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuSubtract(lhs, rhs)
}),
expr[Multiply](
"Multiplication",
(a, conf, p, r) => new BinaryExprMeta[Multiply](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[Multiply](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuMultiply(lhs, rhs)
}),
Expand All @@ -1215,20 +1219,21 @@ object GpuOverrides {
"Check if the values are equal",
(a, conf, p, r) => new BinaryExprMeta[EqualTo](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)

override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuEqualTo(lhs, rhs)
}),
expr[GreaterThan](
"> operator",
(a, conf, p, r) => new BinaryExprMeta[GreaterThan](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[GreaterThan](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThan(lhs, rhs)
}),
expr[GreaterThanOrEqual](
">= operator",
(a, conf, p, r) => new BinaryExprMeta[GreaterThanOrEqual](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[GreaterThanOrEqual](a, conf, p, r,
allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuGreaterThanOrEqual(lhs, rhs)
}),
Expand Down Expand Up @@ -1269,13 +1274,13 @@ object GpuOverrides {
}),
expr[LessThan](
"< operator",
(a, conf, p, r) => new BinaryExprMeta[LessThan](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[LessThan](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThan(lhs, rhs)
}),
expr[LessThanOrEqual](
"<= operator",
(a, conf, p, r) => new BinaryExprMeta[LessThanOrEqual](a, conf, p, r) {
(a, conf, p, r) => new BinaryExprMeta[LessThanOrEqual](a, conf, p, r, allowDecimal = true) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuLessThanOrEqual(lhs, rhs)
}),
Expand Down Expand Up @@ -1823,6 +1828,7 @@ object GpuOverrides {
new SparkPlanMeta[ProjectExec](proj, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand Down Expand Up @@ -1911,6 +1917,7 @@ object GpuOverrides {
(filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t,
allowDecimal = true,
allowStringMaps = true,
allowArray = true,
allowNesting = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ private object GpuRowToColumnConverter {
case (TimestampType, false) => NotNullLongConverter
case (StringType, true) => StringConverter
case (StringType, false) => NotNullStringConverter
case (dt: DecimalType, true) => new DecimalConverter(dt.precision, dt.scale)
case (dt: DecimalType, false) => new NotNullDecimalConverter(dt.precision, dt.scale)
// NOT SUPPORTED YET
// case CalendarIntervalType => CalendarConverter
case (at: ArrayType, true) =>
Expand All @@ -100,8 +102,6 @@ private object GpuRowToColumnConverter {
// NOT SUPPORTED YET
// case st: StructType => new StructConverter(st.fields.map(
// (f) => getConverterForType(f.dataType)))
// NOT SUPPORTED YET
// case dt: DecimalType => new DecimalConverter(dt)
// NOT SUPPORTED YET
case (MapType(k, v, vcn), true) =>
MapConverter(getConverterForType(k, nullable = false),
Expand Down Expand Up @@ -289,6 +289,32 @@ private object GpuRowToColumnConverter {
}
}

private class DecimalConverter(precision: Int, scale: Int) extends TypeConverter {
override def append(
row: SpecializedGetters,
column: Int,
builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = {
if (row.isNullAt(column)) {
builder.appendNull()
} else {
new NotNullDecimalConverter(precision, scale).append(row, column, builder)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is on the data path. I would like us to avoid object creation if at all possible to speed up the data path. Please make a static method instead, or use inheritance, which I think is less ideal.

}
// Infer the storage type via precision, because we can't access DType of builder.
(if (precision > ai.rapids.cudf.DType.DECIMAL32_MAX_PRECISION) 8 else 4) + VALIDITY
}
}

private class NotNullDecimalConverter(precision: Int, scale: Int) extends TypeConverter {
override def append(
row: SpecializedGetters,
column: Int,
builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = {
builder.append(row.getDecimal(column, precision, scale).toJavaBigDecimal)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to see us use toUnscaledLong to avoid any extra object creation, it is also what we are going to ultimately store the data as. This might mean that we need different classes for DECIMAL64 and DECIMAL32 too.

// Infer the storage type via precision, because we can't access DType of builder.
if (precision > ai.rapids.cudf.DType.DECIMAL32_MAX_PRECISION) 8 else 4
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here too it would be great to avoid conditionals in the data path. I would prefer it if we either passed in the size or had separate implementations for DECIMAL32 and DECIMAL64

}
}

private[this] def mapConvert(
keyConverter: TypeConverter,
valueConverter: TypeConverter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ object HostColumnarToGpu {
for (i <- 0 until rows) {
b.appendUTF8String(cv.getUTF8String(i).getBytes)
}
case (dt, nullable) if dt.isDecimalType =>
val precision = if (dt.isBackedByInt) {
DType.DECIMAL32_MAX_PRECISION
} else {
DType.DECIMAL64_MAX_PRECISION
}
if (nullable) {
for (i <- 0 until rows) {
if (cv.isNullAt(i)) {
b.appendNull()
} else {
b.append(cv.getDecimal(i, precision, -dt.getScale).toJavaBigDecimal)
}
}
} else {
for (i <- 0 until rows) {
b.append(cv.getDecimal(i, precision, -dt.getScale).toJavaBigDecimal)
}
}
case (t, n) =>
throw new UnsupportedOperationException(s"Converting to GPU for ${t} is not currently " +
s"supported")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty

// We assume that all common plans support decimal by default, since whether
// decimal is allowed is determined mainly at the expression level.
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowDecimal = true)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have tests that verify that we can support decimal for all top level spark operations? Have we tested join, expand, generate, filter, project, union, window, sort, or hash aggregate? What about all of the arrow python UDF code where we go to/from arrow?

I think it would be much better if we split this big PR up into smaller pieces and put each piece in separately with corresponding tests to show that it works, and we only add decimal to the allow list for those things that we know it works for because we have tested it. If you want me to help with this I am happy to do it. I am already in the middle of doing it for Lists I am going to add in structs, maps, binary, null type and finally calendar interval based off of how much time I have and priorities. Some of these we will only be able to do very basic things with, but that should be enough to unblock others to use them for more complicated processing.


override def convertToCpu(): SparkPlan = {
wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
}
Expand Down Expand Up @@ -765,9 +770,13 @@ abstract class BinaryExprMeta[INPUT <: BinaryExpression](
expr: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
rule: ConfKeysAndIncompat,
allowDecimal: Boolean = false)
extends ExprMeta[INPUT](expr, conf, parent, rule) {

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowDecimal = allowDecimal)

override final def convertToGpu(): GpuExpression =
convertToGpu(childExprs(0).convertToGpu(), childExprs(1).convertToGpu())

Expand Down
23 changes: 23 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ object GpuScalar {
case DType.TIMESTAMP_DAYS => v.getInt
case DType.TIMESTAMP_MICROSECONDS => v.getLong
case DType.STRING => v.getJavaString
case dt: DType if dt.isDecimalType => Decimal(v.getBigDecimal)
case t => throw new IllegalStateException(s"$t is not a supported rapids scalar type yet")
}

Expand All @@ -88,12 +89,34 @@ object GpuScalar {
case b: Boolean => Scalar.fromBool(b)
case s: String => Scalar.fromString(s)
case s: UTF8String => Scalar.fromString(s.toString)
case dec: Decimal =>
Scalar.fromDecimal(-dec.scale, dec.toUnscaledLong)
case dec: BigDecimal =>
Scalar.fromDecimal(-dec.scale, dec.bigDecimal.unscaledValue().longValueExact())
case _ =>
throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
}

def from(v: Any, t: DataType): Scalar = v match {
case _ if v == null => Scalar.fromNull(GpuColumnVector.getRapidsType(t))
case _ if t.isInstanceOf[DecimalType] =>
var bigDec = v match {
case vv: Decimal => vv.toBigDecimal.bigDecimal
case vv: BigDecimal => vv.bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Float => BigDecimal(vv).bigDecimal
case vv: String => BigDecimal(vv).bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Long => BigDecimal(vv).bigDecimal
case vv: Int => BigDecimal(vv).bigDecimal
case vv => throw new IllegalStateException(
s"${vv.getClass} '${vv}' is not supported as a scalar yet")
}
bigDec = bigDec.setScale(t.asInstanceOf[DecimalType].scale)
if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) {
throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t")
}
Scalar.fromDecimal(-bigDec.scale(), bigDec.unscaledValue().longValueExact())
case l: Long => t match {
case LongType => Scalar.fromLong(l)
case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ object GpuDivModLike {
case DType.INT64 => Scalar.fromLong(0L)
case DType.FLOAT32 => Scalar.fromFloat(0f)
case DType.FLOAT64 => Scalar.fromDouble(0)
case dt if dt.isDecimalType && dt.isBackedByInt => Scalar.fromDecimal(0, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is going to work in all cases. I think this might be another place where we have some tech debt to pay off and need to pass in a DataType instead of a DType.

case dt if dt.isDecimalType && dt.isBackedByLong => Scalar.fromDecimal(0, 0L)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the div/mod to work properly we also need to update isScalarZero, or we are going to miss a divide by zero case in decimal.

case t => throw new IllegalArgumentException(s"Unexpected type: $t")
}
}
Expand Down
Loading