diff --git a/py-polars/polars/internals/dataframe/frame.py b/py-polars/polars/internals/dataframe/frame.py index e9c8a841103e..8c0e9dd2d062 100644 --- a/py-polars/polars/internals/dataframe/frame.py +++ b/py-polars/polars/internals/dataframe/frame.py @@ -8,7 +8,7 @@ import typing from collections import namedtuple from collections.abc import Sized -from datetime import timedelta +from datetime import date, datetime, time, timedelta from io import BytesIO, IOBase, StringIO from pathlib import Path from typing import ( @@ -5453,7 +5453,22 @@ def select( def with_columns( self, exprs: pli.Expr | pli.Series | Sequence[pli.Expr | pli.Series] | None = None, - **named_exprs: pli.Expr | pli.Series, + **named_exprs: ( + pli.Expr + | bool + | int + | float + | str + | pli.Series + | None + | date + | datetime + | time + | timedelta + | pli.WhenThen + | pli.WhenThenThen + | Sequence[(int | float | str | None)] + ), ) -> DataFrame: """ Add or overwrite multiple columns in a DataFrame. diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index fd1a10f6c884..10b018af2d6a 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -2395,7 +2395,22 @@ def join( def with_columns( self: LDF, exprs: pli.Expr | pli.Series | Sequence[pli.Expr | pli.Series] | None = None, - **named_exprs: pli.Expr | pli.Series | str, + **named_exprs: ( + pli.Expr + | bool + | int + | float + | str + | pli.Series + | None + | date + | datetime + | time + | timedelta + | pli.WhenThen + | pli.WhenThenThen + | Sequence[(int | float | str | None)] + ), ) -> LDF: """ Add or overwrite multiple columns in a DataFrame. @@ -2471,7 +2486,7 @@ def with_columns( else ([exprs] if isinstance(exprs, pli.Expr) else list(exprs)) ) exprs.extend( - (pli.lit(expr).alias(name) if isinstance(expr, str) else expr.alias(name)) + pli.expr_to_lit_or_expr(expr).alias(name) for name, expr in named_exprs.items() ) pyexprs = [] diff --git a/py-polars/tests/unit/test_df.py b/py-polars/tests/unit/test_df.py index f76b51bd497d..d48e36bac4bf 100644 --- a/py-polars/tests/unit/test_df.py +++ b/py-polars/tests/unit/test_df.py @@ -2401,6 +2401,8 @@ def test_selection_regex_and_multicol() -> None: def test_with_columns() -> None: + import datetime + df = pl.DataFrame( { "a": [1, 2, 3, 4], @@ -2419,12 +2421,28 @@ def test_with_columns() -> None: "d": [0.5, 8.0, 30.0, 52.0], "e": [False, False, True, False], "f": [3, 2, 1, 0], + "g": True, + "h": pl.Series(values=[1, 1, 1, 1], dtype=pl.Int32), + "i": 3.2, + "j": "d", + "k": pl.Series(values=[None, None, None, None], dtype=pl.Boolean), + "l": datetime.datetime(2001, 1, 1, 0, 0), } ) # as exprs list dx = df.with_columns( - [(pl.col("a") * pl.col("b")).alias("d"), ~pl.col("c").alias("e"), srs_named] + [ + (pl.col("a") * pl.col("b")).alias("d"), + ~pl.col("c").alias("e"), + srs_named, + pl.lit(True).alias("g"), + pl.lit(1).alias("h"), + pl.lit(3.2).alias("i"), + pl.lit("d").alias("j"), + pl.lit(None).alias("k"), + pl.lit(datetime.datetime(2001, 1, 1, 0, 0)).alias("l"), + ] ) assert_frame_equal(dx, expected) @@ -2435,6 +2453,12 @@ def test_with_columns() -> None: d=pl.col("a") * pl.col("b"), e=~pl.col("c"), f=srs_unnamed, + g=True, + h=1, + i=3.2, + j="d", + k=None, + l=datetime.datetime(2001, 1, 1, 0, 0), ) assert_frame_equal(dx, expected) @@ -2443,6 +2467,12 @@ def test_with_columns() -> None: [(pl.col("a") * pl.col("b")).alias("d")], e=~pl.col("c"), f=srs_unnamed, + g=True, + h=1, + i=3.2, + j="d", + k=None, + l=datetime.datetime(2001, 1, 1, 0, 0), ) assert_frame_equal(dx, expected)