Skip to content

Commit

Permalink
feat: implementation of Attrs Factory (litestar-org#313)
Browse files Browse the repository at this point in the history
* feat: implement attrs factory

* docs: add documentation for attrs factory

* docs: added example usage of AttrsFactory
  • Loading branch information
guacs committed Jul 28, 2023
1 parent 0bfa3b4 commit 96c61ae
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 2 deletions.
29 changes: 29 additions & 0 deletions docs/examples/declaring_factories/test_example_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from datetime import date, datetime
from typing import Any, Dict, List, Union
from uuid import UUID

import attrs

from polyfactory.factories.attrs_factory import AttrsFactory


@attrs.define
class Person:
id: UUID
name: str
hobbies: List[str]
age: Union[float, int]
# an aliased variable
birthday: Union[datetime, date] = attrs.field(alias="date_of_birth")
# a "private" variable
_assets: List[Dict[str, Dict[str, Any]]]


class PersonFactory(AttrsFactory[Person]):
__model__ = Person


def test_person_factory() -> None:
person = PersonFactory.build()

assert isinstance(person, Person)
5 changes: 5 additions & 0 deletions docs/reference/factories/attrs_factory.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
attrs_factory
================

.. automodule:: polyfactory.factories.attrs_factory
:members:
1 change: 1 addition & 0 deletions docs/reference/factories/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ factories
msgspec_factory
odmantic_odm_factory
beanie_odm_factory
attrs_factory
9 changes: 9 additions & 0 deletions docs/usage/declaring_factories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ Or for pydantic models:
:caption: Declaring a factory for a pydantic dataclass
:language: python

Or for attrs models:

.. literalinclude:: /examples/declaring_factories/test_example_7.py
:caption: Declaring a factory for a attrs model
:language: python

.. note::
Validators are not currently supported - neither the built in validators that come
with `attrs` nor custom validators.

Imperative Factory Creation
---------------------------
Expand Down
3 changes: 3 additions & 0 deletions docs/usage/library_factories.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ These include:
:class:`MsgspecFactory <polyfactory.factories.msgspec_factory.MsgspecFactory>`
a base factory for `msgspec <https://jcristharif.com/msgspec/>`_ Structs

:class:`AttrsFactory <polyfactory.factories.attrs_factory.AttrsFactory>`
a base factory for `attrs <https://www.attrs.org/en/stable/index.html>`_ models.

.. note::
All factories exported from ``polyfactory.factories`` do not require any additional dependencies. The other factories,
such as :class:`ModelFactory <polyfactory.factories.pydantic_factory.ModelFactory>`, require an additional but optional
Expand Down
5 changes: 3 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions polyfactory/factories/attrs_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

from inspect import isclass
from typing import TYPE_CHECKING, TypeVar

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta, Null

if TYPE_CHECKING:
from typing import Any, TypeGuard


try:
import attrs
from attr._make import Factory
except ImportError as ex:
raise MissingDependencyException("attrs is not installed") from ex

T = TypeVar("T", bound=attrs.AttrsInstance)


class AttrsFactory(BaseFactory[T]):
"""Base factory for attrs classes."""

__is_base_factory__ = True

@classmethod
def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
return isclass(value) and hasattr(value, "__attrs_attrs__")

@classmethod
def get_model_fields(cls) -> list[FieldMeta]:
field_metas: list[FieldMeta] = []
fields = attrs.fields(cls.__model__)
none_type = type(None)

for field in fields:
annotation = none_type if field.type is None else field.type

default = field.default
if isinstance(default, Factory):
# The default value is not currently being used when generating
# the field values. When that is implemented, this would need
# to be handled differently since the `default.factory` could
# take a `self` argument.
default_value = default.factory
elif default is None:
default_value = Null
else:
default_value = default

field_metas.append(
FieldMeta.from_type(
annotation=annotation,
name=field.alias,
default=default_value,
random=cls.__random__,
randomize_collection_length=cls.__randomize_collection_length__,
min_collection_length=cls.__min_collection_length__,
max_collection_length=cls.__max_collection_length__,
)
)

return field_metas
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ msgspec = { version = "*", optional = true }
odmantic = { version = "*", optional = true }
pydantic = { version = "*", optional = true, extras = ["email"] }
typing-extensions = "*"
attrs = {version = "*", optional = true}

[tool.poetry.group.dev.dependencies]
hypothesis = "*"
Expand Down Expand Up @@ -98,11 +99,13 @@ pydantic = ["pydantic"]
msgspec = ["msgspec"]
odmantic = ["odmantic", "pydantic"]
beanie = ["beanie", "pydantic"]
attrs = ["attrs"]
full = [
"pydantic",
"odmantic",
"msgspec",
"beanie",
"attrs",
]

[build-system]
Expand Down
224 changes: 224 additions & 0 deletions tests/test_attrs_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import datetime as dt
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, FrozenSet, Generic, List, Set, Tuple, TypeVar
from uuid import UUID

import attrs
import pytest
from attrs import asdict, define

from polyfactory.factories.attrs_factory import AttrsFactory

pytestmark = [pytest.mark.attrs]


def test_is_supported_type() -> None:
@define
class Foo:
...

assert AttrsFactory.is_supported_type(Foo) is True


def test_is_supported_type_without_struct() -> None:
class Foo:
...

assert AttrsFactory.is_supported_type(Foo) is False


def test_with_basic_types_annotated() -> None:
class SampleEnum(Enum):
FOO = "foo"
BAR = "bar"

@define
class Foo:
bool_field: bool
int_field: int
float_field: float
str_field: str
bytse_field: bytes
bytearray_field: bytearray
tuple_field: Tuple[int, str]
tuple_with_variadic_args: Tuple[int, ...]
list_field: List[int]
dict_field: Dict[str, int]
datetime_field: dt.datetime
date_field: dt.date
time_field: dt.time
uuid_field: UUID
decimal_field: Decimal
enum_type: SampleEnum
any_type: Any

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
foo_dict = asdict(foo)

assert foo == Foo(**foo_dict)


def test_with_basic_types_attrs_field() -> None:
@define
class Foo:
bool_field = attrs.field(type=bool) # pyright: ignore
int_field = attrs.field(type=int) # pyright: ignore
float_field = attrs.field(type=float) # pyright: ignore
str_field = attrs.field(type=str) # pyright: ignore
bytes_field = attrs.field(type=bytes) # pyright: ignore
bytearray_field = attrs.field(type=bytearray) # pyright: ignore
tuple_field = attrs.field(type=Tuple[int, str]) # type: ignore
tuple_with_variadic_args = attrs.field(type=Tuple[int, ...]) # type: ignore
list_field = attrs.field(type=List[int]) # pyright: ignore
dict_field = attrs.field(type=Dict[int, str]) # pyright: ignore
datetime_field = attrs.field(type=dt.datetime) # pyright: ignore
date_field = attrs.field(type=dt.date) # pyright: ignore
time_field = attrs.field(type=dt.time) # pyright: ignore
uuid_field = attrs.field(type=UUID) # pyright: ignore
decimal_field = attrs.field(type=Decimal) # pyright: ignore

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
foo_dict = asdict(foo)

assert foo == Foo(**foo_dict)


def test_with_nested_attr_model() -> None:
@define
class Foo:
int_field: int

@define
class Bar:
int_field: int
foo_field: Foo

class BarFactory(AttrsFactory[Bar]):
__model__ = Bar

bar = BarFactory.build()
bar_dict = asdict(bar, recurse=False)

assert bar == Bar(**bar_dict)


def test_with_private_fields() -> None:
@define
class Foo:
_private: int

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()

assert foo == Foo(foo._private)


def test_with_aliased_fields() -> None:
@define
class Foo:
aliased: int = attrs.field(alias="foo")

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()

assert foo == Foo(foo.aliased)


@pytest.mark.parametrize("type_", (Set, FrozenSet, List))
def test_variable_length(type_: Any) -> None:
@define
class Foo:
items: type_[int]

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

__randomize_collection_length__ = True
number_of_args = 3

__min_collection_length__ = number_of_args
__max_collection_length__ = number_of_args

foo = FooFactory.build()
assert len(foo.items) == 3


def test_variable_length__dict() -> None:
@define
class Foo:
items: Dict[int, float]

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

__randomize_collection_length__ = True
number_of_args = 3

__min_collection_length__ = number_of_args
__max_collection_length__ = number_of_args

foo = FooFactory.build()
assert len(foo.items) == 3


def test_variable_length__tuple() -> None:
@define
class Foo:
items: Tuple[int, ...]

class FooFactory(AttrsFactory[Foo]):
__model__ = Foo

__randomize_collection_length__ = True
number_of_args = 3

__min_collection_length__ = number_of_args
__max_collection_length__ = number_of_args

foo = FooFactory.build()
assert len(foo.items) == 3


def test_with_generics() -> None:
T = TypeVar("T")

@define
class Foo(Generic[T]):
x: T

class FooFactory(AttrsFactory[Foo[str]]):
__model__ = Foo

foo = FooFactory.build()
foo_dict = asdict(foo)

assert foo == Foo(**foo_dict)


def test_with_inheritance() -> None:
@define
class Parent:
int_field: int

@define
class Child:
str_field: str

class ChildFactory(AttrsFactory[Child]):
__model__ = Child

child = ChildFactory.build()
child_dict = asdict(child)

assert child == Child(**child_dict)

0 comments on commit 96c61ae

Please sign in to comment.