Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
  • Loading branch information
thirtiseven committed Dec 13, 2023
1 parent 69223e9 commit f3cbbb4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 37 deletions.
40 changes: 6 additions & 34 deletions integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,48 +130,20 @@ def test_datediff(data_gen):
'datediff(a, date(null))',
'datediff(a, \'2016-03-02\')'))

@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu(*non_utc_allow)
def test_hour():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'),
conf = {'spark.rapids.sql.nonUTC.enabled': True})
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'))

@allow_non_gpu('ProjectExec')
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_hour_fallback():
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('hour(a)'),
'ProjectExec',
conf = {'spark.rapids.sql.nonUTC.enabled': True})

@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu(*non_utc_allow)
def test_minute():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'),
conf = {'spark.rapids.sql.nonUTC.enabled': True})

@allow_non_gpu('ProjectExec')
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_minute_fallback():
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'),
'ProjectExec',
conf = {'spark.rapids.sql.nonUTC.enabled': True})
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('minute(a)'))

@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
@allow_non_gpu(*non_utc_allow)
def test_second():
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'),
conf = {'spark.rapids.sql.nonUTC.enabled': True})

@allow_non_gpu('ProjectExec')
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_second_fallback():
assert_gpu_fallback_collect(
lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'),
'ProjectExec',
conf = {'spark.rapids.sql.nonUTC.enabled': True})

lambda spark : unary_op_df(spark, timestamp_gen).selectExpr('second(a)'))

def test_quarter():
assert_gpu_and_cpu_are_equal_collect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ case class GpuMinute(child: Expression, timeZoneId: Option[String] = None)
input.getBase.minute()
} else {
// Non-UTC time zone
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId.normalized())) {
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId)) {
shifted => shifted.minute()
}
}
Expand All @@ -112,7 +112,7 @@ case class GpuSecond(child: Expression, timeZoneId: Option[String] = None)
input.getBase.second()
} else {
// Non-UTC time zone
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId.normalized())) {
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId)) {
shifted => shifted.second()
}
}
Expand All @@ -129,7 +129,7 @@ case class GpuHour(child: Expression, timeZoneId: Option[String] = None)
input.getBase.hour()
} else {
// Non-UTC time zone
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId.normalized())) {
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.getBase, zoneId)) {
shifted => shifted.hour()
}
}
Expand Down

0 comments on commit f3cbbb4

Please sign in to comment.