Supports casting between ANSI interval types and integral types
Signed-off-by: Chong Gao <res_life@163.com>
Chong Gao committed Apr 27, 2022
1 parent 785b4ac commit f570c04
Showing 8 changed files with 323 additions and 22 deletions.
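
For orientation, a hedged usage sketch of what this change enables on the GPU. It assumes a Spark 3.3.0+ session named `spark` with the RAPIDS Accelerator enabled; it is illustrative only and not part of the commit.

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DayTimeIntervalType, LongType}

// Single-field ANSI intervals can now be cast to integral types and back on the GPU.
val df = spark.sql("SELECT INTERVAL '3' DAY AS d, 42L AS n")
df.select(
  col("d").cast(LongType),  // day-time interval (DAY field) -> number of days as a long
  col("n").cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))  // long -> INTERVAL DAY
).show()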
37 changes: 37 additions & 0 deletions integration_tests/src/main/python/cast_test.py
@@ -20,6 +20,7 @@
from marks import allow_non_gpu, approximate_float
from pyspark.sql.types import *
from spark_init_internal import spark_version
import math

_decimal_gen_36_5 = DecimalGen(precision=36, scale=5)

@@ -343,3 +344,39 @@ def fun(spark):
df = spark.createDataFrame(data, StringType())
return df.select(f.col('value').cast(dtType)).collect()
assert_gpu_and_cpu_error(fun, {}, "java.lang.IllegalArgumentException")

integral_types = [ByteType(), ShortType(), IntegerType(), LongType()]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval types and integral types is not supported before Spark 3.3.0')
@pytest.mark.parametrize('integral_type', integral_types)
def test_cast_day_time_interval_to_integral_no_overflow(integral_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400)))
.select(f.col('a').cast(integral_type)))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600)))
.select(f.col('a').cast(integral_type)))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60)))
.select(f.col('a').cast(integral_type)))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127)))
.select(f.col('a').cast(integral_type)))


integral_gens_no_overflow = [
LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1]),
ShortGen(),
ByteGen()
]
@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval types and integral types is not supported before Spark 3.3.0')
@pytest.mark.parametrize('integral_gen_no_overflow', integral_gens_no_overflow)
def test_cast_integral_to_day_time_interval_no_overflow(integral_gen_no_overflow):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(0, 0))))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(1, 1))))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(2, 2))))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(3, 3))))
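
The generator bounds above are picked so neither direction can overflow: interval values stay within the byte range per end field, and integral values stay below Long.MaxValue divided by the microseconds-per-unit factor. A small standalone sketch of that arithmetic (illustrative only, not part of the test suite):

object NoOverflowBoundsSketch {
  private val MicrosPerDay = 86400L * 1000000L

  def main(args: Array[String]): Unit = {
    // Integral -> interval direction: a long cast to INTERVAL ... DAY is multiplied by
    // MicrosPerDay, so the generators cap values at LONG_MAX / 86400 / 1000000.
    println(Long.MaxValue / MicrosPerDay)   // 106751991
    println(Long.MinValue / MicrosPerDay)   // -106751991

    // Interval -> integral direction: day counts are kept in [-128, 127] (the byte range)
    // so even the narrowest target type cannot overflow.
    println((-128L * 86400, 127L * 86400))  // the second bounds used by the DAY-field generator
  }
}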
@@ -25,10 +25,58 @@ import org.apache.spark.sql.types.DataType
object GpuIntervalUtils {

def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = {
throw new IllegalStateException()
throw new IllegalStateException("Not supported in this Shim")
}

def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = {
throw new IllegalStateException()
throw new IllegalStateException("Not supported in this Shim")
}

def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def intToDayTimeInterval(intCv: ColumnVector, dt: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}

def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = {
throw new IllegalStateException("Not supported in this Shim")
}
}
@@ -112,6 +112,14 @@ object GpuTypeShims

def typesDayTimeCanCastTo: TypeSig = TypeSig.none

def typesYearMonthCanCastTo: TypeSig = TypeSig.none

def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING

def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING

def additionalTypesIntegralCanCastTo: TypeSig = TypeSig.none

def additionalTypesStringCanCastTo: TypeSig = TypeSig.none

/**
@@ -21,11 +21,11 @@ import java.util.concurrent.TimeUnit.{DAYS, HOURS, MINUTES, SECONDS}
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.Arm
import com.nvidia.spark.rapids.CloseableHolder
import com.nvidia.spark.rapids.{Arm, BoolUtils, CloseableHolder}

import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND}
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND, MONTHS_PER_YEAR}
import org.apache.spark.sql.rapids.shims.IntervalUtils
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, YearMonthIntervalType => YM}

/**
* Parse DayTimeIntervalType string column to long column of micro seconds
@@ -741,4 +741,151 @@ object GpuIntervalUtils extends Arm {
}
}
}

def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = {
dt.asInstanceOf[DT].endField match {
case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros =>
dtCv.div(micros)
}
case DT.HOUR => withResource(Scalar.fromLong(MICROS_PER_HOUR)) { micros =>
dtCv.div(micros)
}
case DT.MINUTE => withResource(Scalar.fromLong(MICROS_PER_MINUTE)) { micros =>
dtCv.div(micros)
}
case DT.SECOND => withResource(Scalar.fromLong(MICROS_PER_SECOND)) { micros =>
dtCv.div(micros)
}
}
}

def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = {
withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv =>
castToTargetWithOverflowCheck(longCv, DType.INT32)
}
}

def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = {
withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv =>
castToTargetWithOverflowCheck(longCv, DType.INT16)
}
}

def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = {
withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv =>
castToTargetWithOverflowCheck(longCv, DType.INT8)
}
}

def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = {
ym.asInstanceOf[YM].endField match {
case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { monthsPerYear =>
ymCv.div(monthsPerYear)
}
case YM.MONTH => ymCv.castTo(DType.INT64)
}
}

def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = {
ym.asInstanceOf[YM].endField match {
case YM.YEAR => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { monthsPerYear =>
ymCv.div(monthsPerYear)
}
case YM.MONTH => ymCv.incRefCount()
}
}

def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = {
withResource(yearMonthIntervalToInt(ymCv, ym)) { i =>
castToTargetWithOverflowCheck(i, DType.INT16)
}
}

def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = {
withResource(yearMonthIntervalToInt(ymCv, ym)) { i =>
castToTargetWithOverflowCheck(i, DType.INT8)
}
}

private def castToTargetWithOverflowCheck(cv: ColumnVector, dType: DType): ColumnVector = {
withResource(cv.castTo(dType)) { retTarget =>
withResource(cv.notEqualTo(retTarget)) { notEqual =>
if (BoolUtils.isAnyValidTrue(notEqual)) {
// TODO change to SparkArithmeticException
throw new ArithmeticException(s"Cast to $dType overflowed")
} else {
retTarget.incRefCount()
}
}
}
}

/**
* Convert long cv to `day time interval`
*/
def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = {
val microsScalar = dt.asInstanceOf[DT].endField match {
case DT.DAY => Scalar.fromLong(MICROS_PER_DAY)
case DT.HOUR => Scalar.fromLong(MICROS_PER_HOUR)
case DT.MINUTE => Scalar.fromLong(MICROS_PER_MINUTE)
case DT.SECOND => Scalar.fromLong(MICROS_PER_SECOND)
}
withResource(microsScalar) { micros =>
// leverage `Decimal 128` to check the overflow
IntervalUtils.multipleToLongWithOverflowCheck(longCv, micros)
}
}

/**
* Convert (byte | short | int) cv to `day time interval`
*/
def intToDayTimeInterval(intCv: ColumnVector, dt: DataType): ColumnVector = {
dt.asInstanceOf[DT].endField match {
case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros =>
// leverage `Decimal 128` to check the overflow
IntervalUtils.multipleToLongWithOverflowCheck(intCv, micros)
}
case DT.HOUR => withResource(Scalar.fromLong(MICROS_PER_HOUR)) { micros =>
// no need to check overflow
intCv.mul(micros)
}
case DT.MINUTE => withResource(Scalar.fromLong(MICROS_PER_MINUTE)) { micros =>
// no need to check overflow
intCv.mul(micros)
}
case DT.SECOND => withResource(Scalar.fromLong(MICROS_PER_SECOND)) { micros =>
// no need to check overflow
intCv.mul(micros)
}
}
}

/**
* Convert long cv to `year month interval`
*/
def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = {
ym.asInstanceOf[YM].endField match {
case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { num12 =>
// leverage `Decimal 128` to check the overflow
IntervalUtils.multipleToIntWithOverflowCheck(longCv, num12)
}
case YM.MONTH => IntervalUtils.castLongToIntWithOverflowCheck(longCv)
}
}

/**
* Convert (byte | short | int) cv to `year month interval`
*/
def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = {
(ym.asInstanceOf[YM].endField, intCv.getType) match {
case (YM.YEAR, DType.INT32) => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { num12 =>
// leverage `Decimal 128` to check the overflow
IntervalUtils.multipleToIntWithOverflowCheck(intCv, num12)
}
case (YM.YEAR, DType.INT16 | DType.INT8) => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) {
num12 => intCv.mul(num12)
}
case (YM.MONTH, _) => intCv.castTo(DType.INT32)
}
}
}
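
A hedged row-wise sketch of the overflow rule the column-level helpers above enforce (hypothetical helper; the real code operates on cudf ColumnVectors): narrow the value, then compare it back against the original.

object IntervalToByteSketch {
  private val MicrosPerDay = 86400L * 1000000L

  // Row-wise analogue of dayTimeIntervalToLong (endField = DAY) followed by
  // castToTargetWithOverflowCheck(..., DType.INT8).
  def dayIntervalMicrosToByte(micros: Long): Byte = {
    val days = micros / MicrosPerDay      // divide by the unit of the interval's end field
    val narrowed = days.toByte            // narrow to the target width
    if (narrowed.toLong != days) {        // any mismatch means the cast overflowed
      throw new ArithmeticException("Cast to INT8 overflowed")
    }
    narrowed
  }

  def main(args: Array[String]): Unit = {
    println(dayIntervalMicrosToByte(127L * MicrosPerDay))   // 127
    // dayIntervalMicrosToByte(128L * MicrosPerDay)         // would throw ArithmeticException
  }
}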
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids.shims

import ai.rapids.cudf
import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.{ColumnarCopyHelper, TypeSig}
import com.nvidia.spark.rapids.{ColumnarCopyHelper, TypeEnum, TypeSig}
import com.nvidia.spark.rapids.GpuRowToColumnConverter.{IntConverter, LongConverter, NotNullIntConverter, NotNullLongConverter, TypeConverter}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
@@ -202,7 +202,33 @@ object GpuTypeShims {
*/
def additionalCsvSupportedTypes: TypeSig = TypeSig.DAYTIME

def typesDayTimeCanCastTo: TypeSig = TypeSig.DAYTIME + TypeSig.STRING
def typesDayTimeCanCastTo: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + TypeSig.integral

def typesYearMonthCanCastTo: TypeSig = TypeSig.integral

def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + TypeSig.integral

def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING +
TypeSig.integral

val yearMonthFieldsEqualsTypeSig = new TypeSig(
TypeEnum.ValueSet(TypeEnum.YEARMONTH),
extraCheck = t => t match {
// if `data type` is `YearMonthIntervalType`, check if startField == endField
case YearMonthIntervalType(startField, endField) => startField == endField
case _ => true
})

val dayTimeFieldsEqualsTypeSig = new TypeSig(
TypeEnum.ValueSet(TypeEnum.DAYTIME),
extraCheck = t => t match {
// if `data type` is `DayTimeIntervalType`, check if startField == endField
case DayTimeIntervalType(startField, endField) => startField == endField
case _ => true
})

def additionalTypesIntegralCanCastTo: TypeSig = yearMonthFieldsEqualsTypeSig +
dayTimeFieldsEqualsTypeSig

def additionalTypesStringCanCastTo: TypeSig = TypeSig.DAYTIME
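
The extraCheck predicates above restrict GPU support to single-field intervals (for example INTERVAL DAY or INTERVAL HOUR); multi-field types such as INTERVAL DAY TO SECOND are not advertised. A standalone sketch of that predicate, written outside the plugin's TypeSig framework for illustration:

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, YearMonthIntervalType => YM}

object SingleFieldIntervalCheck {
  // Mirrors the extraCheck used by yearMonthFieldsEqualsTypeSig / dayTimeFieldsEqualsTypeSig.
  def isSingleFieldInterval(t: DataType): Boolean = t match {
    case DT(start, end) => start == end
    case YM(start, end) => start == end
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isSingleFieldInterval(DT(DT.DAY, DT.DAY)))     // true  -> cast to/from integral supported
    println(isSingleFieldInterval(DT(DT.DAY, DT.SECOND)))  // false -> not covered by this check
  }
}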

@@ -66,8 +66,8 @@ object IntervalUtils extends Arm {
Multiply with overflow check, then cast to long
* Equivalent to Math.multiplyExact
*
* @param left cv or scalar
* @param right cv or scalar, will not be scalar if left is scalar
@param left cv (byte, short, int, long) or scalar
@param right cv (byte, short, int, long) or scalar, will not be a scalar if left is a scalar
* @return the long result of left * right
*/
def multipleToLongWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = {
@@ -82,8 +82,8 @@
Multiply with overflow check, then cast to int
* Equivalent to Math.multiplyExact
*
* @param left cv or scalar
* @param right cv or scalar, will not be scalar if left is scalar
@param left cv (byte, short, int, long) or scalar
@param right cv (byte, short, int, long) or scalar, will not be a scalar if left is a scalar
* @return the int result of left * right
*/
def multipleToIntWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = {
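
These multiply-with-overflow-check helpers are what the interval conversions above rely on: the product is computed with 128-bit headroom and rejected if it no longer fits the target width. A scalar-level sketch of the same idea (hypothetical; on the GPU cudf DECIMAL128 columns are used, BigInt stands in here):

object MultiplyExactSketch {
  // Scalar analogue of multipleToLongWithOverflowCheck: multiply with headroom,
  // then fail if the product does not round-trip through Long.
  def multiplyToLongExact(a: Long, b: Long): Long = {
    val wide = BigInt(a) * BigInt(b)
    if (!wide.isValidLong) {
      throw new ArithmeticException(s"$a * $b overflows a 64-bit integer")
    }
    wide.toLong
  }

  def main(args: Array[String]): Unit = {
    val microsPerDay = 86400L * 1000000L
    println(multiplyToLongExact(106751991L, microsPerDay))  // fits in a long
    // multiplyToLongExact(106751992L, microsPerDay)        // would throw ArithmeticException
  }
}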
33 changes: 33 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -561,6 +561,39 @@ object GpuCast extends Arm {
GpuIntervalUtils.castStringToDayTimeIntervalWithThrow(
input.asInstanceOf[ColumnVector], dayTime)

// cast(`day time interval` as integral)
case (dt: DataType, _: LongType) if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.dayTimeIntervalToLong(input.asInstanceOf[ColumnVector], dt)
case (dt: DataType, _: IntegerType) if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.dayTimeIntervalToInt(input.asInstanceOf[ColumnVector], dt)
case (dt: DataType, _: ShortType) if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.dayTimeIntervalToShort(input.asInstanceOf[ColumnVector], dt)
case (dt: DataType, _: ByteType) if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.dayTimeIntervalToByte(input.asInstanceOf[ColumnVector], dt)

// cast(integral as `day time interval`)
case (_: LongType, dt: DataType) if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.longToDayTimeInterval(input.asInstanceOf[ColumnVector], dt)
case (_: IntegerType | ShortType | ByteType, dt: DataType)
if GpuTypeShims.isSupportedDayTimeType(dt) =>
GpuIntervalUtils.intToDayTimeInterval(input.asInstanceOf[ColumnVector], dt)

// cast(`year month interval` as integral)
case (ym: DataType, _: LongType) if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.yearMonthIntervalToLong(input.asInstanceOf[ColumnVector], ym)
case (ym: DataType, _: IntegerType) if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.yearMonthIntervalToInt(input.asInstanceOf[ColumnVector], ym)
case (ym: DataType, _: ShortType) if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.yearMonthIntervalToShort(input.asInstanceOf[ColumnVector], ym)
case (ym: DataType, _: ByteType) if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.yearMonthIntervalToByte(input.asInstanceOf[ColumnVector], ym)

// cast(integral as `year month interval`)
case (_: LongType, ym: DataType) if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.longToYearMonthInterval(input.asInstanceOf[ColumnVector], ym)
case (_: IntegerType | ShortType | ByteType, ym: DataType)
if GpuTypeShims.isSupportedYearMonthType(ym) =>
GpuIntervalUtils.intToYearMonthInterval(input.asInstanceOf[ColumnVector], ym)
case _ =>
input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}