diff --git a/py-polars/polars/testing/parametric/strategies/data.py b/py-polars/polars/testing/parametric/strategies/data.py index 2bb3345c3ce3..71313436dc36 100644 --- a/py-polars/polars/testing/parametric/strategies/data.py +++ b/py-polars/polars/testing/parametric/strategies/data.py @@ -4,7 +4,7 @@ import decimal from datetime import timedelta -from typing import TYPE_CHECKING, Any, Literal, Sequence +from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence import hypothesis.strategies as st from hypothesis.errors import InvalidArgument @@ -34,6 +34,7 @@ Decimal, Duration, Enum, + Field, Float32, Float64, Int8, @@ -43,6 +44,7 @@ List, Null, String, + Struct, Time, UInt8, UInt16, @@ -58,10 +60,10 @@ if TYPE_CHECKING: from datetime import date, datetime, time - from hypothesis.strategies import SearchStrategy + from hypothesis.strategies import DrawFn, SearchStrategy from polars.datatypes import DataType, DataTypeClass - from polars.type_aliases import PolarsDataType, TimeUnit + from polars.type_aliases import PolarsDataType, SchemaDict, TimeUnit _DEFAULT_LIST_LEN_LIMIT = 3 _DEFAULT_N_CATEGORIES = 10 @@ -278,6 +280,28 @@ def lists( ) +@st.composite +def structs( # noqa: D417 + draw: DrawFn, /, fields: Sequence[Field] | SchemaDict, **kwargs: Any +) -> dict[str, Any]: + """ + Create a strategy for generating structs with the given fields. + + Parameters + ---------- + fields + The fields that make up the struct. Can be either a sequence of Field + objects or a mapping of column names to data types. + **kwargs + Additional arguments that are passed to nested data generation strategies. + """ + if isinstance(fields, Mapping): + fields = [Field(name, dtype) for name, dtype in fields.items()] + + strats = {f.name: data(f.dtype, **kwargs) for f in fields} + return {col: draw(strat) for col, strat in strats.items()} + + def nulls() -> SearchStrategy[None]: """Create a strategy for generating null values.""" return st.none() @@ -360,6 +384,9 @@ def data( allow_null=allow_null, **kwargs, ) + elif dtype == Struct: + fields = getattr(dtype, "fields", None) or [Field("f0", Null())] + strategy = structs(fields, **kwargs) else: msg = f"unsupported data type: {dtype}" raise InvalidArgument(msg) diff --git a/py-polars/polars/testing/parametric/strategies/dtype.py b/py-polars/polars/testing/parametric/strategies/dtype.py index d3a192e462a1..dac7049def65 100644 --- a/py-polars/polars/testing/parametric/strategies/dtype.py +++ b/py-polars/polars/testing/parametric/strategies/dtype.py @@ -72,7 +72,7 @@ # TODO: Enable nested types by default when various issues are solved. # List, # Array, - # Struct, + Struct, ] # Supported data type classes that do not contain other data types _FLAT_DTYPES = _SIMPLE_DTYPES + _COMPLEX_DTYPES diff --git a/py-polars/tests/unit/dataframe/test_null_count.py b/py-polars/tests/unit/dataframe/test_null_count.py index 11755bbdcb9b..a9b1141a2a67 100644 --- a/py-polars/tests/unit/dataframe/test_null_count.py +++ b/py-polars/tests/unit/dataframe/test_null_count.py @@ -11,7 +11,11 @@ min_size=1, min_cols=1, allow_null=True, - excluded_dtypes=[pl.String, pl.List], + excluded_dtypes=[ + pl.String, + pl.List, + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], ) ) @example(df=pl.DataFrame(schema=["x", "y", "z"])) diff --git a/py-polars/tests/unit/dataframe/test_to_dict.py b/py-polars/tests/unit/dataframe/test_to_dict.py index 30414f7c4a23..e95fc014caf5 100644 --- a/py-polars/tests/unit/dataframe/test_to_dict.py +++ b/py-polars/tests/unit/dataframe/test_to_dict.py @@ -10,7 +10,13 @@ from polars.testing.parametric import dataframes -@given(df=dataframes()) +@given( + df=dataframes( + excluded_dtypes=[ + pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196 + ] + ) +) def test_to_dict(df: pl.DataFrame) -> None: d = df.to_dict(as_series=False) result = pl.from_dict(d, schema=df.schema) diff --git a/py-polars/tests/unit/operations/test_clear.py b/py-polars/tests/unit/operations/test_clear.py index 7799e7e05ce2..ff2e94fecb94 100644 --- a/py-polars/tests/unit/operations/test_clear.py +++ b/py-polars/tests/unit/operations/test_clear.py @@ -8,14 +8,24 @@ from polars.testing.parametric import series -@given(s=series(), n=st.integers(min_value=0, max_value=10)) -def test_clear_series_parametric(s: pl.Series, n: int) -> None: +@given(s=series()) +def test_clear_series_parametric(s: pl.Series) -> None: result = s.clear() assert result.dtype == s.dtype assert result.name == s.name assert result.is_empty() + +@given( + s=series( + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ] + ), + n=st.integers(min_value=0, max_value=10), +) +def test_clear_series_n_parametric(s: pl.Series, n: int) -> None: result = s.clear(n) assert result.dtype == s.dtype diff --git a/py-polars/tests/unit/operations/test_drop_nulls.py b/py-polars/tests/unit/operations/test_drop_nulls.py index 4250ecad154e..287a7ec2b7b0 100644 --- a/py-polars/tests/unit/operations/test_drop_nulls.py +++ b/py-polars/tests/unit/operations/test_drop_nulls.py @@ -7,7 +7,14 @@ from polars.testing.parametric import series -@given(s=series(allow_null=True)) +@given( + s=series( + allow_null=True, + excluded_dtypes=[ + pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462 + ], + ) +) def test_drop_nulls_parametric(s: pl.Series) -> None: result = s.drop_nulls() assert result.len() == s.len() - s.null_count() diff --git a/py-polars/tests/unit/testing/parametric/strategies/test_data.py b/py-polars/tests/unit/testing/parametric/strategies/test_data.py index a7316b0a7a7b..d3c5282ac959 100644 --- a/py-polars/tests/unit/testing/parametric/strategies/test_data.py +++ b/py-polars/tests/unit/testing/parametric/strategies/test_data.py @@ -29,3 +29,9 @@ def test_data_enum(cat: str) -> None: @given(cat=data(pl.Enum(["hello", "world"]))) def test_data_enum_instantiated(cat: str) -> None: assert cat in ("hello", "world") + + +@given(struct=data(pl.Struct({"a": pl.Int8, "b": pl.String}))) +def test_data_struct(struct: dict[str, int | str]) -> None: + assert isinstance(struct["a"], int) + assert isinstance(struct["b"], str)