Skip to content

Commit

Permalink
Fix list of literal generation (#239) (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
Simske committed Jun 25, 2023
1 parent 26073c4 commit 01ba280
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
24 changes: 19 additions & 5 deletions polyfactory/value_generators/complex_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, AbstractSet, Any, Collection, MutableMapping, MutableSequence, Set, TypeVar

from typing_extensions import is_typeddict
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Collection,
MutableMapping,
MutableSequence,
Set,
TypeVar,
)

from typing_extensions import get_args, is_typeddict

from polyfactory.utils.helpers import unwrap_annotation
from polyfactory.utils.predicates import get_type_origin, is_any, is_union
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_union
from polyfactory.value_generators.primitives import create_random_string

if TYPE_CHECKING:
Expand Down Expand Up @@ -58,7 +67,12 @@ def handle_complex_type(field_meta: FieldMeta, factory: type[BaseFactory]) -> An
:returns: A built result.
"""
if origin := get_type_origin(unwrap_annotation(field_meta.annotation, random=factory.__random__)):
unwrapped_annotation = unwrap_annotation(field_meta.annotation, random=factory.__random__)

if is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
return factory.__random__.choice(literal_args)

if origin := get_type_origin(unwrapped_annotation):
if issubclass(origin, Collection):
return handle_collection_type(field_meta, origin, factory)
return factory.get_mock_value(origin)
Expand Down
1 change: 1 addition & 0 deletions tests/test_complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MyModel(BaseModel):
nested_dict: Dict[str, Dict[Union[int, str], Dict[Any, List[Dict[str, str]]]]]
dict_str_any: Dict[str, Any]
nested_list: List[List[List[Dict[str, List[Any]]]]]
sequence_literal: Sequence[Literal[1, 2, 3]]
sequence_dict: Sequence[Dict]
iterable_float: Iterable[float]
tuple_ellipsis: Tuple[int, ...]
Expand Down

0 comments on commit 01ba280

Please sign in to comment.