Support arithmetic operators on ANSI interval types #5020

Merged (13 commits) on Apr 8, 2022
integration_tests/src/main/python/arithmetic_ops_test.py (81 changes: 80 additions & 1 deletion)
@@ -14,13 +14,14 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import ignore_order, incompat, approximate_float, allow_non_gpu
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_320, is_before_spark_330, is_databricks91_or_later
import pyspark.sql.functions as f
from datetime import timedelta

# No overflow gens here because we just focus on verifying the fallback to CPU when
# enabling ANSI mode. But overflows will fail the tests because CPU runs raise
@@ -867,3 +868,81 @@ def test_subtraction_overflow_with_ansi_enabled(data, tp, expr):
    assert_gpu_and_cpu_are_equal_collect(
        func=lambda spark: _get_overflow_df(spark, data, tp, expr),
        conf=ansi_enabled_conf)


@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_unary_minus_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DAY_TIME_GEN_NO_OVER_FLOW).selectExpr('-a'),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_unary_minus_ansi_overflow_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), '-a').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_abs_ansi_no_overflow_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DAY_TIME_GEN_NO_OVER_FLOW).selectExpr('abs(a)'),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_abs_ansi_overflow_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), 'abs(a)').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_addition_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, DAY_TIME_GEN_NO_OVER_FLOW, DAY_TIME_GEN_NO_OVER_FLOW).select(
            f.col('a') + f.col('b')),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_add_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: spark.createDataFrame(
            SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MAX), timedelta(microseconds=10)),]),
            StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
        ).selectExpr('a + b').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_subtraction_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, DAY_TIME_GEN_NO_OVER_FLOW, DAY_TIME_GEN_NO_OVER_FLOW).select(
            f.col('a') - f.col('b')),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: spark.createDataFrame(
            SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MIN), timedelta(microseconds=10)),]),
            StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
        ).selectExpr('a - b').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_positive_day_time_interval():
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DayTimeIntervalGen()).selectExpr('+a'))
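For context, the behaviour these tests exercise can also be reproduced directly in a Spark 3.3.0+ session. The following Scala sketch is illustrative only (local session and made-up values, not part of this PR); it simply runs day-time interval arithmetic under ANSI mode, where an overflow of the underlying microseconds long raises ArithmeticException:

import java.time.Duration
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._
spark.conf.set("spark.sql.ansi.enabled", "true")

// java.time.Duration maps to Spark's DayTimeIntervalType
val df = Seq((Duration.ofDays(1), Duration.ofHours(36))).toDF("a", "b")

// Negation, abs, addition and subtraction on day-time intervals
df.selectExpr("-a", "abs(a)", "a + b", "a - b").show(truncate = false)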
integration_tests/src/main/python/ast_test.py (9 changes: 7 additions & 2 deletions)
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
from data_gen import *
from marks import approximate_float
from spark_session import with_cpu_session
from spark_session import with_cpu_session, is_before_spark_330
import pyspark.sql.functions as f

# Each descriptor contains a list of data generators and a corresponding boolean
@@ -105,6 +105,11 @@ def test_bitwise_not(data_descr):
def test_unary_positive(data_descr):
    assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_positive_for_daytime_interval():
    data_descr = (DayTimeIntervalGen(), True)
    assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.parametrize('data_descr', ast_arithmetic_descrs, ids=idfn)
def test_unary_minus(data_descr):
    assert_unary_ast(data_descr, lambda df: df.selectExpr('-a'))
@@ -66,4 +66,7 @@ object TypeSigUtil extends TypeSigUtilBase {
  /** Get numeric and interval TypeSig */
  override def getNumericAndInterval(): TypeSig =
    TypeSig.cpuNumeric + TypeSig.CALENDAR

  /** Get the ANSI year-month and day-time TypeSig (available from Spark 3.2.0+) */
  override def getAnsiInterval: TypeSig = TypeSig.none
}
@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims
import ai.rapids.cudf
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter
import com.nvidia.spark.rapids.TypeSig

import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnVector
@@ -80,4 +81,22 @@ object GpuTypeShims {

  def csvRead(cv: cudf.ColumnVector, dt: DataType): cudf.ColumnVector =
    throw new RuntimeException(s"Not support type $dt.")

  /**
   * Whether this Shim supports the day-time interval type.
   * Operators such as Alias, Add, Subtract and Positive do not support the
   * day-time interval type on this Shim.
   */
  def isSupportedDayTimeType(dt: DataType): Boolean = false

  /**
   * Whether this Shim supports the year-month interval type.
   * Operators such as Alias, Add, Subtract and Positive do not support the
   * year-month interval type on this Shim.
   */
  def isSupportedYearMonthType(dt: DataType): Boolean = false

  /**
   * Get additional supported types for this Shim.
   */
  def additionalArithmeticSupportedTypes: TypeSig = TypeSig.none

}
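A rough sketch of how these shim hooks can be consulted when deciding whether an arithmetic expression's input type is handled on the GPU; the helper name below is an assumption for illustration, not the plugin's actual code:

import org.apache.spark.sql.types.{DataType, NumericType}

// Hypothetical helper: an input type is eligible for GPU arithmetic if it is a
// numeric type, or an interval type that the active Shim declares supported.
// On this pre-3.3.0 Shim both interval checks return false, so intervals fall back to the CPU.
def arithmeticInputSupported(dt: DataType): Boolean = dt match {
  case _: NumericType => true
  case other =>
    GpuTypeShims.isSupportedDayTimeType(other) || GpuTypeShims.isSupportedYearMonthType(other)
}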
@@ -61,4 +61,7 @@ object TypeSigUtil extends TypeSigUtilBase {
  /** Get numeric and interval TypeSig */
  override def getNumericAndInterval(): TypeSig =
    TypeSig.cpuNumeric + TypeSig.CALENDAR + TypeSig.DAYTIME + TypeSig.YEARMONTH

  /** Get the ANSI year-month and day-time TypeSig */
  override def getAnsiInterval: TypeSig = TypeSig.DAYTIME + TypeSig.YEARMONTH
}
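As an illustration of how this signature piece could be folded into an operator's supported inputs (the val name here is an assumption for the sketch, not a definition from the PR):

// Numeric plus calendar intervals, extended with the ANSI interval types on Spark 3.2.0+
val arithmeticInputSig: TypeSig =
  TypeSigUtil.getNumericAndInterval() + TypeSigUtil.getAnsiInterval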
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.shims

import ai.rapids.cudf
import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.ColumnarCopyHelper
import com.nvidia.spark.rapids.{ColumnarCopyHelper, TypeSig}
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{IntConverter, LongConverter, NotNullIntConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
@@ -155,7 +155,6 @@ object GpuTypeShims {
}
}


  def supportCsvRead(dt: DataType) : Boolean = {
    dt match {
      case DayTimeIntervalType(_, _) => true
@@ -170,4 +169,21 @@
}
}

  /**
   * Whether this Shim supports the day-time interval type.
   * Operators such as Alias, Add, Subtract and Positive support the day-time
   * interval type on this Shim.
   */
  def isSupportedDayTimeType(dt: DataType): Boolean = dt.isInstanceOf[DayTimeIntervalType]

  /**
   * Whether this Shim supports the year-month interval type.
   * Operators such as Alias, Add, Subtract and Positive support the year-month
   * interval type on this Shim.
   */
  def isSupportedYearMonthType(dt: DataType): Boolean = dt.isInstanceOf[YearMonthIntervalType]

  /**
   * Get additional supported types for this Shim.
   */
  def additionalArithmeticSupportedTypes: TypeSig = TypeSig.ansiIntervals

}
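A minimal sketch of how this Shim's additional types might be appended to a base arithmetic signature; only GpuTypeShims and TypeSig come from this diff, and the val name is illustrative:

// On Spark 3.3.0+ this adds DAYTIME and YEARMONTH to whatever the base signature allows
val arithmeticInputSig: TypeSig =
  TypeSig.cpuNumeric + GpuTypeShims.additionalArithmeticSupportedTypes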