Supports casting between ANSI interval types and integral types #5353

Merged
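
What this PR enables, as a minimal sketch (not taken from the PR itself): on Spark 3.3.0+, ANSI day-time intervals can be cast to integral types and back, and this change lets the RAPIDS plugin run those casts on the GPU. The column name 'value' below is simply what createDataFrame assigns to a single-column DataFrame built from atomic values.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DayTimeIntervalType, IntegerType

spark = SparkSession.builder.getOrCreate()

# integral -> interval: the integer is interpreted in the target field's unit (3 -> 3 days)
df = spark.createDataFrame([3], IntegerType())
df.select(col('value').cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))).show()

# interval -> integral: INTERVAL '3' DAY casts back to the day count 3
spark.sql("SELECT CAST(INTERVAL '3' DAY AS INT)").show()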
integration_tests/src/main/python/cast_test.py (107 additions, 0 deletions)
@@ -20,6 +20,7 @@
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
import math

_decimal_gen_36_5 = DecimalGen(precision=36, scale=5)

@@ -343,3 +344,109 @@ def fun(spark):
        df = spark.createDataFrame(data, StringType())
        return df.select(f.col('value').cast(dtType)).collect()
    assert_gpu_and_cpu_error(fun, {}, "java.lang.IllegalArgumentException")

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_day_time_interval_to_integral_no_overflow():
    second_dt_gen = DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127), nullable=False)
    gen = StructGen([('a', DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400))),
                     ('b', DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600))),
                     ('c', DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60))),
                     ('d', second_dt_gen),
                     ('c_array', ArrayGen(second_dt_gen)),
                     ('c_struct', StructGen([("a", second_dt_gen), ("b", second_dt_gen)])),
                     ('c_map', MapGen(second_dt_gen, second_dt_gen))
                     ], nullable=False)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gen).select(f.col('a').cast(ByteType()), f.col('a').cast(ShortType()), f.col('a').cast(IntegerType()), f.col('a').cast(LongType()),
                                                f.col('b').cast(ByteType()), f.col('b').cast(ShortType()), f.col('b').cast(IntegerType()), f.col('b').cast(LongType()),
                                                f.col('c').cast(ByteType()), f.col('c').cast(ShortType()), f.col('c').cast(IntegerType()), f.col('c').cast(LongType()),
                                                f.col('d').cast(ByteType()), f.col('d').cast(ShortType()), f.col('d').cast(IntegerType()), f.col('d').cast(LongType()),
                                                f.col('c_array').cast(ArrayType(ByteType())),
                                                f.col('c_struct').cast(StructType([StructField('a', ShortType()), StructField('b', ShortType())])),
                                                f.col('c_map').cast(MapType(IntegerType(), IntegerType()))
                                                ))

integral_gens_no_overflow = [
    LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
    # The int generator must use the same bound: +/-106751991 is the largest day count whose
    # microsecond value still fits in a long, and it lies well inside the int range.
    # Dividing INT_MIN/INT_MAX instead would collapse the generated range to [0, 0].
    IntegerGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
    ShortGen(),
    ByteGen(),
    # StructGen([("a", ShortGen()), ("b", ByteGen())])
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_integral_to_day_time_interval_no_overflow():
    # Bound both generators to +/-floor(LONG_MAX / 86400 / 1000000) days so the cast to
    # interval day cannot overflow a long; the original INT_MIN/INT_MAX division collapsed
    # the int range to [0, 0]. The names are also fixed to match the generated types.
    int_gen = IntegerGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1])
    long_gen = LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1], nullable=False)
    gen = StructGen([("a", int_gen),
                     ("b", long_gen),
                     ("c", ShortGen()),
                     ("d", ByteGen()),
                     ("c_struct", StructGen([("a", int_gen), ("b", long_gen)], nullable=False)),
                     ('c_array', ArrayGen(long_gen)),
                     ('c_map', MapGen(long_gen, int_gen))], nullable=False)
    # day_time_field: 0 is day, 1 is hour, 2 is minute, 3 is second
    day_type = DayTimeIntervalType(0, 0)
    hour_type = DayTimeIntervalType(1, 1)
    minute_type = DayTimeIntervalType(2, 2)
    second_type = DayTimeIntervalType(3, 3)

    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gen).select(
            f.col('a').cast(day_type), f.col('a').cast(hour_type), f.col('a').cast(minute_type), f.col('a').cast(second_type),
            f.col('b').cast(day_type), f.col('b').cast(hour_type), f.col('b').cast(minute_type), f.col('b').cast(second_type),
            f.col('c').cast(day_type), f.col('c').cast(hour_type), f.col('c').cast(minute_type), f.col('c').cast(second_type),
            f.col('d').cast(day_type), f.col('d').cast(hour_type), f.col('d').cast(minute_type), f.col('d').cast(second_type),
            f.col('c_struct').cast(StructType([StructField('a', day_type), StructField('b', hour_type)])),
            f.col('c_array').cast(ArrayType(hour_type)),
            f.col('c_map').cast(MapType(minute_type, second_type)),
        ))

cast_day_time_to_integral_overflow_pairs = [
    (INT_MIN - 1, IntegerType()),
    (INT_MAX + 1, IntegerType()),
    (SHORT_MIN - 1, ShortType()),
    (SHORT_MAX + 1, ShortType()),
    (BYTE_MIN - 1, ByteType()),
    (BYTE_MAX + 1, ByteType())
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('large_second, integral_type', cast_day_time_to_integral_overflow_pairs)
def test_cast_day_time_interval_to_integral_overflow(large_second, integral_type):
    def getDf(spark):
        return spark.createDataFrame([timedelta(seconds=large_second)], DayTimeIntervalType(DayTimeIntervalType.SECOND, DayTimeIntervalType.SECOND))
    assert_gpu_and_cpu_error(
        lambda spark: getDf(spark).select(f.col('value').cast(integral_type)).collect(),
        conf={},
        error_message="overflow")

day_time_interval_max_day = math.floor(LONG_MAX / (86400 * 1000000))
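# Assuming LONG_MAX is 2**63 - 1, this evaluates to 106751991: the largest day count whose
# microsecond representation still fits in a long, matching the 106751991 cited in the
# side-effect tests below.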
large_days_overflow_pairs = [
    (-day_time_interval_max_day - 1, LongType()),
    (+day_time_interval_max_day + 1, LongType()),
    (-day_time_interval_max_day - 1, IntegerType()),
    (+day_time_interval_max_day + 1, IntegerType())
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('large_day,integral_type', large_days_overflow_pairs)
def test_cast_integral_to_day_time_interval_overflow(large_day, integral_type):
    def getDf(spark):
        return spark.createDataFrame([large_day], integral_type)
    assert_gpu_and_cpu_error(
        lambda spark: getDf(spark).select(f.col('value').cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))).collect(),
        conf={},
        error_message="overflow")

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_integral_to_day_time_side_effect():
    def getDf(spark):
        # INT_MAX > 106751991 (max value of interval day)
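        # The overflowing values sit only on the row where c_b is True, which the `if`
        # routes to the literal branch, so the guarded GPU cast must not raise for it.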
        return spark.createDataFrame([(True, INT_MAX, LONG_MAX), (False, 0, 0)], "c_b boolean, c_i int, c_l long").repartition(1)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: getDf(spark).selectExpr("if(c_b, interval 0 day, cast(c_i as interval day))", "if(c_b, interval 0 second, cast(c_l as interval second))"))

@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0')
def test_cast_day_time_to_integral_side_effect():
    def getDf(spark):
        # 106751991 > Byte.MaxValue
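        # MAX_DAY_TIME_INTERVAL overflows the byte/short/int casts, but only on the row
        # masked out by c_b, so no overflow error may escape the `if`.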
        return spark.createDataFrame([(True, MAX_DAY_TIME_INTERVAL), (False, (timedelta(microseconds=0)))], "c_b boolean, c_dt interval day to second").repartition(1)
    assert_gpu_and_cpu_are_equal_collect(lambda spark: getDf(spark).selectExpr("if(c_b, 0, cast(c_dt as byte))", "if(c_b, 0, cast(c_dt as short))", "if(c_b, 0, cast(c_dt as int))"))
GpuIntervalUtils.scala
@@ -15,7 +15,7 @@
*/
package com.nvidia.spark.rapids.shims

-import ai.rapids.cudf.ColumnVector
+import ai.rapids.cudf.{ColumnVector, ColumnView}

import org.apache.spark.sql.types.DataType

@@ -24,11 +24,59 @@ import org.apache.spark.sql.types.DataType
*/
object GpuIntervalUtils {

-  def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = {
-    throw new IllegalStateException()
+  def castStringToDayTimeIntervalWithThrow(cv: ColumnView, t: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
  }

-  def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = {
-    throw new IllegalStateException()
+  def toDayTimeIntervalString(micros: ColumnView, dayTimeType: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
  }
+
+  def dayTimeIntervalToLong(dtCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def dayTimeIntervalToInt(dtCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def dayTimeIntervalToShort(dtCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def dayTimeIntervalToByte(dtCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def yearMonthIntervalToLong(ymCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def yearMonthIntervalToInt(ymCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def yearMonthIntervalToShort(ymCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def yearMonthIntervalToByte(ymCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def longToDayTimeInterval(longCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def intToDayTimeInterval(intCv: ColumnView, dt: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def longToYearMonthInterval(longCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
+
+  def intToYearMonthInterval(intCv: ColumnView, ym: DataType): ColumnVector = {
+    throw new IllegalStateException("Not supported in this Shim")
+  }
}
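
These stub methods belong to a shim for Spark versions before 3.3.0, where ANSI interval types (and therefore these casts) do not exist; each throws IllegalStateException because the plugin should never route interval casts to this shim.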
GpuTypeShims.scala
@@ -112,6 +112,14 @@ object GpuTypeShims

  def typesDayTimeCanCastTo: TypeSig = TypeSig.none

+  def typesYearMonthCanCastTo: TypeSig = TypeSig.none
+
+  def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING
+
+  def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING
+
+  def additionalTypesIntegralCanCastTo: TypeSig = TypeSig.none
+
  def additionalTypesStringCanCastTo: TypeSig = TypeSig.none

/**
@@ -125,4 +133,7 @@
   */
  def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.none

+  def hasSideEffectsIfCastIntToYearMonth(ym: DataType): Boolean = false
+
+  def hasSideEffectsIfCastIntToDayTime(dt: DataType): Boolean = false
}