Skip to content

Commit

Permalink
Support casting between day-time interval and string (#5155)
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <res_life@163.com>
  • Loading branch information
res-life authored Apr 18, 2022
1 parent 0b4465f commit 5251bac
Show file tree
Hide file tree
Showing 9 changed files with 520 additions and 62 deletions.
38 changes: 38 additions & 0 deletions integration_tests/src/main/python/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,41 @@ def fun(spark):
df = spark.createDataFrame(data, DoubleType())
return df.select(f.col('value').cast(TimestampType())).collect()
assert_gpu_and_cpu_error(fun, {"spark.sql.ansi.enabled": True}, "java.time.DateTimeException")

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_cast_day_time_interval_to_string():
    """Cast day-time interval -> string must match CPU for every valid field range.

    Covers all (start_field, end_field) combinations with start <= end, and forces
    the extreme and zero intervals as special cases to exercise sign handling and
    field-boundary formatting.
    """
    special_cases = [MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]
    fields = ['day', 'hour', 'minute', 'second']
    # Enumerate the 10 valid field ranges in the same order as before:
    # (day,day), (day,hour), ... (second,second).
    for i, start_field in enumerate(fields):
        for end_field in fields[i:]:
            _assert_cast_to_string_equal(
                DayTimeIntervalGen(start_field=start_field, end_field=end_field,
                                   special_cases=special_cases),
                {})

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
def test_cast_string_to_day_time_interval():
    """Round-trip interval -> string -> interval must match CPU behavior.

    Exercises two target types: DAY TO SECOND and HOUR TO SECOND. The numeric
    codes passed to DayTimeIntervalType are 0=day, 1=hour, 3=second.
    """
    special_cases = [MIN_DAY_TIME_INTERVAL, MAX_DAY_TIME_INTERVAL, timedelta(seconds=0)]
    for start_field, start_code in [('day', 0), ('hour', 1)]:
        gen = DayTimeIntervalGen(start_field=start_field, end_field='second',
                                 special_cases=special_cases)
        dt_type = DayTimeIntervalType(start_code, 3)  # 3 is second
        assert_gpu_and_cpu_are_equal_collect(
            lambda spark: unary_op_df(spark, gen).select(f.col('a').cast(StringType()).cast(dt_type)))

@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0')
@pytest.mark.parametrize('invalid_string', [
    "INTERVAL 'xxx' DAY TO SECOND", # invalid format
    "-999999999 04:00:54.775808000" # exceeds min value, min value is "-106751991 04:00:54.775808000"
])
def test_cast_string_to_day_time_interval_exception(invalid_string):
    """Casting a malformed or out-of-range string to a day-time interval must
    raise the same error class on both CPU and GPU."""
    day_to_second_type = DayTimeIntervalType(0, 3)

    def cast_invalid_string(spark):
        # Single-column frame holding the bad input; the cast happens on collect.
        frame = spark.createDataFrame([invalid_string], StringType())
        return frame.select(f.col('value').cast(day_to_second_type)).collect()

    assert_gpu_and_cpu_error(cast_invalid_string, {}, "java.lang.IllegalArgumentException")
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.ColumnVector

import org.apache.spark.sql.types.DataType

/**
 * Casting between day-time interval and string is not supported by this Shim.
 * These stubs exist only to satisfy the common-code interface; they must never
 * be reached at runtime — the plan should not have selected GPU interval casts
 * on this Shim in the first place.
 */
object GpuIntervalUtils {

  /** Not supported on this Shim — always throws [[IllegalStateException]]. */
  def castStringToDayTimeIntervalWithThrow(cv: ColumnVector, t: DataType): ColumnVector = {
    // Include a message so an accidental invocation is diagnosable from the stack trace.
    throw new IllegalStateException(
      s"Casting string to day-time interval ($t) is not supported on this Shim")
  }

  /** Not supported on this Shim — always throws [[IllegalStateException]]. */
  def toDayTimeIntervalString(micros: ColumnVector, dayTimeType: DataType): ColumnVector = {
    throw new IllegalStateException(
      s"Casting day-time interval ($dayTimeType) to string is not supported on this Shim")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ object GpuTypeShims {
throw new RuntimeException(s"Not support type $dt.")

/**
* Whether the Shim supports day-time interval type for specific operator
* Alias, Add, Subtract, Positive... operators do not support day-time interval type on this Shim
* Note: Spark 3.2.x does support `DayTimeIntervalType`, this is for the GPU operators
*/
def isSupportedDayTimeType(dt: DataType): Boolean = false // always false: GPU operators on this Shim never handle day-time interval

Expand All @@ -109,6 +110,10 @@ object GpuTypeShims {
*/
def additionalCsvSupportedTypes: TypeSig = TypeSig.none // no extra CSV-supported types on this Shim

/** Types a day-time interval column can be cast to on this Shim: none (casts not supported here). */
def typesDayTimeCanCastTo: TypeSig = TypeSig.none

/** Additional cast-target types for string on this Shim: none. */
def additionalTypesStringCanCastTo: TypeSig = TypeSig.none

/**
* Get additional Parquet supported types for this Shim
*/
Expand Down
Loading

0 comments on commit 5251bac

Please sign in to comment.