Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Implement support for Struct types in parametric tests #16197

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions py-polars/polars/testing/parametric/strategies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import decimal
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal, Sequence
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence

import hypothesis.strategies as st
from hypothesis.errors import InvalidArgument
Expand Down Expand Up @@ -34,6 +34,7 @@
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Int8,
Expand All @@ -43,6 +44,7 @@
List,
Null,
String,
Struct,
Time,
UInt8,
UInt16,
Expand All @@ -58,10 +60,10 @@
if TYPE_CHECKING:
from datetime import date, datetime, time

from hypothesis.strategies import SearchStrategy
from hypothesis.strategies import DrawFn, SearchStrategy

from polars.datatypes import DataType, DataTypeClass
from polars.type_aliases import PolarsDataType, TimeUnit
from polars.type_aliases import PolarsDataType, SchemaDict, TimeUnit

_DEFAULT_LIST_LEN_LIMIT = 3
_DEFAULT_N_CATEGORIES = 10
Expand Down Expand Up @@ -278,6 +280,28 @@ def lists(
)


@st.composite
def structs( # noqa: D417
draw: DrawFn, /, fields: Sequence[Field] | SchemaDict, **kwargs: Any
) -> dict[str, Any]:
"""
Create a strategy for generating structs with the given fields.
Parameters
----------
fields
The fields that make up the struct. Can be either a sequence of Field
objects or a mapping of column names to data types.
**kwargs
Additional arguments that are passed to nested data generation strategies.
"""
if isinstance(fields, Mapping):
fields = [Field(name, dtype) for name, dtype in fields.items()]

strats = {f.name: data(f.dtype, **kwargs) for f in fields}
return {col: draw(strat) for col, strat in strats.items()}


def nulls() -> SearchStrategy[None]:
"""Create a strategy for generating null values."""
return st.none()
Expand Down Expand Up @@ -360,6 +384,9 @@ def data(
allow_null=allow_null,
**kwargs,
)
elif dtype == Struct:
fields = getattr(dtype, "fields", None) or [Field("f0", Null())]
strategy = structs(fields, **kwargs)
else:
msg = f"unsupported data type: {dtype}"
raise InvalidArgument(msg)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/testing/parametric/strategies/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
# TODO: Enable nested types by default when various issues are solved.
# List,
# Array,
# Struct,
Struct,
]
# Supported data type classes that do not contain other data types
_FLAT_DTYPES = _SIMPLE_DTYPES + _COMPLEX_DTYPES
Expand Down
6 changes: 5 additions & 1 deletion py-polars/tests/unit/dataframe/test_null_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
min_size=1,
min_cols=1,
allow_null=True,
excluded_dtypes=[pl.String, pl.List],
excluded_dtypes=[
pl.String,
pl.List,
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
],
)
)
@example(df=pl.DataFrame(schema=["x", "y", "z"]))
Expand Down
8 changes: 7 additions & 1 deletion py-polars/tests/unit/dataframe/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from polars.testing.parametric import dataframes


@given(df=dataframes())
@given(
df=dataframes(
excluded_dtypes=[
pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196
]
)
)
def test_to_dict(df: pl.DataFrame) -> None:
d = df.to_dict(as_series=False)
result = pl.from_dict(d, schema=df.schema)
Expand Down
14 changes: 12 additions & 2 deletions py-polars/tests/unit/operations/test_clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
from polars.testing.parametric import series


@given(s=series(), n=st.integers(min_value=0, max_value=10))
def test_clear_series_parametric(s: pl.Series, n: int) -> None:
@given(s=series())
def test_clear_series_parametric(s: pl.Series) -> None:
result = s.clear()

assert result.dtype == s.dtype
assert result.name == s.name
assert result.is_empty()


@given(
s=series(
excluded_dtypes=[
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
]
),
n=st.integers(min_value=0, max_value=10),
)
def test_clear_series_n_parametric(s: pl.Series, n: int) -> None:
result = s.clear(n)

assert result.dtype == s.dtype
Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/unit/operations/test_drop_nulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from polars.testing.parametric import series


@given(s=series(allow_null=True))
@given(
s=series(
allow_null=True,
excluded_dtypes=[
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
],
)
)
def test_drop_nulls_parametric(s: pl.Series) -> None:
result = s.drop_nulls()
assert result.len() == s.len() - s.null_count()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ def test_data_enum(cat: str) -> None:
@given(cat=data(pl.Enum(["hello", "world"])))
def test_data_enum_instantiated(cat: str) -> None:
assert cat in ("hello", "world")


@given(struct=data(pl.Struct({"a": pl.Int8, "b": pl.String})))
def test_data_struct(struct: dict[str, int | str]) -> None:
assert isinstance(struct["a"], int)
assert isinstance(struct["b"], str)