Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Improve Schema and DataType interop with Python types #18308

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py-polars/polars/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
Type[List[Any]],
Type[Tuple[Any, ...]],
Type[bytes],
Type[object],
Type["Decimal"],
Type[None],
]
Expand Down
22 changes: 11 additions & 11 deletions py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,14 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData
# this is required as pass through. Don't remove
elif input == Unknown:
return Unknown

elif hasattr(input, "__origin__") and hasattr(input, "__args__"):
return _parse_generic_into_dtype(input)

else:
_raise_on_invalid_dtype(input)


def _parse_generic_into_dtype(input: Any) -> PolarsDataType:
"""Parse a generic type into a Polars data type."""
"""Parse a generic type (from typing annotation) into a Polars data type."""
base_type = input.__origin__
if base_type not in (tuple, list):
_raise_on_invalid_dtype(input)
Expand All @@ -124,19 +122,19 @@ def _parse_generic_into_dtype(input: Any) -> PolarsDataType:


PY_TYPE_STR_TO_DTYPE: SchemaDict = {
"int": Int64(),
"float": Float64(),
"Decimal": Decimal,
"NoneType": Null(),
"bool": Boolean(),
"str": String(),
"bytes": Binary(),
"date": Date(),
"time": Time(),
"datetime": Datetime("us"),
"float": Float64(),
"int": Int64(),
"list": List,
"object": Object(),
"NoneType": Null(),
"str": String(),
"time": Time(),
"timedelta": Duration,
"Decimal": Decimal,
"list": List,
"tuple": List,
}

Expand Down Expand Up @@ -177,5 +175,7 @@ def _parse_union_type_into_dtype(input: Any) -> PolarsDataType:

def _raise_on_invalid_dtype(input: Any) -> NoReturn:
"""Raise an informative error if the input could not be parsed."""
msg = f"cannot parse input of type {type(input).__name__!r} into Polars data type: {input!r}"
input_type = input if type(input) is type else f"of type {type(input).__name__!r}"
input_detail = "" if type(input) is type else f" (given: {input!r})"
msg = f"cannot parse input {input_type} into Polars data type{input_detail}"
raise TypeError(msg) from None
51 changes: 51 additions & 0 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def is_temporal(cls) -> bool: # noqa: D102
def is_nested(cls) -> bool: # noqa: D102
...

@classmethod
def from_python(cls, py_type: PythonDataType) -> PolarsDataType: # noqa: D102
...

@classmethod
def to_python(self) -> PythonDataType: # noqa: D102
...


class DataType(metaclass=DataTypeClass):
"""Base class for all Polars data types."""
Expand Down Expand Up @@ -180,6 +188,49 @@ def is_nested(cls) -> bool:
"""Check whether the data type is a nested type."""
return issubclass(cls, NestedType)

@classmethod
def from_python(cls, py_type: PythonDataType) -> PolarsDataType:
"""
Return the Polars data type corresponding to a given Python type.

Notes
-----
Not every Python type has a corresponding Polars data type; in general
you should declare Polars data types explicitly to exactly specify
the desired type and its properties (such as scale/unit).

Examples
--------
>>> pl.DataType.from_python(int)
Int64
>>> pl.DataType.from_python(float)
Float64
>>> from datetime import tzinfo
>>> pl.DataType.from_python(tzinfo) # doctest: +SKIP
TypeError: cannot parse input <class 'datetime.tzinfo'> into Polars data type
"""
from polars.datatypes._parse import parse_into_dtype

return parse_into_dtype(py_type)

@classinstmethod # type: ignore[arg-type]
def to_python(self) -> PythonDataType:
"""
Return the Python type corresponding to this Polars data type.

Examples
--------
>>> pl.Int16().to_python()
<class 'int'>
>>> pl.Float32().to_python()
<class 'float'>
>>> pl.Array(pl.Date(), 10).to_python()
<class 'list'>
"""
from polars.datatypes import dtype_to_py_type

return dtype_to_py_type(self)


class NumericType(DataType):
"""Base class for numeric data types."""
Expand Down
82 changes: 44 additions & 38 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Datetime,
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Expand Down Expand Up @@ -134,88 +135,93 @@ class _DataTypeMappings:
@functools.lru_cache # noqa: B019
def DTYPE_TO_FFINAME(self) -> dict[PolarsDataType, str]:
return {
Int8: "i8",
Int16: "i16",
Int32: "i32",
Int64: "i64",
UInt8: "u8",
UInt16: "u16",
UInt32: "u32",
UInt64: "u64",
Float32: "f32",
Float64: "f64",
Decimal: "decimal",
Binary: "binary",
Boolean: "bool",
String: "str",
List: "list",
Categorical: "categorical",
Date: "date",
Datetime: "datetime",
Decimal: "decimal",
Duration: "duration",
Time: "time",
Float32: "f32",
Float64: "f64",
Int16: "i16",
Int32: "i32",
Int64: "i64",
Int8: "i8",
List: "list",
Object: "object",
Categorical: "categorical",
String: "str",
Struct: "struct",
Binary: "binary",
Time: "time",
UInt16: "u16",
UInt32: "u32",
UInt64: "u64",
UInt8: "u8",
}

