Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Support Series init as struct from @dataclass and annotated NamedTuple #5057

Merged
merged 2 commits into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ctypes
import re
import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand All @@ -9,6 +10,7 @@
TYPE_CHECKING,
Any,
Dict,
ForwardRef,
Mapping,
Optional,
Sequence,
Expand All @@ -35,6 +37,8 @@
_DOCUMENTING = True

UnionType: type
OptionType = type(Optional[type])

if sys.version_info >= (3, 10):
from types import UnionType
else:
Expand Down Expand Up @@ -409,6 +413,9 @@ def __hash__(self) -> int:
Decimal: Float64,
}

# Lookup from the *name* of a Python type to its Polars dtype, derived from
# the type-keyed mapping above.
_PY_STR_TO_DTYPE: dict[str, PolarsDataType] = {
    str(py_type.__name__): polars_dtype
    for py_type, polars_dtype in _PY_TYPE_TO_DTYPE.items()
}

_DTYPE_TO_PY_TYPE: dict[PolarsDataType, type] = {
Float64: float,
Expand Down Expand Up @@ -539,13 +546,23 @@ def is_polars_dtype(data_type: Any) -> bool:


def py_type_to_dtype(data_type: Any, raise_unmatched: bool = True) -> PolarsDataType:
"""Convert a Python dtype to a Polars dtype."""
# when the passed in is already a Polars datatype, return that
"""Convert a Python dtype (or type annotation) to a Polars dtype."""
if isinstance(data_type, ForwardRef):
annotation = data_type.__forward_arg__
data_type = (
_PY_STR_TO_DTYPE.get(
re.sub(r"(^None \|)|(\| None$)", "", annotation).strip(), data_type
)
if isinstance(annotation, str) # type: ignore[redundant-expr]
else annotation
)

if is_polars_dtype(data_type):
return data_type
elif isinstance(data_type, UnionType):
# not exhaustive; currently handles the common "type | None" case,
# but ideally would pick appropriate supertype when n_types > 1

elif isinstance(data_type, (OptionType, UnionType)):
# not exhaustive; handles the common "type | None" case, but
# should probably pick appropriate supertype when n_types > 1?
possible_types = [tp for tp in get_args(data_type) if tp is not NoneType]
if len(possible_types) == 1:
data_type = possible_types[0]
Expand Down
48 changes: 28 additions & 20 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def dataclass_type_hints(obj: type) -> dict[str, Any]:
from polars.internals.type_aliases import Orientation


def is_namedtuple(value: Any, annotated: bool = False) -> bool:
    """
    Infer whether value is a NamedTuple.

    Parameters
    ----------
    value
        Object to inspect; detection is structural (presence of the
        ``_fields``, ``_field_defaults`` and ``_replace`` attributes).
    annotated
        If True, additionally require that the number of type annotations
        matches the number of fields.

    Returns
    -------
    bool
        True if ``value`` looks like a namedtuple (and, when ``annotated`` is
        set, has one annotation per field); False otherwise.

    """
    if not all(
        hasattr(value, attr) for attr in ("_fields", "_field_defaults", "_replace")
    ):
        return False
    if not annotated:
        return True

    # Read the annotations from the class and fall back to "none" when the
    # attribute is missing: a plain ``collections.namedtuple`` defines no
    # ``__annotations__``, so a bare ``value.__annotations__`` lookup can raise
    # AttributeError instead of letting us answer False.
    owner = value if isinstance(value, type) else type(value)
    annotations = getattr(owner, "__annotations__", None) or {}
    return len(annotations) == len(value._fields)


################################
# Series constructor interface #
################################
Expand Down Expand Up @@ -214,17 +221,20 @@ def sequence_to_pyseries(

value = _get_first_non_none(values)
if value is not None:
# for temporal dtypes:
# * if the values are integer, we take the physical branch.
# * if the values are python types, take the temporal branch.
# * if the values are ISO-8601 strings, init then convert via strptime.
if dtype in py_temporal_types and isinstance(value, int):
dtype = py_type_to_dtype(dtype) # construct from integer
elif (
dtype in pl_temporal_types or type(dtype) in pl_temporal_types
) and not isinstance(value, int):
temporal_unit = getattr(dtype, "tu", None)
python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type]
if is_dataclass(value) or is_namedtuple(value, annotated=True):
return pli.DataFrame(values).to_struct(name)._s
else:
# for temporal dtypes:
# * if the values are integer, we take the physical branch.
# * if the values are python types, take the temporal branch.
# * if the values are ISO-8601 strings, init then convert via strptime.
if dtype in py_temporal_types and isinstance(value, int):
dtype = py_type_to_dtype(dtype) # construct from integer
elif (
dtype in pl_temporal_types or type(dtype) in pl_temporal_types
) and not isinstance(value, int):
temporal_unit = getattr(dtype, "tu", None)
python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type]

