Skip to content

Commit

Permalink
fix(python): treat literal values consistently in select context, i…
Browse files Browse the repository at this point in the history
…mprove related typing (#6628)
  • Loading branch information
alexander-beedie authored Feb 2, 2023
1 parent ffe8fbb commit 06c0b54
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 61 deletions.
20 changes: 15 additions & 5 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@
ParallelStrategy,
ParquetCompression,
PivotAgg,
PolarsExprType,
PythonLiteral,
RollingInterpolationMethod,
SizeUnit,
StartBy,
Expand Down Expand Up @@ -5544,12 +5546,13 @@ def select(
self: DF,
exprs: (
str
| pli.Expr
| PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[str | pli.Expr | pli.Series | pli.WhenThen | pli.WhenThenThen]
| Iterable[str | PolarsExprType | PythonLiteral | pli.Series]
| None
) = None,
**named_exprs: Any,
**named_exprs: PolarsExprType | PythonLiteral | pli.Series | None,
) -> DF:
"""
Select columns from this DataFrame.
Expand Down Expand Up @@ -5658,8 +5661,15 @@ def select(

def with_columns(
self,
exprs: pli.Expr | pli.Series | Sequence[pli.Expr | pli.Series] | None = None,
**named_exprs: Any,
exprs: (
str
| PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[str | PolarsExprType | PythonLiteral | pli.Series]
| None
) = None,
**named_exprs: PolarsExprType | PythonLiteral | pli.Series | None,
) -> DataFrame:
"""
Return a new DataFrame with the columns added (if new), or replaced.
Expand Down
44 changes: 12 additions & 32 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from polars.internals.expr.meta import ExprMetaNameSpace
from polars.internals.expr.string import ExprStringNameSpace
from polars.internals.expr.struct import ExprStructNameSpace
from polars.internals.type_aliases import PolarsExprType, PythonLiteral
from polars.utils import _timedelta_to_pl_duration, sphinx_accessor

try:
Expand All @@ -49,32 +50,23 @@

def selection_to_pyexpr_list(
exprs: (
str
| Expr
PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[
str
| Expr
| pli.Series
| timedelta
| date
| datetime
| int
| float
| pli.WhenThen
| pli.WhenThenThen
]
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
),
structify: bool = False,
) -> list[PyExpr]:
if exprs is None:
exprs = []
elif isinstance(exprs, (str, Expr, pli.Series)):
elif isinstance(exprs, (str, Expr, pli.Series, pli.WhenThen, pli.WhenThenThen)):
exprs = [exprs]
elif not isinstance(exprs, Iterable):
exprs = [exprs]
return [
expr_to_lit_or_expr(e, str_to_lit=False, structify=structify)._pyexpr
for e in exprs
for e in exprs # type: ignore[union-attr]
]


Expand All @@ -87,20 +79,11 @@ def expr_output_name(expr: pli.Expr) -> str | None:

def expr_to_lit_or_expr(
expr: (
Expr
| bool
| int
| float
| str
PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
| date
| datetime
| time
| timedelta
| pli.WhenThen
| pli.WhenThenThen
| Sequence[int | float | str | None]
),
str_to_lit: bool = True,
structify: bool = False,
Expand Down Expand Up @@ -2049,10 +2032,7 @@ def take(
indices = cast("np.ndarray[Any, Any]", indices)
indices_lit = pli.lit(pli.Series("", indices, dtype=UInt32))
else:
indices_lit = pli.expr_to_lit_or_expr(
indices, # type: ignore[arg-type]
str_to_lit=False,
)
indices_lit = pli.expr_to_lit_or_expr(indices, str_to_lit=False)
return pli.wrap_expr(self._pyexpr.take(indices_lit._pyexpr))

def shift(self, periods: int = 1) -> Expr:
Expand Down
16 changes: 12 additions & 4 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, overload

from polars import internals as pli
from polars.datatypes import (
Expand All @@ -23,7 +23,7 @@
)
from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np
from polars.internals.type_aliases import EpochTimeUnit
from polars.internals.type_aliases import EpochTimeUnit, PolarsExprType
from polars.utils import (
_datetime_to_pl_timestamp,
_time_to_pl_time,
Expand Down Expand Up @@ -72,6 +72,7 @@
if TYPE_CHECKING:
from polars.internals.type_aliases import (
IntoExpr,
PythonLiteral,
RollingInterpolationMethod,
TimeUnit,
)
Expand Down Expand Up @@ -2246,8 +2247,15 @@ def collect_all(


def select(
exprs: str | pli.Expr | Sequence[str | pli.Expr] | pli.Series,
**named_exprs: Any,
exprs: (
str
| PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[str | PolarsExprType | PythonLiteral | pli.Series]
| None
) = None,
**named_exprs: PolarsExprType | PythonLiteral | pli.Series | None,
) -> pli.DataFrame:
"""
Run polars expressions without a context.
Expand Down
22 changes: 16 additions & 6 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from polars.internals import selection_to_pyexpr_list
from polars.internals.lazyframe.groupby import LazyGroupBy
from polars.internals.slice import LazyPolarsSlice
from polars.internals.type_aliases import PythonLiteral
from polars.utils import (
_in_notebook,
_prepare_row_count_args,
Expand Down Expand Up @@ -77,6 +78,7 @@
FillNullStrategy,
JoinStrategy,
ParallelStrategy,
PolarsExprType,
RollingInterpolationMethod,
StartBy,
UniqueKeepStrategy,
Expand Down Expand Up @@ -1557,12 +1559,13 @@ def select(
self: LDF,
exprs: (
str
| pli.Expr
| PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[str | pli.Expr | pli.Series | pli.WhenThen | pli.WhenThenThen]
| Iterable[str | PolarsExprType | PythonLiteral | pli.Series]
| None
) = None,
**named_exprs: Any,
**named_exprs: PolarsExprType | PythonLiteral | pli.Series | None,
) -> LDF:
"""
Select columns from this DataFrame.
Expand Down Expand Up @@ -2460,11 +2463,18 @@ def join(

def with_columns(
self: LDF,
exprs: pli.Expr | pli.Series | Sequence[pli.Expr | pli.Series] | None = None,
**named_exprs: Any,
exprs: (
str
| PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[str | PolarsExprType | PythonLiteral | pli.Series]
| None
) = None,
**named_exprs: PolarsExprType | PythonLiteral | pli.Series | None,
) -> LDF:
"""
Return a new LazyFrame with the columns added, if new, or replaced.
Return a new LazyFrame with the columns added (if new), or replaced.
Notes
-----
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/internals/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal

from polars import internals as pli

Expand All @@ -14,6 +16,11 @@
else:
from typing_extensions import TypeAlias

from typing import Union

# Types that qualify as expressions (eg: for use in 'select', 'with_columns'...)
PolarsExprType: TypeAlias = "pli.Expr | pli.WhenThen | pli.WhenThenThen"

IntoExpr: TypeAlias = "int | float | str | pli.Expr | pli.Series"
ComparisonOperator: TypeAlias = Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"]

Expand Down Expand Up @@ -73,3 +80,8 @@
Orientation: TypeAlias = Literal["col", "row"]
TransferEncoding: TypeAlias = Literal["hex", "base64"]
SearchSortedSide: TypeAlias = Literal["any", "left", "right"]

# literal types that are allowed in expressions (auto-converted to pl.lit)
PythonLiteral: TypeAlias = Union[
str, int, float, bool, date, time, datetime, timedelta, bytes, Decimal
]
38 changes: 24 additions & 14 deletions py-polars/polars/internals/whenthen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import typing
from typing import Any, Sequence
from typing import Any, Iterable

try:
from polars.polars import when as pywhen
Expand All @@ -11,6 +11,7 @@
_DOCUMENTING = True

from polars import internals as pli
from polars.internals.type_aliases import PolarsExprType, PythonLiteral


class WhenThenThen:
Expand All @@ -27,13 +28,11 @@ def when(self, predicate: pli.Expr | bool) -> WhenThenThen:
def then(
self,
expr: (
pli.Expr
| int
| float
| str
| None
PolarsExprType
| PythonLiteral
| pli.Series
| Sequence[int | float | str | None]
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
),
) -> WhenThenThen:
"""
Expand All @@ -51,7 +50,11 @@ def then(
def otherwise(
self,
expr: (
pli.Expr | int | float | str | None | Sequence[int | float | str | None]
PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
),
) -> pli.Expr:
"""
Expand Down Expand Up @@ -83,7 +86,16 @@ def when(self, predicate: pli.Expr | bool) -> WhenThenThen:
predicate = pli.expr_to_lit_or_expr(predicate)
return WhenThenThen(self._pywhenthen.when(predicate._pyexpr))

def otherwise(self, expr: pli.Expr | int | float | str | None) -> pli.Expr:
def otherwise(
self,
expr: (
PolarsExprType
| PythonLiteral
| pli.Series
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
),
) -> pli.Expr:
"""
Values to return in case of the predicate being `False`.
Expand Down Expand Up @@ -111,13 +123,11 @@ def __init__(self, pywhen: pywhen):
def then(
self,
expr: (
pli.Expr
PolarsExprType
| PythonLiteral
| pli.Series
| int
| float
| str
| Iterable[PolarsExprType | PythonLiteral | pli.Series]
| None
| Sequence[None | int | float | str]
),
) -> WhenThen:
"""
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2332,6 +2332,22 @@ def test_fill_null_limits() -> None:
}


def test_selection_misc() -> None:
df = pl.DataFrame({"x": "abc"}, schema={"x": pl.Utf8})

# literal values (as scalar/list)
for zero in (0, [0]):
assert df.select(zero)["literal"].to_list() == [0] # type: ignore[arg-type]
assert df.select(literal=0)["literal"].to_list() == [0]

# expect string values to be interpreted as cols
for x in ("x", ["x"], pl.col("x")):
assert df.select(x).rows() == [("abc",)] # type: ignore[arg-type]

# string col + lit
assert df.with_columns(["x", 0]).to_dicts() == [{"x": "abc", "literal": 0}]


def test_selection_regex_and_multicol() -> None:
test_df = pl.DataFrame(
{
Expand Down

0 comments on commit 06c0b54

Please sign in to comment.