Skip to content

Commit

Permalink
Add total seconds float support to the TimeDelta field
Browse files Browse the repository at this point in the history
Able to (de)serialise to float values as a representation of the total seconds.
  • Loading branch information
marcosatti committed Jun 3, 2022
1 parent 645cba2 commit 5b692a5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ Contributors (chronological)
- Kevin Kirsche `@kkirsche <https://github.com/kkirsche>`_
- Isira Seneviratne `@Isira-Seneviratne <https://github.com/Isira-Seneviratne>`_
- Karthikeyan Singaravelan `@tirkarthi <https://github.com/tirkarthi>`_
- Marco Satti `@marcosatti <https://github.com/marcosatti>`_
31 changes: 25 additions & 6 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,17 +1434,24 @@ def _make_object_from_format(value, data_format):

class TimeDelta(Field):
"""A field that (de)serializes a :class:`datetime.timedelta` object to an
integer and vice versa. The integer can represent the number of days,
seconds or microseconds.
integer or float and vice versa. The integer or float can represent the
number of days, seconds or microseconds (for integers) or the total seconds
(for floats).
:param precision: Influences how the integer is interpreted during
:param precision: Influences how the integer or float is interpreted during
(de)serialization. Must be 'days', 'seconds', 'microseconds',
'milliseconds', 'minutes', 'hours' or 'weeks'.
'milliseconds', 'minutes', 'hours', 'weeks' or 'total_seconds'.
The 'total_seconds' precision interprets values as floats.
Other precisions will be interpreted as integers.
When using the `total_seconds` precision mode, data precision loss
may occur.
:param kwargs: The same keyword arguments that :class:`Field` receives.
.. versionchanged:: 2.0.0
Always serializes to an integer value to avoid rounding errors.
Add `precision` parameter.
.. versionchanged:: 3.17.0
Allow (de)serialization to float as the value of total seconds.
"""

DAYS = "days"
Expand All @@ -1454,6 +1461,7 @@ class TimeDelta(Field):
MINUTES = "minutes"
HOURS = "hours"
WEEKS = "weeks"
TOTAL_SECONDS = "total_seconds"

#: Default error messages.
default_error_messages = {
Expand All @@ -1471,6 +1479,7 @@ def __init__(self, precision: str = SECONDS, **kwargs):
self.MINUTES,
self.HOURS,
self.WEEKS,
self.TOTAL_SECONDS,
)

if precision not in units:
Expand All @@ -1485,18 +1494,28 @@ def __init__(self, precision: str = SECONDS, **kwargs):
def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
if self.precision == self.TOTAL_SECONDS:
return value.total_seconds()
base_unit = dt.timedelta(**{self.precision: 1})
delta = utils.timedelta_to_microseconds(value)
unit = utils.timedelta_to_microseconds(base_unit)
return delta // unit

def _deserialize(self, value, attr, data, **kwargs):
if self.precision == self.TOTAL_SECONDS:
deser_type = float
else:
deser_type = int

try:
value = int(value)
value = deser_type(value)
except (TypeError, ValueError) as error:
raise self.make_error("invalid") from error

kwargs = {self.precision: value}
if self.precision == self.TOTAL_SECONDS:
kwargs = {"seconds": value}
else:
kwargs = {self.precision: value}

try:
return dt.timedelta(**kwargs)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,12 @@ def test_timedelta_field_deserialization(self):
assert result.seconds == 123
assert result.microseconds == 456000

total_seconds_value = 322.223
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
result = field.deserialize(total_seconds_value)
assert isinstance(result, dt.timedelta)
assert result.total_seconds() == total_seconds_value

@pytest.mark.parametrize("in_value", ["", "badvalue", [], 9999999999])
def test_invalid_timedelta_field_deserialization(self, in_value):
field = fields.TimeDelta(fields.TimeDelta.DAYS)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,34 +658,44 @@ def test_timedelta_field(self, user):
assert field.serialize("d1", user) == 86401000001
field = fields.TimeDelta(fields.TimeDelta.HOURS)
assert field.serialize("d1", user) == 24
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d1", user) == user.d1.total_seconds()

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d2", user) == 1
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d2", user) == 86401
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d2", user) == 86401000001
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d2", user) == user.d2.total_seconds()

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d3", user) == 1
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d3", user) == 86401
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d3", user) == 86401000001
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d3", user) == user.d3.total_seconds()

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d4", user) == 0
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d4", user) == 0
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d4", user) == 0
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d4", user) == user.d4.total_seconds()

field = fields.TimeDelta(fields.TimeDelta.DAYS)
assert field.serialize("d5", user) == -1
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d5", user) == -86400
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d5", user) == -86400000000
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d5", user) == user.d5.total_seconds()

field = fields.TimeDelta(fields.TimeDelta.WEEKS)
assert field.serialize("d6", user) == 1
Expand All @@ -708,6 +718,8 @@ def test_timedelta_field(self, user):
assert field.serialize("d6", user) == d6_seconds * 1000 + 1
field = fields.TimeDelta(fields.TimeDelta.MICROSECONDS)
assert field.serialize("d6", user) == d6_seconds * 10**6 + 1000 + 1
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
assert field.serialize("d6", user) == user.d6.total_seconds()

user.d7 = None
assert field.serialize("d7", user) is None
Expand All @@ -721,6 +733,19 @@ def test_timedelta_field(self, user):
field = fields.TimeDelta(fields.TimeDelta.SECONDS)
assert field.serialize("d9", user) == 1

user.d10 = dt.timedelta(
weeks=1,
days=6,
hours=2,
minutes=5,
seconds=51,
milliseconds=10,
microseconds=742,
)
field = fields.TimeDelta(fields.TimeDelta.TOTAL_SECONDS)
# Test for reasonable approximate equality. No guarantees are made about accuracy.
assert abs(field.serialize("d10", user) - 1130751.010742) < 0.1

def test_datetime_list_field(self):
obj = DateTimeList([dt.datetime.utcnow(), dt.datetime.now()])
field = fields.List(fields.DateTime)
Expand Down

0 comments on commit 5b692a5

Please sign in to comment.