Supports casting between ANSI interval types and integral types (#5353)
* Supports casting between ANSI interval types and integral types
* Add test cases for nested types; Add the shim layer for hasSideEffects

Signed-off-by: Chong Gao <res_life@163.com>
Co-authored-by: Chong Gao <res_life@163.com>
res-life and Chong Gao authored May 18, 2022
1 parent 5f449cd commit 78750ef
Showing 10 changed files with 709 additions and 41 deletions.
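For context, these are the Spark 3.3.0 cast semantics under test: a day-time interval casts to the count of its end field's units, an integral value casts to an interval of the target unit, and both directions throw on overflow. A quick illustration of the CPU behavior the plugin must match (illustration only, assuming Spark >= 3.3.0 and an active SparkSession named spark):

    spark.sql("SELECT CAST(INTERVAL '100' SECOND AS INT)").show()   # 100
    spark.sql("SELECT CAST(3 AS INTERVAL HOUR)").show()             # INTERVAL '03' HOUR
    # Out-of-range values raise an overflow error, e.g.
    # CAST(INTERVAL '200' SECOND AS TINYINT) fails because 200 > Byte.MaxValue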
107 changes: 107 additions & 0 deletions integration_tests/src/main/python/cast_test.py
@@ -20,6 +20,7 @@
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
import math

_decimal_gen_36_5 = DecimalGen(precision=36, scale=5)

@@ -408,3 +409,109 @@ def fun(spark):
        df = spark.createDataFrame(data, StringType())
        return df.select(f.col('value').cast(dtType)).collect()
    assert_gpu_and_cpu_error(fun, {}, "java.lang.IllegalArgumentException")

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_day_time_interval_to_integral_no_overflow():
    second_dt_gen = DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127), nullable=False)
    gen = StructGen([('a', DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400))),
                     ('b', DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600))),
                     ('c', DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60))),
                     ('d', second_dt_gen),
                     ('c_array', ArrayGen(second_dt_gen)),
                     ('c_struct', StructGen([("a", second_dt_gen), ("b", second_dt_gen)])),
                     ('c_map', MapGen(second_dt_gen, second_dt_gen))
                     ], nullable=False)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gen).select(f.col('a').cast(ByteType()), f.col('a').cast(ShortType()), f.col('a').cast(IntegerType()), f.col('a').cast(LongType()),
                                                f.col('b').cast(ByteType()), f.col('b').cast(ShortType()), f.col('b').cast(IntegerType()), f.col('b').cast(LongType()),
                                                f.col('c').cast(ByteType()), f.col('c').cast(ShortType()), f.col('c').cast(IntegerType()), f.col('c').cast(LongType()),
                                                f.col('d').cast(ByteType()), f.col('d').cast(ShortType()), f.col('d').cast(IntegerType()), f.col('d').cast(LongType()),
                                                f.col('c_array').cast(ArrayType(ByteType())),
                                                f.col('c_struct').cast(StructType([StructField('a', ShortType()), StructField('b', ShortType())])),
                                                f.col('c_map').cast(MapType(IntegerType(), IntegerType()))
                                                ))
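The bounds above confine each column to [-128, 127] of its own unit (day, hour, minute, second), so even the narrowest target, ByteType, cannot overflow. For intuition, a rough Python model of the scalar these casts should produce, assuming Spark truncates toward zero the way Java integer division does (the helper is hypothetical, not part of the test suite):

    MICROS_PER_UNIT = {'day': 86400 * 10**6, 'hour': 3600 * 10**6, 'minute': 60 * 10**6, 'second': 10**6}

    def expected_integral(total_micros, end_field):
        # Truncate toward zero; Python's // floors, so handle negatives explicitly.
        q = abs(total_micros) // MICROS_PER_UNIT[end_field]
        return -q if total_micros < 0 else q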

integral_gens_no_overflow = [
LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
ShortGen(),
ByteGen(),
# StructGen([("a", ShortGen()), ("b", ByteGen())])
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_integral_to_day_time_interval_no_overflow():
    int_gen = IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1])
    long_gen = LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1], nullable=False)
    gen = StructGen([("a", int_gen),
                     ("b", long_gen),
                     ("c", ShortGen()),
                     ("d", ByteGen()),
                     ("c_struct", StructGen([("a", int_gen), ("b", long_gen)], nullable=False)),
                     ('c_array', ArrayGen(long_gen)),
                     ('c_map', MapGen(long_gen, int_gen))], nullable=False)
    # day_time_field: 0 is day, 1 is hour, 2 is minute, 3 is second
    day_type = DayTimeIntervalType(0, 0)
    hour_type = DayTimeIntervalType(1, 1)
    minute_type = DayTimeIntervalType(2, 2)
    second_type = DayTimeIntervalType(3, 3)

    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gen).select(
            f.col('a').cast(day_type), f.col('a').cast(hour_type), f.col('a').cast(minute_type), f.col('a').cast(second_type),
            f.col('b').cast(day_type), f.col('b').cast(hour_type), f.col('b').cast(minute_type), f.col('b').cast(second_type),
            f.col('c').cast(day_type), f.col('c').cast(hour_type), f.col('c').cast(minute_type), f.col('c').cast(second_type),
            f.col('d').cast(day_type), f.col('d').cast(hour_type), f.col('d').cast(minute_type), f.col('d').cast(second_type),
            f.col('c_struct').cast(StructType([StructField('a', day_type), StructField('b', hour_type)])),
            f.col('c_array').cast(ArrayType(hour_type)),
            f.col('c_map').cast(MapType(minute_type, second_type)),
        ))

cast_day_time_to_integral_overflow_pairs = [
    (INT_MIN - 1, IntegerType()),
    (INT_MAX + 1, IntegerType()),
    (SHORT_MIN - 1, ShortType()),
    (SHORT_MAX + 1, ShortType()),
    (BYTE_MIN - 1, ByteType()),
    (BYTE_MAX + 1, ByteType())
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('large_second, integral_type', cast_day_time_to_integral_overflow_pairs)
def test_cast_day_time_interval_to_integral_overflow(large_second, integral_type):
    def getDf(spark):
        return spark.createDataFrame([timedelta(seconds=large_second)], DayTimeIntervalType(DayTimeIntervalType.SECOND, DayTimeIntervalType.SECOND))
    assert_gpu_and_cpu_error(
        lambda spark: getDf(spark).select(f.col('value').cast(integral_type)).collect(),
        conf={},
        error_message="overflow")

day_time_interval_max_day = math.floor(LONG_MAX / (86400 * 1000000))
large_days_overflow_pairs = [
    (-day_time_interval_max_day - 1, LongType()),
    (+day_time_interval_max_day + 1, LongType()),
    (-day_time_interval_max_day - 1, IntegerType()),
    (+day_time_interval_max_day + 1, IntegerType())
]
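As a sanity check on day_time_interval_max_day above: LONG_MAX is 2**63 - 1 and a day is 86,400,000,000 microseconds, so the largest whole day count an interval can represent is 106751991, and these pairs step one past it in both directions:

    # 9223372036854775807 // 86400000000 == 106751991, so day 106751992 overflows
    assert (2**63 - 1) // (86400 * 1000000) == 106751991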
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('large_day,integral_type', large_days_overflow_pairs)
def test_cast_integral_to_day_time_interval_overflow(large_day, integral_type):
    def getDf(spark):
        return spark.createDataFrame([large_day], integral_type)
    assert_gpu_and_cpu_error(
        lambda spark: getDf(spark).select(f.col('value').cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))).collect(),
        conf={},
        error_message="overflow")

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_integral_to_day_time_side_effect():
    def getDf(spark):
        # INT_MAX > 106751991 (max value of interval day)
        return spark.createDataFrame([(True, INT_MAX, LONG_MAX), (False, 0, 0)], "c_b boolean, c_i int, c_l long").repartition(1)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: getDf(spark).selectExpr("if(c_b, interval 0 day, cast(c_i as interval day))", "if(c_b, interval 0 second, cast(c_l as interval second))"))

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_day_time_to_integral_side_effect():
    def getDf(spark):
        # 106751991 > Byte.MaxValue
        return spark.createDataFrame([(True, MAX_DAY_TIME_INTERVAL), (False, (timedelta(microseconds=0)))], "c_b boolean, c_dt interval day to second").repartition(1)
    assert_gpu_and_cpu_are_equal_collect(lambda spark: getDf(spark).selectExpr("if(c_b, 0, cast(c_dt as byte))", "if(c_b, 0, cast(c_dt as short))", "if(c_b, 0, cast(c_dt as int))"))
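A note on the two "side effect" tests above: each batch is built so the only overflowing value sits in rows that the if(c_b, ...) branch never selects, so a correct plan must not throw. A columnar GPU conditional would naturally evaluate both branches over the whole batch, so a cast that can throw on overflow has to be reported as side-effecting, forcing evaluation only on the rows that actually take that branch. That appears to be the purpose of the hasSideEffectsIfCastIntToDayTime/hasSideEffectsIfCastIntToYearMonth hooks added to GpuTypeShims below (inferred from the commit message and these tests).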
GpuIntervalUtils.scala
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims

-import ai.rapids.cudf.ColumnVector
+import ai.rapids.cudf.{ColumnVector, ColumnView}

import org.apache.spark.sql.types.DataType

@@ -24,11 +24,59 @@ import org.apache.spark.sql.types.DataType
*/
object GpuIntervalUtils {

-  def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = {
-    throw new IllegalStateException()
+  def castStringToDayTimeIntervalWithThrow(cv: ColumnView, t: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
  }

-  def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = {
-    throw new IllegalStateException()
+  def toDayTimeIntervalString(micros: ColumnView, dayTimeType: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
  }

  def dayTimeIntervalToLong(dtCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def dayTimeIntervalToInt(dtCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def dayTimeIntervalToShort(dtCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def dayTimeIntervalToByte(dtCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def yearMonthIntervalToLong(ymCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def yearMonthIntervalToInt(ymCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def yearMonthIntervalToShort(ymCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def yearMonthIntervalToByte(ymCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def longToDayTimeInterval(longCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def intToDayTimeInterval(intCv: ColumnView, dt: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def longToYearMonthInterval(longCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }

  def intToYearMonthInterval(intCv: ColumnView, ym: DataType): ColumnVector = {
    throw new IllegalStateException("Not supported in this Shim")
  }
}
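Every body above throws because this file is the pre-3.3.0 side of the shim: the type signatures in GpuTypeShims (next file) stay at TypeSig.none on those versions, so these casts can never be planned onto the GPU, and the methods exist only so common code can link against a single interface. The IllegalStateException is an unreachable-code guard; the Spark 3.3.0+ shim presumably supplies the real cudf-backed implementations (not shown in this excerpt).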
GpuTypeShims.scala
@@ -112,6 +112,14 @@ object GpuTypeShims {

  def typesDayTimeCanCastTo: TypeSig = TypeSig.none

  def typesYearMonthCanCastTo: TypeSig = TypeSig.none

  def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING

  def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING

  def additionalTypesIntegralCanCastTo: TypeSig = TypeSig.none

  def additionalTypesStringCanCastTo: TypeSig = TypeSig.none

  /**
@@ -125,4 +133,7 @@
   */
  def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.none

  def hasSideEffectsIfCastIntToYearMonth(ym: DataType): Boolean = false

  def hasSideEffectsIfCastIntToDayTime(dt: DataType): Boolean = false
}
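Both hooks default to false here: before Spark 3.3.0 the interval casts never reach the GPU, so there is nothing to guard. The Spark 3.3.0+ shim presumably overrides them to return true for target types that can overflow, which is what the two side-effect tests in cast_test.py above exercise.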