From 598ca603a6844a12768fa72a8e456a5c3476f0cf Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Thu, 2 May 2024 10:11:43 -0400 Subject: [PATCH] Convert date/datetime in lit construction --- py-polars/polars/functions/lit.py | 16 ++++++- .../tests/unit/namespaces/test_datetime.py | 46 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index b636d5f2544a6..b60c8cb9b9754 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -77,6 +77,9 @@ def lit( time_unit: TimeUnit if isinstance(value, datetime): + if dtype == Date: + dt_int = date_to_int(value.date()) + return lit(dt_int).cast(Date) if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: time_unit = tu # type: ignore[assignment] else: @@ -113,8 +116,17 @@ def lit( return lit(time_int).cast(Time) elif isinstance(value, date): - date_int = date_to_int(value) - return lit(date_int).cast(Date) + if dtype == Datetime: + time_unit = getattr(dtype, "time_unit", "us") or "us" + dt_utc = datetime(value.year, value.month, value.day) + dt_int = datetime_to_int(dt_utc, time_unit) + expr = lit(dt_int).cast(Datetime(time_unit)) + if (time_zone := getattr(dtype, "time_zone", None)) is not None: + expr = expr.dt.replace_time_zone(str(time_zone)) + return expr + else: + date_int = date_to_int(value) + return lit(date_int).cast(Date) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index 5307a2e7e129e..2e3e0f4d53b2e 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from datetime import date, datetime, time, timedelta from typing import TYPE_CHECKING @@ -1310,3 +1311,48 @@ def test_agg_median_expr() -> None: ) assert_frame_equal(df.select(pl.all().median()), expected) + + +@pytest.mark.parametrize( + ("value_in", "dtype", "expected_value"), + [ + (date(2024, 1, 1), pl.Date, date(2024, 1, 1)), + (date(2024, 1, 1), pl.Datetime, datetime(2024, 1, 1)), + (date(2024, 1, 1), pl.Datetime("ms"), datetime(2024, 1, 1)), + (date(2024, 1, 1), pl.Datetime("us"), datetime(2024, 1, 1)), + (date(2024, 1, 1), pl.Datetime("ns"), datetime(2024, 1, 1)), + ( + date(2024, 1, 1), + pl.Datetime("ms", "EST"), + datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")), + ), + (datetime(2024, 1, 1), pl.Date, date(2024, 1, 1)), + (datetime(2024, 1, 1), pl.Datetime, datetime(2024, 1, 1)), + (datetime(2024, 1, 1), pl.Datetime("ms"), datetime(2024, 1, 1)), + (datetime(2024, 1, 1), pl.Datetime("us"), datetime(2024, 1, 1)), + (datetime(2024, 1, 1), pl.Datetime("ns"), datetime(2024, 1, 1)), + ( + datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")), + pl.Datetime("ms", "EST"), + datetime(2024, 1, 1, tzinfo=ZoneInfo("EST")), + ), + ], + ids=[ + "date to date", + "date to datetime", + "date to datetime('ms')", + "date to datetime('us')", + "date to datetime('ns')", + "date to datetime('ms', 'EST')", + "datetime to date", + "datetime to datetime", + "datetime to datetime('ms')", + "datetime to datetime('us')", + "datetime to datetime('ns')", + "datetime(tz=EST) to datetime('ms', 'EST')", + ], +) +def test_literal(value_in, dtype, expected_value) -> None: + out = pl.select(pl.lit(value_in, dtype)) + assert out.schema == OrderedDict({"literal": dtype}) + assert out.item() == expected_value