Support arithmetic operators on ANSI interval types (#5020)
* Support arithmetic operators on ANSI interval types

Signed-off-by: Chong Gao <res_life@163.com>

* Support arithmetic operators on year month interval types

Signed-off-by: Chong Gao <res_life@163.com>

* Refactor

* Fix

* Refactor

Signed-off-by: Chong Gao <res_life@163.com>

* Refactor

* Refactor

* Fix

* Refactor

Co-authored-by: Chong Gao <res_life@163.com>
res-life and Chong Gao authored Apr 8, 2022
1 parent 963a008 commit 09da145
Showing 12 changed files with 565 additions and 108 deletions.
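As background, the kind of day-time interval arithmetic this change keeps on the GPU can be sketched with plain PySpark. This is an illustrative snippet, not part of the diff; it assumes Spark 3.3.0+ and a session with the RAPIDS Accelerator plugin enabled, and the column names and values are arbitrary.

from datetime import timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DayTimeIntervalType
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()
# Two day-time interval columns built from Python timedelta values.
schema = StructType([StructField('a', DayTimeIntervalType()),
                     StructField('b', DayTimeIntervalType())])
df = spark.createDataFrame([(timedelta(days=1), timedelta(hours=6))], schema)
# Addition, subtraction, and unary minus on day-time intervals are the operators exercised below.
df.select(f.col('a') + f.col('b'), f.col('a') - f.col('b'), -f.col('a')).show()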
81 changes: 80 additions & 1 deletion integration_tests/src/main/python/arithmetic_ops_test.py
@@ -14,13 +14,14 @@

import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_error, assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql
from data_gen import *
from marks import ignore_order, incompat, approximate_float, allow_non_gpu
from pyspark.sql.types import *
from pyspark.sql.types import IntegralType
from spark_session import with_cpu_session, with_gpu_session, with_spark_session, is_before_spark_320, is_before_spark_330, is_databricks91_or_later
import pyspark.sql.functions as f
from datetime import timedelta

# No overflow gens here because we just focus on verifying the fallback to CPU when
# enabling ANSI mode. But overflows will fail the tests because CPU runs raise
@@ -867,3 +868,81 @@ def test_subtraction_overflow_with_ansi_enabled(data, tp, expr):
    assert_gpu_and_cpu_are_equal_collect(
        func=lambda spark: _get_overflow_df(spark, data, tp, expr),
        conf=ansi_enabled_conf)


@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_unary_minus_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DAY_TIME_GEN_NO_OVER_FLOW).selectExpr('-a'),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_unary_minus_ansi_overflow_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), '-a').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_abs_ansi_no_overflow_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DAY_TIME_GEN_NO_OVER_FLOW).selectExpr('abs(a)'),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_abs_ansi_overflow_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: _get_overflow_df(spark, [timedelta(microseconds=LONG_MIN)], DayTimeIntervalType(), 'abs(a)').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_addition_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, DAY_TIME_GEN_NO_OVER_FLOW, DAY_TIME_GEN_NO_OVER_FLOW).select(
            f.col('a') + f.col('b')),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_add_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: spark.createDataFrame(
            SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MAX), timedelta(microseconds=10)),]),
            StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
        ).selectExpr('a + b').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_subtraction_day_time_interval(ansi_enabled):
    DAY_TIME_GEN_NO_OVER_FLOW = DayTimeIntervalGen(min_value=timedelta(days=-2000*365), max_value=timedelta(days=3000*365))
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, DAY_TIME_GEN_NO_OVER_FLOW, DAY_TIME_GEN_NO_OVER_FLOW).select(
            f.col('a') - f.col('b')),
        conf={'spark.sql.ansi.enabled': ansi_enabled})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_subtraction_overflow_with_ansi_enabled_day_time_interval(ansi_enabled):
    assert_gpu_and_cpu_error(
        df_fun=lambda spark: spark.createDataFrame(
            SparkContext.getOrCreate().parallelize([(timedelta(microseconds=LONG_MIN), timedelta(microseconds=10)),]),
            StructType([StructField('a', DayTimeIntervalType()), StructField('b', DayTimeIntervalType())])
        ).selectExpr('a - b').collect(),
        conf={'spark.sql.ansi.enabled': ansi_enabled},
        error_message='ArithmeticException')

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_positive_day_time_interval():
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, DayTimeIntervalGen()).selectExpr('+a'))
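The commit message also mentions year-month interval types, but those tests fall outside the lines shown above. The following is a hypothetical sketch of what a year-month counterpart of the addition test could look like, not the commit's actual code; it builds interval values from SQL literals because Python has no native year-month duration type.

# Hypothetical sketch only; the test name, literal values, and version gate are illustrative.
@pytest.mark.skipif(is_before_spark_330(), reason='Sketch assumes the same Spark 3.3.0+ gate as the tests above')
@pytest.mark.parametrize('ansi_enabled', ['false', 'true'])
def test_addition_year_month_interval_sketch(ansi_enabled):
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: spark.range(10).selectExpr(
            "INTERVAL '100-5' YEAR TO MONTH AS a",
            "INTERVAL '3-7' YEAR TO MONTH AS b").select(f.col('a') + f.col('b')),
        conf={'spark.sql.ansi.enabled': ansi_enabled})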
9 changes: 7 additions & 2 deletions integration_tests/src/main/python/ast_test.py
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
from asserts import assert_cpu_and_gpu_are_equal_collect_with_capture
from data_gen import *
from marks import approximate_float
from spark_session import with_cpu_session
from spark_session import with_cpu_session, is_before_spark_330
import pyspark.sql.functions as f

# Each descriptor contains a list of data generators and a corresponding boolean
@@ -105,6 +105,11 @@ def test_bitwise_not(data_descr):
def test_unary_positive(data_descr):
    assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_unary_positive_for_daytime_interval():
    data_descr = (DayTimeIntervalGen(), True)
    assert_unary_ast(data_descr, lambda df: df.selectExpr('+a'))

@pytest.mark.parametrize('data_descr', ast_arithmetic_descrs, ids=idfn)
def test_unary_minus(data_descr):
    assert_unary_ast(data_descr, lambda df: df.selectExpr('-a'))

@@ -66,4 +66,7 @@ object TypeSigUtil extends TypeSigUtilBase {
  /** Get numeric and interval TypeSig */
  override def getNumericAndInterval(): TypeSig =
    TypeSig.cpuNumeric + TypeSig.CALENDAR

  /** Get the ANSI year-month and day-time TypeSig; these interval types only exist in Spark 3.2.0+ */
  override def getAnsiInterval: TypeSig = TypeSig.none
}

@@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims
import ai.rapids.cudf
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter
import com.nvidia.spark.rapids.TypeSig

import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnVector
@@ -80,4 +81,22 @@ object GpuTypeShims {

  def csvRead(cv: cudf.ColumnVector, dt: DataType): cudf.ColumnVector =
    throw new RuntimeException(s"Not support type $dt.")

  /**
   * Whether this Shim supports the day-time interval type.
   * In this Shim, operators such as Alias, Add, Subtract and Positive do not support
   * the day-time interval type.
   */
  def isSupportedDayTimeType(dt: DataType): Boolean = false

  /**
   * Whether this Shim supports the year-month interval type.
   * In this Shim, operators such as Alias, Add, Subtract and Positive do not support
   * the year-month interval type.
   */
  def isSupportedYearMonthType(dt: DataType): Boolean = false

  /**
   * Get the additional arithmetic-supported types for this Shim.
   */
  def additionalArithmeticSupportedTypes: TypeSig = TypeSig.none

}

@@ -61,4 +61,7 @@ object TypeSigUtil extends TypeSigUtilBase {
  /** Get numeric and interval TypeSig */
  override def getNumericAndInterval(): TypeSig =
    TypeSig.cpuNumeric + TypeSig.CALENDAR + TypeSig.DAYTIME + TypeSig.YEARMONTH

  /** Get the ANSI year-month and day-time TypeSig */
  override def getAnsiInterval: TypeSig = TypeSig.DAYTIME + TypeSig.YEARMONTH
}

@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.shims

import ai.rapids.cudf
import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.ColumnarCopyHelper
import com.nvidia.spark.rapids.{ColumnarCopyHelper, TypeSig}
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{IntConverter, LongConverter, NotNullIntConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
@@ -155,7 +155,6 @@ object GpuTypeShims {
    }
  }

  def supportCsvRead(dt: DataType) : Boolean = {
    dt match {
      case DayTimeIntervalType(_, _) => true
@@ -170,4 +169,21 @@
    }
  }

  /**
   * Whether this Shim supports the day-time interval type.
   * In this Shim, operators such as Alias, Add, Subtract and Positive support
   * the day-time interval type.
   */
  def isSupportedDayTimeType(dt: DataType): Boolean = dt.isInstanceOf[DayTimeIntervalType]

  /**
   * Whether this Shim supports the year-month interval type.
   * In this Shim, operators such as Alias, Add, Subtract and Positive support
   * the year-month interval type.
   */
  def isSupportedYearMonthType(dt: DataType): Boolean = dt.isInstanceOf[YearMonthIntervalType]

  /**
   * Get the additional arithmetic-supported types for this Shim.
   */
  def additionalArithmeticSupportedTypes: TypeSig = TypeSig.ansiIntervals

}
(The remaining changed files are not shown here.)
