Skip to content

Commit

Permalink
refactor(python): Add TypeGuard to is_polars_dtype util (#16065)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 5, 2024
1 parent aa2e77b commit 73cbdb2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
8 changes: 7 additions & 1 deletion py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@

from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict, TimeUnit

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

PY_STR_TO_DTYPE: SchemaDict = {
"float": Float64,
Expand Down Expand Up @@ -139,7 +143,9 @@ def _map_py_type_to_dtype(
raise TypeError(msg)


def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool:
def is_polars_dtype(
dtype: Any, *, include_unknown: bool = False
) -> TypeGuard[PolarsDataType]:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
try:
if dtype == Unknown:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/functions/col.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _create_col(
return wrap_expr(plr.cols(names_str))
elif is_polars_dtype(name):
dtypes = [name]
dtypes.extend(more_names)
dtypes.extend(more_names) # type: ignore[arg-type]
return wrap_expr(plr.dtype_cols(dtypes))
else:
msg = (
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _combine_as_selector(
if names:
selected.append(by_name(*names))
if dtypes:
selected.append(by_dtype(*dtypes)) # type: ignore[arg-type]
selected.append(by_dtype(*dtypes))
if regexes:
selected.append(
matches(
Expand Down Expand Up @@ -579,7 +579,7 @@ def by_dtype(
all_dtypes: list[PolarsDataType] = []
for tp in dtypes:
if is_polars_dtype(tp):
all_dtypes.append(tp) # type: ignore[arg-type]
all_dtypes.append(tp)
elif isinstance(tp, Collection):
for t in tp:
if not is_polars_dtype(t):
Expand Down

0 comments on commit 73cbdb2

Please sign in to comment.