Skip to content

Commit

Permalink
fix(python): Correctly handle large timedelta objects in Series const…
Browse files Browse the repository at this point in the history
…ructor (pola-rs#16043)
  • Loading branch information
stinodego authored and Wouittone committed Jun 22, 2024
1 parent 4e2e47b commit ee571a2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _map_py_type_to_dtype(
if issubclass(python_dtype, date):
return Date
if python_dtype is timedelta:
return Duration("us")
return Duration
if python_dtype is time:
return Time
if python_dtype is list:
Expand Down
18 changes: 13 additions & 5 deletions py-polars/src/conversion/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ type InitFn = for<'py> fn(&Bound<'py, PyAny>, bool) -> PyResult<AnyValue<'py>>;
pub(crate) static LUT: crate::gil_once_cell::GILOnceCell<PlHashMap<TypeObjectPtr, InitFn>> =
crate::gil_once_cell::GILOnceCell::new();

/// Convert a Python object to an [`AnyValue`].
pub(crate) fn py_object_to_any_value<'py>(
ob: &Bound<'py, PyAny>,
strict: bool,
Expand Down Expand Up @@ -202,14 +203,21 @@ pub(crate) fn py_object_to_any_value<'py>(

fn get_timedelta(ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult<AnyValue<'static>> {
Python::with_gil(|py| {
let td = UTILS
let f = UTILS
.bind(py)
.getattr(intern!(py, "timedelta_to_int"))
.unwrap()
.call1((ob, intern!(py, "us")))
.unwrap();
let v = td.extract::<i64>().unwrap();
Ok(AnyValue::Duration(v, TimeUnit::Microseconds))
let py_int = f.call1((ob, intern!(py, "us"))).unwrap();

let av = if let Ok(v) = py_int.extract::<i64>() {
AnyValue::Duration(v, TimeUnit::Microseconds)
} else {
// This should be faster than calling `timedelta_to_int` again with `"ms"` input.
let v_us = py_int.extract::<i128>().unwrap();
let v = (v_us / 1000) as i64;
AnyValue::Duration(v, TimeUnit::Milliseconds)
};
Ok(av)
})
}

Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/constructors/test_series.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import timedelta
from typing import Any

import pytest
Expand Down Expand Up @@ -76,3 +77,14 @@ def test_preserve_decimal_precision() -> None:
dtype = pl.Decimal(None, 1)
s = pl.Series(dtype=dtype)
assert s.dtype == dtype


@pytest.mark.parametrize("dtype", [None, pl.Duration("ms")])
def test_large_timedelta(dtype: pl.DataType | None) -> None:
values = [timedelta.min, timedelta.max]
s = pl.Series(values, dtype=dtype)
assert s.dtype == pl.Duration("ms")

# Microsecond precision is lost
expected = [timedelta.min, timedelta.max - timedelta(microseconds=999)]
assert s.to_list() == expected
8 changes: 2 additions & 6 deletions py-polars/tests/unit/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,8 @@ def test_dtype_temporal_units() -> None:
assert pl.Duration("ns") != pl.Duration("us")

# check timeunit from pytype
for inferred_dtype, expected_dtype in (
(py_type_to_dtype(datetime), pl.Datetime),
(py_type_to_dtype(timedelta), pl.Duration),
):
assert inferred_dtype == expected_dtype
assert inferred_dtype.time_unit == "us" # type: ignore[union-attr]
assert py_type_to_dtype(datetime) == pl.Datetime("us")
assert py_type_to_dtype(timedelta) == pl.Duration

with pytest.raises(ValueError, match="invalid `time_unit`"):
pl.Datetime("?") # type: ignore[arg-type]
Expand Down

0 comments on commit ee571a2

Please sign in to comment.