Support reading ANSI day time interval type from CSV source #4927

Merged (9 commits) on Mar 18, 2022
Changes from 4 commits
24 changes: 23 additions & 1 deletion integration_tests/src/main/python/csv_test.py
@@ -266,7 +266,29 @@ def test_basic_csv_read(std_input_path, name, schema, options, read_func, v1_ena
pytest.param(double_gen),
pytest.param(FloatGen(no_nans=False)),
pytest.param(float_gen),
TimestampGen(),
# 365 days * 5000 is about 5000 years
DayTimeIntervalGen(start_field="day", end_field="day", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="day", end_field="hour", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="day", end_field="minute", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="day", end_field="second", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="hour", end_field="hour", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="hour", end_field="minute", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="hour", end_field="second", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="minute", end_field="minute", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="minute", end_field="second", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="second", end_field="second", max_days=365 * 5000, allow_negative=True),
DayTimeIntervalGen(start_field="day", end_field="day", max_days=365 * 5000),
DayTimeIntervalGen(start_field="day", end_field="hour", max_days=365 * 5000),
DayTimeIntervalGen(start_field="day", end_field="minute", max_days=365 * 5000),
DayTimeIntervalGen(start_field="day", end_field="second", max_days=365 * 5000),
DayTimeIntervalGen(start_field="hour", end_field="hour", max_days=365 * 5000),
DayTimeIntervalGen(start_field="hour", end_field="minute", max_days=365 * 5000),
DayTimeIntervalGen(start_field="hour", end_field="second", max_days=365 * 5000),
DayTimeIntervalGen(start_field="minute", end_field="minute", max_days=365 * 5000),
DayTimeIntervalGen(start_field="minute", end_field="second", max_days=365 * 5000),
DayTimeIntervalGen(start_field="second", end_field="second", max_days=365 * 5000),
]
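The 20 generators above enumerate every valid (start_field, end_field) pair twice, once with and once without allow_negative. As a sketch (the names below are illustrative, not part of csv_test.py), the same matrix can be derived from the ordered field list, since a pair is valid only when the start field is no finer than the end field:

```python
# The ordered interval fields; (start, end) is valid only when start comes
# at or before end in this order.
FIELDS = ["day", "hour", "minute", "second"]

def valid_field_pairs():
    return [(start, end)
            for i, start in enumerate(FIELDS)
            for end in FIELDS[i:]]

# 10 valid pairs; generating each with allow_negative True and False
# yields the 20 DayTimeIntervalGen entries listed above.
```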

@approximate_float
@pytest.mark.parametrize('data_gen', csv_supported_gens, ids=idfn)
107 changes: 94 additions & 13 deletions integration_tests/src/main/python/data_gen.py
@@ -613,31 +613,112 @@ def make_null():
self._start(rand, make_null)

# DayTimeIntervalGen is for Spark 3.3.0+
# DayTimeIntervalType(startField, endField):
# Represents a day-time interval which is made up of a contiguous subset of the following fields:
#   SECOND, seconds within minutes and possibly fractions of a second [0..59.999999].
#     Note: Spark now allows up to 99 as the max second; see https://issues.apache.org/jira/browse/SPARK-38324
#     If second is the start field, its max value is Long.MaxValue / microseconds in one second.
#   MINUTE, minutes within hours [0..59].
#     If minute is the start field, its max value is Long.MaxValue / microseconds in one minute.
#   HOUR, hours within days [0..23].
#     If hour is the start field, its max value is Long.MaxValue / microseconds in one hour.
#   DAY, days in the range [0..106751991]; 106751991 is Long.MaxValue / microseconds in one day.
# For more details: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
# Note: 106751991 / 365 = 292471 years, which is far beyond year 9999.
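The limits quoted in the comments follow directly from the signed 64-bit microsecond representation; a quick check in plain Python (not part of data_gen.py):

```python
# A day-time interval is stored as a signed 64-bit microsecond count, so each
# field's documented maximum is Long.MaxValue divided by its size in microseconds.
LONG_MAX = 2**63 - 1
MICROS_PER_SECOND = 1000 * 1000
MICROS_PER_MINUTE = 60 * MICROS_PER_SECOND
MICROS_PER_HOUR = 60 * MICROS_PER_MINUTE
MICROS_PER_DAY = 24 * MICROS_PER_HOUR

MAX_DAYS = LONG_MAX // MICROS_PER_DAY   # 106751991, matching the DAY comment
MAX_YEARS = MAX_DAYS // 365             # roughly 292471 years, well past 9999
```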
class DayTimeIntervalGen(DataGen):
    """Generate DayTimeIntervalType values"""
    def __init__(self, max_days=None, start_field="day", end_field="second", allow_negative=False, nullable=True,
                 special_cases=[timedelta(seconds=0)]):
        if max_days is None:
            self._max_days = 106751991
        else:
            assert 106751991 >= max_days > 0
            self._max_days = max_days
        self._allow_negative = allow_negative
        self._start_field = start_field
        self._end_field = end_field

        fields = ["day", "hour", "minute", "second"]
        start_index = fields.index(start_field)
        end_index = fields.index(end_field)
        if start_index > end_index:
            raise RuntimeError('Start field {} must be <= end field {}; valid fields are {}'.format(
                start_field, end_field, fields))

        super().__init__(DayTimeIntervalType(start_index, end_index), nullable=nullable, special_cases=special_cases)

    def _gen_random(self, rand, start_field, end_field):
        micros_per_second = 1000 * 1000
        micros_per_minute = 60 * micros_per_second
        micros_per_hour = 60 * micros_per_minute
        micros_per_day = 24 * micros_per_hour

        max_micros = self._max_days * micros_per_day

        # set default values
        days = 0
        hours = 0
        minutes = 0
        seconds = 0
        microseconds = 0

        if (start_field, end_field) == ("day", "day"):
I think it would be better to use the structure below, which should perform better:

if ...:
    ...
elif ...:
    ...
elif ...:
    ...
else:
    ...

Can we do this totally differently? There is so much copy/paste between each part of the code. The Python code is really only converting everything to microseconds, so can we just generate a random number of microseconds, up to max_micros, and then truncate it to the proper granularity? Also, can we convert the "string" start/end fields into something simpler to write code against?

# in __init__
DAY = 0
HOUR = 1
MIN = 2
SEC = 3
fields_to_look_at = ["day", "hour", "minute", "second"]
si = fields_to_look_at.index(start_field)
ei = fields_to_look_at.index(end_field)
assert si <= ei

self.hasDays = si <= DAY and ei >= DAY
self.hasHours = si <= HOUR and ei >= HOUR
self.hasMin = si <= MIN and ei >= MIN
self.hasSec = si <= SEC and ei >= SEC

            days = rand.randint(0, self._max_days)
        if (start_field, end_field) == ("day", "hour"):
            days = rand.randint(0, self._max_days)
            hours_remaining = (max_micros - days * micros_per_day) // micros_per_hour
            hours = rand.randint(0, min(23, hours_remaining))
        if (start_field, end_field) == ("day", "minute"):
            days = rand.randint(0, self._max_days)
            hours_remaining = (max_micros - days * micros_per_day) // micros_per_hour
            hours = rand.randint(0, min(23, hours_remaining))
            minutes_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour) // micros_per_minute
            minutes = rand.randint(0, min(59, minutes_remaining))
        if (start_field, end_field) == ("day", "second"):
            days = rand.randint(0, self._max_days)
            hours_remaining = (max_micros - days * micros_per_day) // micros_per_hour
            hours = rand.randint(0, min(23, hours_remaining))
            minutes_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour) // micros_per_minute
            minutes = rand.randint(0, min(59, minutes_remaining))
            seconds_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute) // micros_per_second
            seconds = rand.randint(0, min(99, seconds_remaining))
            microseconds_remaining = max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute - seconds * micros_per_second
            microseconds = rand.randint(0, min(999999, microseconds_remaining))
        if (start_field, end_field) == ("hour", "hour"):
            hours = rand.randint(0, max_micros // micros_per_hour)
        if (start_field, end_field) == ("hour", "minute"):
            hours = rand.randint(0, max_micros // micros_per_hour)
            minutes_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour) // micros_per_minute
            minutes = rand.randint(0, min(59, minutes_remaining))
        if (start_field, end_field) == ("hour", "second"):
            hours = rand.randint(0, max_micros // micros_per_hour)
            minutes_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour) // micros_per_minute
            minutes = rand.randint(0, min(59, minutes_remaining))
            seconds_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute) // micros_per_second
            seconds = rand.randint(0, min(99, seconds_remaining))
            microseconds_remaining = max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute - seconds * micros_per_second
            microseconds = rand.randint(0, min(999999, microseconds_remaining))
        if (start_field, end_field) == ("minute", "minute"):
            minutes = rand.randint(0, max_micros // micros_per_minute)
        if (start_field, end_field) == ("minute", "second"):
            minutes = rand.randint(0, max_micros // micros_per_minute)
            seconds_remaining = (max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute) // micros_per_second
            seconds = rand.randint(0, min(99, seconds_remaining))
            microseconds_remaining = max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute - seconds * micros_per_second
            microseconds = rand.randint(0, min(999999, microseconds_remaining))
        if (start_field, end_field) == ("second", "second"):
            seconds = rand.randint(0, max_micros // micros_per_second)
            microseconds_remaining = max_micros - days * micros_per_day - hours * micros_per_hour - minutes * micros_per_minute - seconds * micros_per_second
            microseconds = rand.randint(0, min(999999, microseconds_remaining))

        if self._allow_negative:
            sign = 1 if (rand.randint(0, 1) == 0) else -1
        else:
            sign = 1
        return timedelta(days=sign * days, hours=sign * hours, minutes=sign * minutes,
                         seconds=sign * seconds, microseconds=sign * microseconds)
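A caveat on the return statement: datetime.timedelta orders its positional parameters as (days, seconds, microseconds, milliseconds, minutes, hours, weeks), not largest-to-smallest, so passing the unit values positionally silently misassigns them. Keyword arguments avoid that:

```python
from datetime import timedelta

# Positional order is (days, seconds, microseconds, milliseconds, minutes, hours,
# weeks), so these two spell the same interval only when written carefully:
positional = timedelta(1, 2, 3, 4, 5, 6)
by_keyword = timedelta(days=1, hours=6, minutes=5, seconds=2,
                       milliseconds=4, microseconds=3)
assert positional == by_keyword
```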

    def start(self, rand):
        self._start(rand, lambda: self._gen_random(rand, self._start_field, self._end_field))
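revans2's suggestion above — draw one random microsecond count and truncate it to the end field's granularity — could look roughly like this. It is a sketch under that suggestion; gen_interval and GRANULARITY are hypothetical names, not the PR's code:

```python
import random
from datetime import timedelta

# Granularity of each end field, in microseconds. SECOND keeps its fractional
# part [0..59.999999], so its effective granularity is one microsecond.
GRANULARITY = {
    "day": 86400 * 1000 * 1000,
    "hour": 3600 * 1000 * 1000,
    "minute": 60 * 1000 * 1000,
    "second": 1,
}

def gen_interval(rand, end_field, max_days=106751991):
    # Draw the whole interval as a single microsecond count, then drop any
    # precision finer than end_field -- one randint instead of one per field.
    max_micros = max_days * GRANULARITY["day"]
    micros = rand.randint(0, max_micros)
    micros -= micros % GRANULARITY[end_field]
    return timedelta(microseconds=micros)
```

Negative intervals could then be handled with a single sign flip at the end, much as the PR's allow_negative flag does.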

def skip_if_not_utc():
if (not is_tz_utc()):
@@ -15,6 +15,7 @@
*/
package com.nvidia.spark.rapids.shims

import ai.rapids.cudf.ColumnVector
Suggested change:
-import ai.rapids.cudf.ColumnVector
+import ai.rapids.cudf

done

I see nothing has changed. And the PR I mentioned is merged, so there should be a conflict here now.

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter

@@ -46,4 +47,9 @@ object GpuTypeShims {
 * @return the cuDF type if the Shim supports it, else null
 */
def toRapidsOrNull(t: DataType): DType = null

def supportCsvRead(dt: DataType): Boolean = false

def csvRead(cv: ColumnVector, dt: DataType): ColumnVector =
Suggested change:
-def csvRead(cv: ColumnVector, dt: DataType): ColumnVector =
+def csvRead(cv: cudf.ColumnVector, dt: DataType): cudf.ColumnVector =

Otherwise it will conflict with PR #4926, which imports org.apache.spark.sql.vectorized.ColumnVector.

Same as above: #4926 is merged, so there should be a conflict here now.

throw new RuntimeException(s"Not support type $dt.")
}