[WIP] Initial work on supporting DecimalType #1063

Closed
wants to merge 13 commits
@@ -161,6 +161,17 @@ private static DType toRapidsOrNull(DataType type) {
return DType.TIMESTAMP_MICROSECONDS;
} else if (type instanceof StringType) {
return DType.STRING;
} else if (type instanceof DecimalType) {
// The decimal support check has already been done during the GPU plan overriding stage,
// so unsupported decimal types do not need to be handled again here.
DecimalType dt = (DecimalType) type;
if (dt.precision() > DType.DECIMAL64_MAX_PRECISION) {
return null;
} else if (dt.precision() > DType.DECIMAL32_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL64, -dt.scale());
} else {
return DType.create(DType.DTypeEnum.DECIMAL32, -dt.scale());
}
}
return null;
}
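As a worked example of the mapping above (a sketch only, assuming the usual cudf limits of 9 digits for DECIMAL32 and 18 for DECIMAL64): the precision picks the backing width, and the Spark scale is negated when the cudf DType is created.

import ai.rapids.cudf.DType
import org.apache.spark.sql.types.DecimalType

val narrow = DecimalType(7, 2)   // precision 7  <= DECIMAL32_MAX_PRECISION -> DECIMAL32, cudf scale -2
val wide   = DecimalType(15, 6)  // precision 15 <= DECIMAL64_MAX_PRECISION -> DECIMAL64, cudf scale -6
val dt32 = DType.create(DType.DTypeEnum.DECIMAL32, -narrow.scale)
val dt64 = DType.create(DType.DTypeEnum.DECIMAL64, -wide.scale)
// DecimalType(20, 4) exceeds DECIMAL64_MAX_PRECISION, so toRapidsOrNull returns null (the plan would then stay on the CPU)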
@@ -289,6 +300,14 @@ public static ColumnarBatch from(Table table, DataType[] colTypes) {
*/
private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataType colType) {
DType dt = cv.getDataType();
if (dt.isDecimalType()) {
if (!(colType instanceof DecimalType)) {
return false;
}
// check that the declared Spark precision fits within the backing cudf decimal width
int maxPrecision = dt.isBackedByLong() ? DType.DECIMAL64_MAX_PRECISION : DType.DECIMAL32_MAX_PRECISION;
return ((DecimalType) colType).precision() <= maxPrecision;
}
if (!dt.isNestedType()) {
return getRapidsType(colType).equals(dt);
}
@@ -30,6 +30,9 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

import java.math.BigDecimal;
import java.math.RoundingMode;

/**
* A GPU accelerated version of the Spark ColumnVector.
* Most of the standard Spark APIs should never be called, as they assume that the data
@@ -166,7 +169,9 @@ public ColumnarMap getMap(int ordinal) {

@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
throw new IllegalStateException("The decimal type is currently not supported by rapids cudf");
BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY);
assert bigDec.precision() <= precision;
return Decimal.apply(bigDec);
}
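A brief note on RoundingMode.UNNECESSARY in the new getDecimal: it turns setScale into an assertion, so a stored value whose scale cannot be matched without rounding fails loudly instead of silently losing digits. A minimal sketch:

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

val padded = new JBigDecimal("12.3").setScale(2, RoundingMode.UNNECESSARY)  // 12.30, widening the scale is fine
// new JBigDecimal("12.345").setScale(2, RoundingMode.UNNECESSARY)          // would throw ArithmeticException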

@Override
@@ -24,6 +24,8 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

import java.math.BigDecimal;
import java.math.RoundingMode;

/**
* A GPU accelerated version of the Spark ColumnVector.
@@ -112,7 +114,9 @@ public ColumnarMap getMap(int ordinal) {

@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
throw new IllegalStateException("The decimal type is currently not supported by rapids cudf");
BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY);
assert bigDec.precision() <= precision;
return Decimal.apply(bigDec);
}

@Override
@@ -503,6 +503,14 @@ object GpuOverrides {
}
}

/**
* A workaround method that adds DecimalType to the supported types for expressions that support Decimal.
*/
def isSupportedTypeWithDecimal(dataType: DataType): Boolean = dataType match {
case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION
case dt => isSupportedType(dt)
}
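Illustrative behavior of this helper (a sketch, assuming DECIMAL64_MAX_PRECISION is 18):

import org.apache.spark.sql.types.{DecimalType, StringType}

isSupportedTypeWithDecimal(DecimalType(10, 2))   // true: representable by a 32- or 64-bit backed decimal
isSupportedTypeWithDecimal(DecimalType(38, 10))  // false: precision 38 is wider than DECIMAL64 allows
isSupportedTypeWithDecimal(StringType)           // falls back to the existing isSupportedType check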

/**
* Checks to see if any expressions are a String Literal
*/
@@ -598,7 +606,7 @@
}

override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowCalendarInterval = true)
GpuOverrides.isSupportedType(t, allowCalendarInterval = true, allowDecimal = true)
}),
expr[Signum](
"Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive",
@@ -609,7 +617,8 @@
"Gives a column a name",
(a, conf, p, r) => new UnaryExprMeta[Alias](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowBinary = true)
GpuOverrides.isSupportedType(t,
allowStringMaps = true, allowBinary = true, allowDecimal = true)

override def convertToGpu(child: Expression): GpuExpression =
GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata)
@@ -618,7 +627,7 @@
"References an input column",
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)

// This is the only NOOP operator. It goes away when things are bound
override def convertToGpu(): Expression = att
@@ -810,13 +819,15 @@
expr[IsNull](
"Checks if a value is null",
(a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)
override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child)
}),
expr[IsNotNull](
"Checks if a value is not null",
(a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)
override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child)
}),
expr[IsNaN](
@@ -1799,7 +1810,7 @@
(proj, conf, p, r) => {
new SparkPlanMeta[ProjectExec](proj, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)

override def convertToGpu(): GpuExec =
GpuProjectExec(childExprs.map(_.convertToGpu()), childPlans(0).convertIfNeeded())
@@ -1884,7 +1895,7 @@
"The backend for most filter statements",
(filter, conf, p, r) => new SparkPlanMeta[FilterExec](filter, conf, p, r) {
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowStringMaps = true)
GpuOverrides.isSupportedType(t, allowStringMaps = true, allowDecimal = true)

override def convertToGpu(): GpuExec =
GpuFilterExec(childExprs(0).convertToGpu(), childPlans(0).convertIfNeeded())
@@ -101,15 +101,15 @@ private object GpuRowToColumnConverter {
case (TimestampType, false) => NotNullLongConverter
case (StringType, true) => StringConverter
case (StringType, false) => NotNullStringConverter
case (dt: DecimalType, true) => DecimalConverter(dt.precision, dt.scale)
case (dt: DecimalType, false) => NotNullDecimalConverter(dt.precision, dt.scale)
// NOT SUPPORTED YET
// case CalendarIntervalType => CalendarConverter
// NOT SUPPORTED YET
// case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType))
// NOT SUPPORTED YET
// case st: StructType => new StructConverter(st.fields.map(
// (f) => getConverterForType(f.dataType)))
// NOT SUPPORTED YET
// case dt: DecimalType => new DecimalConverter(dt)
// NOT SUPPORTED YET
case (MapType(StringType, StringType, _), _) => MapConverter
case (unknown, _) => throw new UnsupportedOperationException(
@@ -264,6 +264,26 @@
}
}

private case class DecimalConverter(
precision: Int, scale: Int) extends FixedWidthTypeConverter {
override def append(row: SpecializedGetters,
column: Int,
builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Unit =
if (row.isNullAt(column)) {
builder.appendNull()
} else {
NotNullDecimalConverter(precision, scale).append(row, column, builder)
}
}

private case class NotNullDecimalConverter(
precision: Int, scale: Int) extends FixedWidthTypeConverter {
override def append(row: SpecializedGetters,
column: Int,
builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Unit =
builder.append(row.getDecimal(column, precision, scale).toJavaBigDecimal)
}
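One note on the converter pair above: the nullable DecimalConverter delegates to NotNullDecimalConverter but constructs a fresh delegate on every appended row. A hypothetical variant (not part of this PR) that hoists the delegate into a field would avoid that per-row allocation:

private case class DecimalConverter(precision: Int, scale: Int) extends FixedWidthTypeConverter {
  // Hypothetical sketch: allocate the not-null delegate once per converter instead of once per row.
  private val notNull = NotNullDecimalConverter(precision, scale)
  override def append(row: SpecializedGetters, column: Int,
      builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Unit =
    if (row.isNullAt(column)) builder.appendNull() else notNull.append(row, column, builder)
}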

// ONLY supports Map(String, String)
private case object MapConverter
extends VariableWidthTypeConverter {
@@ -117,6 +117,25 @@ object HostColumnarToGpu {
for (i <- 0 until rows) {
b.appendUTF8String(cv.getUTF8String(i).getBytes)
}
case (dt, nullable) if dt.isDecimalType =>
val precision = if (dt.isBackedByInt) {
DType.DECIMAL32_MAX_PRECISION
} else {
DType.DECIMAL64_MAX_PRECISION
}
if (nullable) {
for (i <- 0 until rows) {
if (cv.isNullAt(i)) {
b.appendNull()
} else {
b.append(cv.getDecimal(i, precision, -dt.getScale).toJavaBigDecimal)
}
}
} else {
for (i <- 0 until rows) {
b.append(cv.getDecimal(i, precision, -dt.getScale).toJavaBigDecimal)
}
}
case (t, n) =>
throw new UnsupportedOperationException(s"Converting to GPU for ${t} is not currently " +
s"supported")
@@ -434,6 +434,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
override val childParts: Seq[PartMeta[_]] = Seq.empty
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty

// We assume that all common plans support decimal by default, since whether decimal
// is allowed is determined mainly at the expression level.
override def isSupportedType(t: DataType): Boolean =
GpuOverrides.isSupportedType(t, allowDecimal = true)
Review comment from a collaborator:
Do we have tests that verify that we can support decimal for all top-level Spark operations? Have we tested join, expand, generate, filter, project, union, window, sort, or hash aggregate? What about all of the Arrow Python UDF code where we go to/from Arrow?

I think it would be much better if we split this big PR up into smaller pieces and put each piece in separately with corresponding tests to show that it works, and we only add decimal to the allow list for those things that we know work because we have tested them. If you want me to help with this I am happy to do it. I am already in the middle of doing it for Lists; I am going to add structs, maps, binary, null type, and finally calendar interval, depending on how much time I have and priorities. Some of these we will only be able to do very basic things with, but that should be enough to unblock others to use them for more complicated processing.


override def convertToCpu(): SparkPlan = {
wrapped.withNewChildren(childPlans.map(_.convertIfNeeded()))
}
sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala (21 additions, 0 deletions)
@@ -65,6 +65,7 @@ object GpuScalar {
case DType.TIMESTAMP_DAYS => v.getInt
case DType.TIMESTAMP_MICROSECONDS => v.getLong
case DType.STRING => v.getJavaString
case dt: DType if dt.isDecimalType => Decimal(v.getBigDecimal)
case t => throw new IllegalStateException(s"$t is not a supported rapids scalar type yet")
}

@@ -88,12 +89,32 @@
case b: Boolean => Scalar.fromBool(b)
case s: String => Scalar.fromString(s)
case s: UTF8String => Scalar.fromString(s.toString)
case dec: Decimal => Scalar.fromDecimal(dec.toBigDecimal.bigDecimal)
case dec: BigDecimal => Scalar.fromDecimal(dec.bigDecimal)
case _ =>
throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
}

def from(v: Any, t: DataType): Scalar = v match {
case _ if v == null => Scalar.fromNull(GpuColumnVector.getRapidsType(t))
case _ if t.isInstanceOf[DecimalType] =>
var bigDec = v match {
case vv: Decimal => vv.toBigDecimal.bigDecimal
case vv: BigDecimal => vv.bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Float => BigDecimal(vv).bigDecimal
case vv: String => BigDecimal(vv).bigDecimal
case vv: Long => BigDecimal(vv).bigDecimal
case vv: Int => BigDecimal(vv).bigDecimal
case vv => throw new IllegalStateException(
s"${vv.getClass} '${vv}' is not supported as a scalar yet")
}
bigDec = bigDec.setScale(t.asInstanceOf[DecimalType].scale)
if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) {
throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t")
}
Scalar.fromDecimal(bigDec)
case l: Long => t match {
case LongType => Scalar.fromLong(l)
case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l)
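Illustrative expectations for the decimal branch of from(v, t) above (the values are made up; the behavior follows the code in this hunk):

import org.apache.spark.sql.types.{Decimal, DecimalType}

GpuScalar.from(Decimal("123.4"), DecimalType(10, 2))  // rescaled to 123.40 and wrapped in a decimal Scalar
GpuScalar.from(12345L, DecimalType(4, 0))             // IllegalArgumentException: precision 5 exceeds 4
GpuScalar.from("1.005", DecimalType(10, 2))           // ArithmeticException: setScale(2) would require rounding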
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
import org.apache.spark.sql.types.{DataTypes, Decimal, DecimalType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

class GpuBatchUtilsSuite extends FunSuite {
@@ -44,6 +44,11 @@
StructField("c0", DataTypes.StringType, nullable = false)
))

val decimalSchema = new StructType(Array(
StructField("c0", DataTypes.createDecimalType(15, 6), nullable = true),
StructField("c0", DataTypes.createDecimalType(15, 6), nullable = false)
))

/** Mix of data types and nullable and not nullable */
val mixedSchema = new StructType(Array(
StructField("c0", DataTypes.ByteType, nullable = false),
@@ -61,7 +66,9 @@
StructField("c6", DataTypes.StringType, nullable = false),
StructField("c6_nullable", DataTypes.StringType, nullable = true),
StructField("c7", DataTypes.BooleanType, nullable = false),
StructField("c7_nullable", DataTypes.BooleanType, nullable = true)
StructField("c7_nullable", DataTypes.BooleanType, nullable = true),
StructField("c8", DataTypes.createDecimalType(15, 6), nullable = false),
StructField("c8_nullable", DataTypes.createDecimalType(15, 6), nullable = true)
))

test("Calculate GPU memory for batch of 64 rows with integers") {
@@ -72,6 +79,10 @@
compareEstimateWithActual(stringSchema, 64)
}

test("Calculate GPU memory for batch of 64 rows with decimals") {
compareEstimateWithActual(decimalSchema, 64)
}

test("Calculate GPU memory for batch of 64 rows with mixed types") {
compareEstimateWithActual(mixedSchema, 64)
}
@@ -84,6 +95,10 @@
compareEstimateWithActual(stringSchema, 124)
}

test("Calculate GPU memory for batch of 124 rows with decimals") {
compareEstimateWithActual(decimalSchema, 124)
}

test("Calculate GPU memory for batch of 124 rows with mixed types") {
compareEstimateWithActual(mixedSchema, 124)
}
@@ -96,6 +111,10 @@
compareEstimateWithActual(stringSchema, 1024)
}

test("Calculate GPU memory for batch of 1024 rows with decimals") {
compareEstimateWithActual(decimalSchema, 1024)
}

test("Calculate GPU memory for batch of 1024 rows with mixed types") {
compareEstimateWithActual(mixedSchema, 1024)
}
@@ -185,6 +204,10 @@
case DataTypes.LongType => maybeNull(field, i, r.nextLong())
case DataTypes.FloatType => maybeNull(field, i, r.nextFloat())
case DataTypes.DoubleType => maybeNull(field, i, r.nextDouble())
case dataType: DecimalType =>
val upperBound = (0 until dataType.precision).foldLeft(1L)((x, _) => x * 10)
val unScaledValue = r.nextLong() % upperBound
maybeNull(field, i, Decimal(unScaledValue, dataType.precision, dataType.scale))
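// Worked example (illustrative): for precision 15 the foldLeft yields upperBound = 10^15,
// so r.nextLong() % upperBound keeps the unscaled value within 15 digits and
// Decimal(unScaledValue, 15, 6) can never overflow its declared precision.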
case dataType@DataTypes.StringType =>
if (field.nullable) {
// since we want a deterministic test that compares the estimate with actual