Skip to content

Commit

Permalink
Implement auto registration hook (#125)
Browse files Browse the repository at this point in the history
* Implement auto registration hook

* * Remove TypeDict
* Use const value instead of getting class variable.

* Fix Black errors

* fix pylint

* Lint + fix lint errors

* add typedict and dataclass test cases

Co-authored-by: Avihai <ayosef@paypal.com>
  • Loading branch information
avihai-yosef and avihaiyosef committed Nov 30, 2022
1 parent a0151e0 commit 82464cc
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 2 deletions.
40 changes: 38 additions & 2 deletions pydantic_factories/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,18 @@ class ModelFactory(Generic[T]):
__async_persistence__: Optional[Union[Type[AsyncPersistenceProtocol[T]], AsyncPersistenceProtocol[T]]] = None
__allow_none_optionals__: bool = True
__random_seed__: Optional[int] = None
__auto_register__: bool = False

# Private Fields
_registered_model_factory_map: Dict[FactoryTypes, "ModelFactory[T]"] = {}

# Private Methods

def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
super().__init_subclass__(*args, **kwargs)
if cls.__auto_register__:
cls._register_model_factory()

@classmethod
def _get_model(cls) -> Type[T]:
"""Returns the factory's model."""
Expand Down Expand Up @@ -325,6 +334,32 @@ def _handle_factory_field(cls, field_value: Any) -> Any:

return field_value

@classmethod
def _get_or_create_factory(
cls,
model: Type[FactoryTypes],
) -> "ModelFactory":
"""get from registered factories or generate dynamically a
'ModelFactory' for a given pydantic model subclass.
Args:
model: A pydantic model subclass.
Returns:
A 'ModelFactory' subclass.
"""
factory = cls._get_registered_model_factory(model)
if factory:
return factory
return cls.create_factory(model)

@classmethod
def _register_model_factory(cls) -> None:
cls._registered_model_factory_map[cls._get_model()] = cast("ModelFactory", cls)

@classmethod
def _get_registered_model_factory(cls, model: FactoryTypes) -> Optional["ModelFactory"]:
return cls._registered_model_factory_map.get(model)

# Public Methods

@classmethod
Expand Down Expand Up @@ -593,13 +628,14 @@ def get_field_value(

if is_pydantic_model(outer_type) or is_dataclass(outer_type) or is_typeddict(outer_type):

return cls.create_factory(model=outer_type).build(
return cls._get_or_create_factory(model=outer_type).build(
**(field_parameters if isinstance(field_parameters, dict) else {})
)

if isinstance(field_parameters, list) and is_pydantic_model(model_field.type_):
return [
cls.create_factory(model=model_field.type_).build(**build_kwargs) for build_kwargs in field_parameters
cls._get_or_create_factory(model=model_field.type_).build(**build_kwargs)
for build_kwargs in field_parameters
]

if cls.is_constrained_field(outer_type):
Expand Down
99 changes: 99 additions & 0 deletions tests/test_factory_auto_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from dataclasses import dataclass as vanilla_dataclass
from typing import List

from pydantic import BaseModel
from typing_extensions import TypedDict

from pydantic_factories import ModelFactory


class A(BaseModel):
a_text: str


class B(BaseModel):
b_text: str
a: A


class C(BaseModel):
b: B
b_list: List[B]


def test_auto_register_model_factory() -> None:
class AFactory(ModelFactory):
a_text = "const value"
__model__ = A

class BFactory(ModelFactory):
b_text = "const value"
__model__ = B
__auto_register__ = True

class CFactory(ModelFactory):
__model__ = C

c = CFactory.build()

assert c.b.b_text == BFactory.b_text
assert c.b_list[0].b_text == BFactory.b_text
assert c.b.a.a_text != AFactory.a_text


def test_auto_register_model_factory_using_create_factory() -> None:
const_value = "const value"
ModelFactory.create_factory(model=A, a_text=const_value)
ModelFactory.create_factory(model=B, b_text=const_value, __auto_register__=True)
CFactory = ModelFactory.create_factory(model=C)

c = CFactory.build()

assert c.b.b_text == const_value
assert c.b_list[0].b_text == const_value
assert c.b.a.a_text != const_value


def test_dataclass_model_factory_auto_registration() -> None:
@vanilla_dataclass
class DataClass:
text: str

class UpperModel(BaseModel):
nested_field: DataClass
nested_list_field: List[DataClass]

class UpperModelFactory(ModelFactory):
__model__ = UpperModel

class DataClassFactory(ModelFactory):
text = "const value"
__model__ = DataClass
__auto_register__ = True

upper = UpperModelFactory.build()

assert upper.nested_field.text == DataClassFactory.text
assert upper.nested_list_field[0].text == DataClassFactory.text


def test_typeddict_model_factory_auto_registration() -> None:
class TypedDictModel(TypedDict):
text: str

class UpperSchema(BaseModel):
nested_field: TypedDictModel
nested_list_field: List[TypedDictModel]

class UpperModelFactory(ModelFactory):
__model__ = UpperSchema

class TypedDictFactory(ModelFactory):
text = "const value"
__model__ = TypedDictModel
__auto_register__ = True

upper = UpperModelFactory.build()

assert upper.nested_field["text"] == TypedDictFactory.text
assert upper.nested_list_field[0]["text"] == TypedDictFactory.text

0 comments on commit 82464cc

Please sign in to comment.