From 0af3b6436e58ddbb117b67fd28c668b41f4a78b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Sat, 9 Sep 2023 12:59:21 +0200 Subject: [PATCH] :sparkles: Discovering nested schemas (#114) --- flama/schemas/_libs/marshmallow/adapter.py | 42 +- flama/schemas/_libs/marshmallow/fields.py | 7 +- flama/schemas/_libs/pydantic/adapter.py | 29 + flama/schemas/_libs/typesystem/adapter.py | 46 +- flama/schemas/_libs/typesystem/fields.py | 6 +- flama/schemas/adapter.py | 18 + flama/schemas/data_structures.py | 34 +- flama/schemas/generator.py | 103 ++-- pyproject.toml | 2 +- tests/schemas/test_generator.py | 686 ++++++++++++++++++--- tests/test_pagination.py | 21 +- 11 files changed, 832 insertions(+), 162 deletions(-) diff --git a/flama/schemas/_libs/marshmallow/adapter.py b/flama/schemas/_libs/marshmallow/adapter.py index b4dd9efa..d1ee7a44 100644 --- a/flama/schemas/_libs/marshmallow/adapter.py +++ b/flama/schemas/_libs/marshmallow/adapter.py @@ -7,7 +7,7 @@ from apispec.ext.marshmallow import MarshmallowPlugin, resolve_schema_cls from flama.injection import Parameter -from flama.schemas._libs.marshmallow.fields import MAPPING +from flama.schemas._libs.marshmallow.fields import MAPPING, MAPPING_TYPES from flama.schemas.adapter import Adapter from flama.schemas.exceptions import SchemaGenerationError, SchemaValidationError from flama.types import JSONSchema @@ -35,7 +35,7 @@ def build_field( required: bool = True, default: t.Any = None, multiple: bool = False, - **kwargs + **kwargs, ) -> Field: field_args = { "required": required, @@ -80,6 +80,10 @@ def dump(self, schema: t.Union[t.Type[Schema], Schema], value: t.Dict[str, t.Any except Exception as exc: raise SchemaValidationError(errors=str(exc)) + def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str: + s = self.unique_schema(schema) + return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}" + def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, Field]) -> JSONSchema: json_schema: t.Dict[str, t.Any] try: @@ -115,18 +119,32 @@ def unique_schema(self, schema: t.Union[Schema, t.Type[Schema]]) -> t.Type[Schem return schema - def is_schema( - self, obj: t.Any - ) -> t.TypeGuard[ # type: ignore # PORT: Remove this comment when stop supporting 3.9 - t.Union[Schema, t.Type[Schema]] - ]: + def _get_field_type(self, field: Field) -> t.Union[Schema, t.Type]: + if isinstance(field, marshmallow.fields.Nested): + return field.schema + + if isinstance(field, marshmallow.fields.List): + return self._get_field_type(field.inner) # type: ignore + + if isinstance(field, marshmallow.fields.Dict): + return self._get_field_type(field.value_field) # type: ignore + + try: + return MAPPING_TYPES[field.__class__] + except KeyError: + return None + + def schema_fields( + self, schema: t.Union[Schema, t.Type[Schema]] + ) -> t.Dict[str, t.Tuple[t.Union[t.Type, Schema], Field]]: + return { + name: (self._get_field_type(field), field) for name, field in self._schema_instance(schema).fields.items() + } + + def is_schema(self, obj: t.Any) -> t.TypeGuard[t.Union[Schema, t.Type[Schema]]]: # type: ignore return isinstance(obj, Schema) or (inspect.isclass(obj) and issubclass(obj, Schema)) - def is_field( - self, obj: t.Any - ) -> t.TypeGuard[ # type: ignore # PORT: Remove this comment when stop supporting 3.9 - t.Union[Field, t.Type[Field]] - ]: + def is_field(self, obj: t.Any) -> t.TypeGuard[t.Union[Field, t.Type[Field]]]: # type: ignore return isinstance(obj, Field) or (inspect.isclass(obj) and issubclass(obj, Field)) def _schema_instance(self, schema: t.Union[t.Type[Schema], Schema]) -> Schema: diff --git a/flama/schemas/_libs/marshmallow/fields.py b/flama/schemas/_libs/marshmallow/fields.py index abda4972..362159fb 100644 --- a/flama/schemas/_libs/marshmallow/fields.py +++ b/flama/schemas/_libs/marshmallow/fields.py @@ -1,12 +1,11 @@ # ruff: noqa import datetime -import typing +import typing as t import uuid -import marshmallow.fields from marshmallow.fields import * -MAPPING: typing.Dict[typing.Optional[typing.Type], typing.Type[marshmallow.fields.Field]] = { +MAPPING: t.Dict[t.Union[t.Type, None], t.Type[Field]] = { None: Field, int: Integer, float: Float, @@ -19,3 +18,5 @@ datetime.datetime: DateTime, datetime.time: Time, } + +MAPPING_TYPES = {v: k for k, v in MAPPING.items()} diff --git a/flama/schemas/_libs/pydantic/adapter.py b/flama/schemas/_libs/pydantic/adapter.py index 5bfa7534..613f7ce4 100644 --- a/flama/schemas/_libs/pydantic/adapter.py +++ b/flama/schemas/_libs/pydantic/adapter.py @@ -91,6 +91,10 @@ def dump(self, schema: t.Union[Schema, t.Type[Schema]], value: t.Dict[str, t.Any return self.validate(schema_cls, value) + def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str: + s = self.unique_schema(schema) + return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}" + def to_json_schema(self, schema: t.Union[Schema, t.Type[Schema], Field]) -> JSONSchema: try: if self.is_schema(schema): @@ -116,6 +120,31 @@ def to_json_schema(self, schema: t.Union[Schema, t.Type[Schema], Field]) -> JSON def unique_schema(self, schema: t.Union[Schema, t.Type[Schema]]) -> t.Type[Schema]: return schema.__class__ if isinstance(schema, Schema) else schema + def _get_field_type( + self, field: Field + ) -> t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]]: + if not self.is_field(field): + return field + + if t.get_origin(field.annotation) == list: + return self._get_field_type(t.get_args(field.annotation)[0]) + + if t.get_origin(field.annotation) == dict: + return self._get_field_type(t.get_args(field.annotation)[1]) + + return field.annotation + + def schema_fields( + self, schema: t.Union[Schema, t.Type[Schema]] + ) -> t.Dict[ + str, + t.Tuple[ + t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]], + Field, + ], + ]: + return {name: (self._get_field_type(field), field) for name, field in schema.model_fields.items()} + def is_schema( self, obj: t.Any ) -> t.TypeGuard[t.Type[Schema]]: # type: ignore # PORT: Remove this comment when stop supporting 3.9 diff --git a/flama/schemas/_libs/typesystem/adapter.py b/flama/schemas/_libs/typesystem/adapter.py index a31afa94..ff201bb7 100644 --- a/flama/schemas/_libs/typesystem/adapter.py +++ b/flama/schemas/_libs/typesystem/adapter.py @@ -5,7 +5,7 @@ import typesystem from flama.injection import Parameter -from flama.schemas._libs.typesystem.fields import MAPPING +from flama.schemas._libs.typesystem.fields import MAPPING, MAPPING_TYPES from flama.schemas.adapter import Adapter from flama.schemas.exceptions import SchemaGenerationError, SchemaValidationError from flama.types import JSONSchema @@ -30,7 +30,7 @@ def build_field( required: bool = True, default: t.Any = None, multiple: bool = False, - **kwargs + **kwargs, ) -> Field: if required is False and default is not Parameter.empty: kwargs["default"] = default @@ -44,7 +44,7 @@ def build_field( if self.is_schema(type_) else MAPPING[type_]() ), - **kwargs + **kwargs, ) return MAPPING[type_](**kwargs) @@ -82,6 +82,13 @@ def _dump(self, value: t.Any) -> t.Any: return value + @t.no_type_check + def name(self, schema: Schema) -> str: + if not schema.title: + raise ValueError(f"Schema '{schema}' needs to define title attribute") + + return schema.title if schema.__module__ == "builtins" else f"{schema.__module__}.{schema.title}" + @t.no_type_check def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema: try: @@ -100,6 +107,39 @@ def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema: def unique_schema(self, schema: Schema) -> Schema: return schema + def _get_field_type( + self, field: Field + ) -> t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]]: + if isinstance(field, typesystem.Reference): + return field.target + + if isinstance(field, typesystem.Array): + return ( + [self._get_field_type(x) for x in field.items] + if isinstance(field.items, (list, tuple, set)) + else self._get_field_type(field.items) + ) + + if isinstance(field, typesystem.Object): + return {k: self._get_field_type(v) for k, v in field.properties.items()} + + try: + return MAPPING_TYPES[field.__class__] + except KeyError: + return None + + @t.no_type_check + def schema_fields( + self, schema: Schema + ) -> t.Dict[ + str, + t.Tuple[ + t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]], + Field, + ], + ]: + return {name: (self._get_field_type(field), field) for name, field in schema.fields.items()} + @t.no_type_check def is_schema( self, obj: t.Any diff --git a/flama/schemas/_libs/typesystem/fields.py b/flama/schemas/_libs/typesystem/fields.py index 20f9769e..274cc177 100644 --- a/flama/schemas/_libs/typesystem/fields.py +++ b/flama/schemas/_libs/typesystem/fields.py @@ -1,12 +1,12 @@ # ruff: noqa import datetime -import typing +import typing as t import uuid from typesystem.fields import * from typesystem.schemas import Reference -MAPPING: typing.Dict[typing.Any, typing.Type[Field]] = { +MAPPING: t.Dict[t.Union[t.Type, None], t.Type[Field]] = { None: Field, int: Integer, float: Float, @@ -19,3 +19,5 @@ datetime.datetime: DateTime, datetime.time: Time, } + +MAPPING_TYPES = {v: k for k, v in MAPPING.items()} diff --git a/flama/schemas/adapter.py b/flama/schemas/adapter.py index 1c5da1fb..9eee87e6 100644 --- a/flama/schemas/adapter.py +++ b/flama/schemas/adapter.py @@ -74,6 +74,10 @@ def load(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]], value: t.Dict[str, def dump(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]], value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: ... + @abc.abstractmethod + def name(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]]) -> str: + ... + @abc.abstractmethod def to_json_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema], _T_Field]) -> JSONSchema: ... @@ -82,6 +86,20 @@ def to_json_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema], _T_Field] def unique_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]]) -> t.Union[_T_Schema, t.Type[_T_Schema]]: ... + @abc.abstractmethod + def schema_fields( + self, schema: t.Union[_T_Schema, t.Type[_T_Schema]] + ) -> t.Dict[ + str, + t.Tuple[ + t.Union[ + t.Union[_T_Schema, t.Type], t.List[t.Union[_T_Schema, t.Type]], t.Dict[str, t.Union[_T_Schema, t.Type]] + ], + _T_Field, + ], + ]: + ... + @abc.abstractmethod def is_schema( self, obj: t.Any diff --git a/flama/schemas/data_structures.py b/flama/schemas/data_structures.py index 601f6e76..1533cf2b 100644 --- a/flama/schemas/data_structures.py +++ b/flama/schemas/data_structures.py @@ -3,7 +3,6 @@ import sys import typing as t -import flama.types from flama import schemas, types from flama.injection.resolver import Parameter as InjectionParameter @@ -15,6 +14,9 @@ __all__ = ["Field", "Schema", "Parameter", "Parameters"] +UNKNOWN = t.TypeVar("UNKNOWN") + + class ParameterLocation(enum.Enum): query = enum.auto() path = enum.auto() @@ -81,7 +83,7 @@ def is_http_valid_type(cls, type_: t.Type) -> bool: ) @property - def json_schema(self) -> flama.types.JSONSchema: + def json_schema(self) -> types.JSONSchema: return schemas.adapter.to_json_schema(self.field) @@ -121,13 +123,39 @@ def is_schema(cls, obj: t.Any) -> bool: return schemas.adapter.is_schema(obj) @property - def json_schema(self) -> t.Dict[str, t.Any]: + def name(self) -> str: + return schemas.adapter.name(self.schema) + + @property + def json_schema(self) -> types.JSONSchema: return schemas.adapter.to_json_schema(self.schema) @property def unique_schema(self) -> t.Any: return schemas.adapter.unique_schema(self.schema) + @property + def fields(self) -> t.Dict[str, t.Tuple[t.Any, t.Any]]: + return schemas.adapter.schema_fields(self.unique_schema) + + def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]: + if schema == UNKNOWN: + return self.nested_schemas(self) + + if schemas.adapter.is_schema(schema): + return [schema] + + if isinstance(schema, (list, tuple, set)): + return [x for field in schema for x in self.nested_schemas(field)] + + if isinstance(schema, dict): + return [x for field in schema.values() for x in self.nested_schemas(field)] + + if isinstance(schema, Schema): + return [x for field_type, _ in schema.fields.values() for x in self.nested_schemas(field_type)] + + return [] + @t.overload def validate(self, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: ... diff --git a/flama/schemas/generator.py b/flama/schemas/generator.py index 3d88b940..95e4afdd 100644 --- a/flama/schemas/generator.py +++ b/flama/schemas/generator.py @@ -2,13 +2,12 @@ import inspect import itertools import logging -import typing import typing as t from collections import defaultdict import yaml -from flama import routing, schemas +from flama import routing, schemas, types from flama.schemas import Schema, openapi from flama.schemas.data_structures import Parameter from flama.url import RegexPath @@ -39,33 +38,31 @@ def ref(self) -> str: return f"#/components/schemas/{self.name}" @property - def json_schema(self) -> t.Dict[str, t.Any]: + def json_schema(self) -> types.JSONSchema: return Schema(self.schema).json_schema -class SchemaRegistry(typing.Dict[int, SchemaInfo]): - def __init__(self, schemas: t.Optional[typing.Dict[str, schemas.Schema]] = None): +class SchemaRegistry(t.Dict[int, SchemaInfo]): + def __init__(self, schemas: t.Optional[t.Dict[str, schemas.Schema]] = None): super().__init__() for name, schema in (schemas or {}).items(): self.register(schema, name) - def __contains__(self, item: schemas.Schema) -> bool: - return super().__contains__(id(schemas.adapter.unique_schema(item))) + def __contains__(self, item: t.Any) -> bool: + return super().__contains__(id(schemas.Schema(item).unique_schema)) - def __getitem__(self, item: schemas.Schema) -> SchemaInfo: + def __getitem__(self, item: t.Any) -> SchemaInfo: """ Lookup method that allows using Schema classes or instances. :param item: Schema to look for. :return: Registered schema. """ - return super().__getitem__(id(schemas.adapter.unique_schema(item))) + return super().__getitem__(id(schemas.Schema(item).unique_schema)) @t.no_type_check - def _get_schema_references_from_schema( - self, schema: typing.Union[openapi.Schema, openapi.Reference] - ) -> typing.List[str]: + def _get_schema_references_from_schema(self, schema: t.Union[openapi.Schema, openapi.Reference]) -> t.List[str]: if isinstance(schema, openapi.Reference): return [schema.ref] @@ -100,10 +97,10 @@ def _get_schema_references_from_schema( return result - def _get_schema_references_from_path(self, path: openapi.Path) -> typing.List[str]: + def _get_schema_references_from_path(self, path: openapi.Path) -> t.List[str]: return [y for x in path.operations.values() for y in self._get_schema_references_from_operation(x)] - def _get_schema_references_from_operation(self, operation: openapi.Operation) -> typing.List[str]: + def _get_schema_references_from_operation(self, operation: openapi.Operation) -> t.List[str]: return [ *self._get_schema_references_from_operation_parameters(operation.parameters), *self._get_schema_references_from_operation_request_body(operation.requestBody), @@ -111,7 +108,7 @@ def _get_schema_references_from_operation(self, operation: openapi.Operation) -> *self._get_schema_references_from_operation_responses(operation.responses), ] - def _get_schema_references_from_operation_responses(self, responses: openapi.Responses) -> typing.List[str]: + def _get_schema_references_from_operation_responses(self, responses: openapi.Responses) -> t.List[str]: refs = [] for response in [x for x in responses.values() if x.content]: @@ -125,8 +122,8 @@ def _get_schema_references_from_operation_responses(self, responses: openapi.Res return refs def _get_schema_references_from_operation_callbacks( - self, callbacks: typing.Optional[typing.Dict[str, typing.Union[openapi.Callback, openapi.Reference]]] - ) -> typing.List[str]: + self, callbacks: t.Optional[t.Dict[str, t.Union[openapi.Callback, openapi.Reference]]] + ) -> t.List[str]: refs = [] if callbacks: @@ -140,8 +137,8 @@ def _get_schema_references_from_operation_callbacks( return refs def _get_schema_references_from_operation_request_body( - self, request_body: typing.Optional[typing.Union[openapi.RequestBody, openapi.Reference]] - ) -> typing.List[str]: + self, request_body: t.Optional[t.Union[openapi.RequestBody, openapi.Reference]] + ) -> t.List[str]: refs = [] if request_body: @@ -154,8 +151,8 @@ def _get_schema_references_from_operation_request_body( return refs def _get_schema_references_from_operation_parameters( - self, parameters: typing.Optional[typing.List[typing.Union[openapi.Parameter, openapi.Reference]]] - ) -> typing.List[str]: + self, parameters: t.Optional[t.List[t.Union[openapi.Parameter, openapi.Reference]]] + ) -> t.List[str]: refs = [] if parameters: @@ -167,7 +164,7 @@ def _get_schema_references_from_operation_parameters( return refs - def used(self, spec: openapi.OpenAPISpec) -> typing.Dict[int, SchemaInfo]: + def used(self, spec: openapi.OpenAPISpec) -> t.Dict[int, SchemaInfo]: """ Generate a dict containing used schemas. @@ -185,6 +182,14 @@ def used(self, spec: openapi.OpenAPISpec) -> typing.Dict[int, SchemaInfo]: } used_schemas.update({k: v for k, v in self.items() if v.name in refs_from_schemas}) + for child_schema in [y for x in used_schemas.values() for y in schemas.Schema(x.schema).nested_schemas()]: + schema = schemas.Schema(child_schema) + instance = schema.unique_schema + if instance not in used_schemas: + used_schemas[id(instance)] = ( + self[instance] if instance in self else SchemaInfo(name=schema.name, schema=instance) + ) + return used_schemas def register(self, schema: schemas.Schema, name: t.Optional[str] = None) -> int: @@ -198,28 +203,22 @@ def register(self, schema: schemas.Schema, name: t.Optional[str] = None) -> int: if schema in self: raise ValueError("Schema is already registered.") - schema_instance = schemas.adapter.unique_schema(schema) - if name is None: - if not inspect.isclass(schema_instance): - raise ValueError("Cannot infer schema name.") + s = schemas.Schema(schema) - try: - name = ( - schema_instance.__qualname__ - if schema_instance.__module__ == "builtins" - else f"{schema_instance.__module__}.{schema_instance.__qualname__}" - ) - except AttributeError: - raise ValueError("Cannot infer schema name.") + try: + schema_name = name or s.name + except ValueError as e: + raise ValueError("Cannot infer schema name.") from e + schema_instance = s.unique_schema schema_id = id(schema_instance) - self[schema_id] = SchemaInfo(name=name, schema=schema_instance) + self[schema_id] = SchemaInfo(name=schema_name, schema=schema_instance) return schema_id def get_openapi_ref( self, element: schemas.Schema, multiple: t.Optional[bool] = None - ) -> typing.Union[openapi.Schema, openapi.Reference]: + ) -> t.Union[openapi.Schema, openapi.Reference]: """ Builds the reference for a single schema or the array schema containing the reference. @@ -249,7 +248,7 @@ def __init__( contact_email: t.Optional[str] = None, license_name: t.Optional[str] = None, license_url: t.Optional[str] = None, - schemas: t.Optional[typing.Dict] = None, + schemas: t.Optional[t.Dict] = None, ): contact = ( openapi.Contact(name=contact_name, url=contact_url, email=contact_email) @@ -272,8 +271,8 @@ def __init__( self.schemas = SchemaRegistry(schemas=schemas) def get_endpoints( # type: ignore[override] - self, routes: typing.List[routing.BaseRoute], base_path: str = "" - ) -> typing.Dict[str, typing.List[EndpointInfo]]: + self, routes: t.List[routing.BaseRoute], base_path: str = "" + ) -> t.Dict[str, t.List[EndpointInfo]]: """ Given the routes, yields the following information: @@ -288,7 +287,7 @@ def get_endpoints( # type: ignore[override] :param base_path: The base endpoints path. :return: Data structure that contains metadata from every route. """ - endpoints_info: typing.Dict[str, typing.List[EndpointInfo]] = defaultdict(list) + endpoints_info: t.Dict[str, t.List[EndpointInfo]] = defaultdict(list) for route in routes: path = RegexPath(base_path + route.path.path).template @@ -333,8 +332,8 @@ def get_endpoints( # type: ignore[override] return endpoints_info def _build_endpoint_parameters( - self, endpoint: EndpointInfo, metadata: typing.Dict[str, typing.Any] - ) -> typing.Optional[typing.List[openapi.Parameter]]: + self, endpoint: EndpointInfo, metadata: t.Dict[str, t.Any] + ) -> t.Optional[t.List[openapi.Parameter]]: if not endpoint.query_parameters and not endpoint.path_parameters: return None @@ -363,8 +362,8 @@ def _build_endpoint_parameters( ] def _build_endpoint_body( - self, endpoint: EndpointInfo, metadata: typing.Dict[str, typing.Any] - ) -> typing.Optional[openapi.RequestBody]: + self, endpoint: EndpointInfo, metadata: t.Dict[str, t.Any] + ) -> t.Optional[openapi.RequestBody]: if not endpoint.body_parameter: return None @@ -384,8 +383,8 @@ def _build_endpoint_body( ) def _build_endpoint_response( - self, endpoint: EndpointInfo, metadata: typing.Dict[str, typing.Any] - ) -> typing.Tuple[typing.Optional[openapi.Response], str]: + self, endpoint: EndpointInfo, metadata: t.Dict[str, t.Any] + ) -> t.Tuple[t.Optional[openapi.Response], str]: try: response_code, main_response = list(metadata.get("responses", {}).items())[0] except IndexError: @@ -414,7 +413,7 @@ def _build_endpoint_response( str(response_code), ) - def _build_endpoint_default_response(self, metadata: typing.Dict[str, typing.Any]) -> openapi.Response: + def _build_endpoint_default_response(self, metadata: t.Dict[str, t.Any]) -> openapi.Response: return openapi.Response( description=metadata.get("responses", {}).get("default", {}).get("description", "Unexpected error."), content={ @@ -424,9 +423,7 @@ def _build_endpoint_default_response(self, metadata: typing.Dict[str, typing.Any }, ) - def _build_endpoint_responses( - self, endpoint: EndpointInfo, metadata: typing.Dict[str, typing.Any] - ) -> openapi.Responses: + def _build_endpoint_responses(self, endpoint: EndpointInfo, metadata: t.Dict[str, t.Any]) -> openapi.Responses: responses = metadata.get("responses", {}) try: main_response_code = next(iter(responses.keys())) @@ -480,7 +477,7 @@ def _build_endpoint_responses( } ) - def _parse_docstring(self, func: typing.Callable) -> t.Dict[t.Any, t.Any]: + def _parse_docstring(self, func: t.Callable) -> t.Dict[t.Any, t.Any]: """Given a function, parse the docstring as YAML and return a dictionary of info. :param func: Function to analyze docstring. @@ -523,7 +520,7 @@ def get_operation_schema(self, endpoint: EndpointInfo) -> openapi.Operation: }, ) - def get_api_schema(self, routes: typing.List[routing.BaseRoute]) -> typing.Dict[str, typing.Any]: + def get_api_schema(self, routes: t.List[routing.BaseRoute]) -> t.Dict[str, t.Any]: endpoints_info = self.get_endpoints(routes) for path, endpoints in endpoints_info.items(): @@ -533,6 +530,6 @@ def get_api_schema(self, routes: typing.List[routing.BaseRoute]) -> typing.Dict[ for schema in self.schemas.used(self.spec).values(): self.spec.add_schema(schema.name, openapi.Schema(schema.json_schema)) - api_schema: typing.Dict[str, typing.Any] = self.spec.asdict() + api_schema: t.Dict[str, t.Any] = self.spec.asdict() return api_schema diff --git a/pyproject.toml b/pyproject.toml index fba3adf5..4ea89afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,7 +175,7 @@ reportPrivateImportUsage = false [tool.pytest.ini_options] minversion = 3 -addopts = "--dist=loadfile --junitxml=./test-results/pytest/results.xml --no-cov-on-fail --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=pyproject.toml --cov=. --pdbcls=IPython.terminal.debugger:TerminalPdb" +addopts = "--dist=loadfile --junitxml=./test-results/pytest/results.xml --no-cov-on-fail --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=pyproject.toml --cov=. --pdbcls=IPython.terminal.debugger:TerminalPdb -nauto" norecursedirs = [ "*settings*", "*urls*", diff --git a/tests/schemas/test_generator.py b/tests/schemas/test_generator.py index a4d6532a..726a2672 100644 --- a/tests/schemas/test_generator.py +++ b/tests/schemas/test_generator.py @@ -1,9 +1,12 @@ +import contextlib import typing as t +from collections import namedtuple import marshmallow import pydantic import pytest import typesystem +import typesystem.fields from flama import types from flama.endpoints import HTTPEndpoint @@ -19,23 +22,106 @@ def registry(self): return SchemaRegistry() @pytest.fixture(scope="function") - def foo_schema(self, registry): - from flama import schemas - - if schemas.lib == pydantic: + def foo_schema(self, app): + if app.schema.schema_library.lib == pydantic: schema = pydantic.create_model("Foo", name=(str, ...)) - elif schemas.lib == typesystem: + name = "pydantic.main.Foo" + elif app.schema.schema_library.lib == typesystem: schema = typesystem.Schema(title="Foo", fields={"name": typesystem.fields.String()}) - elif schemas.lib == marshmallow: + name = "typesystem.schemas.Foo" + elif app.schema.schema_library.lib == marshmallow: schema = type("Foo", (marshmallow.Schema,), {"name": marshmallow.fields.String()}) + name = "abc.Foo" else: - raise ValueError("Wrong schema lib") - registry.register(schema, "Foo") - return schema + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") + return namedtuple("FooSchema", ("schema", "name"))(schema=schema, name=name) @pytest.fixture(scope="function") def foo_array_schema(self, foo_schema): - return t.List[foo_schema] + return t.List[foo_schema.schema] + + @pytest.fixture(scope="function") + def bar_schema(self, app, foo_schema): + child_schema = foo_schema.schema + if app.schema.schema_library.lib == pydantic: + schema = pydantic.create_model("Bar", foo=(child_schema, ...)) + name = "pydantic.main.Bar" + elif app.schema.schema_library.lib == typesystem: + schema = typesystem.Schema( + title="Bar", + fields={ + "foo": typesystem.Reference(to="Foo", definitions=typesystem.Definitions({"Foo": child_schema})) + }, + ) + name = "typesystem.schemas.Bar" + elif app.schema.schema_library.lib == marshmallow: + schema = type("Bar", (marshmallow.Schema,), {"foo": marshmallow.fields.Nested(child_schema())}) + name = "abc.Bar" + else: + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") + return namedtuple("BarSchema", ("schema", "name"))(schema=schema, name=name) + + @pytest.fixture(scope="function") + def bar_list_schema(self, app, foo_schema): + child_schema = foo_schema.schema + if app.schema.schema_library.lib == pydantic: + schema = pydantic.create_model("BarList", foo=(t.List[child_schema], ...)) + name = "pydantic.main.BarList" + elif app.schema.schema_library.lib == typesystem: + schema = typesystem.Schema( + title="BarList", + fields={ + "foo": typesystem.Array( + typesystem.Reference(to="Foo", definitions=typesystem.Definitions({"Foo": child_schema})) + ) + }, + ) + name = "typesystem.schemas.BarList" + elif app.schema.schema_library.lib == marshmallow: + schema = type( + "BarList", + (marshmallow.Schema,), + {"foo": marshmallow.fields.List(marshmallow.fields.Nested(child_schema()))}, + ) + name = "abc.BarList" + else: + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") + return namedtuple("BarListSchema", ("schema", "name"))(schema=schema, name=name) + + @pytest.fixture(scope="function") + def bar_dict_schema(self, app, foo_schema): + child_schema = foo_schema.schema + if app.schema.schema_library.lib == pydantic: + schema = pydantic.create_model("BarDict", foo=(t.Dict[str, child_schema], ...)) + name = "pydantic.main.BarDict" + elif app.schema.schema_library.lib == typesystem: + schema = typesystem.Schema( + title="BarDict", + fields={ + "foo": typesystem.Object( + properties={ + "x": typesystem.Reference( + to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}) + ) + } + ) + }, + ) + name = "typesystem.schemas.BarDict" + elif app.schema.schema_library.lib == marshmallow: + schema = type( + "BarDict", + (marshmallow.Schema,), + {"foo": marshmallow.fields.Dict(values=marshmallow.fields.Nested(child_schema()))}, + ) + name = "abc.BarDict" + else: + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") + return namedtuple("BarDictSchema", ("schema", "name"))(schema=schema, name=name) + + @pytest.fixture(scope="function") + def schemas(self, foo_schema, bar_schema, bar_list_schema, bar_dict_schema): + return {"Foo": foo_schema, "Bar": bar_schema, "BarList": bar_list_schema, "BarDict": bar_dict_schema} @pytest.fixture(scope="function") def spec(self): @@ -45,7 +131,7 @@ def test_empty_init(self): assert SchemaRegistry() == {} @pytest.mark.parametrize( - "operation,output", + ["operation", "register_schemas", "output"], [ pytest.param( openapi.Operation( @@ -62,9 +148,67 @@ def test_empty_init(self): } ) ), - True, + ["Foo"], + ["Foo"], id="response_reference", ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Reference(ref="#!/components/schemas/Bar") + ) + }, + ) + } + ) + ), + ["Bar"], + ["Foo", "Bar"], + id="response_reference_nested", + ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Reference(ref="#!/components/schemas/BarList") + ) + }, + ) + } + ) + ), + ["BarList"], + ["Foo", "BarList"], + id="response_reference_nested_list", + ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Reference(ref="#!/components/schemas/BarDict") + ) + }, + ) + } + ) + ), + ["BarDict"], + ["Foo", "BarDict"], + id="response_reference_nested_dict", + ), pytest.param( openapi.Operation( responses=openapi.Responses( @@ -85,9 +229,82 @@ def test_empty_init(self): } ) ), - True, + ["Foo"], + ["Foo"], id="response_schema", ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": {"bar": {"$ref": "#!/components/schemas/Bar"}}, + } + ), + ) + }, + ) + } + ) + ), + ["Bar"], + ["Foo", "Bar"], + id="response_schema_nested", + ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="BarList", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": {"bar": {"$ref": "#!/components/schemas/BarList"}}, + } + ), + ) + }, + ) + } + ) + ), + ["BarList"], + ["Foo", "BarList"], + id="response_schema_nested_list", + ), + pytest.param( + openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="BarDict", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": {"bar": {"$ref": "#!/components/schemas/BarDict"}}, + } + ), + ) + }, + ) + } + ) + ), + ["BarDict"], + ["Foo", "BarDict"], + id="response_schema_nested_dict", + ), pytest.param( openapi.Operation( responses=openapi.Responses( @@ -113,7 +330,8 @@ def test_empty_init(self): } ) ), - True, + ["Foo"], + ["Foo"], id="response_array", ), pytest.param( @@ -133,16 +351,42 @@ def test_empty_init(self): } ) ), - False, + [], + [], id="response_wrong", ), pytest.param( openapi.Operation( requestBody=openapi.Reference(ref="#!/components/schemas/Foo"), responses=openapi.Responses({}) ), - True, + ["Foo"], + ["Foo"], id="body_reference", ), + pytest.param( + openapi.Operation( + requestBody=openapi.Reference(ref="#!/components/schemas/Bar"), responses=openapi.Responses({}) + ), + ["Bar"], + ["Foo", "Bar"], + id="body_reference_nested", + ), + pytest.param( + openapi.Operation( + requestBody=openapi.Reference(ref="#!/components/schemas/BarList"), responses=openapi.Responses({}) + ), + ["BarList"], + ["Foo", "BarList"], + id="body_reference_nested_list", + ), + pytest.param( + openapi.Operation( + requestBody=openapi.Reference(ref="#!/components/schemas/BarDict"), responses=openapi.Responses({}) + ), + ["BarDict"], + ["Foo", "BarDict"], + id="body_reference_nested_dict", + ), pytest.param( openapi.Operation( requestBody=openapi.RequestBody( @@ -157,9 +401,64 @@ def test_empty_init(self): ), responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="body_schema", ), + pytest.param( + openapi.Operation( + requestBody=openapi.RequestBody( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + {"type": "object", "properties": {"foo": {"$ref": "#!/components/schemas/Bar"}}} + ), + ) + }, + ), + responses=openapi.Responses({}), + ), + ["Bar"], + ["Foo", "Bar"], + id="body_schema_nested", + ), + pytest.param( + openapi.Operation( + requestBody=openapi.RequestBody( + description="BarList", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + {"type": "object", "properties": {"foo": {"$ref": "#!/components/schemas/BarList"}}} + ), + ) + }, + ), + responses=openapi.Responses({}), + ), + ["BarList"], + ["Foo", "BarList"], + id="body_schema_nested_list", + ), + pytest.param( + openapi.Operation( + requestBody=openapi.RequestBody( + description="BarDict", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + {"type": "object", "properties": {"foo": {"$ref": "#!/components/schemas/BarDict"}}} + ), + ) + }, + ), + responses=openapi.Responses({}), + ), + ["BarDict"], + ["Foo", "BarDict"], + id="body_schema_nested_dict", + ), pytest.param( openapi.Operation( requestBody=openapi.RequestBody( @@ -179,7 +478,8 @@ def test_empty_init(self): ), responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="body_array", ), pytest.param( @@ -194,16 +494,42 @@ def test_empty_init(self): ), responses=openapi.Responses({}), ), - False, + [], + [], id="body_wrong", ), pytest.param( openapi.Operation( parameters=[openapi.Reference(ref="#!/components/schemas/Foo")], responses=openapi.Responses({}) ), - True, + ["Foo"], + ["Foo"], id="parameter_reference", ), + pytest.param( + openapi.Operation( + parameters=[openapi.Reference(ref="#!/components/schemas/Bar")], responses=openapi.Responses({}) + ), + ["Bar"], + ["Foo", "Bar"], + id="parameter_reference_nested", + ), + pytest.param( + openapi.Operation( + parameters=[openapi.Reference(ref="#!/components/schemas/BarList")], responses=openapi.Responses({}) + ), + ["BarList"], + ["Foo", "BarList"], + id="parameter_reference_nested_list", + ), + pytest.param( + openapi.Operation( + parameters=[openapi.Reference(ref="#!/components/schemas/BarDict")], responses=openapi.Responses({}) + ), + ["BarDict"], + ["Foo", "BarDict"], + id="parameter_reference_nested_dict", + ), pytest.param( openapi.Operation( parameters=[ @@ -217,9 +543,61 @@ def test_empty_init(self): ], responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="parameter_schema", ), + pytest.param( + openapi.Operation( + parameters=[ + openapi.Parameter( + in_="query", + name="bar", + schema=openapi.Schema( + {"type": "object", "properties": {"bar": {"$ref": "#!/components/schemas/Bar"}}} + ), + ) + ], + responses=openapi.Responses({}), + ), + ["Bar"], + ["Foo", "Bar"], + id="parameter_schema_nested", + ), + pytest.param( + openapi.Operation( + parameters=[ + openapi.Parameter( + in_="query", + name="bar", + schema=openapi.Schema( + {"type": "object", "properties": {"bar": {"$ref": "#!/components/schemas/BarList"}}} + ), + ) + ], + responses=openapi.Responses({}), + ), + ["BarList"], + ["Foo", "BarList"], + id="parameter_schema_nested_list", + ), + pytest.param( + openapi.Operation( + parameters=[ + openapi.Parameter( + in_="query", + name="bar", + schema=openapi.Schema( + {"type": "object", "properties": {"bar": {"$ref": "#!/components/schemas/BarDict"}}} + ), + ) + ], + responses=openapi.Responses({}), + ), + ["BarDict"], + ["Foo", "BarDict"], + id="parameter_schema_nested_dict", + ), pytest.param( openapi.Operation( parameters=[ @@ -238,7 +616,8 @@ def test_empty_init(self): ], responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="parameter_array", ), pytest.param( @@ -252,7 +631,8 @@ def test_empty_init(self): ], responses=openapi.Responses({}), ), - False, + [], + [], id="parameter_wrong", ), pytest.param( @@ -260,9 +640,37 @@ def test_empty_init(self): callbacks={"200": openapi.Reference(ref="#!/components/schemas/Foo")}, responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="callback_reference", ), + pytest.param( + openapi.Operation( + callbacks={"200": openapi.Reference(ref="#!/components/schemas/Bar")}, + responses=openapi.Responses({}), + ), + ["Bar"], + ["Foo", "Bar"], + id="callback_reference_nested", + ), + pytest.param( + openapi.Operation( + callbacks={"200": openapi.Reference(ref="#!/components/schemas/BarList")}, + responses=openapi.Responses({}), + ), + ["BarList"], + ["Foo", "BarList"], + id="callback_reference_nested_list", + ), + pytest.param( + openapi.Operation( + callbacks={"200": openapi.Reference(ref="#!/components/schemas/BarDict")}, + responses=openapi.Responses({}), + ), + ["BarDict"], + ["Foo", "BarDict"], + id="callback_reference_nested_dict", + ), pytest.param( openapi.Operation( callbacks={ @@ -296,9 +704,121 @@ def test_empty_init(self): }, responses=openapi.Responses({}), ), - True, + ["Foo"], + ["Foo"], id="callback_schema", ), + pytest.param( + openapi.Operation( + callbacks={ + "foo": openapi.Callback( + { + "/callback": openapi.Path( + get=openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="Bar", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": { + "bar": {"$ref": "#!/components/schemas/Bar"} + }, + } + ) + ) + }, + ) + } + ) + ) + ) + } + ) + }, + responses=openapi.Responses({}), + ), + ["Bar"], + ["Foo", "Bar"], + id="callback_schema_nested", + ), + pytest.param( + openapi.Operation( + callbacks={ + "foo": openapi.Callback( + { + "/callback": openapi.Path( + get=openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="BarList", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": { + "bar": {"$ref": "#!/components/schemas/BarList"} + }, + } + ) + ) + }, + ) + } + ) + ) + ) + } + ) + }, + responses=openapi.Responses({}), + ), + ["BarList"], + ["Foo", "BarList"], + id="callback_schema_nested_list", + ), + pytest.param( + openapi.Operation( + callbacks={ + "foo": openapi.Callback( + { + "/callback": openapi.Path( + get=openapi.Operation( + responses=openapi.Responses( + { + "200": openapi.Response( + description="BarDict", + content={ + "application/json": openapi.MediaType( + schema=openapi.Schema( + { + "type": "object", + "properties": { + "bar": {"$ref": "#!/components/schemas/BarDict"} + }, + } + ) + ) + }, + ) + } + ) + ) + ) + } + ) + }, + responses=openapi.Responses({}), + ), + ["BarDict"], + ["Foo", "BarDict"], + id="callback_schema_nested_dict", + ), pytest.param( openapi.Operation( callbacks={ @@ -330,49 +850,54 @@ def test_empty_init(self): }, responses=openapi.Responses({}), ), - False, + [], + [], id="callback_wrong", ), ], ) - def test_used(self, registry, foo_schema, spec, operation, output): - expected_output = {id(foo_schema): registry[foo_schema]} if output else {} + def test_used(self, registry, schemas, spec, operation, register_schemas, output): + for schema in register_schemas: + registry.register(schemas[schema].schema, name=schema) + + expected_output = {id(schemas[schema].schema) for schema in output} + spec.add_path("/", openapi.Path(get=operation)) - assert registry.used(spec) == expected_output + + assert set(registry.used(spec).keys()) == expected_output @pytest.mark.parametrize( - ["schema", "name", "expected_name", "exception"], + ["schema", "explicit_name", "output"], [ - pytest.param(typesystem.Schema(title="Foo", fields={}), "Foo", "Foo", None, id="typesystem_explicit_name"), - pytest.param( - typesystem.Schema(title="Foo", fields={}), - None, - "abc.Foo", - ValueError("Cannot infer schema name."), - id="typesystem_cannot_infer_name", - ), - pytest.param(type("Foo", (marshmallow.Schema,), {}), None, "abc.Foo", None, id="marshmallow_infer_name"), - pytest.param( - pydantic.create_model("Foo", name=(str, ...)), None, "pydantic.main.Foo", None, id="pydantic_infer_name" - ), + pytest.param("Foo", "Foo", {"Foo": "Foo"}, id="explicit_name"), + pytest.param("Foo", None, {"Foo": None}, id="infer_name"), + pytest.param("Bar", "Bar", {"Bar": "Bar"}, id="nested_schemas"), ], - indirect=["exception"], ) - def test_register(self, registry, schema, name, expected_name, exception): + def test_register(self, registry, schemas, schema, explicit_name, output): + schema, name = schemas[schema] + expected_name = name if not explicit_name else explicit_name + exception = ( + contextlib.ExitStack() if expected_name else pytest.raises(ValueError, match="Cannot infer schema name.") + ) with exception: - registry.register(schema, name=name) - assert registry[schema].name == expected_name + registry.register(schema, name=explicit_name) + for s, n in output.items(): + assert schemas[s].schema in registry + assert registry[schemas[s].schema].name == (n or schemas[s].name) def test_register_already_registered(self, registry, foo_schema): + schema = foo_schema.schema + registry.register(schema, name="Foo") with pytest.raises(ValueError, match="Schema is already registered."): - registry.register(foo_schema, name="Foo") + registry.register(schema, name="Foo") @pytest.mark.parametrize( ["multiple", "result"], ( - pytest.param(False, openapi.Reference(ref="#/components/schemas/Foo"), id="single"), + pytest.param(False, openapi.Reference(ref="#/components/schemas/Foo"), id="schema"), pytest.param( - True, openapi.Schema({"type": "array", "items": {"$ref": "#/components/schemas/Foo"}}), id="multiple" + True, openapi.Schema({"type": "array", "items": {"$ref": "#/components/schemas/Foo"}}), id="array" ), pytest.param( None, @@ -384,79 +909,86 @@ def test_register_already_registered(self, registry, foo_schema): ] } ), - id="multiple", + id="array_or_schema", ), ), ) def test_get_openapi_ref(self, multiple, result, registry, foo_schema): - assert registry.get_openapi_ref(foo_schema, multiple=multiple) == result + schema = foo_schema.schema + registry.register(schema, name="Foo") + assert registry.get_openapi_ref(schema, multiple=multiple) == result class TestCaseSchemaGenerator: @pytest.fixture(scope="function") def owner_schema(self, app): - from flama import schemas - - if schemas.lib == pydantic: + if app.schema.schema_library.lib == pydantic: schema = pydantic.create_model("Owner", name=(str, ...)) - elif schemas.lib == typesystem: + name = "pydantic.main.Owner" + elif app.schema.schema_library.lib == typesystem: schema = typesystem.Schema(title="Owner", fields={"name": typesystem.fields.String()}) - elif schemas.lib == marshmallow: + name = "typesystem.schemas.Owner" + elif app.schema.schema_library.lib == marshmallow: schema = type("Owner", (marshmallow.Schema,), {"name": marshmallow.fields.String()}) + name = "abc.Owner" else: raise ValueError("Wrong schema lib") app.schema.schemas["Owner"] = schema - return schema + return namedtuple("OwnerSchema", ("schema", "name"))(schema, name) @pytest.fixture(scope="function") def puppy_schema(self, app, owner_schema): - from flama import schemas - - if schemas.lib == pydantic: - schema = pydantic.create_model("Puppy", name=(str, ...), owner=(owner_schema, ...)) - elif schemas.lib == typesystem: + if app.schema.schema_library.lib == pydantic: + schema = pydantic.create_model("Puppy", name=(str, ...), owner=(owner_schema.schema, ...)) + name = "pydantic.main.Puppy" + elif app.schema.schema_library.lib == typesystem: schema = typesystem.Schema( title="Puppy", fields={ "name": typesystem.fields.String(), - "owner": typesystem.Reference(to="Owner", definitions=app.schema.schemas), + "owner": typesystem.Reference( + to="Owner", definitions=typesystem.Definitions({"Owner": owner_schema.schema}) + ), }, ) - elif schemas.lib == marshmallow: + name = "typesystem.schemas.Puppy" + elif app.schema.schema_library.lib == marshmallow: schema = type( "Puppy", (marshmallow.Schema,), { "name": marshmallow.fields.String(), - "owner": marshmallow.fields.Nested(owner_schema), + "owner": marshmallow.fields.Nested(owner_schema.schema), }, ) + name = "abc.Puppy" else: raise ValueError("Wrong schema lib") app.schema.schemas["Puppy"] = schema - return schema + return namedtuple("PuppySchema", ("schema", "name"))(schema, name) @pytest.fixture(scope="function") def body_param_schema(self, app): - from flama import schemas - - if schemas.lib == pydantic: + if app.schema.schema_library.lib == pydantic: schema = pydantic.create_model("BodyParam", name=(str, ...)) - elif schemas.lib == typesystem: - schema = typesystem.Schema(fields={"name": typesystem.fields.String()}) - elif schemas.lib == marshmallow: + name = "pydantic.main.BodyParam" + elif app.schema.schema_library.lib == typesystem: + schema = typesystem.Schema(title="BodyParam", fields={"name": typesystem.fields.String()}) + name = "typesystem.schemas.BodyParam" + elif app.schema.schema_library.lib == marshmallow: schema = type("BodyParam", (marshmallow.Schema,), {"name": marshmallow.fields.String()}) + name = "abc.BodyParam" else: raise ValueError("Wrong schema lib") app.schema.schemas["BodyParam"] = schema - return schema + return namedtuple("BodyParamSchema", ("schema", "name"))(schema, name) @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, puppy_schema, body_param_schema): @app.route("/endpoint/", methods=["GET"]) class PuppyEndpoint(HTTPEndpoint): - async def get(self) -> puppy_schema: + async def get(self) -> types.Schema[puppy_schema.schema]: """ description: Endpoint. responses: @@ -466,7 +998,7 @@ async def get(self) -> puppy_schema: return {"name": "Canna"} @app.route("/custom-component/", methods=["GET"]) - async def get() -> puppy_schema: + async def get() -> types.Schema[puppy_schema.schema]: """ description: Custom component. responses: @@ -476,7 +1008,7 @@ async def get() -> puppy_schema: return {"name": "Canna"} @app.route("/many-components/", methods=["GET"]) - async def many_components() -> types.Schema[puppy_schema]: + async def many_components() -> types.Schema[puppy_schema.schema]: """ description: Many custom components. responses: @@ -506,7 +1038,7 @@ async def path_param(param: int): return {"name": param} @app.route("/body-param/", methods=["POST"]) - async def body_param(param: types.Schema[body_param_schema]): + async def body_param(param: types.Schema[body_param_schema.schema]): """ description: Body param. responses: diff --git a/tests/test_pagination.py b/tests/test_pagination.py index a3e6bc60..2ffe780e 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -10,24 +10,29 @@ from flama.pagination import paginator +@pytest.fixture(scope="function") +def app(app): + paginator.schemas = {} + return app + + @pytest.fixture(scope="function") def output_schema(app): - from flama import schemas - if schemas.lib == pydantic: + if app.schema.schema_library.lib == pydantic: schema = pydantic.create_model("OutputSchema", value=(t.Optional[int], ...)) - elif schemas.lib == typesystem: + elif app.schema.schema_library.lib == typesystem: schema = typesystem.Schema(title="OutputSchema", fields={"value": typesystem.fields.Integer(allow_null=True)}) - elif schemas.lib == marshmallow: + elif app.schema.schema_library.lib == marshmallow: schema = type("OutputSchema", (marshmallow.Schema,), {"value": marshmallow.fields.Integer(allow_none=True)}) else: - raise ValueError("Wrong schema lib") + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") app.schema.schemas["OutputSchema"] = schema return schema -class TestPageNumberResponse: +class TestCasePageNumberPagination: @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, output_schema): @app.route("/page-number/", methods=["GET"]) @@ -158,7 +163,7 @@ async def test_params(self, client, params, status_code, expected): assert response.json() == expected -class TestLimitOffsetResponse: +class TestCaseLimitOffsetPagination: @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, output_schema): @app.route("/limit-offset/", methods=["GET"]) @@ -176,7 +181,7 @@ def test_registered_schemas(self, app): "flama.APIError", } - def test_invalid_view(self, app, output_schema): + def test_invalid_view(self, output_schema): with pytest.raises(TypeError, match=r"Paginated views must define \*\*kwargs param"): @paginator.limit_offset(schema_name="Foo")