Skip to content

Commit

Permalink
Convert date/datetime in lit construction
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed May 2, 2024
1 parent f03e7e0 commit 598ca60
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
16 changes: 14 additions & 2 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def lit(
time_unit: TimeUnit

if isinstance(value, datetime):
if dtype == Date:
dt_int = date_to_int(value.date())
return lit(dt_int).cast(Date)
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
time_unit = tu # type: ignore[assignment]
else:
Expand Down Expand Up @@ -113,8 +116,17 @@ def lit(
return lit(time_int).cast(Time)

elif isinstance(value, date):
date_int = date_to_int(value)
return lit(date_int).cast(Date)
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
date_int = date_to_int(value)
return lit(date_int).cast(Date)

elif isinstance(value, pl.Series):
value = value._s
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import OrderedDict
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -1310,3 +1311,48 @@ def test_agg_median_expr() -> None:
)

assert_frame_equal(df.select(pl.all().median()), expected)


@pytest.mark.parametrize(
("value_in", "dtype", "expected_value"),
[
(date(2024, 1, 1), pl.Date, date(2024, 1, 1)),
(date(2024, 1, 1), pl.Datetime, datetime(2024, 1, 1)),
(date(2024, 1, 1), pl.Datetime("ms"), datetime(2024, 1, 1)),
(date(2024, 1, 1), pl.Datetime("us"), datetime(2024, 1, 1)),
(date(2024, 1, 1), pl.Datetime("ns"), datetime(2024, 1, 1)),
(
date(2024, 1, 1),
pl.Datetime("ms", "EST"),
datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")),
),
(datetime(2024, 1, 1), pl.Date, date(2024, 1, 1)),
(datetime(2024, 1, 1), pl.Datetime, datetime(2024, 1, 1)),
(datetime(2024, 1, 1), pl.Datetime("ms"), datetime(2024, 1, 1)),
(datetime(2024, 1, 1), pl.Datetime("us"), datetime(2024, 1, 1)),
(datetime(2024, 1, 1), pl.Datetime("ns"), datetime(2024, 1, 1)),
(
datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")),
pl.Datetime("ms", "EST"),
datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")),
),
],
ids=[
"date to date",
"date to datetime",
"date to datetime('ms')",
"date to datetime('us')",
"date to datetime('ns')",
"date to datetime('ms', 'EST')",
"datetime to date",
"datetime to datetime",
"datetime to datetime('ms')",
"datetime to datetime('us')",
"datetime to datetime('ns')",
"datetime(tz=EST) to datetime('ms', 'EST')",
],
)
def test_literal(value_in, dtype, expected_value) -> None:
out = pl.select(pl.lit(value_in, dtype))
assert out.schema == OrderedDict({"literal": dtype})
assert out.item() == expected_value

0 comments on commit 598ca60

Please sign in to comment.