From 447223e2b7a3f06046ab679569655b5e73fa6a93 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Wed, 27 Apr 2022 21:22:11 +0800 Subject: [PATCH 1/4] Supports casting between ANSI interval types and integral types Signed-off-by: Chong Gao --- .../src/main/python/cast_test.py | 67 ++++++++ .../spark/rapids/shims/GpuIntervalUtils.scala | 52 +++++- .../spark/rapids/shims/GpuTypeShims.scala | 8 + .../spark/rapids/shims/GpuIntervalUtils.scala | 154 +++++++++++++++++- .../spark/rapids/shims/GpuTypeShims.scala | 11 +- .../rapids/shims/intervalExpressions.scala | 8 +- .../com/nvidia/spark/rapids/GpuCast.scala | 33 ++++ .../com/nvidia/spark/rapids/TypeChecks.scala | 11 +- .../spark/rapids/IntervalCastSuite.scala | 102 ++++++++++++ 9 files changed, 430 insertions(+), 16 deletions(-) create mode 100644 tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index 62a77ae2004..87d8f9f004c 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -20,6 +20,7 @@ from marks import allow_non_gpu, approximate_float from pyspark.sql.types import * from spark_init_internal import spark_version +import math _decimal_gen_36_5 = DecimalGen(precision=36, scale=5) @@ -343,3 +344,69 @@ def fun(spark): df = spark.createDataFrame(data, StringType()) return df.select(f.col('value').cast(dtType)).collect() assert_gpu_and_cpu_error(fun, {}, "java.lang.IllegalArgumentException") + +integral_types = [ByteType(), ShortType(), IntegerType(), LongType()] +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('integral_type', integral_types) +def test_cast_day_time_interval_to_integral_no_overflow(integral_type): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400))) + .select(f.col('a').cast(integral_type))) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600))) + .select(f.col('a').cast(integral_type))) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60))) + .select(f.col('a').cast(integral_type))) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127))) + .select(f.col('a').cast(integral_type))) + +integral_gens_no_overflow = [ + LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]), + IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1]), + ShortGen(), + ByteGen() +] +day_time_fields = [0, 1, 2, 3] # 0 is day, 1 is hour, 2 is minute, 3 is second +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('integral_gen_no_overflow', integral_gens_no_overflow) 
+@pytest.mark.parametrize('day_time_field', day_time_fields) +def test_cast_integral_to_day_time_interval_no_overflow(integral_gen_no_overflow, day_time_field): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(day_time_field, day_time_field)))) + +cast_day_time_to_inregral_overflow_pairs = [ + (INT_MIN - 1, IntegerType()), + (INT_MAX + 1, IntegerType()), + (SHORT_MIN - 1, ShortType()), + (SHORT_MAX + 1, ShortType()), + (BYTE_MIN - 1, ByteType()), + (BYTE_MAX + 1, ByteType()) +] +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('large_second, integral_type', cast_day_time_to_inregral_overflow_pairs) +def test_cast_day_time_interval_to_integral_overflow(large_second, integral_type): + def getDf(spark): + return spark.createDataFrame([timedelta(seconds=large_second)], DayTimeIntervalType(DayTimeIntervalType.SECOND, DayTimeIntervalType.SECOND)) + assert_gpu_and_cpu_error( + lambda spark: getDf(spark).select(f.col('value').cast(integral_type)).collect(), + conf={}, + error_message="overflow") + +day_time_interval_max_day = math.floor(LONG_MAX / (86400 * 1000000)) +large_days_overflow_pairs = [ + (-day_time_interval_max_day - 1, LongType()), + (+day_time_interval_max_day + 1, LongType()), + (-day_time_interval_max_day - 1, IntegerType()), + (+day_time_interval_max_day + 1, IntegerType()) +] +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') +@pytest.mark.parametrize('large_day,integral_type', large_days_overflow_pairs) +def test_cast_integral_to_day_time_interval_overflow(large_day, integral_type): + def getDf(spark): + return spark.createDataFrame([large_day], integral_type) + assert_gpu_and_cpu_error( + lambda spark: getDf(spark).select(f.col('value').cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))).collect(), + conf={}, + error_message="overflow") diff --git a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala index 45f143102df..f39fdea4c3e 100644 --- a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala +++ b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala @@ -25,10 +25,58 @@ import org.apache.spark.sql.types.DataType object GpuIntervalUtils { def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = { - throw new IllegalStateException() + throw new IllegalStateException("Not supported in this Shim") } def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = { - throw new IllegalStateException() + throw new IllegalStateException("Not supported in this Shim") + } + + def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = { + throw new 
IllegalStateException("Not supported in this Shim") + } + + def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def intToDayTimeInterval(intCv: ColumnVector, dt: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") + } + + def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = { + throw new IllegalStateException("Not supported in this Shim") } } diff --git a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala index 1c069a08950..3760be0735c 100644 --- a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala +++ b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -112,6 +112,14 @@ object GpuTypeShims { def typesDayTimeCanCastTo: TypeSig = TypeSig.none + def typesYearMonthCanCastTo: TypeSig = TypeSig.none + + def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + + def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING + + def additionalTypesIntegralCanCastTo: TypeSig = TypeSig.none + def additionalTypesStringCanCastTo: TypeSig = TypeSig.none /** diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala index 26d09b85a71..ddeaed634ba 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala @@ -21,11 +21,11 @@ import java.util.concurrent.TimeUnit.{DAYS, HOURS, MINUTES, SECONDS} import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar} -import com.nvidia.spark.rapids.Arm -import com.nvidia.spark.rapids.CloseableHolder +import com.nvidia.spark.rapids.{Arm, BoolUtils, CloseableHolder} -import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND} -import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND, MONTHS_PER_YEAR} +import org.apache.spark.sql.rapids.shims.IntervalUtils +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, YearMonthIntervalType => YM} /** * Parse DayTimeIntervalType string column to long column of micro seconds @@ -741,4 +741,150 @@ object GpuIntervalUtils extends Arm { } } } + + 
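For reference, the conversions added below boil down, per row, to dividing the interval's backing microseconds by the end field's unit and then narrowing with a cast-back check. The following is only a minimal scalar sketch of that semantics in plain Scala; DayTimeCastSketch and its members are illustrative names, not plugin code:

object DayTimeCastSketch {
  // micros per unit, matching Spark's DateTimeConstants
  val MicrosPerSecond = 1000000L
  val MicrosPerMinute = 60L * MicrosPerSecond
  val MicrosPerHour   = 60L * MicrosPerMinute
  val MicrosPerDay    = 24L * MicrosPerHour

  // cast(day-time interval as long): truncating division by the end field's unit
  def toLong(micros: Long, microsPerUnit: Long): Long = micros / microsPerUnit

  // cast(day-time interval as int): narrow the long result and verify the round trip,
  // the same cast-back idea the columnar overflow check below relies on
  def toInt(micros: Long, microsPerUnit: Long): Int = {
    val asLong = toLong(micros, microsPerUnit)
    val asInt = asLong.toInt
    if (asInt.toLong != asLong) {
      throw new ArithmeticException(s"overflow occurs when casting $asLong to INT32")
    }
    asInt
  }

  def main(args: Array[String]): Unit = {
    val micros = 129600L * MicrosPerSecond    // INTERVAL '1 12:00:00' DAY TO SECOND
    println(toLong(micros, MicrosPerSecond))  // 129600 (end field SECOND)
    println(toInt(micros, MicrosPerDay))      // 1 (end field DAY, truncated toward zero)
  }
}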
def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = { + dt.asInstanceOf[DT].endField match { + case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros => + dtCv.div(micros) + } + case DT.HOUR => withResource(Scalar.fromLong(MICROS_PER_HOUR)) { micros => + dtCv.div(micros) + } + case DT.MINUTE => withResource(Scalar.fromLong(MICROS_PER_MINUTE)) { micros => + dtCv.div(micros) + } + case DT.SECOND => withResource(Scalar.fromLong(MICROS_PER_SECOND)) { micros => + dtCv.div(micros) + } + } + } + + def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = { + withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => + castToTargetWithOverflowCheck(longCv, DType.INT32) + } + } + + def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = { + withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => + castToTargetWithOverflowCheck(longCv, DType.INT16) + } + } + + def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = { + withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => + castToTargetWithOverflowCheck(longCv, DType.INT8) + } + } + + def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = { + ym.asInstanceOf[YM].endField match { + case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { monthsPerYear => + ymCv.div(monthsPerYear) + } + case YM.MONTH => ymCv.castTo(DType.INT64) + } + } + + def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = { + ym.asInstanceOf[YM].endField match { + case YM.YEAR => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { monthsPerYear => + ymCv.div(monthsPerYear) + } + case YM.MONTH => ymCv.incRefCount() + } + } + + def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = { + withResource(yearMonthIntervalToInt(ymCv, ym)) { i => + castToTargetWithOverflowCheck(i, DType.INT16) + } + } + + def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = { + withResource(yearMonthIntervalToInt(ymCv, ym)) { i => + castToTargetWithOverflowCheck(i, DType.INT8) + } + } + + private def castToTargetWithOverflowCheck(cv: ColumnVector, dType: DType): ColumnVector = { + withResource(cv.castTo(dType)) { retTarget => + withResource(cv.notEqualTo(retTarget)) { notEqual => + if (BoolUtils.isAnyValidTrue(notEqual)) { + throw new ArithmeticException(s"overflow occurs when casting to $dType") + } else { + retTarget.incRefCount() + } + } + } + } + + /** + * Convert long cv to `day time interval` + */ + def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = { + val microsScalar = dt.asInstanceOf[DT].endField match { + case DT.DAY => Scalar.fromLong(MICROS_PER_DAY) + case DT.HOUR => Scalar.fromLong(MICROS_PER_HOUR) + case DT.MINUTE => Scalar.fromLong(MICROS_PER_MINUTE) + case DT.SECOND => Scalar.fromLong(MICROS_PER_SECOND) + } + withResource(microsScalar) { micros => + // leverage `Decimal 128` to check the overflow + IntervalUtils.multipleToLongWithOverflowCheck(longCv, micros) + } + } + + /** + * Convert (byte | short | int) cv to `day time interval` + */ + def intToDayTimeInterval(intCv: ColumnVector, dt: DataType): ColumnVector = { + dt.asInstanceOf[DT].endField match { + case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros => + // leverage `Decimal 128` to check the overflow + IntervalUtils.multipleToLongWithOverflowCheck(intCv, micros) + } + case DT.HOUR => withResource(Scalar.fromLong(MICROS_PER_HOUR)) { micros => + // no need to check overflow + 
intCv.mul(micros) + } + case DT.MINUTE => withResource(Scalar.fromLong(MICROS_PER_MINUTE)) { micros => + // no need to check overflow + intCv.mul(micros) + } + case DT.SECOND => withResource(Scalar.fromLong(MICROS_PER_SECOND)) { micros => + // no need to check overflow + intCv.mul(micros) + } + } + } + + /** + * Convert long cv to `year month interval` + */ + def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = { + ym.asInstanceOf[YM].endField match { + case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { num12 => + // leverage `Decimal 128` to check the overflow + IntervalUtils.multipleToIntWithOverflowCheck(longCv, num12) + } + case YM.MONTH => IntervalUtils.castLongToIntWithOverflowCheck(longCv) + } + } + + /** + * Convert (byte | short | int) cv to `year month interval` + */ + def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = { + (ym.asInstanceOf[YM].endField, intCv.getType) match { + case (YM.YEAR, DType.INT32) => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { num12 => + // leverage `Decimal 128` to check the overflow + IntervalUtils.multipleToIntWithOverflowCheck(intCv, num12) + } + case (YM.YEAR, DType.INT16 | DType.INT8) => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { + num12 => intCv.mul(num12) + } + case (YM.MONTH, _) => intCv.castTo(DType.INT32) + } + } } diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala index ebfb604047d..424cca5e483 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -202,7 +202,16 @@ object GpuTypeShims { */ def additionalCsvSupportedTypes: TypeSig = TypeSig.DAYTIME - def typesDayTimeCanCastTo: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + def typesDayTimeCanCastTo: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + TypeSig.integral + + def typesYearMonthCanCastTo: TypeSig = TypeSig.integral + + def typesDayTimeCanCastToOnSpark: TypeSig = TypeSig.DAYTIME + TypeSig.STRING + TypeSig.integral + + def typesYearMonthCanCastToOnSpark: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING + + TypeSig.integral + + def additionalTypesIntegralCanCastTo: TypeSig = TypeSig.YEARMONTH + TypeSig.DAYTIME def additionalTypesStringCanCastTo: TypeSig = TypeSig.DAYTIME diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala index 2c923ca0791..3364a0b07e3 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -66,8 +66,8 @@ object IntervalUtils extends Arm { * Multiple with overflow check, then cast to long * Equivalent to Math.multiplyExact * - * @param left cv or scalar - * @param right cv or scalar, will not be scalar if left is scalar + * @param left cv(byte, short, int, long) or scalar + * @param right cv(byte, short, int, long) or scalar, will not be scalar if left is scalar * @return the long result of left * right */ def multipleToLongWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = { @@ -82,8 +82,8 @@ object IntervalUtils extends Arm { * Multiple with overflow check, then cast to int * Equivalent to Math.multiplyExact * - * @param left cv or scalar - * 
@param right cv or scalar, will not be scalar if left is scalar + * @param left cv(byte, short, int, long) or scalar + * @param right cv(byte, short, int, long) or scalar, will not be scalar if left is scalar * @return the int result of left * right */ def multipleToIntWithOverflowCheck(left: BinaryOperable, right: BinaryOperable): ColumnVector = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 9fe85782702..5fc48361544 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -561,6 +561,39 @@ object GpuCast extends Arm { GpuIntervalUtils.castStringToDayTimeIntervalWithThrow( input.asInstanceOf[ColumnVector], dayTime) + // cast(`day time interval` as integral) + case (dt: DataType, _: LongType) if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.dayTimeIntervalToLong(input.asInstanceOf[ColumnVector], dt) + case (dt: DataType, _: IntegerType) if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.dayTimeIntervalToInt(input.asInstanceOf[ColumnVector], dt) + case (dt: DataType, _: ShortType) if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.dayTimeIntervalToShort(input.asInstanceOf[ColumnVector], dt) + case (dt: DataType, _: ByteType) if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.dayTimeIntervalToByte(input.asInstanceOf[ColumnVector], dt) + + // cast(integral as `day time interval`) + case (_: LongType, dt: DataType) if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.longToDayTimeInterval(input.asInstanceOf[ColumnVector], dt) + case (_: IntegerType | ShortType | ByteType, dt: DataType) + if GpuTypeShims.isSupportedDayTimeType(dt) => + GpuIntervalUtils.intToDayTimeInterval(input.asInstanceOf[ColumnVector], dt) + + // cast(`year month interval` as integral) + case (ym: DataType, _: LongType) if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.yearMonthIntervalToLong(input.asInstanceOf[ColumnVector], ym) + case (ym: DataType, _: IntegerType) if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.yearMonthIntervalToInt(input.asInstanceOf[ColumnVector], ym) + case (ym: DataType, _: ShortType) if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.yearMonthIntervalToShort(input.asInstanceOf[ColumnVector], ym) + case (ym: DataType, _: ByteType) if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.yearMonthIntervalToByte(input.asInstanceOf[ColumnVector], ym) + + // cast(integral as `year month interval`) + case (_: LongType, ym: DataType) if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.longToYearMonthInterval(input.asInstanceOf[ColumnVector], ym) + case (_: IntegerType | ShortType | ByteType, ym: DataType) + if GpuTypeShims.isSupportedYearMonthType(ym) => + GpuIntervalUtils.intToYearMonthInterval(input.asInstanceOf[ColumnVector], ym) case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 89cf1b0ad39..87855209f10 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1302,8 +1302,9 @@ class CastChecks extends ExprChecks { val sparkBooleanSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING val 
integralChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + - BINARY - val sparkIntegralSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + BINARY + BINARY + GpuTypeShims.additionalTypesIntegralCanCastTo + val sparkIntegralSig: TypeSig = cpuNumeric + BOOLEAN + TIMESTAMP + STRING + BINARY + + BINARY + GpuTypeShims.additionalTypesIntegralCanCastTo val fpToStringPsNote: String = s"Conversion may produce different results and requires " + s"${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to be true." @@ -1357,10 +1358,10 @@ class CastChecks extends ExprChecks { val sparkUdtSig: TypeSig = STRING + UDT val daytimeChecks: TypeSig = GpuTypeShims.typesDayTimeCanCastTo - val sparkDaytimeChecks: TypeSig = DAYTIME + STRING + val sparkDaytimeChecks: TypeSig = GpuTypeShims.typesDayTimeCanCastToOnSpark - val yearmonthChecks: TypeSig = none - val sparkYearmonthChecks: TypeSig = YEARMONTH + STRING + val yearmonthChecks: TypeSig = GpuTypeShims.typesYearMonthCanCastTo + val sparkYearmonthChecks: TypeSig = GpuTypeShims.typesYearMonthCanCastToOnSpark private[this] def getChecksAndSigs(from: DataType): (TypeSig, TypeSig) = from match { case NullType => (nullChecks, sparkNullSig) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala new file mode 100644 index 00000000000..3f7fde597e7 --- /dev/null +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids + +import java.time.Period + +import org.apache.spark.SparkException +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +/** + * Can not put this suite to Pyspark test cases + * because of currently Pyspark not have year-month type. 
+ * See: https://github.com/apache/spark/blob/branch-3.3/python/pyspark/sql/types.py + * Should move the year-month scala test cases to the integration test module, + * filed an issue to track: https://github.com/NVIDIA/spark-rapids/issues/5212 + */ +class IntervalCastSuite extends SparkQueryCompareTestSuite { + testSparkResultsAreEqual( + "test cast year-month to integral", + spark => { + val data = (-128 to 127).map(i => Row(Period.ofMonths(i))) + val schema = StructType(Seq(StructField("c_ym", YearMonthIntervalType()))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr("cast(c_ym as long)", "cast(c_ym as integer)", "cast(c_ym as short)", + "cast(c_ym as byte)") + } + + testSparkResultsAreEqual( + "test cast integral to year-month", + spark => { + val data = (-128 to 127).map(i => Row(i.toLong, i, i.toShort, i.toByte)) + val schema = StructType(Seq(StructField("c_l", LongType), + StructField("c_i", IntegerType), + StructField("c_s", ShortType), + StructField("c_b", ByteType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr("cast(c_l as interval month)", "cast(c_i as interval year)", + "cast(c_s as interval year)", "cast(c_b as interval month)") + } + + val toIntegralOverflowPairs = Array( + (Byte.MinValue - 1, "byte"), + (Byte.MaxValue + 1, "byte"), + (Short.MinValue - 1, "short"), + (Short.MinValue - 1, "short")) + var testLoop = 1 + toIntegralOverflowPairs.foreach { case (months, toType) => + testBothCpuGpuExpectedException[SparkException]( + s"test cast year-month to integral, overflow $testLoop", + e => e.getMessage.contains("overflow"), + spark => { + val data = Seq(Row(Period.ofMonths(months))) + val schema = StructType(Seq(StructField("c_ym", YearMonthIntervalType()))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr(s"cast(c_ym as $toType)") + } + testLoop += 1 + } + + val toYMOverflows = Array( + (Int.MinValue / 12 - 1, IntegerType, "year"), + (Int.MaxValue / 12 + 1, IntegerType, "year"), + (Int.MinValue - 1L, LongType, "month"), + (Int.MaxValue + 1L, LongType, "month"), + (Int.MinValue / 12 - 1L, LongType, "year"), + (Int.MaxValue / 12 + 1L, LongType, "year")) + testLoop = 1 + toYMOverflows.foreach { case (integral, integralType, toField) => + testBothCpuGpuExpectedException[SparkException]( + s"test cast integral to year-month, overflow $testLoop", + e => e.getMessage.contains("overflow"), + spark => { + val data = Seq(Row(integral)) + val schema = StructType(Seq(StructField("c_integral", integralType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr(s"cast(c_integral as interval $toField)") + } + testLoop += 1 + } +} From 98a9e29cd1bdc5634fcfa6b4bef6333692dca47e Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 28 Apr 2022 13:46:14 +0800 Subject: [PATCH 2/4] Fix spark330 build due to mapKeyNotExistError changed Signed-off-by: Chong Gao --- .../com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala | 2 ++ .../com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala | 2 ++ .../com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala | 2 ++ .../com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala | 4 +++- .../org/apache/spark/sql/rapids/collectionOperations.scala | 5 +++-- .../org/apache/spark/sql/rapids/complexTypeExtractors.scala | 2 +- 6 files changed, 13 insertions(+), 4 deletions(-) diff --git 
a/sql-plugin/src/main/311until330-nondb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/311until330-nondb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index f4c81429cfe..100606ebd02 100644 --- a/sql-plugin/src/main/311until330-nondb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/311until330-nondb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -27,6 +28,7 @@ object RapidsErrorUtils { def mapKeyNotExistError( key: String, + keyType: DataType, origin: Origin): NoSuchElementException = { // Follow the Spark string format before 3.3.0 new NoSuchElementException(s"Key $key does not exist.") diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index f4c81429cfe..100606ebd02 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -27,6 +28,7 @@ object RapidsErrorUtils { def mapKeyNotExistError( key: String, + keyType: DataType, origin: Origin): NoSuchElementException = { // Follow the Spark string format before 3.3.0 new NoSuchElementException(s"Key $key does not exist.") diff --git a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index 78d2fb56b7e..759ad5ed79e 100644 --- a/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/321db/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.DataType object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -31,6 +32,7 @@ object RapidsErrorUtils { def mapKeyNotExistError( key: String, + keyType: DataType, origin: Origin): NoSuchElementException = { // For now, the default argument is false. The caller sets the correct value accordingly. 
QueryExecutionErrors.mapKeyNotExistError(key) diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index e8f65ed8c27..0d48b58eb7d 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.DataType object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, @@ -31,8 +32,9 @@ object RapidsErrorUtils { def mapKeyNotExistError( key: String, + keyType: DataType, origin: Origin): NoSuchElementException = { - QueryExecutionErrors.mapKeyNotExistError(key, origin.context) + QueryExecutionErrors.mapKeyNotExistError(key, keyType, origin.context) } def sqlArrayIndexNotStartAtOneError(): ArrayIndexOutOfBoundsException = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala index 85a7cddcfaa..ddfaf51b549 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala @@ -223,7 +223,7 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea array.extractListElement(idx) } } - case _: MapType => + case MapType(keyType, _, _) => (map, keyS) => { val key = keyS.getBase if (failOnError) { @@ -232,7 +232,8 @@ case class GpuElementAt(left: Expression, right: Expression, failOnError: Boolea if (!exist.isValid || exist.getBoolean) { map.getMapValue(key) } else { - throw RapidsErrorUtils.mapKeyNotExistError(keyS.getValue.toString, origin) + throw RapidsErrorUtils.mapKeyNotExistError(keyS.getValue.toString, keyType, + origin) } } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index d6e0d794936..5428161e4d0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -212,7 +212,7 @@ case class GpuGetMapValue(child: Expression, key: Expression, failOnError: Boole withResource(lhs.getBase.getMapKeyExistence(rhs.getBase)) { keyExistenceColumn => withResource(keyExistenceColumn.all) { exist => if (exist.isValid && !exist.getBoolean) { - throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString, origin) + throw RapidsErrorUtils.mapKeyNotExistError(rhs.getValue.toString, keyType, origin) } } } From cbb179e21455c3990cf6cd0ec16a56984accb320 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Fri, 29 Apr 2022 10:39:32 +0800 Subject: [PATCH 3/4] Add test cases for nested types; Add the shim layer for hasSideEffects --- .../src/main/python/cast_test.py | 78 +++++-- .../spark/rapids/shims/GpuIntervalUtils.scala | 30 +-- .../spark/rapids/shims/GpuTypeShims.scala | 3 + .../spark/rapids/shims/GpuIntervalUtils.scala | 55 ++--- .../spark/rapids/shims/GpuTypeShims.scala | 8 + .../rapids/shims/intervalExpressions.scala | 6 +- .../com/nvidia/spark/rapids/GpuCast.scala | 43 ++-- .../nvidia/spark/rapids/GpuOverrides.scala | 6 +- 
.../com/nvidia/spark/rapids/TypeChecks.scala | 7 +- .../spark/rapids/IntervalCastSuite.scala | 190 +++++++++++++++++- 10 files changed, 339 insertions(+), 87 deletions(-) diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index 87d8f9f004c..ddf3d77a6a6 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -345,36 +345,61 @@ def fun(spark): return df.select(f.col('value').cast(dtType)).collect() assert_gpu_and_cpu_error(fun, {}, "java.lang.IllegalArgumentException") -integral_types = [ByteType(), ShortType(), IntegerType(), LongType()] @pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') -@pytest.mark.parametrize('integral_type', integral_types) -def test_cast_day_time_interval_to_integral_no_overflow(integral_type): +def test_cast_day_time_interval_to_integral_no_overflow(): + second_dt_gen = DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127), nullable=False) + gen = StructGen([('a', DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400))), + ('b', DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600))), + ('c', DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60))), + ('d', second_dt_gen), + ('c_array', ArrayGen(second_dt_gen)), + ('c_struct', StructGen([("a", second_dt_gen), ("b", second_dt_gen)])), + ('c_map', MapGen(second_dt_gen, second_dt_gen)) + ], nullable=False) assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='day', end_field='day', min_value=timedelta(seconds=-128 * 86400), max_value=timedelta(seconds=127 * 86400))) - .select(f.col('a').cast(integral_type))) - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='hour', end_field='hour', min_value=timedelta(seconds=-128 * 3600), max_value=timedelta(seconds=127 * 3600))) - .select(f.col('a').cast(integral_type))) - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='minute', end_field='minute', min_value=timedelta(seconds=-128 * 60), max_value=timedelta(seconds=127 * 60))) - .select(f.col('a').cast(integral_type))) - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, DayTimeIntervalGen(start_field='second', end_field='second', min_value=timedelta(seconds=-128), max_value=timedelta(seconds=127))) - .select(f.col('a').cast(integral_type))) + lambda spark: gen_df(spark, gen).select(f.col('a').cast(ByteType()), f.col('a').cast(ShortType()), f.col('a').cast(IntegerType()), f.col('a').cast(LongType()), + f.col('b').cast(ByteType()), f.col('b').cast(ShortType()), f.col('b').cast(IntegerType()), f.col('b').cast(LongType()), + f.col('c').cast(ByteType()), f.col('c').cast(ShortType()), f.col('c').cast(IntegerType()), f.col('c').cast(LongType()), + f.col('d').cast(ByteType()), f.col('d').cast(ShortType()), f.col('d').cast(IntegerType()), f.col('d').cast(LongType()), + f.col('c_array').cast(ArrayType(ByteType())), + f.col('c_struct').cast(StructType([StructField('a', ShortType()), StructField('b', ShortType())])), + 
f.col('c_map').cast(MapType(IntegerType(), IntegerType())) + )) integral_gens_no_overflow = [ LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1]), IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1]), ShortGen(), - ByteGen() + ByteGen(), + # StructGen([("a", ShortGen()), ("b", ByteGen())]) ] -day_time_fields = [0, 1, 2, 3] # 0 is day, 1 is hour, 2 is minute, 3 is second @pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') -@pytest.mark.parametrize('integral_gen_no_overflow', integral_gens_no_overflow) -@pytest.mark.parametrize('day_time_field', day_time_fields) -def test_cast_integral_to_day_time_interval_no_overflow(integral_gen_no_overflow, day_time_field): +def test_cast_integral_to_day_time_interval_no_overflow(): + long_gen = IntegerGen(min_val=math.ceil(INT_MIN / 86400 / 1000000), max_val=math.floor(INT_MAX / 86400 / 1000000), special_cases=[0, 1, -1]) + int_gen = LongGen(min_val=math.ceil(LONG_MIN / 86400 / 1000000), max_val=math.floor(LONG_MAX / 86400 / 1000000), special_cases=[0, 1, -1], nullable=False) + gen = StructGen([("a", long_gen), + ("b", int_gen), + ("c", ShortGen()), + ("d", ByteGen()), + ("c_struct", StructGen([("a", long_gen), ("b", int_gen)], nullable=False)), + ('c_array', ArrayGen(int_gen)), + ('c_map', MapGen(int_gen, long_gen))], nullable=False) + # day_time_field: 0 is day, 1 is hour, 2 is minute, 3 is second + day_type = DayTimeIntervalType(0, 0) + hour_type = DayTimeIntervalType(1, 1) + minute_type = DayTimeIntervalType(2, 2) + second_type = DayTimeIntervalType(3, 3) + assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, integral_gen_no_overflow).select(f.col('a').cast(DayTimeIntervalType(day_time_field, day_time_field)))) + lambda spark: gen_df(spark, gen).select( + f.col('a').cast(day_type), f.col('a').cast(hour_type), f.col('a').cast(minute_type), f.col('a').cast(second_type), + f.col('b').cast(day_type), f.col('b').cast(hour_type), f.col('b').cast(minute_type), f.col('b').cast(second_type), + f.col('c').cast(day_type), f.col('c').cast(hour_type), f.col('c').cast(minute_type), f.col('c').cast(second_type), + f.col('d').cast(day_type), f.col('d').cast(hour_type), f.col('d').cast(minute_type), f.col('d').cast(second_type), + f.col('c_struct').cast(StructType([StructField('a', day_type), StructField('b', hour_type)])), + f.col('c_array').cast(ArrayType(hour_type)), + f.col('c_map').cast(MapType(minute_type, second_type)), + )) cast_day_time_to_inregral_overflow_pairs = [ (INT_MIN - 1, IntegerType()), @@ -410,3 +435,16 @@ def getDf(spark): lambda spark: getDf(spark).select(f.col('value').cast(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.DAY))).collect(), conf={}, error_message="overflow") + +def test_cast_integral_to_day_time_side_effect(): + def getDf(spark): + # INT_MAX > 106751991 (max value of interval day) + return spark.createDataFrame([(True, INT_MAX, LONG_MAX), (False, 0, 0)], "c_b boolean, c_i int, c_l long").repartition(1) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: getDf(spark).selectExpr("if(c_b, interval 0 day, cast(c_i as interval day))", "if(c_b, interval 0 day, cast(c_l as interval second))")) + +def test_cast_day_time_to_integral_side_effect(): + def getDf(spark): + # 106751991 > Byte.MaxValue + return spark.createDataFrame([(True, MAX_DAY_TIME_INTERVAL), (False, 
(timedelta(microseconds=0)))], "c_b boolean, c_dt interval day to second").repartition(1) + assert_gpu_and_cpu_are_equal_collect(lambda spark: getDf(spark).selectExpr("if(c_b, 0, cast(c_dt as byte))", "if(c_b, 0, cast(c_dt as short))", "if(c_b, 0, cast(c_dt as int))")) diff --git a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala index f39fdea4c3e..ebe99d57185 100644 --- a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala +++ b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala @@ -15,7 +15,7 @@ */ package com.nvidia.spark.rapids.shims -import ai.rapids.cudf.ColumnVector +import ai.rapids.cudf.{ColumnVector, ColumnView} import org.apache.spark.sql.types.DataType @@ -24,59 +24,59 @@ import org.apache.spark.sql.types.DataType */ object GpuIntervalUtils { - def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = { + def castStringToDayTimeIntervalWithThrow(cv: ColumnView, t: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = { + def toDayTimeIntervalString(micros: ColumnView, dayTimeType: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToLong(dtCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToInt(dtCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToShort(dtCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToByte(dtCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToLong(ymCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToInt(ymCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToShort(ymCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToByte(ymCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = { + def longToDayTimeInterval(longCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def intToDayTimeInterval(intCv: ColumnVector, 
dt: DataType): ColumnVector = { + def intToDayTimeInterval(intCv: ColumnView, dt: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = { + def longToYearMonthInterval(longCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } - def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = { + def intToYearMonthInterval(intCv: ColumnView, ym: DataType): ColumnVector = { throw new IllegalStateException("Not supported in this Shim") } } diff --git a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala index 3760be0735c..fe5a90b1d4d 100644 --- a/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala +++ b/sql-plugin/src/main/311until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -133,4 +133,7 @@ object GpuTypeShims { */ def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.none + def hasSideEffectsIfCastIntToYearMonth(ym: DataType): Boolean = false + + def hasSideEffectsIfCastIntToDayTime(dt: DataType): Boolean = false } diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala index ddeaed634ba..73cc132d7a0 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuIntervalUtils.scala @@ -135,7 +135,7 @@ object GpuIntervalUtils extends Arm { private val secondPatternString = s"$sign$secondBoundPattern$microPattern" private val secondLiteralRegex = s"^$INTERVAL$blanks$sign'$secondPatternString'$blanks$SECOND$$" - def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = { + def castStringToDayTimeIntervalWithThrow(cv: ColumnView, t: DataType): ColumnVector = { castStringToDayTimeIntervalWithThrow(cv, t.asInstanceOf[DT]) } @@ -147,7 +147,7 @@ object GpuIntervalUtils extends Arm { * @return long column of micros * @throws IllegalArgumentException if have a row failed */ - def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DT): ColumnVector = { + def castStringToDayTimeIntervalWithThrow(cv: ColumnView, t: DT): ColumnVector = { withResource(castStringToDTInterval(cv, t)) { ret => if(ret.getNullCount > cv.getNullCount) { throw new IllegalArgumentException("Cast string to day time interval failed, " + @@ -166,7 +166,7 @@ object GpuIntervalUtils extends Arm { * @param t day-time interval type * @return long column of micros */ - def castStringToDTInterval(cv: ColumnVector, t: DT): ColumnVector = { + def castStringToDTInterval(cv: ColumnView, t: DT): ColumnVector = { (t.startField, t.endField) match { case (DT.DAY, DT.DAY) => withResource(cv.extractRe(dayLiteralRegex)) { groupsTable => { @@ -538,7 +538,7 @@ object GpuIntervalUtils extends Arm { } } - def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = { + def toDayTimeIntervalString(micros: ColumnView, dayTimeType: DataType): ColumnVector = { val t = dayTimeType.asInstanceOf[DT] toDayTimeIntervalString(micros, t.startField, t.endField) } @@ -554,7 +554,7 @@ object GpuIntervalUtils extends Arm { * @return ANSI day-time interval string, e.g.: interval '01 
08:30:30.001' DAY TO SECOND */ def toDayTimeIntervalString( - micros: ColumnVector, + micros: ColumnView, startField: Byte, endField: Byte): ColumnVector = { @@ -563,11 +563,10 @@ object GpuIntervalUtils extends Arm { val to = DT.fieldToString(endField).toUpperCase val postfixStr = s"' ${if (startField == endField) from else s"$from TO $to"}" - val retCv = withResource(new CloseableHolder(micros.incRefCount())) { restHolder => - withResource(new ArrayBuffer[ColumnView]) { parts => + val retCv = withResource(new ArrayBuffer[ColumnView]) { parts => // prefix with sign part: INTERVAL ' or INTERVAL '- parts += withResource(Scalar.fromLong(0L)) { zero => - withResource(restHolder.get.lessThan(zero)) { less => + withResource(micros.lessThan(zero)) { less => withResource(Scalar.fromString(prefixStr + "-")) { negPrefix => withResource(Scalar.fromString(prefixStr)) { prefix => less.ifElse(negPrefix, prefix) @@ -577,7 +576,7 @@ object GpuIntervalUtils extends Arm { } // calculate abs, abs(Long.MinValue) will overflow, handle in the last as special case - restHolder.setAndCloseOld(restHolder.get.abs()) + withResource(new CloseableHolder(micros.abs())) { restHolder => startField match { case DT.DAY => @@ -742,7 +741,7 @@ object GpuIntervalUtils extends Arm { } } - def dayTimeIntervalToLong(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToLong(dtCv: ColumnView, dt: DataType): ColumnVector = { dt.asInstanceOf[DT].endField match { case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros => dtCv.div(micros) @@ -759,25 +758,25 @@ object GpuIntervalUtils extends Arm { } } - def dayTimeIntervalToInt(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToInt(dtCv: ColumnView, dt: DataType): ColumnVector = { withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => castToTargetWithOverflowCheck(longCv, DType.INT32) } } - def dayTimeIntervalToShort(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToShort(dtCv: ColumnView, dt: DataType): ColumnVector = { withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => castToTargetWithOverflowCheck(longCv, DType.INT16) } } - def dayTimeIntervalToByte(dtCv: ColumnVector, dt: DataType): ColumnVector = { + def dayTimeIntervalToByte(dtCv: ColumnView, dt: DataType): ColumnVector = { withResource(dayTimeIntervalToLong(dtCv, dt)) { longCv => castToTargetWithOverflowCheck(longCv, DType.INT8) } } - def yearMonthIntervalToLong(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToLong(ymCv: ColumnView, ym: DataType): ColumnVector = { ym.asInstanceOf[YM].endField match { case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { monthsPerYear => ymCv.div(monthsPerYear) @@ -786,28 +785,28 @@ object GpuIntervalUtils extends Arm { } } - def yearMonthIntervalToInt(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToInt(ymCv: ColumnView, ym: DataType): ColumnVector = { ym.asInstanceOf[YM].endField match { case YM.YEAR => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { monthsPerYear => ymCv.div(monthsPerYear) } - case YM.MONTH => ymCv.incRefCount() + case YM.MONTH => ymCv.copyToColumnVector() } } - def yearMonthIntervalToShort(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def yearMonthIntervalToShort(ymCv: ColumnView, ym: DataType): ColumnVector = { withResource(yearMonthIntervalToInt(ymCv, ym)) { i => castToTargetWithOverflowCheck(i, DType.INT16) } } - def yearMonthIntervalToByte(ymCv: ColumnVector, ym: DataType): ColumnVector = { + def 
yearMonthIntervalToByte(ymCv: ColumnView, ym: DataType): ColumnVector = { withResource(yearMonthIntervalToInt(ymCv, ym)) { i => castToTargetWithOverflowCheck(i, DType.INT8) } } - private def castToTargetWithOverflowCheck(cv: ColumnVector, dType: DType): ColumnVector = { + private def castToTargetWithOverflowCheck(cv: ColumnView, dType: DType): ColumnVector = { withResource(cv.castTo(dType)) { retTarget => withResource(cv.notEqualTo(retTarget)) { notEqual => if (BoolUtils.isAnyValidTrue(notEqual)) { @@ -822,7 +821,7 @@ object GpuIntervalUtils extends Arm { /** * Convert long cv to `day time interval` */ - def longToDayTimeInterval(longCv: ColumnVector, dt: DataType): ColumnVector = { + def longToDayTimeInterval(longCv: ColumnView, dt: DataType): ColumnVector = { val microsScalar = dt.asInstanceOf[DT].endField match { case DT.DAY => Scalar.fromLong(MICROS_PER_DAY) case DT.HOUR => Scalar.fromLong(MICROS_PER_HOUR) @@ -838,11 +837,17 @@ object GpuIntervalUtils extends Arm { /** * Convert (byte | short | int) cv to `day time interval` */ - def intToDayTimeInterval(intCv: ColumnVector, dt: DataType): ColumnVector = { + def intToDayTimeInterval(intCv: ColumnView, dt: DataType): ColumnVector = { dt.asInstanceOf[DT].endField match { case DT.DAY => withResource(Scalar.fromLong(MICROS_PER_DAY)) { micros => - // leverage `Decimal 128` to check the overflow - IntervalUtils.multipleToLongWithOverflowCheck(intCv, micros) + if (intCv.getType.equals(DType.INT32)) { + // leverage `Decimal 128` to check the overflow + // Int.MaxValue * `micros` can cause overflow + IntervalUtils.multipleToLongWithOverflowCheck(intCv, micros) + } else { + // no need to check overflow for short byte types + intCv.mul(micros) + } } case DT.HOUR => withResource(Scalar.fromLong(MICROS_PER_HOUR)) { micros => // no need to check overflow @@ -862,7 +867,7 @@ object GpuIntervalUtils extends Arm { /** * Convert long cv to `year month interval` */ - def longToYearMonthInterval(longCv: ColumnVector, ym: DataType): ColumnVector = { + def longToYearMonthInterval(longCv: ColumnView, ym: DataType): ColumnVector = { ym.asInstanceOf[YM].endField match { case YM.YEAR => withResource(Scalar.fromLong(MONTHS_PER_YEAR)) { num12 => // leverage `Decimal 128` to check the overflow @@ -875,7 +880,7 @@ object GpuIntervalUtils extends Arm { /** * Convert (byte | short | int) cv to `year month interval` */ - def intToYearMonthInterval(intCv: ColumnVector, ym: DataType): ColumnVector = { + def intToYearMonthInterval(intCv: ColumnView, ym: DataType): ColumnVector = { (ym.asInstanceOf[YM].endField, intCv.getType) match { case (YM.YEAR, DType.INT32) => withResource(Scalar.fromInt(MONTHS_PER_YEAR)) { num12 => // leverage `Decimal 128` to check the overflow diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala index 424cca5e483..78eda4cc09d 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -225,4 +225,12 @@ object GpuTypeShims { * (filter, sample, project, alias, table scan ...... 
which GPU supports from 330) */ def additionalCommonOperatorSupportedTypes: TypeSig = TypeSig.ansiIntervals + + def hasSideEffectsIfCastIntToYearMonth(ym: DataType): Boolean = + // if cast(int as interval year), multiplication by 12 can cause overflow + ym.asInstanceOf[YearMonthIntervalType].endField == YearMonthIntervalType.YEAR + + def hasSideEffectsIfCastIntToDayTime(dt: DataType): Boolean = + // if cast(int as interval day), multiplication by (86400 * 1000000) can cause overflow + dt.asInstanceOf[DayTimeIntervalType].endField == DayTimeIntervalType.DAY } diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala index 3364a0b07e3..ab601bbb7e5 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/rapids/shims/intervalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.shims import java.math.BigInteger -import ai.rapids.cudf.{BinaryOperable, ColumnVector, DType, RoundMode, Scalar} +import ai.rapids.cudf.{BinaryOperable, ColumnVector, ColumnView, DType, RoundMode, Scalar} import com.nvidia.spark.rapids.{Arm, BoolUtils, GpuBinaryExpression, GpuColumnVector, GpuScalar} import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant} @@ -30,7 +30,7 @@ object IntervalUtils extends Arm { * Convert long cv to int cv, throws exception if any value in `longCv` exceeds the int limits. * Check (int)(long_value) == long_value */ - def castLongToIntWithOverflowCheck(longCv: ColumnVector): ColumnVector = { + def castLongToIntWithOverflowCheck(longCv: ColumnView): ColumnVector = { withResource(longCv.castTo(DType.INT32)) { intResult => withResource(longCv.notEqualTo(intResult)) { notEquals => if (BoolUtils.isAnyValidTrue(notEquals)) { @@ -42,7 +42,7 @@ object IntervalUtils extends Arm { } } - def checkDecimal128CvInRange(decimal128Cv: ColumnVector, minValue: Long, maxValue: Long): Unit = { + def checkDecimal128CvInRange(decimal128Cv: ColumnView, minValue: Long, maxValue: Long): Unit = { // check min withResource(Scalar.fromLong(minValue)) { minScalar => withResource(decimal128Cv.lessThan(minScalar)) { lessThanMin => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 5fc48361544..bad4fb0c9d3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -555,45 +555,44 @@ object GpuCast extends Arm { castMapToString(input, from, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) case (dayTime: DataType, _: StringType) if GpuTypeShims.isSupportedDayTimeType(dayTime) => - GpuIntervalUtils.toDayTimeIntervalString(input.asInstanceOf[ColumnVector], dayTime) + GpuIntervalUtils.toDayTimeIntervalString(input, dayTime) case (_: StringType, dayTime: DataType) if GpuTypeShims.isSupportedDayTimeType(dayTime) => - GpuIntervalUtils.castStringToDayTimeIntervalWithThrow( - input.asInstanceOf[ColumnVector], dayTime) + GpuIntervalUtils.castStringToDayTimeIntervalWithThrow(input, dayTime) // cast(`day time interval` as integral) case (dt: DataType, _: LongType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.dayTimeIntervalToLong(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.dayTimeIntervalToLong(input, dt) 
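The integral-to-interval paths dispatched a few cases below only guard against overflow where the multiplication up to microseconds can actually exceed a long. The snippet below is a rough scalar analogue of that reasoning, with the bounds worked out in comments; Math.multiplyExact merely stands in for the patch's Decimal128-based columnar check and intToMicrosExact is an illustrative name, not plugin code:

// Long.MaxValue is about 9.22e18 micros, so:
//   Int.MaxValue * MICROS_PER_DAY  = 2147483647 * 86400000000L ~ 1.86e20 -> may overflow
//   Int.MaxValue * MICROS_PER_HOUR = 2147483647 * 3600000000L  ~ 7.73e18 -> always fits
// Hence byte/short/int inputs need a check only for the DAY end field, while long inputs
// need it for every end field; the patch performs the columnar check via Decimal128.
def intToMicrosExact(value: Int, microsPerUnit: Long): Long =
  Math.multiplyExact(value.toLong, microsPerUnit) // throws ArithmeticException on overflow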
case (dt: DataType, _: IntegerType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.dayTimeIntervalToInt(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.dayTimeIntervalToInt(input, dt) case (dt: DataType, _: ShortType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.dayTimeIntervalToShort(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.dayTimeIntervalToShort(input, dt) case (dt: DataType, _: ByteType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.dayTimeIntervalToByte(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.dayTimeIntervalToByte(input, dt) // cast(integral as `day time interval`) case (_: LongType, dt: DataType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.longToDayTimeInterval(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.longToDayTimeInterval(input, dt) case (_: IntegerType | ShortType | ByteType, dt: DataType) if GpuTypeShims.isSupportedDayTimeType(dt) => - GpuIntervalUtils.intToDayTimeInterval(input.asInstanceOf[ColumnVector], dt) + GpuIntervalUtils.intToDayTimeInterval(input, dt) // cast(`year month interval` as integral) case (ym: DataType, _: LongType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.yearMonthIntervalToLong(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.yearMonthIntervalToLong(input, ym) case (ym: DataType, _: IntegerType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.yearMonthIntervalToInt(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.yearMonthIntervalToInt(input, ym) case (ym: DataType, _: ShortType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.yearMonthIntervalToShort(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.yearMonthIntervalToShort(input, ym) case (ym: DataType, _: ByteType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.yearMonthIntervalToByte(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.yearMonthIntervalToByte(input, ym) // cast(integral as `year month interval`) case (_: LongType, ym: DataType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.longToYearMonthInterval(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.longToYearMonthInterval(input, ym) case (_: IntegerType | ShortType | ByteType, ym: DataType) if GpuTypeShims.isSupportedYearMonthType(ym) => - GpuIntervalUtils.intToYearMonthInterval(input.asInstanceOf[ColumnVector], ym) + GpuIntervalUtils.intToYearMonthInterval(input, ym) case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } @@ -1580,6 +1579,20 @@ case class GpuCast( case (FloatType | DoubleType, ShortType) if ansiMode => true case (FloatType | DoubleType, IntegerType) if ansiMode => true case (FloatType | DoubleType, LongType) if ansiMode => true + case (_: LongType, dayTimeIntervalType: DataType) + if GpuTypeShims.isSupportedDayTimeType(dayTimeIntervalType) => true + case (_: IntegerType, dayTimeIntervalType: DataType) + if GpuTypeShims.isSupportedDayTimeType(dayTimeIntervalType) => + GpuTypeShims.hasSideEffectsIfCastIntToDayTime(dayTimeIntervalType) + case (dayTimeIntervalType: DataType, _: IntegerType | ShortType | ByteType) + if GpuTypeShims.isSupportedDayTimeType(dayTimeIntervalType) => true + case (_: LongType, yearMonthIntervalType: DataType) + if GpuTypeShims.isSupportedYearMonthType(yearMonthIntervalType) => true + case (_: IntegerType, yearMonthIntervalType: DataType) + if GpuTypeShims.isSupportedYearMonthType(yearMonthIntervalType) => + 
GpuTypeShims.hasSideEffectsIfCastIntToYearMonth(yearMonthIntervalType) + case (yearMonthIntervalType: DataType, _: ShortType | ByteType) + if GpuTypeShims.isSupportedYearMonthType(yearMonthIntervalType) => true case _ => false } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 04d772b4f95..138ea30fa87 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2086,16 +2086,16 @@ object GpuOverrides extends Logging { "IF expression", ExprChecks.projectOnly( (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + - TypeSig.MAP).nested(), + TypeSig.MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all, Seq(ParamCheck("predicate", TypeSig.BOOLEAN, TypeSig.BOOLEAN), ParamCheck("trueValue", (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + - TypeSig.MAP).nested(), + TypeSig.MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all), ParamCheck("falseValue", (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + - TypeSig.MAP).nested(), + TypeSig.MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes).nested(), TypeSig.all))), (a, conf, p, r) => new ExprMeta[If](a, conf, p, r) { override def convertToGpu(): GpuExpression = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 87855209f10..6f4a458dc7b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1334,14 +1334,14 @@ class CastChecks extends ExprChecks { val arrayChecks: TypeSig = psNote(TypeEnum.STRING, "the array's child type must also support " + "being cast to string") + ARRAY.nested(commonCudfTypes + DECIMAL_128 + NULL + - ARRAY + BINARY + STRUCT + MAP) + + ARRAY + BINARY + STRUCT + MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes) + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to the " + "desired child type(s)") val sparkArraySig: TypeSig = STRING + ARRAY.nested(all) val mapChecks: TypeSig = MAP.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + - STRUCT + MAP) + + STRUCT + MAP + GpuTypeShims.additionalCommonOperatorSupportedTypes) + psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " + "desired child types") + psNote(TypeEnum.STRING, "the map's key and value must also support being cast to string") @@ -1349,7 +1349,8 @@ class CastChecks extends ExprChecks { val structChecks: TypeSig = psNote(TypeEnum.STRING, "the struct's children must also support " + "being cast to string") + - STRUCT.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT + MAP) + + STRUCT.nested(commonCudfTypes + DECIMAL_128 + NULL + ARRAY + BINARY + STRUCT + MAP + + GpuTypeShims.additionalCommonOperatorSupportedTypes) + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + "desired child type(s)") val sparkStructSig: TypeSig = STRING + STRUCT.nested(all) diff --git a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala index 3f7fde597e7..43eb1c189f6 100644 --- 
a/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala +++ b/tests/src/test/330/scala/com/nvidia/spark/rapids/IntervalCastSuite.scala @@ -18,8 +18,9 @@ package com.nvidia.spark.rapids import java.time.Period import org.apache.spark.SparkException +import org.apache.spark.sql.{functions => f} import org.apache.spark.sql.Row -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, IntegerType, LongType, MapType, ShortType, StructField, StructType, YearMonthIntervalType => YM} /** * Can not put this suite to Pyspark test cases @@ -33,7 +34,7 @@ class IntervalCastSuite extends SparkQueryCompareTestSuite { "test cast year-month to integral", spark => { val data = (-128 to 127).map(i => Row(Period.ofMonths(i))) - val schema = StructType(Seq(StructField("c_ym", YearMonthIntervalType()))) + val schema = StructType(Seq(StructField("c_ym", YM()))) spark.createDataFrame(spark.sparkContext.parallelize(data), schema) }) { df => @@ -68,7 +69,7 @@ class IntervalCastSuite extends SparkQueryCompareTestSuite { e => e.getMessage.contains("overflow"), spark => { val data = Seq(Row(Period.ofMonths(months))) - val schema = StructType(Seq(StructField("c_ym", YearMonthIntervalType()))) + val schema = StructType(Seq(StructField("c_ym", YM()))) spark.createDataFrame(spark.sparkContext.parallelize(data), schema) }) { df => @@ -99,4 +100,187 @@ class IntervalCastSuite extends SparkQueryCompareTestSuite { } testLoop += 1 } + + testSparkResultsAreEqual( + "test cast struct(integral, integral) to struct(year-month, year-month)", + spark => { + val data = (-128 to 127).map { i => + Row( + Row(i.toLong, i.toLong), + Row(i, i), + Row(i.toShort, i.toShort), + Row(i.toByte, i.toByte)) + } + val schema = StructType(Seq( + StructField("c_l", StructType( + Seq(StructField("c1", LongType), StructField("c2", LongType)))), + StructField("c_i", StructType( + Seq(StructField("c1", IntegerType), StructField("c2", IntegerType)))), + StructField("c_s", StructType( + Seq(StructField("c1", ShortType), StructField("c2", ShortType)))), + StructField("c_b", StructType( + Seq(StructField("c1", ByteType), StructField("c2", ByteType)))))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_l").cast(StructType( + Seq(StructField("c1", YM(YM.YEAR, YM.YEAR)), StructField("c2", YM(YM.MONTH, YM.MONTH))))), + f.col("c_i").cast(StructType( + Seq(StructField("c1", YM(YM.YEAR, YM.YEAR)), StructField("c2", YM(YM.MONTH, YM.MONTH))))), + f.col("c_s").cast(StructType( + Seq(StructField("c1", YM(YM.YEAR, YM.YEAR)), StructField("c2", YM(YM.MONTH, YM.MONTH))))), + f.col("c_b").cast(StructType( + Seq(StructField("c1", YM(YM.YEAR, YM.YEAR)), StructField("c2", YM(YM.MONTH, YM.MONTH)))))) + } + + testSparkResultsAreEqual( + "test cast array(integral) to array(year-month)", + spark => { + val data = (-128 to 127).map { i => + Row( + Seq(i.toLong, i.toLong), + Seq(i, i), + Seq(i.toShort, i.toShort), + Seq(i.toByte, i.toByte)) + } + val schema = StructType(Seq( + StructField("c_l", ArrayType(LongType)), + StructField("c_i", ArrayType(IntegerType)), + StructField("c_s", ArrayType(ShortType)), + StructField("c_b", ArrayType(ByteType)))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_l").cast(ArrayType(YM(YM.YEAR, YM.YEAR))), + f.col("c_i").cast(ArrayType(YM(YM.MONTH, YM.MONTH))), + f.col("c_s").cast(ArrayType(YM(YM.YEAR, YM.YEAR))), + f.col("c_b").cast(ArrayType(YM(YM.MONTH, 
YM.MONTH)))) + } + + testSparkResultsAreEqual( + "test cast map(integral, integral) to map(year-month, year-month)", + spark => { + val data = (-128 to 127).map { i => + Row( + Map((i.toLong, i.toLong)), + Map((i, i)), + Map((i.toShort, i.toShort)), + Map((i.toByte, i.toByte))) + } + val schema = StructType(Seq( + StructField("c_l", MapType(LongType, LongType)), + StructField("c_i", MapType(IntegerType, IntegerType)), + StructField("c_s", MapType(ShortType, ShortType)), + StructField("c_b", MapType(ByteType, ByteType)))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_l").cast(MapType(YM(YM.YEAR, YM.YEAR), YM(YM.YEAR, YM.YEAR))), + f.col("c_i").cast(MapType(YM(YM.MONTH, YM.MONTH), YM(YM.MONTH, YM.MONTH))), + f.col("c_s").cast(MapType(YM(YM.YEAR, YM.YEAR), YM(YM.YEAR, YM.YEAR))), + f.col("c_b").cast(MapType(YM(YM.MONTH, YM.MONTH), YM(YM.MONTH, YM.MONTH)))) + } + + testSparkResultsAreEqual( + "test cast struct(year-month, year-month) to struct(integral, integral)", + spark => { + val data = (-128 to 127).map { i => + Row(Row(Period.ofMonths(i), Period.ofMonths(i))) + } + val schema = StructType(Seq(StructField("c_ym", StructType(Seq( + StructField("c1", YM(YM.MONTH, YM.MONTH)), + StructField("c2", YM(YM.MONTH, YM.MONTH))))))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_ym").cast(StructType( + Seq(StructField("c1", LongType), StructField("c2", LongType)))), + f.col("c_ym").cast(StructType( + Seq(StructField("c1", IntegerType), StructField("c2", IntegerType)))), + f.col("c_ym").cast(StructType( + Seq(StructField("c1", ShortType), StructField("c2", ShortType)))), + f.col("c_ym").cast(StructType( + Seq(StructField("c1", ByteType), StructField("c2", ByteType))))) + } + + testSparkResultsAreEqual( + "test cast array(year-month) to array(integral)", + spark => { + val data = (-128 to 127).map { i => + Row(Seq(Period.ofMonths(i), Period.ofMonths(i))) + } + val schema = StructType(Seq(StructField("c_ym", ArrayType(YM(YM.MONTH, YM.MONTH))))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_ym").cast(ArrayType(LongType)), + f.col("c_ym").cast(ArrayType(IntegerType)), + f.col("c_ym").cast(ArrayType(ShortType)), + f.col("c_ym").cast(ArrayType(ByteType))) + } + + testSparkResultsAreEqual( + "test cast map(year-month, year-month) to map(integral, integral)", + spark => { + val data = (-128 to 127).map { i => + Row(Map((Period.ofMonths(i), Period.ofMonths(i)))) + } + val schema = StructType(Seq(StructField("c_ym", + MapType(YM(YM.MONTH, YM.MONTH), YM(YM.MONTH, YM.MONTH))))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.select( + f.col("c_ym").cast(MapType(LongType, LongType)), + f.col("c_ym").cast(MapType(IntegerType, IntegerType)), + f.col("c_ym").cast(MapType(ShortType, ShortType)), + f.col("c_ym").cast(MapType(ByteType, ByteType))) + } + + testSparkResultsAreEqual( + "test cast(ym as (byte or short)) side effect", + spark => { + val data = (-128 to 127).map { i => + val boolean = if (i % 2 == 0) true else false + val sideEffectValue = Period.ofMonths(Int.MaxValue) + val ymValue = if (boolean) sideEffectValue else Period.ofMonths(i) + Row(boolean, ymValue) + } + val schema = StructType(Seq(StructField("c_b", BooleanType), + StructField("c_ym", YM(YM.MONTH, YM.MONTH)))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + // when `c_b` is true, 
the `c_ym` is Period.ofMonths(Int.MaxValue) + // cast(Period.ofMonths(Int.MaxValue) as byte) will overflow + df.selectExpr("if(c_b, cast(0 as byte), cast(c_ym as byte))", + "if(c_b, cast(0 as short), cast(c_ym as short))") + } + + testSparkResultsAreEqual( + "test cast((long or int) as year-month) side effect", + spark => { + val data = (-128 to 127).map { i => + val boolean = if (i % 2 == 0) true else false + val sideEffectLongValue = Long.MaxValue + val sideEffectIntValue = Int.MaxValue + Row(boolean, + if (boolean) sideEffectLongValue else i.toLong, + if (boolean) sideEffectIntValue else i + ) + } + val schema = StructType(Seq(StructField("c_b", BooleanType), + StructField("c_l", LongType), + StructField("c_i", IntegerType))) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + }) { + df => + df.selectExpr("if(c_b, interval 0 month, cast(c_l as interval month))", + "if(c_b, interval 0 year, cast(c_i as interval year))") + } } From e4f88b40a43f59984da7bd890808987e7588a848 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Tue, 10 May 2022 20:58:50 +0800 Subject: [PATCH 4/4] Fix --- integration_tests/src/main/python/cast_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index ddf3d77a6a6..b838af92586 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -436,13 +436,15 @@ def getDf(spark): conf={}, error_message="overflow") +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') def test_cast_integral_to_day_time_side_effect(): def getDf(spark): # INT_MAX > 106751991 (max value of interval day) return spark.createDataFrame([(True, INT_MAX, LONG_MAX), (False, 0, 0)], "c_b boolean, c_i int, c_l long").repartition(1) assert_gpu_and_cpu_are_equal_collect( - lambda spark: getDf(spark).selectExpr("if(c_b, interval 0 day, cast(c_i as interval day))", "if(c_b, interval 0 day, cast(c_l as interval second))")) + lambda spark: getDf(spark).selectExpr("if(c_b, interval 0 day, cast(c_i as interval day))", "if(c_b, interval 0 second, cast(c_l as interval second))")) +@pytest.mark.skipif(is_before_spark_330(), reason='casting between interval and integral is not supported before Pyspark 3.3.0') def test_cast_day_time_to_integral_side_effect(): def getDf(spark): # 106751991 > Byte.MaxValue
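
Both side-effect tests above hide the overflowing values behind an if(c_b, ...) guard: on the GPU the cast in the untaken branch is still evaluated column-wise, which is why these casts are reported through hasSideEffectsIfCastIntToDayTime / hasSideEffectsIfCastIntToYearMonth instead of being allowed to throw for rows the condition masks out. The bounds the test comments cite (INT_MAX > 106751991, the largest day count a day-time interval can hold) follow from simple microsecond arithmetic. A minimal, self-contained Scala sketch of those bounds is below; it is editorial, not part of the patch, and the object and value names are illustrative only.

import scala.util.Try

object DayTimeBoundsSketch {
  private val MicrosPerDay    = 86400L * 1000000L // one day in microseconds
  private val MicrosPerSecond = 1000000L

  def main(args: Array[String]): Unit = {
    // Largest day count whose microsecond value still fits in a Long: 106751991.
    println(Long.MaxValue / MicrosPerDay)

    // INT_MAX days overflows a day-time interval, so cast(c_i as interval day) can throw...
    println(Try(Math.multiplyExact(Int.MaxValue.toLong, MicrosPerDay)).isSuccess)    // false

    // ...while INT_MAX seconds fits comfortably.
    println(Try(Math.multiplyExact(Int.MaxValue.toLong, MicrosPerSecond)).isSuccess) // true

    // LONG_MAX seconds overflows, so cast(c_l as interval second) also needs the guard.
    println(Try(Math.multiplyExact(Long.MaxValue, MicrosPerSecond)).isSuccess)       // false
  }
}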