diff --git a/py-polars/polars/convert.py b/py-polars/polars/convert.py index bb5366e62c33..08263d842116 100644 --- a/py-polars/polars/convert.py +++ b/py-polars/polars/convert.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING, Any, Mapping, Sequence, overload -from polars.datatypes import N_INFER_DEFAULT, Schema +from polars.datatypes import N_INFER_DEFAULT, SchemaDefinition, SchemaDict from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa from polars.internals import DataFrame, Series +from polars.utils import deprecated_alias if TYPE_CHECKING: from polars.internals.type_aliases import Orientation @@ -14,7 +15,9 @@ def from_dict( data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series], - columns: Sequence[str] | None = None, + columns: SchemaDefinition | None = None, + *, + schema_overrides: SchemaDict | None = None, ) -> DataFrame: """ Construct a DataFrame from a dictionary of sequences. @@ -29,6 +32,9 @@ def from_dict( columns : Sequence of str, default None Column labels to use for resulting DataFrame. If specified, overrides any labels already present in the data. Must match data dimensions. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. Returns ------- @@ -36,8 +42,7 @@ def from_dict( Examples -------- - >>> data = {"a": [1, 2], "b": [3, 4]} - >>> df = pl.from_dict(data) + >>> df = pl.from_dict({"a": [1, 2], "b": [3, 4]}) >>> df shape: (2, 2) ┌─────┬─────┐ @@ -50,14 +55,17 @@ def from_dict( └─────┴─────┘ """ - return DataFrame._from_dict(data=data, columns=columns) + return DataFrame._from_dict( + data=data, schema=columns, schema_overrides=schema_overrides + ) +@deprecated_alias(schema="schema_overrides") def from_dicts( dicts: Sequence[dict[str, Any]], infer_schema_length: int | None = N_INFER_DEFAULT, *, - schema: Schema | None = None, + schema_overrides: SchemaDict | None = None, ) -> DataFrame: """ Construct a DataFrame from a sequence of dictionaries. This operation clones data. @@ -69,8 +77,8 @@ def from_dicts( infer_schema_length How many dictionaries/rows to scan to determine the data types if set to `None` all rows are scanned. This will be slow. - schema - Schema that (partially) overwrites the inferred schema. + schema_overrides : dict, default None + Support override of inferred types for one or more columns. Returns ------- @@ -93,7 +101,7 @@ def from_dicts( └─────┴─────┘ >>> # overwrite first column name and dtype - >>> pl.from_dicts(data, schema={"c": pl.Int32}) + >>> pl.from_dicts(data, schema_overrides={"c": pl.Int32}) shape: (3, 2) ┌─────┬─────┐ │ c ┆ b │ @@ -107,7 +115,9 @@ def from_dicts( >>> # let polars infer the dtypes >>> # but inform about a 3rd column - >>> pl.from_dicts(data, schema={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32}) + >>> pl.from_dicts( + ... data, schema_overrides={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32} + ... 
) shape: (3, 3) ┌─────┬─────┬──────┐ │ a ┆ b ┆ c │ @@ -120,7 +130,9 @@ def from_dicts( └─────┴─────┴──────┘ """ - return DataFrame._from_dicts(dicts, infer_schema_length, schema) + return DataFrame._from_dicts( + dicts, infer_schema_length, schema_overrides=schema_overrides + ) def from_records( @@ -128,6 +140,7 @@ columns: Sequence[str] | None = None, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, + schema_overrides: SchemaDict | None = None, ) -> DataFrame: """ Construct a DataFrame from a sequence of sequences. This operation clones data. @@ -148,6 +161,9 @@ def from_records( infer_schema_length How many dictionaries/rows to scan to determine the data types if set to `None` all rows are scanned. This will be slow. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. Returns ------- @@ -171,7 +187,11 @@ def from_records( """ return DataFrame._from_records( - data, columns=columns, orient=orient, infer_schema_length=infer_schema_length + data, + columns=columns, + schema_overrides=schema_overrides, + orient=orient, + infer_schema_length=infer_schema_length, ) @@ -179,6 +199,7 @@ def from_numpy( data: np.ndarray[Any, Any], columns: Sequence[str] | None = None, orient: Orientation | None = None, + schema_overrides: SchemaDict | None = None, ) -> DataFrame: """ Construct a DataFrame from a numpy ndarray. This operation clones data. @@ -196,6 +217,9 @@ def from_numpy( Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. Returns ------- @@ -219,11 +243,16 @@ def from_numpy( └─────┴─────┘ """ - return DataFrame._from_numpy(data, columns=columns, orient=orient) + return DataFrame._from_numpy( + data, columns=columns, orient=orient, schema_overrides=schema_overrides + ) def from_arrow( - a: pa.Table | pa.Array | pa.ChunkedArray, rechunk: bool = True + a: pa.Table | pa.Array | pa.ChunkedArray, + rechunk: bool = True, + schema: Sequence[str] | None = None, + schema_overrides: SchemaDict | None = None, ) -> DataFrame | Series: """ Create a DataFrame or Series from an Arrow Table or Array. @@ -234,9 +263,27 @@ def from_arrow( Parameters ---------- a : :class:`pyarrow.Table` or :class:`pyarrow.Array` - Data represented as Arrow Table or Array. + Data representing an Arrow Table or Array. rechunk : bool, default True Make sure that all data is in contiguous memory. + schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict, default None + Column labels to use for resulting DataFrame. Must match data dimensions. + If not specified, existing Array table columns are used, with missing names + named as `column_0`, `column_1`, etc. + The resulting DataFrame schema may be declared in several ways: + + * As a dict of {name:type} pairs; if the type is None, it will be auto-inferred. + * As a list of column names; in this case types are all automatically inferred. + * As a list of (name,type) pairs; this is equivalent to the dictionary form.
+ + If you supply a list of column names that does not match the names in the + underlying data, the names supplied here will overwrite them. The number + of names given in the schema should match the underlying data dimensions. + + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the schema param will be overridden. Returns ------- @@ -277,7 +324,9 @@ def from_arrow( """ if isinstance(a, pa.Table): - return DataFrame._from_arrow(a, rechunk=rechunk) + return DataFrame._from_arrow( + a, rechunk=rechunk, columns=schema, schema_overrides=schema_overrides + ) elif isinstance(a, (pa.Array, pa.ChunkedArray)): return Series._from_arrow("", a, rechunk) else: @@ -289,6 +338,7 @@ def from_pandas( df: pd.DataFrame, rechunk: bool = True, nan_to_none: bool = True, + schema_overrides: SchemaDict | None = None, ) -> DataFrame: ... @@ -298,6 +348,7 @@ def from_pandas( df: pd.Series | pd.DatetimeIndex, rechunk: bool = True, nan_to_none: bool = True, + schema_overrides: SchemaDict | None = None, ) -> Series: ... @@ -306,6 +357,7 @@ def from_pandas( df: pd.DataFrame | pd.Series | pd.DatetimeIndex, rechunk: bool = True, nan_to_none: bool = True, + schema_overrides: SchemaDict | None = None, ) -> DataFrame | Series: """ Construct a Polars DataFrame or Series from a pandas DataFrame or Series. @@ -322,6 +374,8 @@ def from_pandas( Make sure that all data is in contiguous memory. nan_to_none : bool, default True If data contains `NaN` values PyArrow will convert the ``NaN`` to ``None`` + schema_overrides : dict, default None + Support override of inferred types for one or more columns. Returns ------- @@ -363,6 +417,11 @@ def from_pandas( if isinstance(df, (pd.Series, pd.DatetimeIndex)): return Series._from_pandas("", df, nan_to_none=nan_to_none) elif isinstance(df, pd.DataFrame): - return DataFrame._from_pandas(df, rechunk=rechunk, nan_to_none=nan_to_none) + return DataFrame._from_pandas( + df, + rechunk=rechunk, + nan_to_none=nan_to_none, + schema_overrides=schema_overrides, + ) else: raise ValueError(f"Expected pandas DataFrame or Series, got {type(df)}.") diff --git a/py-polars/polars/datatypes.py b/py-polars/polars/datatypes.py index 31004f54093a..678366ea82a0 100644 --- a/py-polars/polars/datatypes.py +++ b/py-polars/polars/datatypes.py @@ -85,12 +85,12 @@ def get_args(tp: Any) -> Any: Type[Decimal], ] -ColumnsType: TypeAlias = Union[ +SchemaDefinition: TypeAlias = Union[ Sequence[str], Mapping[str, Union[PolarsDataType, PythonDataType]], Sequence[Union[str, Tuple[str, Union[PolarsDataType, PythonDataType, None]]]], ] -Schema: TypeAlias = Mapping[str, PolarsDataType] +SchemaDict: TypeAlias = Mapping[str, PolarsDataType] DTYPE_TEMPORAL_UNITS: frozenset[TimeUnit] = frozenset(["ns", "us", "ms"]) @@ -375,7 +375,7 @@ def __repr__(self) -> str: class Struct(NestedType): - def __init__(self, fields: Sequence[Field] | Mapping[str, PolarsDataType]): + def __init__(self, fields: Sequence[Field] | SchemaDict): """ Struct composite type. 
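Review note: to make the datatypes.py rename concrete — `SchemaDefinition` covers the three equivalent ways a caller may spell a schema, while `SchemaDict` is the plain `{name: dtype}` mapping now also accepted by `Struct`. A minimal sketch against the public constructors touched in this diff (variable names and values are illustrative only; the `columns=` parameter name is the one still used at this point in the PR):

    import polars as pl

    data = {"a": [1, 2], "b": ["x", "y"]}

    # SchemaDefinition: three equivalent spellings
    df1 = pl.DataFrame(data, columns=["a", "b"])                     # names only; dtypes inferred
    df2 = pl.DataFrame(data, columns=[("a", pl.Int8), ("b", None)])  # (name, dtype) pairs; None = infer
    df3 = pl.DataFrame(data, columns={"a": pl.Int8, "b": pl.Utf8})   # {name: dtype} dict

    # SchemaDict: the same plain mapping, now accepted directly by Struct
    struct_dtype = pl.Struct({"x": pl.Int64, "y": pl.Utf8})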
@@ -415,7 +415,7 @@ def __repr__(self) -> str: class_name = self.__class__.__name__ return f"{class_name}({self.fields})" - def to_schema(self) -> dict[str, PolarsDataType] | None: + def to_schema(self) -> SchemaDict | None: """Return Struct dtype as a schema dict.""" return dict(self) @@ -510,7 +510,7 @@ def PY_TYPE_TO_DTYPE(self) -> dict[PythonDataType | type[object], PolarsDataType @property @cache - def PY_STR_TO_DTYPE(self) -> dict[str, PolarsDataType]: + def PY_STR_TO_DTYPE(self) -> SchemaDict: return {str(tp.__name__): dtype for tp, dtype in self.PY_TYPE_TO_DTYPE.items()} @property @@ -539,7 +539,7 @@ def DTYPE_TO_PY_TYPE(self) -> dict[PolarsDataType, PythonDataType]: @property @cache - def NUMPY_CHAR_CODE_TO_DTYPE(self) -> dict[str, PolarsDataType]: + def NUMPY_CHAR_CODE_TO_DTYPE(self) -> SchemaDict: # Note: Windows behaves differently from other platforms as C long # is only 32-bit on Windows, while it is 64-bit on other platforms. # See: https://numpy.org/doc/stable/reference/arrays.scalars.html diff --git a/py-polars/polars/internals/batched.py b/py-polars/polars/internals/batched.py index 88c07658a4ca..002671a46a6a 100644 --- a/py-polars/polars/internals/batched.py +++ b/py-polars/polars/internals/batched.py @@ -1,10 +1,15 @@ from __future__ import annotations from pathlib import Path -from typing import Mapping, Sequence +from typing import Sequence import polars.internals as pli -from polars.datatypes import N_INFER_DEFAULT, PolarsDataType, py_type_to_dtype +from polars.datatypes import ( + N_INFER_DEFAULT, + PolarsDataType, + SchemaDict, + py_type_to_dtype, +) from polars.internals.type_aliases import CsvEncoding from polars.utils import ( _prepare_row_count_args, @@ -31,7 +36,7 @@ def __init__( comment_char: str | None = None, quote_char: str | None = r'"', skip_rows: int = 0, - dtypes: None | (Mapping[str, PolarsDataType] | Sequence[PolarsDataType]) = None, + dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None, null_values: str | list[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, diff --git a/py-polars/polars/internals/construction.py b/py-polars/polars/internals/construction.py index f7d5309a5f77..acf8397abe9b 100644 --- a/py-polars/polars/internals/construction.py +++ b/py-polars/polars/internals/construction.py @@ -11,6 +11,7 @@ Generator, Iterable, Mapping, + MutableMapping, Sequence, get_type_hints, ) @@ -19,13 +20,14 @@ from polars.datatypes import ( N_INFER_DEFAULT, Categorical, - ColumnsType, Date, Datetime, Duration, Float32, List, PolarsDataType, + SchemaDefinition, + SchemaDict, Struct, Time, Unknown, @@ -85,8 +87,8 @@ def is_namedtuple(value: Any, annotated: bool = False) -> bool: def include_unknowns( - schema: dict[str, PolarsDataType], cols: Sequence[str] -) -> dict[str, PolarsDataType]: + schema: SchemaDict, cols: Sequence[str] +) -> MutableMapping[str, PolarsDataType]: """Complete partial schema dict by including Unknown type.""" return {col: schema.get(col, Unknown) for col in cols} @@ -356,11 +358,10 @@ def sequence_to_pyseries( strict=strict, ) - # logs will show a panic if we infer wrong dtype - # and its hard to error from rust side - # to reduce the likelihood of this happening - # we infer the dtype of first 100 elements - # if all() fails, we will hit the PySeries.new_object + # logs will show a panic if we infer wrong dtype and it's hard to error + # from the rust side. 
to reduce the likelihood of this happening we + # infer the dtype of first 100 elements; if all() fails, we will hit + # the PySeries.new_object if not _PYARROW_AVAILABLE: # check lists for consistent inner types if isinstance(value, list): @@ -513,7 +514,7 @@ def pandas_to_pyseries( def _handle_columns_arg( - data: list[PySeries], columns: Sequence[str] | None = None + data: list[PySeries], columns: Sequence[str] | None = None, from_dict: bool = False ) -> list[PySeries]: """Rename data according to columns argument.""" if not columns: @@ -522,6 +523,11 @@ def _handle_columns_arg( if not data: return [pli.Series(c, None)._s for c in columns] elif len(data) == len(columns): + if from_dict: + series_map = {s.name(): s for s in data} + if all((col in series_map) for col in columns): + return [series_map[col] for col in columns] + for i, c in enumerate(columns): data[i].rename(c) return data @@ -531,13 +537,15 @@ def _handle_columns_arg( def _post_apply_columns( pydf: PyDataFrame, - columns: ColumnsType | None, - categoricals: set[str] | None = None, + columns: SchemaDefinition | None, structs: dict[str, Struct] | None = None, + schema_overrides: SchemaDict | None = None, ) -> PyDataFrame: """Apply 'columns' param _after_ PyDataFrame creation (if no alternative).""" pydf_columns, pydf_dtypes = pydf.columns(), pydf.dtypes() - columns, dtypes = _unpack_columns(columns or pydf_columns) + columns, dtypes = _unpack_columns( + (columns or pydf_columns), schema_overrides=schema_overrides + ) column_subset: list[str] = [] if columns != pydf_columns: if len(columns) < len(pydf_columns) and columns == pydf_columns[: len(columns)]: @@ -547,7 +555,7 @@ def _post_apply_columns( column_casts = [] for i, col in enumerate(columns): - if categoricals and col in categoricals: + if dtypes.get(col) == Categorical != pydf_dtypes[i]: column_casts.append(pli.col(col).cast(Categorical)._pyexpr) elif structs and col in structs and structs[col] != pydf_dtypes[i]: column_casts.append(pli.col(col).cast(structs[col])._pyexpr) @@ -566,14 +574,16 @@ def _post_apply_columns( def _unpack_columns( - columns: ColumnsType | None, - lookup_names: Iterable[str] | None = None, + columns: SchemaDefinition | None, + schema_overrides: SchemaDict | None = None, n_expected: int | None = None, -) -> tuple[list[str], dict[str, PolarsDataType]]: + lookup_names: Iterable[str] | None = None, +) -> tuple[list[str], SchemaDict]: """ Unpack column names and create dtype lookup. - Works for any (name, dtype) pairs or schema dict input. + Works for any (name, dtype) pairs or schema dict input, + overriding any inferred dtypes with explicit dtypes if supplied. 
""" if isinstance(columns, dict): columns = list(columns.items()) @@ -586,26 +596,29 @@ def _unpack_columns( lookup = { col: name for col, name in zip_longest(column_names, lookup_names or []) if name } - dtypes = { + column_dtypes = { lookup.get(col[0], col[0]): col[1] for col in (columns or []) if not isinstance(col, str) and col[1] } + if schema_overrides: + column_dtypes.update(schema_overrides) + return ( column_names or None, # type: ignore[return-value] - dtypes, + column_dtypes, ) def _expand_dict_scalars( data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | pli.Series], - dtypes: dict[str, PolarsDataType] | None = None, + schema_overrides: SchemaDict | None = None, order: Sequence[str] | None = None, ) -> dict[str, pli.Series]: """Expand scalar values in dict data (propagate literal as array).""" updated_data = {} if data: - dtypes = dtypes or {} + dtypes = schema_overrides or {} array_len = max((arrlen(val) or 0) for val in data.values()) if array_len > 0: for name, val in data.items(): @@ -647,10 +660,13 @@ def _expand_dict_scalars( def dict_to_pydf( data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | pli.Series], - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, ) -> PyDataFrame: """Construct a PyDataFrame from a dictionary of sequences.""" - if columns is not None: + if not columns: + columns = list(data) + if columns: # the columns arg may also set the dtype/column order of the series if isinstance(columns, dict) and data: if not all((col in columns) for col in data): @@ -659,15 +675,19 @@ def dict_to_pydf( ) data = {col: data[col] for col in columns} - columns, dtypes = _unpack_columns(columns, lookup_names=data.keys()) - if not data and dtypes: + columns, schema_overrides = _unpack_columns( + columns, lookup_names=data.keys(), schema_overrides=schema_overrides + ) + if not data and schema_overrides: data_series = [ - pli.Series(name, [], dtypes.get(name))._s for name in columns + pli.Series(name, [], schema_overrides.get(name))._s for name in columns ] else: - data_series = [s._s for s in _expand_dict_scalars(data, dtypes).values()] + data_series = [ + s._s for s in _expand_dict_scalars(data, schema_overrides).values() + ] - data_series = _handle_columns_arg(data_series, columns=columns) + data_series = _handle_columns_arg(data_series, columns=columns, from_dict=True) return PyDataFrame(data_series) if _NUMPY_AVAILABLE: @@ -704,34 +724,45 @@ def dict_to_pydf( def sequence_to_pydf( data: Sequence[Any], - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ) -> PyDataFrame: """Construct a PyDataFrame from a sequence.""" data_series: list[PySeries] - if len(data) == 0: - return dict_to_pydf({}, columns=columns) + return dict_to_pydf({}, columns=columns, schema_overrides=schema_overrides) + if isinstance(data[0], Generator): data = [list(row) for row in data] if isinstance(data[0], pli.Series): series_names = [s.name for s in data] - columns, dtypes = _unpack_columns(columns or series_names, n_expected=len(data)) + columns, schema_overrides = _unpack_columns( + columns or series_names, + schema_overrides=schema_overrides, + n_expected=len(data), + ) data_series = [] for i, s in enumerate(data): if not s.name: # TODO: Replace by `if s.name is None` once allowed s.rename(columns[i], in_place=True) - new_dtype = 
dtypes.get(columns[i]) + new_dtype = schema_overrides.get(columns[i]) if new_dtype and new_dtype != s.dtype: s = s.cast(new_dtype) data_series.append(s._s) elif isinstance(data[0], dict): - column_names, dtypes = _unpack_columns(columns) - schema_overrides = include_unknowns(dtypes, column_names) if dtypes else None - pydf = PyDataFrame.read_dicts(data, infer_schema_length, schema_overrides) + column_names, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides + ) + dtypes = ( + include_unknowns(schema_overrides, column_names) + if schema_overrides + else None + ) + pydf = PyDataFrame.read_dicts(data, infer_schema_length, dtypes) if column_names: pydf = _post_apply_columns(pydf, column_names) return pydf @@ -752,29 +783,38 @@ def sequence_to_pydf( orient = "col" if len(columns) == len(data) else "row" if orient == "row": - column_names, dtypes = _unpack_columns(columns) - schema_override = include_unknowns(dtypes, column_names) if dtypes else {} + column_names, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides, n_expected=len(data[0]) + ) + schema_override = ( + include_unknowns(schema_overrides, column_names) + if schema_overrides + else {} + ) if column_names and data and len(data[0]) != len(column_names): raise ShapeError("The row data does not match the number of columns") - categoricals = { - col for col, tp in schema_override.items() if tp == Categorical - } - for col in categoricals: - schema_override[col] = Utf8 + + for col, tp in schema_override.items(): + if tp == Categorical: + schema_override[col] = Utf8 pydf = PyDataFrame.read_rows( data, infer_schema_length, schema_override or None, ) - if column_names: - pydf = _post_apply_columns(pydf, column_names, categoricals) + if column_names or schema_overrides: + pydf = _post_apply_columns( + pydf, column_names, schema_overrides=schema_overrides + ) return pydf elif orient == "col" or orient is None: - columns, dtypes = _unpack_columns(columns, n_expected=len(data)) + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides, n_expected=len(data) + ) data_series = [ - pli.Series(columns[i], data[i], dtypes.get(columns[i]))._s + pli.Series(columns[i], data[i], schema_overrides.get(columns[i]))._s for i in range(len(data)) ] else: @@ -784,17 +824,23 @@ def sequence_to_pydf( elif is_dataclass(data[0]): if columns: - columns, dtypes = _unpack_columns(columns) - schema_override = {col: dtypes.get(col, Unknown) for col in columns} + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides + ) + schema_override = { + col: schema_overrides.get(col, Unknown) for col in columns + } else: columns = None schema_override = { col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown) for col, tp in dataclass_type_hints(data[0].__class__).items() } - categoricals = {col for col, tp in schema_override.items() if tp == Categorical} - for col in categoricals: - schema_override[col] = Utf8 + schema_override.update(schema_overrides or {}) + + for col, tp in schema_override.items(): + if tp == Categorical: + schema_override[col] = Utf8 pydf = PyDataFrame.read_rows( [astuple(dc) for dc in data], infer_schema_length, schema_override or None @@ -803,20 +849,24 @@ def sequence_to_pydf( structs = { col: tp for col, tp in schema_override.items() if isinstance(tp, Struct) } - pydf = _post_apply_columns(pydf, columns, categoricals, structs) + pydf = _post_apply_columns( + pydf, columns, structs, schema_overrides=schema_overrides + ) 
return pydf elif _check_for_pandas(data[0]) and isinstance( data[0], (pd.Series, pd.DatetimeIndex) ): - dtypes = {} if columns is not None: - columns, dtypes = _unpack_columns(columns, n_expected=1) + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides, n_expected=1 + ) + schema_overrides = schema_overrides or {} data_series = [] for i, s in enumerate(data): name = columns[i] if columns else s.name - dtype = dtypes.get(name, None) + dtype = schema_overrides.get(name, None) pyseries = pandas_to_pyseries(name=name, values=s) if dtype is not None and dtype != pyseries.dtype(): pyseries = pyseries.cast(dtype, strict=True) @@ -824,8 +874,12 @@ def sequence_to_pydf( columns = None else: - columns, dtypes = _unpack_columns(columns, n_expected=1) - data_series = [pli.Series(columns[0], data, dtypes.get(columns[0]))._s] + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides, n_expected=1 + ) + data_series = [ + pli.Series(columns[0], data, schema_overrides.get(columns[0]))._s + ] data_series = _handle_columns_arg(data_series, columns=columns) return PyDataFrame(data_series) @@ -833,7 +887,8 @@ def sequence_to_pydf( def numpy_to_pydf( data: np.ndarray[Any, Any], - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, ) -> PyDataFrame: """Construct a PyDataFrame from a numpy ndarray.""" @@ -870,7 +925,6 @@ def numpy_to_pydf( raise ValueError( f"orient must be one of {{'col', 'row', None}}, got {orient} instead." ) - else: raise ValueError( "Cannot create DataFrame from numpy array with more than two dimensions." @@ -879,24 +933,28 @@ def numpy_to_pydf( if columns is not None and len(columns) != n_columns: raise ValueError("Dimensions of columns arg must match data dimensions.") - columns, dtypes = _unpack_columns(columns, n_expected=n_columns) + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides, n_expected=n_columns + ) # Convert data to series if shape == (0,): data_series = [] elif len(shape) == 1: - data_series = [pli.Series(columns[0], data, dtypes.get(columns[0]))._s] + data_series = [ + pli.Series(columns[0], data, schema_overrides.get(columns[0]))._s + ] else: if orient == "row": data_series = [ - pli.Series(columns[i], data[:, i], dtypes.get(columns[i]))._s + pli.Series(columns[i], data[:, i], schema_overrides.get(columns[i]))._s for i in range(n_columns) ] else: data_series = [ - pli.Series(columns[i], data[i], dtypes.get(columns[i]))._s + pli.Series(columns[i], data[i], schema_overrides.get(columns[i]))._s for i in range(n_columns) ] @@ -905,18 +963,21 @@ def numpy_to_pydf( def arrow_to_pydf( - data: pa.Table, columns: ColumnsType | None = None, rechunk: bool = True + data: pa.Table, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, + rechunk: bool = True, ) -> PyDataFrame: """Construct a PyDataFrame from an Arrow Table.""" - original_columns, dtypes = columns, None - if columns is not None: - columns, dtypes = _unpack_columns(columns) - try: + original_columns = columns + columns, schema_overrides = _unpack_columns( + (columns or data.column_names), schema_overrides=schema_overrides + ) + try: + if columns and columns != data.column_names: data = data.rename_columns(columns) - except pa.lib.ArrowInvalid as e: - raise ValueError( - "Dimensions of columns arg must match data dimensions." 
- ) from e + except pa.lib.ArrowInvalid as e: + raise ValueError("Dimensions of columns arg must match data dimensions.") from e data_dict = {} # dictionaries cannot be built in different batches (categorical does not allow @@ -978,18 +1039,35 @@ def arrow_to_pydf( df = df[names] pydf = df._df - if columns is not None and dtypes and original_columns: - pydf = _post_apply_columns(pydf, original_columns) + if columns != original_columns and (schema_overrides or original_columns): + pydf = _post_apply_columns( + pydf, original_columns, schema_overrides=schema_overrides + ) + elif schema_overrides: + for col, dtype in zip(pydf.columns(), pydf.dtypes()): + override_dtype = schema_overrides.get(col) + if override_dtype is not None and dtype != override_dtype: + pydf = _post_apply_columns( + pydf, original_columns, schema_overrides=schema_overrides + ) + break + return pydf -def series_to_pydf(data: pli.Series, columns: ColumnsType | None = None) -> PyDataFrame: +def series_to_pydf( + data: pli.Series, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, +) -> PyDataFrame: """Construct a PyDataFrame from a Polars Series.""" data_series = [data._s] series_name = [s.name() for s in data_series] - columns, dtypes = _unpack_columns(columns or series_name, n_expected=1) - if dtypes: - new_dtype = list(dtypes.values())[0] + columns, schema_overrides = _unpack_columns( + columns or series_name, schema_overrides=schema_overrides, n_expected=1 + ) + if schema_overrides: + new_dtype = list(schema_overrides.values())[0] if new_dtype != data.dtype: data_series[0] = data_series[0].cast(new_dtype, True) @@ -999,25 +1077,32 @@ def series_to_pydf(data: pli.Series, columns: ColumnsType | None = None) -> PyDa def iterable_to_pydf( data: Iterable[Any], - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, chunk_size: int | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ) -> PyDataFrame: """Construct a PyDataFrame from an iterable/generator.""" original_columns = columns - dtypes: dict[str, PolarsDataType] = {} dtypes_by_idx: dict[int, PolarsDataType] = {} if columns is not None: - columns, dtypes = _unpack_columns(columns) + columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides + ) + elif schema_overrides: + _columns, schema_overrides = _unpack_columns( + columns, schema_overrides=schema_overrides + ) if not isinstance(data, Generator): data = iter(data) if orient == "col": - if columns is not None and dtypes: + if columns and schema_overrides: dtypes_by_idx = { - idx: dtypes.get(col, Unknown) for idx, col in enumerate(columns) + idx: schema_overrides.get(col, Unknown) + for idx, col in enumerate(columns) } return pli.DataFrame( @@ -1029,7 +1114,9 @@ def iterable_to_pydf( } )._df - def to_frame_chunk(values: list[Any], columns: ColumnsType | None) -> pli.DataFrame: + def to_frame_chunk( + values: list[Any], columns: SchemaDefinition | None + ) -> pli.DataFrame: return pli.DataFrame( data=values, columns=columns, @@ -1071,7 +1158,8 @@ def to_frame_chunk(values: list[Any], columns: ColumnsType | None) -> pli.DataFr def pandas_to_pydf( data: pd.DataFrame, - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, rechunk: bool = True, nan_to_none: bool = True, ) -> PyDataFrame: @@ -1084,7 +1172,9 @@ def pandas_to_pydf( for col in data.columns } arrow_table = 
pa.table(arrow_dict) - return arrow_to_pydf(arrow_table, columns=columns, rechunk=rechunk) + return arrow_to_pydf( + arrow_table, columns=columns, schema_overrides=schema_overrides, rechunk=rechunk + ) def coerce_arrow(array: pa.Array, rechunk: bool = True) -> pa.Array: diff --git a/py-polars/polars/internals/dataframe/frame.py b/py-polars/polars/internals/dataframe/frame.py index f64a2de77692..134d7f4e4e2c 100644 --- a/py-polars/polars/internals/dataframe/frame.py +++ b/py-polars/polars/internals/dataframe/frame.py @@ -32,13 +32,13 @@ from polars.datatypes import ( N_INFER_DEFAULT, Boolean, - ColumnsType, Int8, Int16, Int32, Int64, PolarsDataType, - Schema, + SchemaDefinition, + SchemaDict, UInt8, UInt16, UInt32, @@ -146,6 +146,9 @@ class DataFrame: columns : Sequence of str, (str,DataType) pairs, or {str:DataType,} dict Column labels (with optional type) to use for resulting DataFrame. If specified, overrides any labels already present in the data. Must match data dimensions. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. orient : {'col', 'row'}, default None Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If @@ -285,39 +288,54 @@ def __init__( | pli.Series | None ) = None, - columns: ColumnsType | None = None, + columns: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ): if data is None: - self._df = dict_to_pydf({}, columns=columns) + self._df = dict_to_pydf( + {}, columns=columns, schema_overrides=schema_overrides + ) elif isinstance(data, dict): - self._df = dict_to_pydf(data, columns=columns) + self._df = dict_to_pydf( + data, columns=columns, schema_overrides=schema_overrides + ) elif isinstance(data, (list, tuple, Sequence)): self._df = sequence_to_pydf( data, columns=columns, + schema_overrides=schema_overrides, orient=orient, infer_schema_length=infer_schema_length, ) elif isinstance(data, pli.Series): - self._df = series_to_pydf(data, columns=columns) + self._df = series_to_pydf( + data, columns=columns, schema_overrides=schema_overrides + ) elif _check_for_numpy(data) and isinstance(data, np.ndarray): - self._df = numpy_to_pydf(data, columns=columns, orient=orient) + self._df = numpy_to_pydf( + data, columns=columns, schema_overrides=schema_overrides, orient=orient + ) elif _check_for_pyarrow(data) and isinstance(data, pa.Table): - self._df = arrow_to_pydf(data, columns=columns) + self._df = arrow_to_pydf( + data, columns=columns, schema_overrides=schema_overrides + ) elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): - self._df = pandas_to_pydf(data, columns=columns) + self._df = pandas_to_pydf( + data, columns=columns, schema_overrides=schema_overrides + ) elif not isinstance(data, Sized) and isinstance(data, (Generator, Iterable)): self._df = iterable_to_pydf( data, columns=columns, + schema_overrides=schema_overrides, orient=orient, infer_schema_length=infer_schema_length, ) @@ -338,9 +356,9 @@ def _from_dicts( cls: type[DF], data: Sequence[dict[str, Any]], infer_schema_length: int | None = N_INFER_DEFAULT, - schema: Schema | None = None, + schema_overrides: SchemaDict | None = None, ) -> DF: - pydf = PyDataFrame.read_dicts(data, infer_schema_length, schema) + pydf = PyDataFrame.read_dicts(data, infer_schema_length, 
schema_overrides) return cls._from_pydf(pydf) @classmethod @@ -349,7 +367,8 @@ def _from_dict( data: Mapping[ str, Sequence[object] | Mapping[str, Sequence[object]] | pli.Series ], - columns: Sequence[str] | None = None, + schema: SchemaDefinition | None = None, + schema_overrides: SchemaDict | None = None, ) -> DF: """ Construct a DataFrame from a dictionary of sequences. @@ -357,24 +376,30 @@ def _from_dict( Parameters ---------- data : dict of sequences - Two-dimensional data represented as a dictionary. dict must contain - Sequences. - columns : Sequence of str, default None - Column labels to use for resulting DataFrame. If specified, overrides any - labels already present in the data. Must match data dimensions. + Two-dimensional data represented as a dictionary. dict must contain + Sequences. + schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict + Column labels to use for resulting DataFrame. If specified, overrides any + labels already present in the data. Must match data dimensions. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. Returns ------- DataFrame """ - return cls._from_pydf(dict_to_pydf(data, columns=columns)) + return cls._from_pydf( + dict_to_pydf(data, columns=schema, schema_overrides=schema_overrides) + ) @classmethod def _from_records( cls: type[DF], data: Sequence[Sequence[Any]], columns: Sequence[str] | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ) -> DF: @@ -388,6 +413,9 @@ def _from_records( columns : Sequence of str, default None Column labels to use for resulting DataFrame. Must match data dimensions. If not specified, columns will be named `column_0`, `column_1`, etc. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. orient : {'col', 'row'}, default None Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If @@ -404,6 +432,7 @@ def _from_records( sequence_to_pydf( data, columns=columns, + schema_overrides=schema_overrides, orient=orient, infer_schema_length=infer_schema_length, ) @@ -414,6 +443,7 @@ def _from_numpy( cls: type[DF], data: np.ndarray[Any, Any], columns: Sequence[str] | None = None, + schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, ) -> DF: """ @@ -426,6 +456,9 @@ def _from_numpy( columns : Sequence of str, default None Column labels to use for resulting DataFrame. Must match data dimensions. If not specified, columns will be named `column_0`, `column_1`, etc. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. orient : {'col', 'row'}, default None Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. 
If @@ -436,13 +469,18 @@ def _from_numpy( DataFrame """ - return cls._from_pydf(numpy_to_pydf(data, columns=columns, orient=orient)) + return cls._from_pydf( + numpy_to_pydf( + data, columns=columns, schema_overrides=schema_overrides, orient=orient + ) + ) @classmethod def _from_arrow( cls: type[DF], data: pa.Table, columns: Sequence[str] | None = None, + schema_overrides: SchemaDict | None = None, rechunk: bool = True, ) -> DF: """ @@ -453,12 +491,15 @@ def _from_arrow( Parameters ---------- - data : numpy ndarray or Sequence of sequences - Two-dimensional data represented as Arrow table. + data : arrow table, array, or sequence of sequences + Data representing an Arrow Table or Array. columns : Sequence of str, default None Column labels to use for resulting DataFrame. Must match data dimensions. If not specified, existing Array table columns are used, with missing names named as `column_0`, `column_1`, etc. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. rechunk : bool, default True Make sure that all data is in contiguous memory. @@ -467,13 +508,21 @@ def _from_arrow( DataFrame """ - return cls._from_pydf(arrow_to_pydf(data, columns=columns, rechunk=rechunk)) + return cls._from_pydf( + arrow_to_pydf( + data, + columns=columns, + schema_overrides=schema_overrides, + rechunk=rechunk, + ) + ) @classmethod def _from_pandas( cls: type[DF], data: pd.DataFrame, columns: Sequence[str] | None = None, + schema_overrides: SchemaDict | None = None, rechunk: bool = True, nan_to_none: bool = True, ) -> DF: @@ -487,6 +536,9 @@ def _from_pandas( columns : Sequence of str, default None Column labels to use for resulting DataFrame. If specified, overrides any labels already present in the data. Must match data dimensions. + schema_overrides : dict, default None + Support type specification or override of one or more columns; note that + any dtypes inferred from the columns param will be overridden. rechunk : bool, default True Make sure that all data is in contiguous memory. nan_to_none : bool, default True @@ -505,13 +557,18 @@ def _from_pandas( if pd_series.dtype == np.dtype("O"): series.append(pli.Series(name, [], dtype=Utf8)) else: - col = pli.Series(name, pd_series) + dtype = (schema_overrides or {}).get(name) + col = pli.Series(name, pd_series, dtype=dtype) series.append(pli.Series(name, col)) return cls(series) return cls._from_pydf( pandas_to_pydf( - data, columns=columns, rechunk=rechunk, nan_to_none=nan_to_none + data, + columns=columns, + schema_overrides=schema_overrides, + rechunk=rechunk, + nan_to_none=nan_to_none, ) ) @@ -525,7 +582,7 @@ def _read_csv( comment_char: str | None = None, quote_char: str | None = r'"', skip_rows: int = 0, - dtypes: None | (Mapping[str, PolarsDataType] | Sequence[PolarsDataType]) = None, + dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None, null_values: str | list[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -993,7 +1050,7 @@ def dtypes(self) -> list[PolarsDataType]: return self._df.dtypes() @property - def schema(self) -> dict[str, PolarsDataType]: + def schema(self) -> SchemaDict: """ Get a dict[column name, DataType]. 
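Review note: the precedence that the `schema_overrides` docstrings above describe — explicit overrides win over anything declared or inferred via `columns` — comes from `_unpack_columns` applying `schema_overrides` last (`column_dtypes.update(schema_overrides)`). A short sketch of the intended behaviour (illustrative values only):

    import polars as pl

    df = pl.DataFrame(
        data={"a": [1, 2], "b": [3.0, 4.0]},
        columns=[("a", pl.Int64), "b"],     # dtype declared for "a", inferred for "b"
        schema_overrides={"a": pl.UInt8},   # override wins: "a" becomes UInt8
    )
    assert df.schema == {"a": pl.UInt8, "b": pl.Float64}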
diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py index 02bd9a9f5607..b762b4cb0946 100644 --- a/py-polars/polars/internals/lazy_functions.py +++ b/py-polars/polars/internals/lazy_functions.py @@ -2,7 +2,7 @@ import sys from datetime import date, datetime, time, timedelta -from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, Callable, Sequence, overload from polars import internals as pli from polars.datatypes import ( @@ -14,6 +14,7 @@ Duration, Int64, PolarsDataType, + SchemaDict, Struct, Time, UInt32, @@ -2278,7 +2279,7 @@ def select( def struct( exprs: Sequence[pli.Expr | str | pli.Series] | pli.Expr | pli.Series, eager: Literal[True], - schema: Mapping[str, PolarsDataType] | None = None, + schema: SchemaDict | None = None, ) -> pli.Series: ... @@ -2287,7 +2288,7 @@ def struct( def struct( exprs: Sequence[pli.Expr | str | pli.Series] | pli.Expr | pli.Series, eager: Literal[False], - schema: Mapping[str, PolarsDataType] | None = None, + schema: SchemaDict | None = None, ) -> pli.Expr: ... @@ -2296,7 +2297,7 @@ def struct( def struct( exprs: Sequence[pli.Expr | str | pli.Series] | pli.Expr | pli.Series, eager: bool = False, - schema: Mapping[str, PolarsDataType] | None = None, + schema: SchemaDict | None = None, ) -> pli.Expr | pli.Series: ... @@ -2304,7 +2305,7 @@ def struct( def struct( exprs: Sequence[pli.Expr | str | pli.Series] | pli.Expr | pli.Series, eager: bool = False, - schema: Mapping[str, PolarsDataType] | None = None, + schema: SchemaDict | None = None, ) -> pli.Expr | pli.Series: """ Collect several columns into a Series of dtype Struct. diff --git a/py-polars/polars/internals/lazyframe/frame.py b/py-polars/polars/internals/lazyframe/frame.py index 364bb75e06ea..64bd140afbe0 100644 --- a/py-polars/polars/internals/lazyframe/frame.py +++ b/py-polars/polars/internals/lazyframe/frame.py @@ -35,7 +35,7 @@ Int32, Int64, PolarsDataType, - Schema, + SchemaDict, Time, UInt8, UInt16, @@ -122,7 +122,7 @@ def _scan_csv( comment_char: str | None = None, quote_char: str | None = r'"', skip_rows: int = 0, - dtypes: dict[str, PolarsDataType] | None = None, + dtypes: SchemaDict | None = None, null_values: str | list[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -417,7 +417,7 @@ def dtypes(self) -> list[type[DataType]]: return self._ldf.dtypes() @property - def schema(self) -> Schema: + def schema(self) -> SchemaDict: """ Get a dict[column name, DataType]. 
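Review note: `SchemaDict` in the overloads above is the same plain `{name: DataType}` mapping that the `schema` properties return, so these hunks only tighten annotations; there is no runtime change. A quick sketch, assuming the signatures shown in this diff (the `schema` argument to `pl.struct` simply declares the expected field dtypes):

    import polars as pl

    ldf = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}).lazy()
    assert ldf.schema == {"a": pl.Int64, "b": pl.Utf8}  # .schema is a SchemaDict

    # the same mapping shape feeds pl.struct's optional schema parameter
    expr = pl.struct(["a", "b"], schema={"a": pl.Int64, "b": pl.Utf8})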
@@ -3723,7 +3723,7 @@ def map( projection_pushdown: bool = True, slice_pushdown: bool = True, no_optimizations: bool = False, - schema: None | Schema = None, + schema: None | SchemaDict = None, validate_output_schema: bool = True, ) -> LDF: """ diff --git a/py-polars/polars/internals/lazyframe/groupby.py b/py-polars/polars/internals/lazyframe/groupby.py index c4321ec73acf..34b0bba3521d 100644 --- a/py-polars/polars/internals/lazyframe/groupby.py +++ b/py-polars/polars/internals/lazyframe/groupby.py @@ -3,7 +3,7 @@ from typing import Callable, Generic, Sequence, TypeVar import polars.internals as pli -from polars.datatypes import Schema +from polars.datatypes import SchemaDict from polars.internals import selection_to_pyexpr_list from polars.utils import is_expr_sequence @@ -154,7 +154,7 @@ def tail(self, n: int = 5) -> LDF: return self._lazyframe_class._from_pyldf(self.lgb.tail(n)) def apply( - self, f: Callable[[pli.DataFrame], pli.DataFrame], schema: Schema | None + self, f: Callable[[pli.DataFrame], pli.DataFrame], schema: SchemaDict | None ) -> LDF: """ Apply a custom/user-defined function (UDF) over the groups as a new DataFrame. diff --git a/py-polars/polars/io.py b/py-polars/polars/io.py index 85ba4eb0dfa8..0e154017e3e3 100644 --- a/py-polars/polars/io.py +++ b/py-polars/polars/io.py @@ -25,7 +25,7 @@ import polars.internals as pli from polars.convert import from_arrow -from polars.datatypes import N_INFER_DEFAULT, DataType, PolarsDataType, Utf8 +from polars.datatypes import N_INFER_DEFAULT, DataType, SchemaDict, Utf8 from polars.dependencies import _DELTALAKE_AVAILABLE, _PYARROW_AVAILABLE, deltalake from polars.dependencies import pyarrow as pa from polars.internals import DataFrame, LazyFrame, _scan_ds @@ -282,7 +282,7 @@ def read_csv( [f"column_{int(column[1:]) + 1}" for column in tbl.column_names] ) - df = cast(DataFrame, from_arrow(tbl, rechunk)) + df = cast(DataFrame, from_arrow(tbl, rechunk=rechunk)) if new_columns: return pli._update_columns(df, new_columns) return df @@ -421,7 +421,7 @@ def scan_csv( comment_char: str | None = None, quote_char: str | None = r'"', skip_rows: int = 0, - dtypes: dict[str, PolarsDataType] | None = None, + dtypes: SchemaDict | None = None, null_values: str | list[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py index 752a6a35e863..13cf07939044 100644 --- a/py-polars/polars/testing/__init__.py +++ b/py-polars/polars/testing/__init__.py @@ -1,11 +1,15 @@ from polars.testing.asserts import ( assert_frame_equal, assert_frame_equal_local_categoricals, + assert_frame_not_equal, assert_series_equal, + assert_series_not_equal, ) __all__ = [ "assert_series_equal", + "assert_series_not_equal", "assert_frame_equal", + "assert_frame_not_equal", "assert_frame_equal_local_categoricals", ] diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 6fc8ac3c49ad..ead43b8be2ae 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -29,7 +29,7 @@ def assert_frame_equal( check_row_order: bool = True, ) -> None: """ - Raise detailed AssertionError if `left` does not equal `right`. + Raise detailed AssertionError if `left` does NOT equal `right`. Parameters ---------- @@ -49,10 +49,10 @@ def assert_frame_equal( nans_compare_equal if your assert/test requires float NaN != NaN, set this to False. 
check_column_order - if False, allows the assert/test to succeed if the required columns are present, + if False, frames will compare equal if the required columns are present, irrespective of the order in which they appear. check_row_order - if False, allows the assert/test to succeed if the required rows are present, + if False, frames will compare equal if the required rows are present, irrespective of the order in which they appear; as this requires sorting, you cannot set on frames that contain unsortable columns. @@ -114,6 +114,71 @@ def assert_frame_equal( ) +def assert_frame_not_equal( + left: pli.DataFrame | pli.LazyFrame, + right: pli.DataFrame | pli.LazyFrame, + check_dtype: bool = True, + check_exact: bool = False, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, + check_column_order: bool = True, + check_row_order: bool = True, +) -> None: + """ + Raise AssertionError if `left` DOES equal `right`. + + Parameters + ---------- + left + the dataframe to compare. + right + the dataframe to compare with. + check_dtype + if True, data types need to match exactly. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + check_column_order + if False, frames will compare equal if the required columns are present, + irrespective of the order in which they appear. + check_row_order + if False, frames will compare equal if the required rows are present, + irrespective of the order in which they appear; as this requires + sorting, this cannot be set on frames that contain unsortable columns. + + Examples + -------- + >>> from polars.testing import assert_frame_not_equal + >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) + >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) + >>> assert_frame_not_equal(df1, df2) + + """ + try: + assert_frame_equal( + left=left, + right=right, + check_dtype=check_dtype, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + check_column_order=check_column_order, + check_row_order=check_row_order, + ) + except AssertionError: + return + + raise AssertionError("Expected the two frames to compare unequal") + + def assert_series_equal( left: pli.Series, right: pli.Series, @@ -125,7 +190,7 @@ def assert_series_equal( nans_compare_equal: bool = True, ) -> None: """ - Raise detailed AssertionError if `left` does not equal `right`. + Raise detailed AssertionError if `left` does NOT equal `right`. Parameters ---------- @@ -175,6 +240,64 @@ def assert_series_equal( ) +def assert_series_not_equal( + left: pli.Series, + right: pli.Series, + check_dtype: bool = True, + check_names: bool = True, + check_exact: bool = False, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, +) -> None: + """ + Raise AssertionError if `left` DOES equal `right`. + + Parameters + ---------- + left + the series to compare. + right + the series to compare with. + check_dtype + if True, data types need to match exactly. + check_names + if True, names need to match. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking.
+ nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + + Examples + -------- + >>> from polars.testing import assert_series_not_equal + >>> s1 = pl.Series([1, 2, 3]) + >>> s2 = pl.Series([2, 3, 4]) + >>> assert_series_not_equal(s1, s2) + + """ + try: + assert_series_equal( + left=left, + right=right, + check_dtype=check_dtype, + check_names=check_names, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + ) + except AssertionError: + return + + raise AssertionError("Expected the two series to compare unequal") + + def _assert_series_inner( left: pli.Series, right: pli.Series, diff --git a/py-polars/tests/unit/test_constructors.py b/py-polars/tests/unit/test_constructors.py index 336f2c6ddcb2..53a3585a55da 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/test_constructors.py @@ -547,7 +547,9 @@ def test_from_dicts_schema() -> None: # let polars infer the dtypes # but inform about a 3rd column - df = pl.from_dicts(data, schema={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32}) + df = pl.from_dicts( + data, schema_overrides={"a": pl.Unknown, "b": pl.Unknown, "c": pl.Int32} + ) assert df.dtypes == [pl.Int64, pl.Int64, pl.Int32] assert df.to_dict(False) == { "a": [1, 2, 3], diff --git a/py-polars/tests/unit/test_df.py b/py-polars/tests/unit/test_df.py index 664acc0eafae..17fad4140331 100644 --- a/py-polars/tests/unit/test_df.py +++ b/py-polars/tests/unit/test_df.py @@ -16,7 +16,11 @@ from polars.datatypes import DTYPE_TEMPORAL_UNITS from polars.dependencies import zoneinfo from polars.internals.construction import iterable_to_pydf -from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing import ( + assert_frame_equal, + assert_frame_not_equal, + assert_series_equal, +) from polars.testing.parametric import columns if TYPE_CHECKING: @@ -187,6 +191,7 @@ def test_from_arrow() -> None: "b": pa.array([1, 2], pa.timestamp("ms")), "c": pa.array([1, 2], pa.timestamp("us")), "d": pa.array([1, 2], pa.timestamp("ns")), + "e": pa.array([1, 2], pa.int32()), "decimal1": pa.array([1, 2], pa.decimal128(2, 1)), } ) @@ -195,6 +200,7 @@ def test_from_arrow() -> None: "b": pl.Datetime("ms"), "c": pl.Datetime("us"), "d": pl.Datetime("ns"), + "e": pl.Int32, "decimal1": pl.Float64, } expected_data = [ @@ -203,6 +209,7 @@ def test_from_arrow() -> None: datetime(1970, 1, 1, 0, 0, 0, 1000), datetime(1970, 1, 1, 0, 0, 0, 1), datetime(1970, 1, 1, 0, 0), + 1, 1.0, ), ( @@ -210,6 +217,7 @@ def test_from_arrow() -> None: datetime(1970, 1, 1, 0, 0, 0, 2000), datetime(1970, 1, 1, 0, 0, 0, 2), datetime(1970, 1, 1, 0, 0), + 2, 2.0, ), ] @@ -223,26 +231,39 @@ def test_from_arrow() -> None: assert df.schema == expected_schema assert df.rows() == [] + # try a single column dtype override + for t in (tbl, empty_tbl): + df = pl.DataFrame(t, schema_overrides={"e": pl.Int8}) + override_schema = expected_schema.copy() + override_schema["e"] = pl.Int8 + assert df.schema == override_schema + assert df.rows() == expected_data[: (len(df))] + -def test_from_dict_with_dict_columns() -> None: - # expect schema order to take precedence +def test_from_dict_with_column_order() -> None: + # expect schema/columns order to take precedence schema = {"a": pl.UInt8, "b": pl.UInt32} - df = pl.DataFrame({"b": [3, 4], "a": [1, 2]}, columns=schema) - # ┌─────┬─────┐ - # │ a ┆ b │ - # │ --- ┆ --- │ - # │ u8 ┆ u32 │ - # ╞═════╪═════╡ - # │ 1 ┆ 3 │ - # │ 2 ┆ 4 │ - # └─────┴─────┘ - assert df.columns == ["a", "b"] - 
assert df.rows() == [(1, 3), (2, 4)] - - # expected error - mismatched_schema = {"x": pl.UInt8, "b": pl.UInt32} - with pytest.raises(ValueError): - pl.DataFrame({"b": [3, 4], "a": [1, 2]}, columns=mismatched_schema) + data = {"b": [3, 4], "a": [1, 2]} + for df in ( + pl.DataFrame(data, columns=schema), + pl.DataFrame(data, columns=["a", "b"], schema_overrides=schema), + ): + # ┌─────┬─────┐ + # │ a ┆ b │ + # │ --- ┆ --- │ + # │ u8 ┆ u32 │ + # ╞═════╪═════╡ + # │ 1 ┆ 3 │ + # │ 2 ┆ 4 │ + # └─────┴─────┘ + assert df.columns == ["a", "b"] + assert df.schema == {"a": pl.UInt8, "b": pl.UInt32} + assert df.rows() == [(1, 3), (2, 4)] + + # expect an error + mismatched_schema = {"x": pl.UInt8, "b": pl.UInt32} + with pytest.raises(ValueError): + pl.DataFrame({"b": [3, 4], "a": [1, 2]}, columns=mismatched_schema) def test_from_dict_with_scalars() -> None: @@ -288,6 +309,7 @@ def test_from_dict_with_scalars() -> None: "key": pl.Int8, }, ) + assert df4.columns == ["value", "other", "misc", "key"] assert df4.to_dict(False) == { "value": ["x", "y", "z"], "other": [7.0, 8.0, 9.0], @@ -302,16 +324,22 @@ def test_from_dict_with_scalars() -> None: } # mixed with struct cols - df5 = pl.from_dict( - {"x": {"b": [1, 3], "c": [2, 4]}, "y": [5, 6], "z": "x"}, - columns=["x", ("y", pl.Int8), "z"], # type: ignore[list-item] - ) - assert df5.rows() == [({"b": 1, "c": 2}, 5, "x"), ({"b": 3, "c": 4}, 6, "x")] - assert df5.schema == { - "x": pl.Struct([pl.Field("b", pl.Int64), pl.Field("c", pl.Int64)]), - "y": pl.Int8, - "z": pl.Utf8, - } + for df5 in ( + pl.from_dict( + {"x": {"b": [1, 3], "c": [2, 4]}, "y": [5, 6], "z": "x"}, + schema_overrides={"y": pl.Int8}, + ), + pl.from_dict( + {"x": {"b": [1, 3], "c": [2, 4]}, "y": [5, 6], "z": "x"}, + columns=["x", ("y", pl.Int8), "z"], + ), + ): + assert df5.rows() == [({"b": 1, "c": 2}, 5, "x"), ({"b": 3, "c": 4}, 6, "x")] + assert df5.schema == { + "x": pl.Struct([pl.Field("b", pl.Int64), pl.Field("c", pl.Int64)]), + "y": pl.Int8, + "z": pl.Utf8, + } def test_dataclasses_and_namedtuple() -> None: @@ -350,7 +378,19 @@ class TradeNT(NamedTuple): } assert df.rows() == raw_data - # in conjunction with 'columns' override (rename/downcast) + # partial dtypes override + df = DF( # type: ignore[operator] + data=trades, + schema_overrides={"timestamp": pl.Datetime("ms"), "size": pl.Int32}, + ) + assert df.schema == { + "timestamp": pl.Datetime("ms"), + "ticker": pl.Utf8, + "price": pl.Float64, + "size": pl.Int32, + } + + # in conjunction with full 'columns' override (rename/downcast) df = pl.DataFrame( data=trades, columns=[ @@ -1197,11 +1237,18 @@ def test_string_cache_eager_lazy() -> None: # also check row-wise categorical insert. 
# (column-wise is preferred, but this shouldn't fail) - df3 = pl.DataFrame( - data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], - columns=[("region_ids", pl.Categorical)], - ) - assert_frame_equal(df1, df3) + for params in ( + {"columns": [("region_ids", pl.Categorical)]}, + { + "columns": ["region_ids"], + "schema_overrides": {"region_ids": pl.Categorical}, + }, + ): + df3 = pl.DataFrame( # type: ignore[arg-type] + data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], + **params, + ) + assert_frame_equal(df1, df3) def test_assign() -> None: @@ -1248,7 +1295,8 @@ def test_literal_series() -> None: [datetime(2022, 8, 16), datetime(2022, 8, 17), datetime(2022, 8, 18)], dtype=" None: .collect() ) expected_schema = { - "a": pl.Float32, + "a": pl.Float64, "b": pl.Int8, "c": pl.Utf8, "d": pl.Datetime("ns"), @@ -1415,16 +1463,17 @@ def test_from_rows() -> None: ) df = pl.from_records( [[1, datetime.fromtimestamp(100)], [2, datetime.fromtimestamp(2398754908)]], + schema_overrides={"column_0": pl.UInt32}, orient="row", ) - assert df.dtypes == [pl.Int64, pl.Datetime] + assert df.dtypes == [pl.UInt32, pl.Datetime] def test_repeat_by() -> None: df = pl.DataFrame({"name": ["foo", "bar"], "n": [2, 3]}) - out = df.select(pl.col("n").repeat_by("n")) s = out["n"] + assert s[0].to_list() == [2, 2] assert s[1].to_list() == [3, 3, 3] @@ -2168,7 +2217,6 @@ def test_to_dict(as_series: bool, inner_dtype: Any) -> None: "optional": [28, 300, None, 2, -30], } ) - s = df.to_dict(as_series=as_series) assert isinstance(s, dict) for v in s.values(): @@ -2177,9 +2225,11 @@ def test_to_dict(as_series: bool, inner_dtype: Any) -> None: def test_df_broadcast() -> None: - df = pl.DataFrame({"a": [1, 2, 3]}) - out = df.with_column(pl.Series([[1, 2]])) + df = pl.DataFrame({"a": [1, 2, 3]}, schema_overrides={"a": pl.UInt8}) + out = df.with_column(pl.Series("s", [[1, 2]])) assert out.shape == (3, 2) + assert out.schema == {"a": pl.UInt8, "s": pl.List(pl.Int64)} + assert out.rows() == [(1, [1, 2]), (2, [1, 2]), (3, [1, 2])] def test_product() -> None: @@ -2189,11 +2239,16 @@ def test_product() -> None: "flt": [-1.0, 12.0, 9.0], "bool_0": [True, False, True], "bool_1": [True, True, True], - } + }, + schema_overrides={ + "int": pl.UInt16, + "flt": pl.Float32, + }, ) out = df.product() expected = pl.DataFrame({"int": [6], "flt": [-108.0], "bool_0": [0], "bool_1": [1]}) - assert out.frame_equal(expected) + assert_frame_not_equal(out, expected, check_dtype=True) + assert_frame_equal(out, expected, check_dtype=False) def test_first_last_expression(fruits_cars: pl.DataFrame) -> None: @@ -2644,20 +2699,31 @@ def test_init_datetimes_with_timezone() -> None: dtm = datetime(2022, 10, 12, 12, 30, tzinfo=zoneinfo.ZoneInfo("UTC")) for tu in DTYPE_TEMPORAL_UNITS | frozenset([None]): - df = pl.DataFrame( - data={"d1": [dtm], "d2": [dtm]}, - columns=[ - ("d1", pl.Datetime(tu, tz_us)), - ("d2", pl.Datetime(tu, tz_europe)), - ], - ) - assert (df["d1"].to_physical() == df["d2"].to_physical()).all() - assert df.rows() == [ - ( - datetime(2022, 10, 12, 8, 30, tzinfo=zoneinfo.ZoneInfo(tz_us)), - datetime(2022, 10, 12, 14, 30, tzinfo=zoneinfo.ZoneInfo(tz_europe)), + for type_overrides in ( + { + "columns": [ + ("d1", pl.Datetime(tu, tz_us)), + ("d2", pl.Datetime(tu, tz_europe)), + ] + }, + { + "schema_overrides": { + "d1": pl.Datetime(tu, tz_us), + "d2": pl.Datetime(tu, tz_europe), + } + }, + ): + df = pl.DataFrame( # type: ignore[arg-type] + data={"d1": [dtm], "d2": [dtm]}, + **type_overrides, ) - ] + assert (df["d1"].to_physical() == 
df["d2"].to_physical()).all() + assert df.rows() == [ + ( + datetime(2022, 10, 12, 8, 30, tzinfo=zoneinfo.ZoneInfo(tz_us)), + datetime(2022, 10, 12, 14, 30, tzinfo=zoneinfo.ZoneInfo(tz_europe)), + ) + ] def test_init_physical_with_timezone() -> None: diff --git a/py-polars/tests/unit/test_interop.py b/py-polars/tests/unit/test_interop.py index 39e6910eb724..86f3739e90c0 100644 --- a/py-polars/tests/unit/test_interop.py +++ b/py-polars/tests/unit/test_interop.py @@ -119,6 +119,12 @@ def test_from_pandas() -> None: (False, False, 3, 3.0, 3.0, 3.0, "ham", "ham", "ham"), ] + # partial dtype overrides from pandas + overrides = {"int": pl.Int8, "int_nulls": pl.Int32, "floats": pl.Float32} + out = pl.from_pandas(df, schema_overrides=overrides) + for col, dtype in overrides.items(): + assert out.schema[col] == dtype + def test_from_pandas_nan_to_none() -> None: df = pd.DataFrame( @@ -264,7 +270,7 @@ def test_from_dicts() -> None: def test_from_dict_no_inference() -> None: schema = {"a": pl.Utf8} data = [{"a": "aa"}] - pl.from_dicts(data, schema=schema, infer_schema_length=0) + pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0) def test_from_dicts_schema_override() -> None: @@ -333,9 +339,15 @@ def test_from_records() -> None: def test_from_numpy() -> None: data = np.array([[1, 2, 3], [4, 5, 6]]) - df = pl.from_numpy(data, columns=["a", "b"], orient="col") + df = pl.from_numpy( + data, + columns=["a", "b"], + orient="col", + schema_overrides={"a": pl.UInt32, "b": pl.UInt32}, + ) assert df.shape == (3, 2) assert df.rows() == [(1, 4), (2, 5), (3, 6)] + assert df.schema == {"a": pl.UInt32, "b": pl.UInt32} def test_from_arrow() -> None: @@ -348,6 +360,12 @@ def test_from_arrow() -> None: with pytest.raises(ValueError): _ = pl.from_arrow([1, 2]) + df = pl.from_arrow( + data, schema=["a", "b"], schema_overrides={"a": pl.UInt32, "b": pl.UInt64} + ) + assert df.rows() == [(1, 4), (2, 5), (3, 6)] # type: ignore[union-attr] + assert df.schema == {"a": pl.UInt32, "b": pl.UInt64} # type: ignore[union-attr] + def test_from_pandas_dataframe() -> None: pd_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"]) diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 1437dd31d0cc..9dcacb99adf8 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -4,12 +4,18 @@ import polars as pl from polars.exceptions import InvalidAssert -from polars.testing import assert_frame_equal, assert_series_equal +from polars.testing import ( + assert_frame_equal, + assert_series_equal, + assert_series_not_equal, +) def test_compare_series_value_mismatch() -> None: srs1 = pl.Series([1, 2, 3]) srs2 = pl.Series([2, 3, 4]) + + assert_series_not_equal(srs1, srs2) with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"): assert_series_equal(srs1, srs2) @@ -17,7 +23,10 @@ def test_compare_series_value_mismatch() -> None: def test_compare_series_empty_equal() -> None: srs1 = pl.Series([]) srs2 = pl.Series(()) + assert_series_equal(srs1, srs2) + with pytest.raises(AssertionError): + assert_series_not_equal(srs1, srs2) def test_compare_series_nans_assert_equal() -> None: @@ -34,8 +43,11 @@ def test_compare_series_nans_assert_equal() -> None: with pytest.raises(AssertionError): assert_series_equal(srs1, srs1, nans_compare_equal=False) + assert_series_not_equal(srs1, srs1, nans_compare_equal=False) + with pytest.raises(AssertionError): assert_series_equal(srs1, srs1, nans_compare_equal=False, check_exact=True) + 
assert_series_not_equal(srs1, srs1, nans_compare_equal=False, check_exact=True) for check_exact, nans_equal in ( (False, False), @@ -64,6 +76,7 @@ def test_compare_series_nans_assert_equal() -> None: assert_series_equal(srs4, srs6, check_dtype=False) with pytest.raises(AssertionError): assert_series_equal(srs5, srs6, check_dtype=False) + assert_series_not_equal(srs5, srs6, check_dtype=True) def test_compare_series_nulls() -> None: @@ -73,6 +86,7 @@ def test_compare_series_nulls() -> None: srs1 = pl.Series([1, 2, 3]) srs2 = pl.Series([1, None, None]) + with pytest.raises(AssertionError, match="Value mismatch"): assert_series_equal(srs1, srs2) with pytest.raises(AssertionError, match="Exact value mismatch"):
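Review note: for anyone trying out the new negative assertions exported from polars.testing, a usage sketch mirroring the tests above (values are illustrative only):

    import pytest
    import polars as pl
    from polars.testing import assert_frame_not_equal, assert_series_not_equal

    # pass: values differ
    assert_frame_not_equal(pl.DataFrame({"a": [1, 2]}), pl.DataFrame({"a": [2, 3]}))
    assert_series_not_equal(pl.Series([1, 2, 3]), pl.Series([2, 3, 4]))

    # equal inputs raise, with the message added in asserts.py above
    with pytest.raises(AssertionError, match="compare unequal"):
        assert_frame_not_equal(pl.DataFrame({"a": [1]}), pl.DataFrame({"a": [1]}))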