@property
@functools.lru_cache # noqa: B019
def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]:
return {
Float64: float,
Array: list,
Binary: bytes,
Boolean: bool,
Date: date,
Datetime: datetime,
Decimal: PyDecimal,
Duration: timedelta,
Float32: float,
Int64: int,
Int32: int,
Float64: float,
Int16: int,
Int32: int,
Int64: int,
Int8: int,
List: list,
Null: None.__class__,
Object: object,
String: str,
UInt8: int,
Struct: dict,
Time: time,
UInt16: int,
UInt32: int,
UInt64: int,
Decimal: PyDecimal,
Boolean: bool,
Duration: timedelta,
Datetime: datetime,
Date: date,
Time: time,
Binary: bytes,
List: list,
Array: list,
Null: None.__class__,
UInt8: int,
# the below mappings are appropriate as we restrict cat/enum to strings
Enum: str,
Categorical: str,
}

@property
@functools.lru_cache # noqa: B019
def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], PolarsDataType]:
return {
# (np.dtype().kind, np.dtype().itemsize)
("M", 8): Datetime,
("b", 1): Boolean,
("f", 4): Float32,
("f", 8): Float64,
("i", 1): Int8,
("i", 2): Int16,
("i", 4): Int32,
("i", 8): Int64,
("m", 8): Duration,
("u", 1): UInt8,
("u", 2): UInt16,
("u", 4): UInt32,
("u", 8): UInt64,
("f", 4): Float32,
("f", 8): Float64,
("m", 8): Duration,
("M", 8): Datetime,
}

@property
@functools.lru_cache # noqa: B019
def PY_TYPE_TO_ARROW_TYPE(self) -> dict[PythonDataType, pa.lib.DataType]:
return {
bool: pa.bool_(),
date: pa.date32(),
datetime: pa.timestamp("us"),
float: pa.float64(),
int: pa.int64(),
str: pa.large_utf8(),
bool: pa.bool_(),
date: pa.date32(),
time: pa.time64("us"),
datetime: pa.timestamp("us"),
timedelta: pa.duration("us"),
None.__class__: pa.null(),
}
Expand Down Expand Up @@ -338,7 +344,7 @@ def maybe_cast(el: Any, dtype: PolarsDataType) -> Any:
py_type = dtype_to_py_type(dtype)
if not isinstance(el, py_type):
try:
el = py_type(el) # type: ignore[call-arg, misc]
el = py_type(el) # type: ignore[call-arg]
except Exception:
msg = f"cannot convert Python type {type(el).__name__!r} to {dtype!r}"
raise TypeError(msg) from None
Expand Down
33 changes: 29 additions & 4 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING, Iterable, Mapping
from collections.abc import Mapping
from typing import TYPE_CHECKING, Iterable

from polars.datatypes._parse import parse_into_dtype

if TYPE_CHECKING:
from polars._typing import PythonDataType
from polars.datatypes import DataType

BaseSchema = OrderedDict[str, DataType]
Expand Down Expand Up @@ -49,10 +53,19 @@ class Schema(BaseSchema):

def __init__(
self,
schema: Mapping[str, DataType] | Iterable[tuple[str, DataType]] | None = None,
schema: (
Mapping[str, DataType | PythonDataType]
| Iterable[tuple[str, DataType | PythonDataType]]
| None
) = None,
):
schema = schema or {}
super().__init__(schema)
input = (
schema.items() if schema and isinstance(schema, Mapping) else (schema or {})
)
super().__init__({name: parse_into_dtype(tp) for name, tp in input}) # type: ignore[misc]

def __setitem__(self, name: str, dtype: DataType | PythonDataType) -> None:
super().__setitem__(name, parse_into_dtype(dtype)) # type: ignore[assignment]

def names(self) -> list[str]:
"""Get the column names of the schema."""
Expand All @@ -65,3 +78,15 @@ def dtypes(self) -> list[DataType]:
def len(self) -> int:
"""Get the number of columns in the schema."""
return len(self)

def to_python(self) -> dict[str, type]:
"""
Return Schema as a dictionary of column names and their Python types.

Examples
--------
>>> s = pl.Schema({"x": pl.Int8(), "y": pl.String(), "z": pl.Duration("ms")})
>>> s.to_python()
{'x': <class 'int'>, 'y': <class 'str'>, 'z': <class 'datetime.timedelta'>}
"""
return {name: tp.to_python() for name, tp in self.items()}
10 changes: 8 additions & 2 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import io
from datetime import date, datetime, time
from datetime import date, datetime, time, tzinfo
from decimal import Decimal
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -326,10 +326,16 @@ def test_datetime_time_add_err() -> None:
def test_invalid_dtype() -> None:
with pytest.raises(
TypeError,
match="cannot parse input of type 'str' into Polars data type: 'mayonnaise'",
match=r"cannot parse input of type 'str' into Polars data type \(given: 'mayonnaise'\)",
):
pl.Series([1, 2], dtype="mayonnaise") # type: ignore[arg-type]

with pytest.raises(
TypeError,
match="cannot parse input <class 'datetime.tzinfo'> into Polars data type",
):
pl.Series([None], dtype=tzinfo) # type: ignore[arg-type]


def test_arr_eval_named_cols() -> None:
df = pl.DataFrame({"A": ["a", "b"], "B": [["a", "b"], ["c", "d"]]})
Expand Down
Loading