Skip to content

Commit

Permalink
Fix inheritance support for msgspec (litestar-org#295)
Browse files Browse the repository at this point in the history
* fix inherited models support for msgspec

* use `get_type_hints`
  • Loading branch information
abdulhaq-e committed Jul 10, 2023
1 parent 79e8145 commit 8cfaeb4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
4 changes: 3 additions & 1 deletion polyfactory/factories/msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
cast,
)

from typing_extensions import get_type_hints

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta, Null
Expand Down Expand Up @@ -58,7 +60,7 @@ def get_model_fields(cls) -> list[FieldMeta]:
fields_meta: list[FieldMeta] = []

for field in type_info.fields:
annotation = cls.__model__.__annotations__[field.name]
annotation = get_type_hints(cls.__model__, include_extras=True)[field.name]
if field.default is not msgspec.NODEFAULT:
default_value = field.default
elif field.default_factory is not msgspec.NODEFAULT:
Expand Down
39 changes: 38 additions & 1 deletion tests/test_msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, FrozenSet, List, NewType, Set, Tuple, Type, Union
from typing import Any, Dict, FrozenSet, Generic, List, NewType, Set, Tuple, Type, TypeVar, Union
from uuid import UUID

import msgspec
Expand Down Expand Up @@ -227,3 +227,40 @@ class FooFactory(MsgspecFactory[Foo]):

foo = FooFactory.build()
assert len(foo.items) == 3


def test_inheritence() -> None:
class Foo(Struct):
int_field: int

class Bar(Foo):
pass

class BarFactory(MsgspecFactory[Bar]):
__model__ = Bar

bar = BarFactory.build()
bar_dict = structs.asdict(bar)

validated_bar = msgspec.convert(bar_dict, type=Bar)
assert validated_bar == bar


def test_inheritence_with_generics() -> None:
T = TypeVar("T")

class Foo(Struct, Generic[T]):
int_field: int
generic_field: T

class Bar(Foo[str]):
pass

class BarFactory(MsgspecFactory[Bar]):
__model__ = Bar

bar = BarFactory.build()
bar_dict = structs.asdict(bar)

validated_bar = msgspec.convert(bar_dict, type=Bar)
assert validated_bar == bar

0 comments on commit 8cfaeb4

Please sign in to comment.