From 004084e88ced44cb744ec7fa04b72fe9735a64d2 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sun, 5 May 2024 19:29:34 +0200 Subject: [PATCH] Add TypeGuard to is_polars_dtype --- py-polars/polars/datatypes/convert.py | 8 +++++++- py-polars/polars/functions/col.py | 2 +- py-polars/polars/selectors.py | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 8adffd125850..cdfc76875858 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -68,6 +68,10 @@ from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict, TimeUnit + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard PY_STR_TO_DTYPE: SchemaDict = { "float": Float64, @@ -139,7 +143,9 @@ def _map_py_type_to_dtype( raise TypeError(msg) -def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool: +def is_polars_dtype( + dtype: Any, *, include_unknown: bool = False +) -> TypeGuard[PolarsDataType]: """Indicate whether the given input is a Polars dtype, or dtype specialization.""" try: if dtype == Unknown: diff --git a/py-polars/polars/functions/col.py b/py-polars/polars/functions/col.py index 90e0fd843ec9..6837020af7f8 100644 --- a/py-polars/polars/functions/col.py +++ b/py-polars/polars/functions/col.py @@ -29,7 +29,7 @@ def _create_col( return wrap_expr(plr.cols(names_str)) elif is_polars_dtype(name): dtypes = [name] - dtypes.extend(more_names) + dtypes.extend(more_names) # type: ignore[arg-type] return wrap_expr(plr.dtype_cols(dtypes)) else: msg = ( diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index b732c53bab42..0f4daad424c7 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -223,7 +223,7 @@ def _combine_as_selector( if names: selected.append(by_name(*names)) if dtypes: - selected.append(by_dtype(*dtypes)) # type: ignore[arg-type] + selected.append(by_dtype(*dtypes)) if regexes: selected.append( matches( @@ -579,7 +579,7 @@ def by_dtype( all_dtypes: list[PolarsDataType] = [] for tp in dtypes: if is_polars_dtype(tp): - all_dtypes.append(tp) # type: ignore[arg-type] + all_dtypes.append(tp) elif isinstance(tp, Collection): for t in tp: if not is_polars_dtype(t):