# physical branch
# flat data
Expand All @@ -235,7 +245,6 @@ def sequence_to_pyseries(
if dtype in (Date, Datetime, Duration, Time, Categorical):
pyseries = pyseries.cast(dtype, True)
return pyseries

else:
if python_dtype is None:
if value is None:
Expand All @@ -244,7 +253,6 @@ def sequence_to_pyseries(
else:
python_dtype = type(value)
if datetime == python_dtype:
# note: python-native datetimes have microsecond precision
temporal_unit = "us"

# temporal branch
Expand Down Expand Up @@ -307,7 +315,7 @@ def sequence_to_pyseries(
dtype = py_type_to_dtype(nested_dtype)
with suppress(BaseException):
return PySeries.new_list(name, values, dtype)
# pass we create an object if we get here
# pass; we create an object if we get here
else:
try:
to_arrow_type = (
Expand Down Expand Up @@ -369,7 +377,6 @@ def sequence_to_pyseries(
# - bools: "'int' object cannot be converted to 'PyBool'"
elif str_val == "'int' object cannot be converted to 'PyBool'":
constructor = py_type_to_constructor(int)

else:
raise error

Expand Down Expand Up @@ -587,13 +594,14 @@ def sequence_to_pydf(
return pydf

elif isinstance(data[0], Sequence) and not isinstance(data[0], str):
# infer orientation
if all(
hasattr(data[0], attr)
for attr in ("_fields", "_field_defaults", "_replace")
): # namedtuple
if is_namedtuple(data[0]):
if columns is None:
columns = data[0]._fields # type: ignore[attr-defined]
if len(data[0].__annotations__) == len(columns):
columns = [
(name, py_type_to_dtype(tp, raise_unmatched=False))
for name, tp in data[0].__annotations__.items()
]
elif orient is None:
orient = "row"

Expand Down
65 changes: 63 additions & 2 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,22 @@
import pytest

import polars as pl
from polars.datatypes import Date, Datetime, Float64, Int32, Int64, Time, UInt32, UInt64
from polars.testing import assert_series_equal, verify_series_and_expr_api
from polars.datatypes import (
Date,
Datetime,
Field,
Float64,
Int32,
Int64,
Time,
UInt32,
UInt64,
)
from polars.testing import (
assert_frame_equal,
assert_series_equal,
verify_series_and_expr_api,
)

if TYPE_CHECKING:
from polars.internals.type_aliases import TimeUnit
Expand Down Expand Up @@ -100,6 +114,53 @@ def test_init_inputs(monkeypatch: Any) -> None:
pl.DataFrame(np.array([1, 2, 3]), columns=["a"])


def test_init_dataclass_namedtuple() -> None:
    """Check Series init from dataclass and annotated NamedTuple rows."""
    from dataclasses import dataclass
    from typing import NamedTuple

    # the optional field is spelled "int | None" in one class and
    # "None | int" in the other, so both orderings are exercised
    @dataclass
    class TeaShipmentDC:
        exporter: str
        importer: str
        product: str
        tonnes: int | None

    class TeaShipmentNT(NamedTuple):
        exporter: str
        importer: str
        product: str
        tonnes: None | int

    expected_fields = [
        Field("exporter", pl.Utf8),
        Field("importer", pl.Utf8),
        Field("product", pl.Utf8),
        Field("tonnes", pl.Int64),
    ]
    expected_rows = [
        {
            "exporter": "Sri Lanka",
            "importer": "USA",
            "product": "Ceylon",
            "tonnes": 10,
        },
        {
            "exporter": "India",
            "importer": "UK",
            "product": "Darjeeling",
            "tonnes": 25,
        },
    ]

    for shipment_cls in (TeaShipmentDC, TeaShipmentNT):
        first = shipment_cls(
            exporter="Sri Lanka", importer="USA", product="Ceylon", tonnes=10
        )
        second = shipment_cls(
            exporter="India", importer="UK", product="Darjeeling", tonnes=25
        )

        series = pl.Series("t", [first, second])

        assert isinstance(series, pl.Series)
        assert series.dtype.fields == expected_fields  # type: ignore[attr-defined]
        assert series.to_list() == expected_rows

        # the struct Series must match what DataFrame init infers for the column
        frame = pl.DataFrame({"t": [first, second]})
        assert_frame_equal(series.to_frame(), frame)


def test_concat() -> None:
s = pl.Series("a", [2, 1, 3])

Expand Down