From eb6786a1b1e43b92cb2e008b3c962e29bf74e5df Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 1 Oct 2022 00:53:51 +0400 Subject: [PATCH 1/2] feat[python]: support Series init as struct from dataclass or annotated namedtuple --- py-polars/polars/datatypes.py | 26 +++++++-- py-polars/polars/internals/construction.py | 48 +++++++++------- py-polars/tests/unit/test_series.py | 65 +++++++++++++++++++++- 3 files changed, 112 insertions(+), 27 deletions(-) diff --git a/py-polars/polars/datatypes.py b/py-polars/polars/datatypes.py index da14bd57bcfe..c1e90d90084b 100644 --- a/py-polars/polars/datatypes.py +++ b/py-polars/polars/datatypes.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Any, Dict, + ForwardRef, Mapping, Optional, Sequence, @@ -35,6 +36,8 @@ _DOCUMENTING = True UnionType: type +OptionType = type(Optional[type]) + if sys.version_info >= (3, 10): from types import UnionType else: @@ -409,6 +412,9 @@ def __hash__(self) -> int: Decimal: Float64, } +_PY_STR_TO_DTYPE: dict[str, PolarsDataType] = { + str(tp.__name__): dtype for tp, dtype in _PY_TYPE_TO_DTYPE.items() +} _DTYPE_TO_PY_TYPE: dict[PolarsDataType, type] = { Float64: float, @@ -539,13 +545,23 @@ def is_polars_dtype(data_type: Any) -> bool: def py_type_to_dtype(data_type: Any, raise_unmatched: bool = True) -> PolarsDataType: - """Convert a Python dtype to a Polars dtype.""" - # when the passed in is already a Polars datatype, return that + """Convert a Python dtype (or type annotation) to a Polars dtype.""" + if isinstance(data_type, ForwardRef): + dtype = data_type.__forward_arg__ + data_type = ( + _PY_STR_TO_DTYPE.get( + dtype.removeprefix("None | ").removesuffix(" | None").strip(), data_type + ) + if isinstance(dtype, str) # type: ignore[redundant-expr] + else data_type + ) + if is_polars_dtype(data_type): return data_type - elif isinstance(data_type, UnionType): - # not exhaustive; currently handles the common "type | None" case, - # but ideally would pick appropriate supertype when n_types > 1 + + elif isinstance(data_type, (OptionType, UnionType)): + # not exhaustive; handles the common "type | None" case, but + # should probably pick appropriate supertype when n_types > 1? possible_types = [tp for tp in get_args(data_type) if tp is not NoneType] if len(possible_types) == 1: data_type = possible_types[0] diff --git a/py-polars/polars/internals/construction.py b/py-polars/polars/internals/construction.py index 8ae7f067c7e2..1f44d7881e39 100644 --- a/py-polars/polars/internals/construction.py +++ b/py-polars/polars/internals/construction.py @@ -70,6 +70,13 @@ def dataclass_type_hints(obj: type) -> dict[str, Any]: from polars.internals.type_aliases import Orientation +def is_namedtuple(value: Any, annotated: bool = False) -> bool: + """Infer whether value is a NamedTuple.""" + if all(hasattr(value, attr) for attr in ("_fields", "_field_defaults", "_replace")): + return len(value.__annotations__) == len(value._fields) if annotated else True + return False + + ################################ # Series constructor interface # ################################ @@ -214,17 +221,20 @@ def sequence_to_pyseries( value = _get_first_non_none(values) if value is not None: - # for temporal dtypes: - # * if the values are integer, we take the physical branch. - # * if the values are python types, take the temporal branch. - # * if the values are ISO-8601 strings, init then convert via strptime. - if dtype in py_temporal_types and isinstance(value, int): - dtype = py_type_to_dtype(dtype) # construct from integer - elif ( - dtype in pl_temporal_types or type(dtype) in pl_temporal_types - ) and not isinstance(value, int): - temporal_unit = getattr(dtype, "tu", None) - python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type] + if is_dataclass(value) or is_namedtuple(value, annotated=True): + return pli.DataFrame(values).to_struct(name)._s + else: + # for temporal dtypes: + # * if the values are integer, we take the physical branch. + # * if the values are python types, take the temporal branch. + # * if the values are ISO-8601 strings, init then convert via strptime. + if dtype in py_temporal_types and isinstance(value, int): + dtype = py_type_to_dtype(dtype) # construct from integer + elif ( + dtype in pl_temporal_types or type(dtype) in pl_temporal_types + ) and not isinstance(value, int): + temporal_unit = getattr(dtype, "tu", None) + python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type] # physical branch # flat data @@ -235,7 +245,6 @@ def sequence_to_pyseries( if dtype in (Date, Datetime, Duration, Time, Categorical): pyseries = pyseries.cast(dtype, True) return pyseries - else: if python_dtype is None: if value is None: @@ -244,7 +253,6 @@ def sequence_to_pyseries( else: python_dtype = type(value) if datetime == python_dtype: - # note: python-native datetimes have microsecond precision temporal_unit = "us" # temporal branch @@ -307,7 +315,7 @@ def sequence_to_pyseries( dtype = py_type_to_dtype(nested_dtype) with suppress(BaseException): return PySeries.new_list(name, values, dtype) - # pass we create an object if we get here + # pass; we create an object if we get here else: try: to_arrow_type = ( @@ -369,7 +377,6 @@ def sequence_to_pyseries( # - bools: "'int' object cannot be converted to 'PyBool'" elif str_val == "'int' object cannot be converted to 'PyBool'": constructor = py_type_to_constructor(int) - else: raise error @@ -587,13 +594,14 @@ def sequence_to_pydf( return pydf elif isinstance(data[0], Sequence) and not isinstance(data[0], str): - # infer orientation - if all( - hasattr(data[0], attr) - for attr in ("_fields", "_field_defaults", "_replace") - ): # namedtuple + if is_namedtuple(data[0]): if columns is None: columns = data[0]._fields # type: ignore[attr-defined] + if len(data[0].__annotations__) == len(columns): + columns = [ + (name, py_type_to_dtype(tp, raise_unmatched=False)) + for name, tp in data[0].__annotations__.items() + ] elif orient is None: orient = "row" diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 38e2ac36569c..adf04335ba5e 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -10,8 +10,22 @@ import pytest import polars as pl -from polars.datatypes import Date, Datetime, Float64, Int32, Int64, Time, UInt32, UInt64 -from polars.testing import assert_series_equal, verify_series_and_expr_api +from polars.datatypes import ( + Date, + Datetime, + Field, + Float64, + Int32, + Int64, + Time, + UInt32, + UInt64, +) +from polars.testing import ( + assert_frame_equal, + assert_series_equal, + verify_series_and_expr_api, +) if TYPE_CHECKING: from polars.internals.type_aliases import TimeUnit @@ -100,6 +114,53 @@ def test_init_inputs(monkeypatch: Any) -> None: pl.DataFrame(np.array([1, 2, 3]), columns=["a"]) +def test_init_dataclass_namedtuple() -> None: + from dataclasses import dataclass + from typing import NamedTuple + + @dataclass + class TeaShipmentDC: + exporter: str + importer: str + product: str + weight: float | None + + class TeaShipmentNT(NamedTuple): + exporter: str + importer: str + product: str + weight: None | float + + for Tea in (TeaShipmentDC, TeaShipmentNT): + t0 = Tea(exporter="Sri Lanka", importer="USA", product="Ceylon", weight=100) + t1 = Tea(exporter="India", importer="UK", product="Darjeeling", weight=250) + + s = pl.Series("t", [t0, t1]) + + assert isinstance(s, pl.Series) + assert s.dtype.fields == [ # type: ignore[attr-defined] + Field("exporter", pl.Utf8), + Field("importer", pl.Utf8), + Field("product", pl.Utf8), + Field("weight", pl.Float64), + ] + assert s.to_list() == [ + { + "exporter": "Sri Lanka", + "importer": "USA", + "product": "Ceylon", + "weight": 100.0, + }, + { + "exporter": "India", + "importer": "UK", + "product": "Darjeeling", + "weight": 250.0, + }, + ] + assert_frame_equal(s.to_frame(), pl.DataFrame({"t": [t0, t1]})) + + def test_concat() -> None: s = pl.Series("a", [2, 1, 3]) From 20f2a2eb858af07867afb9c7f84749507435628c Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 1 Oct 2022 10:50:58 +0400 Subject: [PATCH 2/2] py 3.7 compatibility --- py-polars/polars/datatypes.py | 9 +++++---- py-polars/tests/unit/test_series.py | 14 +++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/datatypes.py b/py-polars/polars/datatypes.py index c1e90d90084b..f90c03cfa4af 100644 --- a/py-polars/polars/datatypes.py +++ b/py-polars/polars/datatypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import ctypes +import re import sys from datetime import date, datetime, time, timedelta from decimal import Decimal @@ -547,13 +548,13 @@ def is_polars_dtype(data_type: Any) -> bool: def py_type_to_dtype(data_type: Any, raise_unmatched: bool = True) -> PolarsDataType: """Convert a Python dtype (or type annotation) to a Polars dtype.""" if isinstance(data_type, ForwardRef): - dtype = data_type.__forward_arg__ + annotation = data_type.__forward_arg__ data_type = ( _PY_STR_TO_DTYPE.get( - dtype.removeprefix("None | ").removesuffix(" | None").strip(), data_type + re.sub(r"(^None \|)|(\| None$)", "", annotation).strip(), data_type ) - if isinstance(dtype, str) # type: ignore[redundant-expr] - else data_type + if isinstance(annotation, str) # type: ignore[redundant-expr] + else annotation ) if is_polars_dtype(data_type): diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index adf04335ba5e..5ea0adb6eb2f 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -123,17 +123,17 @@ class TeaShipmentDC: exporter: str importer: str product: str - weight: float | None + tonnes: int | None class TeaShipmentNT(NamedTuple): exporter: str importer: str product: str - weight: None | float + tonnes: None | int for Tea in (TeaShipmentDC, TeaShipmentNT): - t0 = Tea(exporter="Sri Lanka", importer="USA", product="Ceylon", weight=100) - t1 = Tea(exporter="India", importer="UK", product="Darjeeling", weight=250) + t0 = Tea(exporter="Sri Lanka", importer="USA", product="Ceylon", tonnes=10) + t1 = Tea(exporter="India", importer="UK", product="Darjeeling", tonnes=25) s = pl.Series("t", [t0, t1]) @@ -142,20 +142,20 @@ class TeaShipmentNT(NamedTuple): Field("exporter", pl.Utf8), Field("importer", pl.Utf8), Field("product", pl.Utf8), - Field("weight", pl.Float64), + Field("tonnes", pl.Int64), ] assert s.to_list() == [ { "exporter": "Sri Lanka", "importer": "USA", "product": "Ceylon", - "weight": 100.0, + "tonnes": 10, }, { "exporter": "India", "importer": "UK", "product": "Darjeeling", - "weight": 250.0, + "tonnes": 25, }, ] assert_frame_equal(s.to_frame(), pl.DataFrame({"t": [t0, t1]}))