Skip to content

Commit

Permalink
fix: Method dt.truncate was sometimes returning incorrect results for…
Browse files Browse the repository at this point in the history
… pre-1970 datetimes (#17582)
  • Loading branch information
MarcoGorelli authored Jul 12, 2024
1 parent 64b45a8 commit a6236c1
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PolarsTruncate for DatetimeChunked {
return Ok(self
.apply_values(|t| {
let remainder = t % every;
t - remainder + every * (remainder < 0) as i64
t - (remainder + every * (remainder < 0) as i64)
})
.into_datetime(self.time_unit(), time_zone.clone()));
} else {
Expand Down
3 changes: 1 addition & 2 deletions py-polars/polars/expr/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,8 @@ def truncate(self, every: str | dt.timedelta | Expr) -> Expr:
│ 2001-01-01 01:00:00 ┆ 2001-01-01 01:00:00 │
└─────────────────────┴─────────────────────┘
"""
if not isinstance(every, pl.Expr):
if isinstance(every, dt.timedelta):
every = parse_as_duration_string(every)

every = parse_into_expression(every, str_as_lit=True)
return wrap_expr(self._pyexpr.dt_truncate(every))

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,7 @@ def offset_by(self, by: str | Expr) -> Series:
]
"""

def truncate(self, every: str | dt.timedelta | Expr) -> Series:
def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
"""
Divide the date/ datetime range into buckets.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from datetime import date, datetime
from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import pytest
from hypothesis import given

import polars as pl
from polars._utils.convert import parse_as_duration_string
from polars.testing import assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -92,3 +93,29 @@ def test_truncate_datetime_w_expression(time_unit: TimeUnit) -> None:
result = df.select(pl.col("a").dt.truncate(pl.col("b")))["a"]
assert result[0] == datetime(2020, 1, 1)
assert result[1] == datetime(2020, 1, 3)


def test_pre_epoch_truncate_17581() -> None:
s = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1, 1)])
result = s.dt.truncate("1d")
expected = pl.Series([datetime(1980, 1, 1), datetime(1969, 1, 1)])
assert_series_equal(result, expected)


@given(
datetimes=st.lists(
st.datetimes(min_value=datetime(1960, 1, 1), max_value=datetime(1980, 1, 1)),
min_size=1,
max_size=3,
),
every=st.timedeltas(
min_value=timedelta(microseconds=1), max_value=timedelta(days=1)
).map(parse_as_duration_string),
)
def test_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None:
s = pl.Series(datetimes)
# Might use fastpath:
result = s.dt.truncate(every)
# Definitely uses slowpath:
expected = s.dt.truncate(pl.Series([every] * len(datetimes)))
assert_series_equal(result, expected)

0 comments on commit a6236c1

Please sign in to comment.