Skip to content

Commit

Permalink
Enable direct lit
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Aug 15, 2024
1 parent c2c162f commit 5480c5a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
16 changes: 6 additions & 10 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import polars._reexport as pl
from polars._utils.convert import (
date_to_int,
datetime_to_int,
time_to_int,
timedelta_to_int,
)
Expand Down Expand Up @@ -79,8 +77,7 @@ def lit(

if isinstance(value, datetime):
if dtype == Date:
dt_int = date_to_int(value.date())
return lit(dt_int).cast(Date)
return wrap_expr(plr.lit(value.date(), allow_object=False))

# parse time unit
if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None:
Expand Down Expand Up @@ -109,8 +106,7 @@ def lit(
raise TypeError(msg)

dt_utc = value.replace(tzinfo=timezone.utc)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(Datetime(time_unit))
if tz is not None:
expr = expr.dt.replace_time_zone(
tz, ambiguous="earliest" if value.fold == 0 else "latest"
Expand All @@ -134,14 +130,14 @@ def lit(
if dtype == Datetime:
time_unit = getattr(dtype, "time_unit", "us") or "us"
dt_utc = datetime(value.year, value.month, value.day)
dt_int = datetime_to_int(dt_utc, time_unit)
expr = lit(dt_int).cast(Datetime(time_unit))
expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(
Datetime(time_unit)
)
if (time_zone := getattr(dtype, "time_zone", None)) is not None:
expr = expr.dt.replace_time_zone(str(time_zone))
return expr
else:
date_int = date_to_int(value)
return lit(date_int).cast(Date)
return wrap_expr(plr.lit(value, allow_object=False))

elif isinstance(value, pl.Series):
value = value._s
Expand Down
5 changes: 4 additions & 1 deletion py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,10 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
Ok(dsl::lit(Null {}).into())
} else if let Ok(value) = value.downcast::<PyBytes>() {
Ok(dsl::lit(value.as_bytes()).into())
} else if value.get_type().qualname().unwrap() == "Decimal" {
} else if matches!(
value.get_type().qualname().unwrap().as_str(),
"date" | "datetime" | "Decimal"
) {
let av = py_object_to_any_value(value, true)?;
Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
} else if allow_object {
Expand Down

0 comments on commit 5480c5a

Please sign in to comment.