Skip to content

Commit

Permalink
feat(python): additional "schema_overrides" param for DataFrame init
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 16, 2023
1 parent 9df1c69 commit 81da75d
Show file tree
Hide file tree
Showing 15 changed files with 657 additions and 218 deletions.
93 changes: 76 additions & 17 deletions py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@

from typing import TYPE_CHECKING, Any, Mapping, Sequence, overload

from polars.datatypes import N_INFER_DEFAULT, Schema
from polars.datatypes import N_INFER_DEFAULT, SchemaDefinition, SchemaDict
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.internals import DataFrame, Series
from polars.utils import deprecated_alias

if TYPE_CHECKING:
from polars.internals.type_aliases import Orientation


def from_dict(
data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series],
columns: Sequence[str] | None = None,
columns: SchemaDefinition | None = None,
*,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a dictionary of sequences.
Expand All @@ -29,15 +32,17 @@ def from_dict(
columns : Sequence of str, default None
Column labels to use for resulting DataFrame. If specified, overrides any
labels already present in the data. Must match data dimensions.
schema_overrides : dict, default None
Support type specification or override of one or more columns; note that
any dtypes inferred from the columns param will be overridden.
Returns
-------
:class:`DataFrame`
Examples
--------
>>> data = {"a": [1, 2], "b": [3, 4]}
>>> df = pl.from_dict(data)
>>> df = pl.from_dict({"a": [1, 2], "b": [3, 4]})
>>> df
shape: (2, 2)
┌─────┬─────┐
Expand All @@ -50,14 +55,17 @@ def from_dict(
└─────┴─────┘
"""
return DataFrame._from_dict(data=data, columns=columns)
return DataFrame._from_dict(
data=data, schema=columns, schema_overrides=schema_overrides
)


@deprecated_alias(schema="schema_overrides")
def from_dicts(
dicts: Sequence[dict[str, Any]],
infer_schema_length: int | None = N_INFER_DEFAULT,
*,
schema: Schema | None = None,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a sequence of dictionaries. This operation clones data.
Expand All @@ -69,8 +77,8 @@ def from_dicts(
infer_schema_length
How many dictionaries/rows to scan to determine the data types
if set to `None` all rows are scanned. This will be slow.
schema
Schema that (partially) overwrites the inferred schema.
schema_overrides : dict, default None
Support override of inferred types for one or more columns.
Returns
-------
Expand All @@ -93,7 +101,7 @@ def from_dicts(
└─────┴─────┘
>>> # overwrite first column name and dtype
>>> pl.from_dicts(data, schema={"c": pl.Int32})
>>> pl.from_dicts(data, schema_overrides={"c": pl.Int32})
shape: (3, 2)
┌─────┬─────┐
│ c ┆ b │
Expand All @@ -107,7 +115,9 @@ def from_dicts(
>>> # let polars infer the dtypes
>>> # but inform about a 3rd column
>>> pl.from_dicts(data, schema={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32})
>>> pl.from_dicts(
... data, schema_overrides={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32}
... )
shape: (3, 3)
┌─────┬─────┬──────┐
│ a ┆ b ┆ c │
Expand All @@ -120,14 +130,17 @@ def from_dicts(
└─────┴─────┴──────┘
"""
return DataFrame._from_dicts(dicts, infer_schema_length, schema)
return DataFrame._from_dicts(
dicts, infer_schema_length, schema_overrides=schema_overrides
)


def from_records(
data: Sequence[Sequence[Any]],
columns: Sequence[str] | None = None,
orient: Orientation | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a sequence of sequences. This operation clones data.
Expand All @@ -148,6 +161,9 @@ def from_records(
infer_schema_length
How many dictionaries/rows to scan to determine the data types
if set to `None` all rows are scanned. This will be slow.
schema_overrides : dict, default None
Support type specification or override of one or more columns; note that
any dtypes inferred from the columns param will be overridden.
Returns
-------
Expand All @@ -171,14 +187,19 @@ def from_records(
"""
return DataFrame._from_records(
data, columns=columns, orient=orient, infer_schema_length=infer_schema_length
data,
columns=columns,
schema_overrides=schema_overrides,
orient=orient,
infer_schema_length=infer_schema_length,
)


def from_numpy(
data: np.ndarray[Any, Any],
columns: Sequence[str] | None = None,
orient: Orientation | None = None,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a numpy ndarray. This operation clones data.
Expand All @@ -196,6 +217,9 @@ def from_numpy(
Whether to interpret two-dimensional data as columns or as rows. If None,
the orientation is inferred by matching the columns and data dimensions. If
this does not yield conclusive results, column orientation is used.
schema_overrides : dict, default None
Support type specification or override of one or more columns; note that
any dtypes inferred from the columns param will be overridden.
Returns
-------
Expand All @@ -219,11 +243,16 @@ def from_numpy(
└─────┴─────┘
"""
return DataFrame._from_numpy(data, columns=columns, orient=orient)
return DataFrame._from_numpy(
data, columns=columns, orient=orient, schema_overrides=schema_overrides
)


def from_arrow(
a: pa.Table | pa.Array | pa.ChunkedArray, rechunk: bool = True
a: pa.Table | pa.Array | pa.ChunkedArray,
rechunk: bool = True,
schema: Sequence[str] | None = None,
schema_overrides: SchemaDict | None = None,
) -> DataFrame | Series:
"""
Create a DataFrame or Series from an Arrow Table or Array.
Expand All @@ -234,9 +263,27 @@ def from_arrow(
Parameters
----------
a : :class:`pyarrow.Table` or :class:`pyarrow.Array`
Data represented as Arrow Table or Array.
Data representing an Arrow Table or Array.
rechunk : bool, default True
Make sure that all data is in contiguous memory.
schema : Sequence of str, dict, default None
Column labels to use for resulting DataFrame. Must match data dimensions.
If not specified, existing Array table columns are used, with missing names
named as `column_0`, `column_1`, etc.
schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict
The resulting DataFrame schema may be declared in several ways:
* As a dict of {name:type} pairs; if the type is None, it will be auto-inferred.
* As a list of column names; in this case types are all automatically inferred.
* As a list of (name,type) pairs; this is equivalent to the dictionary form.
If you supply a list of column names that does not match the names in the
underlying data, the names supplied here will overwrite them. The number
of names given in the schema should match the underlying data dimensions.
schema_overrides : dict, default None
Support type specification or override of one or more columns; note that
any dtypes inferred from the schema param will be overridden.
Returns
-------
Expand Down Expand Up @@ -277,7 +324,9 @@ def from_arrow(
"""
if isinstance(a, pa.Table):
return DataFrame._from_arrow(a, rechunk=rechunk)
return DataFrame._from_arrow(
a, rechunk=rechunk, columns=schema, schema_overrides=schema_overrides
)
elif isinstance(a, (pa.Array, pa.ChunkedArray)):
return Series._from_arrow("", a, rechunk)
else:
Expand All @@ -289,6 +338,7 @@ def from_pandas(
df: pd.DataFrame,
rechunk: bool = True,
nan_to_none: bool = True,
schema_overrides: SchemaDict | None = None,
) -> DataFrame:
...

Expand All @@ -298,6 +348,7 @@ def from_pandas(
df: pd.Series | pd.DatetimeIndex,
rechunk: bool = True,
nan_to_none: bool = True,
schema_overrides: SchemaDict | None = None,
) -> Series:
...

Expand All @@ -306,6 +357,7 @@ def from_pandas(
df: pd.DataFrame | pd.Series | pd.DatetimeIndex,
rechunk: bool = True,
nan_to_none: bool = True,
schema_overrides: SchemaDict | None = None,
) -> DataFrame | Series:
"""
Construct a Polars DataFrame or Series from a pandas DataFrame or Series.
Expand All @@ -322,6 +374,8 @@ def from_pandas(
Make sure that all data is in contiguous memory.
nan_to_none : bool, default True
If data contains `NaN` values PyArrow will convert the ``NaN`` to ``None``
schema_overrides : dict, default None
Support override of inferred types for one or more columns.
Returns
-------
Expand Down Expand Up @@ -363,6 +417,11 @@ def from_pandas(
if isinstance(df, (pd.Series, pd.DatetimeIndex)):
return Series._from_pandas("", df, nan_to_none=nan_to_none)
elif isinstance(df, pd.DataFrame):
return DataFrame._from_pandas(df, rechunk=rechunk, nan_to_none=nan_to_none)
return DataFrame._from_pandas(
df,
rechunk=rechunk,
nan_to_none=nan_to_none,
schema_overrides=schema_overrides,
)
else:
raise ValueError(f"Expected pandas DataFrame or Series, got {type(df)}.")
12 changes: 6 additions & 6 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ def get_args(tp: Any) -> Any:
Type[Decimal],
]

ColumnsType: TypeAlias = Union[
SchemaDefinition: TypeAlias = Union[
Sequence[str],
Mapping[str, Union[PolarsDataType, PythonDataType]],
Sequence[Union[str, Tuple[str, Union[PolarsDataType, PythonDataType, None]]]],
]
Schema: TypeAlias = Mapping[str, PolarsDataType]
SchemaDict: TypeAlias = Mapping[str, PolarsDataType]

DTYPE_TEMPORAL_UNITS: frozenset[TimeUnit] = frozenset(["ns", "us", "ms"])

Expand Down Expand Up @@ -375,7 +375,7 @@ def __repr__(self) -> str:


class Struct(NestedType):
def __init__(self, fields: Sequence[Field] | Mapping[str, PolarsDataType]):
def __init__(self, fields: Sequence[Field] | SchemaDict):
"""
Struct composite type.
Expand Down Expand Up @@ -415,7 +415,7 @@ def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}({self.fields})"

def to_schema(self) -> dict[str, PolarsDataType] | None:
def to_schema(self) -> SchemaDict | None:
"""Return Struct dtype as a schema dict."""
return dict(self)

Expand Down Expand Up @@ -510,7 +510,7 @@ def PY_TYPE_TO_DTYPE(self) -> dict[PythonDataType | type[object], PolarsDataType

@property
@cache
def PY_STR_TO_DTYPE(self) -> dict[str, PolarsDataType]:
def PY_STR_TO_DTYPE(self) -> SchemaDict:
return {str(tp.__name__): dtype for tp, dtype in self.PY_TYPE_TO_DTYPE.items()}

@property
Expand Down Expand Up @@ -539,7 +539,7 @@ def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]:

@property
@cache
def NUMPY_CHAR_CODE_TO_DTYPE(self) -> dict[str, PolarsDataType]:
def NUMPY_CHAR_CODE_TO_DTYPE(self) -> SchemaDict:
# Note: Windows behaves differently from other platforms as C long
# is only 32-bit on Windows, while it is 64-bit on other platforms.
# See: https://numpy.org/doc/stable/reference/arrays.scalars.html
Expand Down
11 changes: 8 additions & 3 deletions py-polars/polars/internals/batched.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from pathlib import Path
from typing import Mapping, Sequence
from typing import Sequence

import polars.internals as pli
from polars.datatypes import N_INFER_DEFAULT, PolarsDataType, py_type_to_dtype
from polars.datatypes import (
N_INFER_DEFAULT,
PolarsDataType,
SchemaDict,
py_type_to_dtype,
)
from polars.internals.type_aliases import CsvEncoding
from polars.utils import (
_prepare_row_count_args,
Expand All @@ -31,7 +36,7 @@ def __init__(
comment_char: str | None = None,
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: None | (Mapping[str, PolarsDataType] | Sequence[PolarsDataType]) = None,
dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None,
null_values: str | list[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down
Loading

0 comments on commit 81da75d

Please sign in to comment.