Support reading ANSI day time interval type from CSV source #4927

Merged: 9 commits, Mar 18, 2022
26 changes: 26 additions & 0 deletions integration_tests/src/main/python/csv_test.py
@@ -485,3 +485,29 @@ def do_csv_scan(spark):
do_csv_scan,
exist_classes= "FileSourceScanExec",
non_exist_classes= "GpuBatchScanExec")

@pytest.mark.skipif(is_before_spark_330(), reason='Reading day-time interval type is supported from Spark 3.3.0')
@pytest.mark.parametrize('v1_enabled_list', ["", "csv"])
def test_round_trip_for_interval(spark_tmp_path, v1_enabled_list):
csv_interval_gens = [
DayTimeIntervalGen(start_field="day", end_field="day"),
DayTimeIntervalGen(start_field="day", end_field="hour"),
DayTimeIntervalGen(start_field="day", end_field="minute"),
DayTimeIntervalGen(start_field="day", end_field="second"),
DayTimeIntervalGen(start_field="hour", end_field="hour"),
DayTimeIntervalGen(start_field="hour", end_field="minute"),
DayTimeIntervalGen(start_field="hour", end_field="second"),
DayTimeIntervalGen(start_field="minute", end_field="minute"),
DayTimeIntervalGen(start_field="minute", end_field="second"),
DayTimeIntervalGen(start_field="second", end_field="second"),
]

gen = StructGen([('_c' + str(i), csv_interval_gens[i]) for i in range(0, len(csv_interval_gens))], nullable=False)
data_path = spark_tmp_path + '/CSV_DATA'
schema = gen.data_type
updated_conf = copy_and_update(_enable_all_types_conf, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
with_cpu_session(
lambda spark: gen_df(spark, gen).write.csv(data_path))
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.schema(schema).csv(data_path),
conf=updated_conf)
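
For orientation (an illustrative sketch, not part of this PR's diff): the round trip that test exercises can be reproduced with plain PySpark >= 3.3.0, since Spark writes day-time interval columns to CSV as ANSI strings (e.g. INTERVAL '1 02:03:04' DAY TO SECOND) and parses them back when the schema is supplied explicitly. The local session and output path below are assumptions for the example.

# Standalone round-trip sketch, assuming PySpark >= 3.3.0.
from datetime import timedelta
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DayTimeIntervalType

spark = SparkSession.builder.master('local[1]').getOrCreate()
schema = StructType([
    StructField('_c0', DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND))])
df = spark.createDataFrame([(timedelta(days=1, hours=2, minutes=3, seconds=4),)], schema)
df.write.mode('overwrite').csv('/tmp/interval_csv_example')
spark.read.schema(schema).csv('/tmp/interval_csv_example').show(truncate=False)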
48 changes: 30 additions & 18 deletions integration_tests/src/main/python/data_gen.py
@@ -613,31 +613,43 @@ def make_null():
self._start(rand, make_null)

# DayTimeIntervalGen is for Spark 3.3.0+
# DayTimeIntervalType(startField, endField): Represents a day-time interval which is made up of a contiguous subset of the following fields:
# DayTimeIntervalType(startField, endField):
# Represents a day-time interval which is made up of a contiguous subset of the following fields:
# SECOND, seconds within minutes and possibly fractions of a second [0..59.999999],
# Note Spark now uses 99 as max second, see issue https://issues.apache.org/jira/browse/SPARK-38324
# If second is start field, its max value is long.max / microseconds in one second
# MINUTE, minutes within hours [0..59],
# If minute is start field, its max value is long.max / microseconds in one minute
# HOUR, hours within days [0..23],
# DAY, days in the range [0..106751991].
# If hour is start field, its max value is long.max / microseconds in one hour
# DAY, days in the range [0..106751991]. 106751991 is long.max / microseconds in one day
# For more details: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
# Note: 106751991/365 = 292471 years which is much bigger than 9999 year, seems something is wrong
MIN_DAY_TIME_INTERVAL = timedelta(microseconds=-pow(2, 63))
MAX_DAY_TIME_INTERVAL = timedelta(microseconds=(pow(2, 63) - 1))
class DayTimeIntervalGen(DataGen):
"""Generate DayTimeIntervalType values"""
def __init__(self, max_days = None, nullable=True, special_cases =[timedelta(seconds = 0)]):
super().__init__(DayTimeIntervalType(), nullable=nullable, special_cases=special_cases)
if max_days is None:
self._max_days = 106751991
else:
self._max_days = max_days
def __init__(self, min_value=MIN_DAY_TIME_INTERVAL, max_value=MAX_DAY_TIME_INTERVAL, start_field="day", end_field="second",
nullable=True, special_cases=[timedelta(seconds=0)]):
# Note the nano seconds are truncated for min_value and max_value
self._min_micros = (math.floor(min_value.total_seconds()) * 1000000) + min_value.microseconds
self._max_micros = (math.floor(max_value.total_seconds()) * 1000000) + max_value.microseconds
fields = ["day", "hour", "minute", "second"]
start_index = fields.index(start_field)
end_index = fields.index(end_field)
if start_index > end_index:
raise RuntimeError('Start field {}, end field {}, valid fields are {}, start field index should be <= end '
'field index'.format(start_field, end_field, fields))
super().__init__(DayTimeIntervalType(start_index, end_index), nullable=nullable, special_cases=special_cases)

def _gen_random(self, rand):
micros = rand.randint(self._min_micros, self._max_micros)
# issue: Interval types are not truncated to the expected endField when creating a DataFrame via Duration
# https://issues.apache.org/jira/browse/SPARK-38577
# If above issue is fixed, should update this DayTimeIntervalGen.
return timedelta(microseconds=micros)

def start(self, rand):
self._start(rand,
lambda : timedelta(
microseconds = rand.randint(0, 999999),
seconds = rand.randint(0, 59),
minutes = rand.randint(0, 59),
hours = rand.randint(0, 23),
days = rand.randint(0, self._max_days),
)
)
self._start(rand, lambda: self._gen_random(rand))

def skip_if_not_utc():
if (not is_tz_utc()):
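Editor's note on the microsecond bounds computed in DayTimeIntervalGen.__init__ above (a hedged worked example, not part of the diff): Python normalizes a negative timedelta so that only its days component is negative, so flooring total_seconds() before scaling and adding the non-negative microseconds component back yields the intended signed microsecond count.

# Worked example of the timedelta-to-microseconds conversion used by DayTimeIntervalGen.
from datetime import timedelta
import math

td = timedelta(days=-1, microseconds=250)  # normalized by Python to days=-1, microseconds=250
micros = (math.floor(td.total_seconds()) * 1000000) + td.microseconds
assert micros == -86400 * 1000000 + 250    # -86399999750 microseconds
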
4 changes: 2 additions & 2 deletions integration_tests/src/main/python/date_time_test.py
@@ -45,9 +45,9 @@ def test_timeadd(data_gen):
def test_timeadd_daytime_column():
gen_list = [
# timestamp column max year is 1000
('t', TimestampGen(end = datetime(1000, 1, 1, tzinfo=timezone.utc))),
('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))),
# max days is 8000 year, so added result will not be out of range
('d', DayTimeIntervalGen(max_days = 8000 * 365))]
('d', DayTimeIntervalGen(min_value=timedelta(days=0), max_value=timedelta(days=8000 * 365)))]
assert_gpu_and_cpu_are_equal_collect(
lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND"))

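A quick sanity check on the generator bounds used in test_timeadd_daytime_column above (illustrative only): capping the timestamp at year 1000 and the interval at 8000 * 365 days (roughly 8000 years) keeps the sum well below the year 9999 limit referenced in the data_gen.py comments.

# Rough bound check for the limits used in test_timeadd_daytime_column.
from datetime import datetime, timedelta, timezone

latest = datetime(1000, 1, 1, tzinfo=timezone.utc) + timedelta(days=8000 * 365)
print(latest.year)  # ~8994, comfortably below 9999
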
@@ -15,6 +15,7 @@
*/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter

@@ -74,4 +75,9 @@ object GpuTypeShims {
def toScalarForType(t: DataType, v: Any) = {
throw new RuntimeException(s"Can not convert $v to scalar for type $t.")
}

def supportCsvRead(dt: DataType) : Boolean = false

def csvRead(cv: cudf.ColumnVector, dt: DataType): cudf.ColumnVector =
throw new RuntimeException(s"Not support type $dt.")
}