Rename DECIMAL_128_FULL and rework usage of TypeSig.gpuNumeric (NVIDIA#4462)

* Rename DECIMAL_128_FULL and rework usage of TypeSig.gpuNumeric

Signed-off-by: Kuhu Shukla <kuhus@nvidia.com>

Co-authored-by: Kuhu Shukla <kuhus@nvidia.com>
Kuhu Shukla and kuhushukla authored Jan 10, 2022
1 parent 9417ee5 commit 5749673
Showing 15 changed files with 334 additions and 334 deletions.
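
The gist of the change, as a minimal self-contained Scala sketch (the names below are simplified stand-ins, not the real com.nvidia.spark.rapids.TypeSig): after this commit gpuNumeric already includes 128-bit decimals, so the old spelling gpuNumeric + DECIMAL_128_FULL collapses to plain gpuNumeric, DECIMAL_128_FULL itself is renamed DECIMAL_128, and the Spark-side signature formerly called numeric is spelled cpuNumeric.

// Hypothetical, simplified model of the TypeSig set algebra.
object TypeSigSketch {
  // A signature is modeled as a set of type names.
  final case class Sig(types: Set[String]) {
    def +(other: Sig): Sig = Sig(types ++ other.types)
  }

  val integral    = Sig(Set("BYTE", "SHORT", "INT", "LONG"))
  val fp          = Sig(Set("FLOAT", "DOUBLE"))
  val DECIMAL_128 = Sig(Set("DECIMAL128"))  // formerly named DECIMAL_128_FULL

  // gpuNumeric now carries DECIMAL_128 itself; cpuNumeric is the renamed
  // Spark-side counterpart of the old TypeSig.numeric.
  val gpuNumeric = integral + fp + DECIMAL_128
  val cpuNumeric = integral + fp + DECIMAL_128

  def main(args: Array[String]): Unit = {
    val oldSpelling = gpuNumeric + DECIMAL_128  // was gpuNumeric + DECIMAL_128_FULL
    assert(oldSpelling == gpuNumeric)           // adding DECIMAL_128 again is a no-op
  }
}

That collapse is exactly the mechanical rewrite applied in the hunks below.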
@@ -130,7 +130,7 @@ abstract class Spark30XdbShims extends Spark30XdbShimsBase with Logging {
"Databricks-specific window function exec, for \"running\" windows, " +
"i.e. (UNBOUNDED PRECEDING TO CURRENT ROW)",
ExecChecks(
- (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128_FULL +
+ (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(),
TypeSig.all,
Map("partitionSpec" ->
@@ -142,7 +142,7 @@ abstract class Spark30XdbShims extends Spark30XdbShimsBase with Logging {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
- TypeSig.ARRAY + TypeSig.DECIMAL_128_FULL).nested(), TypeSig.all),
+ TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
@@ -297,11 +297,11 @@ abstract class Spark30XdbShims extends Spark30XdbShimsBase with Logging {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
Seq(ParamCheck("input",
- TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric))),
+ TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128,
+ TypeSig.cpuNumeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing
@@ -335,8 +335,8 @@ abstract class Spark30XdbShims extends Spark30XdbShimsBase with Logging {
GpuOverrides.expr[Abs](
"Absolute value",
ExprChecks.unaryProjectAndAstInputMatchesOutput(
- TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric),
+ TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric,
+ TypeSig.cpuNumeric),
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
@@ -96,7 +96,7 @@ trait Spark30XdbShimsBase extends SparkShims {

override def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] = exec[CustomShuffleReaderExec](
"A wrapper of shuffle query stage",
- ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128_FULL + TypeSig.ARRAY +
+ ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all),
(exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r))

@@ -121,7 +121,7 @@ abstract class Spark30XShims extends Spark301until320Shims with Logging {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
- TypeSig.ARRAY + TypeSig.DECIMAL_128_FULL).nested(), TypeSig.all),
+ TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
@@ -243,11 +243,11 @@ abstract class Spark30XShims extends Spark301until320Shims with Logging {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
Seq(ParamCheck("input",
- TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric))),
+ TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128,
+ TypeSig.cpuNumeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing
@@ -281,8 +281,8 @@ abstract class Spark30XShims extends Spark301until320Shims with Logging {
GpuOverrides.expr[Abs](
"Absolute value",
ExprChecks.unaryProjectAndAstInputMatchesOutput(
- TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric),
+ TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric,
+ TypeSig.cpuNumeric),
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
@@ -65,5 +65,5 @@ object TypeSigUtil extends TypeSigUtilBase {

/** Get numeric and interval TypeSig */
override def getNumericAndInterval(): TypeSig =
- TypeSig.numeric + TypeSig.CALENDAR
+ TypeSig.cpuNumeric + TypeSig.CALENDAR
}
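
Note the recurring two-signature shape in these checks: the first signature in each pair describes what the plugin can execute on the GPU, the second what Spark itself accepts, which is why the Spark-side TypeSig.numeric becomes TypeSig.cpuNumeric while the GPU side folds DECIMAL_128 into gpuNumeric. A minimal sketch of that pairing, with hypothetical names (not the real ExprChecks API):

// Hypothetical sketch of the GPU-signature / Spark-signature pairing;
// the real ExprChecks machinery in spark-rapids is considerably richer.
final case class ParamCheckSketch(
    name: String,
    gpuSupported: Set[String],   // types the plugin can run on the GPU
    sparkAccepts: Set[String]) { // types Spark accepts for this parameter
  def willRunOnGpu(t: String): Boolean = gpuSupported(t) && sparkAccepts(t)
}

object PairingDemo extends App {
  val gpuNumeric = Set("INT", "LONG", "FLOAT", "DOUBLE", "DECIMAL128")
  val cpuNumeric = gpuNumeric  // for Abs the two sides line up after this commit
  val absInput   = ParamCheckSketch("input", gpuNumeric, cpuNumeric)
  println(absInput.willRunOnGpu("DECIMAL128"))  // true
}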
@@ -118,7 +118,7 @@ trait Spark301until320Shims extends SparkShims {

override def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] = exec[CustomShuffleReaderExec](
"A wrapper of shuffle query stage",
- ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128_FULL + TypeSig.ARRAY +
+ ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all),
(exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r))

@@ -113,15 +113,15 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
import TypeSig._
// nullChecks are the same

- override val booleanChecks: TypeSig = integral + fp + BOOLEAN + STRING + DECIMAL_128_FULL
- override val sparkBooleanSig: TypeSig = numeric + BOOLEAN + STRING
+ override val booleanChecks: TypeSig = integral + fp + BOOLEAN + STRING + DECIMAL_128
+ override val sparkBooleanSig: TypeSig = cpuNumeric + BOOLEAN + STRING

- override val integralChecks: TypeSig = gpuNumeric + BOOLEAN + STRING + DECIMAL_128_FULL
- override val sparkIntegralSig: TypeSig = numeric + BOOLEAN + STRING
+ override val integralChecks: TypeSig = gpuNumeric + BOOLEAN + STRING
+ override val sparkIntegralSig: TypeSig = cpuNumeric + BOOLEAN + STRING

- override val fpChecks: TypeSig = (gpuNumeric + BOOLEAN + STRING + DECIMAL_128_FULL)
+ override val fpChecks: TypeSig = (gpuNumeric + BOOLEAN + STRING)
    .withPsNote(TypeEnum.STRING, fpToStringPsNote)
- override val sparkFpSig: TypeSig = numeric + BOOLEAN + STRING
+ override val sparkFpSig: TypeSig = cpuNumeric + BOOLEAN + STRING

override val dateChecks: TypeSig = TIMESTAMP + DATE + STRING
override val sparkDateSig: TypeSig = TIMESTAMP + DATE + STRING
@@ -131,25 +131,25 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {

// stringChecks are the same
// binaryChecks are the same
- override val decimalChecks: TypeSig = gpuNumeric + DECIMAL_128_FULL + STRING
- override val sparkDecimalSig: TypeSig = numeric + BOOLEAN + STRING
+ override val decimalChecks: TypeSig = gpuNumeric + STRING
+ override val sparkDecimalSig: TypeSig = cpuNumeric + BOOLEAN + STRING

// calendarChecks are the same

override val arrayChecks: TypeSig =
- ARRAY.nested(commonCudfTypes + DECIMAL_128_FULL + NULL + ARRAY + BINARY + STRUCT) +
+ ARRAY.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT) +
psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
"the desired child type")
override val sparkArraySig: TypeSig = ARRAY.nested(all)

override val mapChecks: TypeSig =
- MAP.nested(commonCudfTypes + DECIMAL_128_FULL + NULL + ARRAY + BINARY + STRUCT + MAP) +
+ MAP.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT + MAP) +
psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " +
"desired child types")
override val sparkMapSig: TypeSig = MAP.nested(all)

override val structChecks: TypeSig =
- STRUCT.nested(commonCudfTypes + DECIMAL_128_FULL + NULL + ARRAY + BINARY + STRUCT) +
+ STRUCT.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT) +
psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " +
"desired child type(s)")
override val sparkStructSig: TypeSig = STRUCT.nested(all)
@@ -162,11 +162,11 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
GpuOverrides.expr[Average](
"Average aggregate operator",
ExprChecks.fullAgg(
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
- TypeSig.DOUBLE + TypeSig.DECIMAL_128_FULL,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
+ TypeSig.DOUBLE + TypeSig.DECIMAL_128,
Seq(ParamCheck("input",
- TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric))),
+ TypeSig.integral + TypeSig.fp + TypeSig.DECIMAL_128,
+ TypeSig.cpuNumeric))),
(a, conf, p, r) => new AggExprMeta[Average](a, conf, p, r) {
override def tagAggForGpu(): Unit = {
// For Decimal Average the SUM adds a precision of 10 to avoid overflowing
@@ -200,8 +200,8 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
GpuOverrides.expr[Abs](
"Absolute value",
ExprChecks.unaryProjectAndAstInputMatchesOutput(
- TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric + TypeSig.DECIMAL_128_FULL,
- TypeSig.numeric),
+ TypeSig.implicitCastsAstTypes, TypeSig.gpuNumeric,
+ TypeSig.cpuNumeric),
(a, conf, p, r) => new UnaryAstExprMeta[Abs](a, conf, p, r) {
// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, false)
@@ -222,17 +222,17 @@
GpuOverrides.expr[Lead](
"Window function that returns N entries ahead of this one",
ExprChecks.windowOnly(
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL + TypeSig.NULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL + TypeSig.NULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
@@ -245,17 +245,17 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
GpuOverrides.expr[Lag](
"Window function that returns N entries behind this one",
ExprChecks.windowOnly(
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL + TypeSig.NULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all,
Seq(
ParamCheck("input",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all),
ParamCheck("offset", TypeSig.INT, TypeSig.INT),
ParamCheck("default",
- (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128_FULL + TypeSig.NULL +
+ (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL +
TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all)
)
@@ -269,10 +269,10 @@
"Gets the field at `ordinal` in the Array",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
- TypeSig.DECIMAL_128_FULL + TypeSig.MAP).nested(),
+ TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
- TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128_FULL + TypeSig.MAP),
+ TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT)),
(in, conf, p, r) => new GpuGetArrayItemMeta(in, conf, p, r){
@@ -293,9 +293,9 @@
"Returns value for the given key in value if column is map.",
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
- TypeSig.DECIMAL_128_FULL + TypeSig.MAP).nested(), TypeSig.all,
+ TypeSig.DECIMAL_128 + TypeSig.MAP).nested(), TypeSig.all,
("array/map", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
- TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128_FULL + TypeSig.MAP) +
+ TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP) +
TypeSig.MAP.nested(TypeSig.STRING)
.withPsNote(TypeEnum.MAP ,"If it's map, only string is supported."),
TypeSig.ARRAY.nested(TypeSig.all) + TypeSig.MAP.nested(TypeSig.all)),
@@ -318,10 +318,10 @@
// Match exactly with the checks for GetArrayItem
ExprChecks.binaryProject(
(TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.NULL +
- TypeSig.DECIMAL_128_FULL + TypeSig.MAP).nested(),
+ TypeSig.DECIMAL_128 + TypeSig.MAP).nested(),
TypeSig.all,
("array", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.ARRAY +
- TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128_FULL + TypeSig.MAP),
+ TypeSig.STRUCT + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.MAP),
TypeSig.ARRAY.nested(TypeSig.all)),
("ordinal", TypeSig.lit(TypeEnum.INT), TypeSig.INT))
case _ => throw new IllegalStateException("Only Array or Map is supported as input.")
@@ -361,7 +361,7 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
- TypeSig.ARRAY + TypeSig.DECIMAL_128_FULL).nested(), TypeSig.all),
+ TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart