Support arithmetic operators on ANSI interval types #5020

Merged (13 commits) on Apr 8, 2022
20 changes: 12 additions & 8 deletions integration_tests/src/main/python/arithmetic_ops_test.py
@@ -879,10 +879,11 @@ def test_unary_minus_day_time_interval(ansi_enabled):
conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_minus_ansi_overflow_day_time_interval():
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_unary_minus_ansi_overflow_day_time_interval(ansi_enabled):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), '-a').collect(),
conf=ansi_enabled_conf,
conf={'spark.sql.ansi.enabled': ansi_enabled},
error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@@ -894,10 +895,11 @@ def test_abs_ansi_no_overflow_day_time_interval(ansi_enabled):
conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_abs_ansi_overflow_day_time_interval():
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_abs_ansi_overflow_day_time_interval(ansi_enabled):
assert_gpu_and_cpu_error(
df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), 'abs(a)').collect(),
conf=ansi_enabled_conf,
conf={'spark.sql.ansi.enabled': ansi_enabled},
error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@@ -910,13 +912,14 @@ def test_addition_day_time_interval(ansi_enabled):
conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_add_overflow_with_ansi_enabled_day_time_interval():
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_add_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
assert_gpu_and_cpu_error(
df_fun=lambda spark: spark.createDataFrame(
SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MAX), timedelta(microseconds=10)),]),
StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
).selectExpr('a + b').collect(),
conf=ansi_enabled_conf,
conf={'spark.sql.ansi.enabled': ansi_enabled},
error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@@ -929,13 +932,14 @@ def test_subtraction_day_time_interval(ansi_enabled):
conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_subtraction_overflow_with_ansi_enabled_day_time_interval():
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
assert_gpu_and_cpu_error(
df_fun=lambda spark: spark.createDataFrame(
SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MIN), timedelta(microseconds=10)),]),
StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
).selectExpr('a - b').collect(),
conf=ansi_enabled_conf,
conf={'spark.sql.ansi.enabled': ansi_enabled},
error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
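Note on the test changes above: the added parametrization runs each overflow test with spark.sql.ansi.enabled set to both 'false' and 'true', because overflow on day-time interval arithmetic raises an ArithmeticException regardless of ANSI mode, and the GPU must match the CPU in both settings. The following is a standalone plain-Spark sketch of that behavior, illustrative only and not part of this PR; it assumes a local Spark 3.3+ session.

```scala
// Illustrative sketch only (not part of this PR): day-time interval overflow on Spark 3.3+
// raises an ArithmeticException whether or not spark.sql.ansi.enabled is set.
import java.time.Duration
import java.time.temporal.ChronoUnit
import org.apache.spark.sql.SparkSession

object DayTimeIntervalOverflowDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("interval-overflow").getOrCreate()
    import spark.implicits._

    spark.conf.set("spark.sql.ansi.enabled", "false") // the same error is raised with "true"

    // Long.MaxValue microseconds is the largest value a DayTimeIntervalType can hold,
    // so adding any positive interval overflows.
    val maxInterval = Duration.of(Long.MaxValue, ChronoUnit.MICROS)
    val df = Seq((maxInterval, Duration.ofSeconds(1))).toDF("a", "b")

    try {
      df.selectExpr("a + b").collect()
    } catch {
      case e: Exception => println(s"Caught expected overflow: $e")
    }

    spark.stop()
  }
}
```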
7 changes: 6 additions & 1 deletion integration_tests/src/main/python/ast_test.py
@@ -17,7 +17,7 @@
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
from data_gen import *
from marks import approximate_float
from spark_session import with_cpu_session
from spark_session import with_cpu_session, is_before_spark_330
import pyspark.sql.functions as f

# Each descriptor contains a list of data generators and a corresponding boolean
@@ -105,6 +105,11 @@ def test_bitwise_not(data_descr):
def test_unary_positive(data_descr):
assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_positive_for_daytime_interval():
data_descr = (DayTimeIntervalGen(), True)
assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.parametrize('data_descr', ast_arithmetic_descrs, ids=idfn)
def test_unary_minus(data_descr):
assert_unary_ast(data_descr, lambda df: df.selectExpr('-a'))
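The new AST test above exercises unary plus on a day-time interval column, which is an identity projection. A plain-Spark sketch of the CPU-side query shape it covers, illustrative only and not part of this PR:

```scala
// Illustrative sketch only: unary plus on a day-time interval column returns the values unchanged.
import java.time.Duration
import org.apache.spark.sql.SparkSession

object UnaryPositiveIntervalDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("unary-positive").getOrCreate()
    import spark.implicits._

    val df = Seq(Duration.ofDays(2), Duration.ofHours(-5)).toDF("a")
    df.selectExpr("+a").show(truncate = false) // same rows as df.select("a")

    spark.stop()
  }
}
```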
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims
import ai.rapids.cudf
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter
import com.nvidia.spark.rapids.TypeSig

import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnVector
@@ -82,15 +83,55 @@ object GpuTypeShims {
throw new RuntimeException(s"Not support type $dt.")

/**
* Spark supports interval type from 320; Spark supports to/from Parquet from 330,
* And lots of interval operators from 330, so just return false for Spark330-
* Alias supported types for this shim
*/
def isDayTimeIntervalType(dt: DataType) : Boolean = false
def supportedTypesForAlias: TypeSig = TypeSig.none

/**
* Spark supports interval type from 320; Spark supports to/from Parquet from 330,
* And lots of interval operators from 330, so just return false for Spark330-
* Abs supported types for this shim
*/
def isYearMonthIntervalType(dt: DataType) : Boolean = false
def supportedTypesForAbs: TypeSig = TypeSig.none

def isDayTimeTypeAndAbsSupports(dt: DataType): Boolean = false

def isYearMonthTypeAndAbsSupports(dt: DataType): Boolean = false

/**
* Minus supported types for this shim
*/
def supportedTypesForMinus: TypeSig = TypeSig.none

def isDayTimeTypeAndMinusSupports(dt: DataType): Boolean = false

def isYearMonthTypeAndMinusSupports(dt: DataType): Boolean = false

/**
* Add supported types for this shim
*/
def supportedTypesForAdd: TypeSig = TypeSig.none

def isDayTimeTypeAndAddSupports(dt: DataType): Boolean = false

def isYearMonthTypeAndAddSupports(dt: DataType): Boolean = false

/**
* Subtract supported types for this shim
*/
def supportedTypesForSubtract: TypeSig = TypeSig.none

def isDayTimeTypeAndSubtractSupports(dt: DataType): Boolean = false

def isYearMonthTypeAndSubtractSupports(dt: DataType): Boolean = false

/**
* Positive supported types for this shim
*/
def supportedTypesForPositive: TypeSig = TypeSig.none

def supportedTypesForCoalesce: TypeSig = TypeSig.none

def supportedTypesForShuffle: TypeSig = TypeSig.none

def supportedTypesForAttributeReference: TypeSig = TypeSig.none

}
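For shims that target Spark releases before 3.3.0, every one of these hooks answers with TypeSig.none or false, so any query touching ANSI interval types stays off the GPU for those operators. The following is a minimal sketch of the dispatch pattern, using a toy stand-in rather than the real com.nvidia.spark.rapids.TypeSig:

```scala
// Toy illustration of the shim pattern above; TypeSig here is a simplified stand-in
// (a set of type-name strings), not the real com.nvidia.spark.rapids.TypeSig.
object ShimPatternSketch {
  final case class TypeSig(types: Set[String]) {
    def +(other: TypeSig): TypeSig = TypeSig(types ++ other.types)
  }
  object TypeSig {
    val none: TypeSig = TypeSig(Set.empty)
    val commonCudfTypes: TypeSig = TypeSig(Set("INT", "LONG", "DOUBLE", "STRING"))
    val ansiIntervals: TypeSig = TypeSig(Set("DAYTIME", "YEARMONTH"))
  }

  // Shim built against Spark < 3.3.0: no interval support for Add.
  object GpuTypeShimsPre330 {
    def supportedTypesForAdd: TypeSig = TypeSig.none
  }

  // Shim built against Spark 3.3.0+: ANSI intervals are allowed for Add.
  object GpuTypeShims330Plus {
    def supportedTypesForAdd: TypeSig = TypeSig.ansiIntervals
  }

  // Version-agnostic rule code just unions in whatever the active shim reports,
  // so it never has to mention DayTimeIntervalType / YearMonthIntervalType directly.
  def addInputTypes(shimTypes: TypeSig): TypeSig = TypeSig.commonCudfTypes + shimTypes

  def main(args: Array[String]): Unit = {
    println(addInputTypes(GpuTypeShimsPre330.supportedTypesForAdd))  // no interval types
    println(addInputTypes(GpuTypeShims330Plus.supportedTypesForAdd)) // includes DAYTIME and YEARMONTH
  }
}
```

Declaring the supported TypeSig per operator in each shim lets shared override rules stay identical across Spark versions, with interval support widened only where the Spark APIs actually exist.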
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.shims

import ai.rapids.cudf
import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.ColumnarCopyHelper
import com.nvidia.spark.rapids.{ColumnarCopyHelper, TypeSig}
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{IntConverter, LongConverter, NotNullIntConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
@@ -170,7 +170,64 @@ object GpuTypeShims {
}
}

def isDayTimeIntervalType(dt: DataType) : Boolean = dt.isInstanceOf[DayTimeIntervalType]
/**
* Alias supported types for this shim
*/
def supportedTypesForAlias: TypeSig = TypeSig.ansiIntervals

/**
* Abs supported types for this shim
*/
def supportedTypesForAbs: TypeSig = TypeSig.ansiIntervals

def isDayTimeTypeAndAbsSupports(dt: DataType): Boolean =
dt.isInstanceOf[DayTimeIntervalType]

def isYearMonthTypeAndAbsSupports(dt: DataType): Boolean =
dt.isInstanceOf[YearMonthIntervalType]

/**
* Minus supported types for this shim
*/
def supportedTypesForMinus: TypeSig = TypeSig.ansiIntervals

def isDayTimeTypeAndMinusSupports(dt: DataType): Boolean =
dt.isInstanceOf[DayTimeIntervalType]

def isYearMonthTypeAndMinusSupports(dt: DataType): Boolean =
dt.isInstanceOf[YearMonthIntervalType]

/**
* Add supported types for this shim
*/
def supportedTypesForAdd: TypeSig = TypeSig.ansiIntervals

def isDayTimeTypeAndAddSupports(dt: DataType): Boolean =
dt.isInstanceOf[DayTimeIntervalType]

def isYearMonthTypeAndAddSupports(dt: DataType): Boolean =
dt.isInstanceOf[YearMonthIntervalType]

/**
* Subtract supported types for this shim
*/
def supportedTypesForSubtract: TypeSig = TypeSig.ansiIntervals

def isDayTimeTypeAndSubtractSupports(dt: DataType): Boolean =
dt.isInstanceOf[DayTimeIntervalType]

def isYearMonthTypeAndSubtractSupports(dt: DataType): Boolean =
dt.isInstanceOf[YearMonthIntervalType]

/**
* Positive supported types for this shim
*/
def supportedTypesForPositive: TypeSig = TypeSig.ansiIntervals

def supportedTypesForCoalesce: TypeSig = TypeSig.ansiIntervals

def supportedTypesForShuffle: TypeSig = TypeSig.ansiIntervals

def supportedTypesForAttributeReference: TypeSig = TypeSig.ansiIntervals

def isYearMonthIntervalType(dt: DataType) : Boolean = dt.isInstanceOf[YearMonthIntervalType]
}
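This 3.3.0+ shim reports TypeSig.ansiIntervals, which covers both day-time and year-month intervals, so the corresponding GPU operators also accept YearMonthIntervalType columns. A plain-Spark sketch of year-month interval arithmetic and its overflow behavior, illustrative only and not part of this PR:

```scala
// Illustrative sketch only: year-month interval arithmetic on Spark 3.3+. The type is backed
// by a 32-bit month count, so exceeding Int.MaxValue months raises an ArithmeticException.
import java.time.Period
import org.apache.spark.sql.SparkSession

object YearMonthIntervalDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("year-month-interval").getOrCreate()
    import spark.implicits._

    val ok = Seq((Period.ofYears(1), Period.ofMonths(6))).toDF("a", "b")
    ok.selectExpr("a + b", "a - b", "-a", "abs(b)").show(truncate = false)

    val overflow = Seq((Period.ofMonths(Int.MaxValue), Period.ofMonths(1))).toDF("a", "b")
    try {
      overflow.selectExpr("a + b").collect()
    } catch {
      case e: Exception => println(s"Caught expected overflow: $e")
    }

    spark.stop()
  }
}
```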
@@ -24,16 +24,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{BaseSubqueryExec, CoalesceExec, FileSourceScanExec, FilterExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec}
import org.apache.spark.sql.execution.{BaseSubqueryExec, FileSourceScanExec, FilterExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.execution.GpuShuffleMeta
import org.apache.spark.sql.rapids.shims.GpuTimeAdd
import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, DecimalType, StructType}
import org.apache.spark.unsafe.types.CalendarInterval
@@ -97,39 +95,6 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL
val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
GpuOverrides.expr[Coalesce](
"Returns the first non-null argument if exists. Otherwise, null",
ExprChecks.projectOnly(
(_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.ansiIntervals).nested(),
TypeSig.all,
repeatingParamCheck = Some(RepeatingParamCheck("param",
(_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT +
TypeSig.ansiIntervals).nested(),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[Coalesce](a, conf, p, r) {
override def convertToGpu():
GpuExpression = GpuCoalesce(childExprs.map(_.convertToGpu()))
}),
GpuOverrides.expr[AttributeReference](
"References an input column",
ExprChecks.projectAndAst(
TypeSig.astTypes,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.ansiIntervals).nested(),
TypeSig.all),
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
// This is the only NOOP operator. It goes away when things are bound
override def convertToGpu(): Expression = att

// There are so many of these that we don't need to print them out, unless it
// will not work on the GPU
override def print(append: StringBuilder, depth: Int, all: Boolean): Unit = {
if (!this.canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) {
super.print(append, depth, all)
}
}
}),
GpuOverrides.expr[RoundCeil](
"Computes the ceiling of the given expression to d decimal places",
ExprChecks.binaryProject(
@@ -329,18 +294,7 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {

// ANSI support for ABS was added in 3.2.0 SPARK-33275
override def convertToGpu(child: Expression): GpuExpression = GpuAbs(child, ansiEnabled)
}),
GpuOverrides.expr[Alias](
"Gives a column a name",
ExprChecks.unaryProjectAndAstInputMatchesOutput(
TypeSig.astTypes,
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT
+ TypeSig.DECIMAL_128 + TypeSig.ansiIntervals).nested(),
TypeSig.all),
(a, conf, p, r) => new UnaryAstExprMeta[Alias](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuAlias(child, a.name)(a.exprId, a.qualifier, a.explicitMetadata)
})
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
}
@@ -349,18 +303,6 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL
val map: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
GpuOverrides.exec[ShuffleExchangeExec](
"The backend for most data being exchanged between processes",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.ARRAY +
TypeSig.STRUCT + TypeSig.MAP + TypeSig.DAYTIME + TypeSig.ansiIntervals).nested()
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(TypeEnum.ARRAY, "Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(TypeEnum.MAP, "Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
(shuffle, conf, p, r) => new GpuShuffleMeta(shuffle, conf, p, r)),
GpuOverrides.exec[BatchScanExec](
"The backend for most file input",
ExecChecks(
@@ -383,15 +325,6 @@ trait Spark33XShims extends Spark321PlusShims with Spark320PlusNonDBShims {
override def convertToGpu(): GpuExec = GpuBatchScanExec(p.output,
childScans.head.convertToGpu(), p.runtimeFilters, p.keyGroupedPartitioning)
}),
GpuOverrides.exec[CoalesceExec](
"The backend for the dataframe coalesce method",
ExecChecks((_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY +
TypeSig.MAP + TypeSig.ansiIntervals).nested(),
TypeSig.all),
(coalesce, conf, parent, r) => new SparkPlanMeta[CoalesceExec](coalesce, conf, parent, r) {
override def convertToGpu(): GpuExec =
GpuCoalesceExec(coalesce.numPartitions, childPlans.head.convertIfNeeded())
}),
GpuOverrides.exec[DataWritingCommandExec](
"Writing data",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128.withPsNote(
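The expression and exec rules deleted from this shim (Coalesce, AttributeReference, Alias, ShuffleExchangeExec, CoalesceExec) are not dropped features; presumably their definitions now live in version-agnostic code that folds in the new GpuTypeShims.supportedTypesFor* hooks instead of being re-declared per Spark version. A plain-Spark sketch of the query shapes those rules cover for interval columns, illustrative only and not part of this PR:

```scala
// Illustrative sketch only: the plan shapes covered by the rules removed from this shim --
// Coalesce/Alias expressions plus shuffle and coalesce execs -- applied to a day-time
// interval column.
import java.time.Duration
import org.apache.spark.sql.SparkSession

object IntervalPlanShapesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("interval-plan-shapes").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (Duration.ofDays(1), Duration.ofHours(2)),
      (null.asInstanceOf[Duration], Duration.ofMinutes(30))
    ).toDF("a", "b")

    val result = df
      .selectExpr("coalesce(a, b) AS first_non_null") // Coalesce + Alias on DayTimeIntervalType
      .repartition(2)                                 // planned as a ShuffleExchangeExec
      .coalesce(1)                                    // planned as a CoalesceExec

    result.explain()
    result.show(truncate = false)

    spark.stop()
  }
}
```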