diff --git a/flama/http.py b/flama/http.py index 305fd30f..c3415fd5 100644 --- a/flama/http.py +++ b/flama/http.py @@ -113,7 +113,7 @@ def __init__(self, content: t.Any = None, schema: t.Optional["flama.schemas.type def render(self, content: t.Any): if self.schema is not None: try: - content = schemas.Schema(self.schema, multiple=isinstance(content, list)).dump(content) + content = schemas.Schema(self.schema).dump(content) except schemas.SchemaValidationError as e: raise SerializationError(status_code=500, detail=e.errors) diff --git a/flama/pagination/limit_offset.py b/flama/pagination/limit_offset.py index 4d7cdcbf..4ad3b8f5 100644 --- a/flama/pagination/limit_offset.py +++ b/flama/pagination/limit_offset.py @@ -73,8 +73,7 @@ def _inner(func: t.Callable): schema = schemas.Schema.build( paginated_schema_name, schema=schemas.schemas.LimitOffset, - fields=[schemas.Field("data", t.List[resource_schema])], # type: ignore[valid-type] - multiple=False, + fields=[schemas.Field("data", resource_schema, multiple=True)], ).unique_schema forge_revision_list = ( diff --git a/flama/pagination/page_number.py b/flama/pagination/page_number.py index 300adb54..4cba8b34 100644 --- a/flama/pagination/page_number.py +++ b/flama/pagination/page_number.py @@ -80,8 +80,7 @@ def _inner(func: t.Callable): schema = schemas.Schema.build( paginated_schema_name, schema=schemas.schemas.PageNumber, - fields=[schemas.Field("data", t.List[resource_schema])], # type: ignore[valid-type] - multiple=False, + fields=[schemas.Field("data", resource_schema, multiple=True)], ).unique_schema forge_revision_list = ( diff --git a/flama/schemas/data_structures.py b/flama/schemas/data_structures.py index a6d834f3..9b7879c1 100644 --- a/flama/schemas/data_structures.py +++ b/flama/schemas/data_structures.py @@ -20,23 +20,28 @@ class Field: type: t.Type nullable: bool = dataclasses.field(init=False) field: t.Any = dataclasses.field(hash=False, init=False, compare=False) - multiple: bool = dataclasses.field(hash=False, init=False, compare=False) + multiple: t.Optional[bool] = dataclasses.field(hash=False, compare=False, default=None) required: bool = True default: t.Any = InjectionParameter.empty def __post_init__(self) -> None: object.__setattr__(self, "nullable", type(None) in t.get_args(self.type) or self.default is None) - object.__setattr__(self, "multiple", t.get_origin(self.type) is list) + + field_type = t.get_args(self.type)[0] if t.get_origin(self.type) in (list, t.Union) else self.type + + if not Schema.is_schema(field_type) and self.multiple is None: + object.__setattr__(self, "multiple", t.get_origin(self.type) is list) + object.__setattr__( self, "field", schemas.adapter.build_field( self.name, - t.get_args(self.type)[0] if t.get_origin(self.type) in (list, t.Union) else self.type, + field_type, nullable=self.nullable, required=self.required, default=self.default, - multiple=self.multiple, + multiple=bool(self.multiple), ), ) @@ -69,28 +74,21 @@ def is_http_valid_type(cls, type_: t.Type) -> bool: @property def json_schema(self) -> schemas.types.JSONSchema: - schema = schemas.adapter.to_json_schema(self.field) - - if self.multiple: - schema = {"items": {"$ref": schema}, "type": "array"} - - return schema + return schemas.adapter.to_json_schema(self.field) @dataclasses.dataclass(frozen=True) class Schema: schema: t.Any = dataclasses.field(hash=False, compare=False) - multiple: bool = dataclasses.field(hash=False, compare=False, default=False) @classmethod - def from_type(cls, type: t.Optional[t.Type]) -> "Schema": - multiple = t.get_origin(type) is list - schema = t.get_args(type)[0] if multiple else type + def from_type(cls, type_: t.Optional[t.Type]) -> "Schema": + schema = t.get_args(type_)[0] if t.get_origin(type_) is list else type_ if not schemas.adapter.is_schema(schema): raise ValueError("Wrong schema type") - return cls(schema=schema, multiple=multiple) + return cls(schema=schema) @classmethod def build( @@ -98,13 +96,11 @@ def build( name: t.Optional[str] = None, schema: t.Any = None, fields: t.Optional[t.List[Field]] = None, - multiple: bool = False, ) -> "Schema": return cls( schema=schemas.adapter.build_schema( name=name, schema=schema, fields={f.name: f.field for f in (fields or [])} ), - multiple=multiple, ) @classmethod @@ -113,12 +109,7 @@ def is_schema(cls, obj: t.Any) -> bool: @property def json_schema(self) -> t.Dict[str, t.Any]: - schema = schemas.adapter.to_json_schema(self.schema) - - if self.multiple: - schema = {"items": {"$ref": schema}, "type": "array"} - - return schema + return schemas.adapter.to_json_schema(self.schema) @property def unique_schema(self) -> t.Any: @@ -133,7 +124,7 @@ def validate(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.A ... def validate(self, values): - if self.multiple and isinstance(values, (list, tuple)): + if isinstance(values, (list, tuple)): return [schemas.adapter.validate(self.schema, value) for value in values] return schemas.adapter.validate(self.schema, values) @@ -147,7 +138,7 @@ def load(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Any]: ... def load(self, values): - if self.multiple: + if isinstance(values, (list, tuple)): return [schemas.adapter.load(self.schema, value) for value in values] return schemas.adapter.load(self.schema, values) @@ -161,7 +152,7 @@ def dump(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.Any]] ... def dump(self, values): - if self.multiple and isinstance(values, (list, tuple)): + if isinstance(values, (list, tuple)): return [schemas.adapter.dump(self.schema, value) for value in values] return schemas.adapter.dump(self.schema, values) @@ -186,7 +177,7 @@ def __post_init__(self) -> None: field = None except ValueError: if self.type in (None, InjectionParameter.empty): - schema = Schema(schema=None, multiple=False) + schema = Schema(schema=None) field = None else: schema = None diff --git a/flama/schemas/generator.py b/flama/schemas/generator.py index 1133b9d0..732e6d21 100644 --- a/flama/schemas/generator.py +++ b/flama/schemas/generator.py @@ -70,20 +70,25 @@ def _get_schema_references_from_schema( return [schema.ref] result = [] - for name, prop in schema.get("properties", {}).items(): - if "$ref" in prop: - result.append(prop["$ref"]) - if prop.get("type", "") == "array" and prop.get("items", {}).get("$ref"): - result.append(prop["items"]["$ref"]) + if "$ref" in schema: + result.append(schema["$ref"]) - result += [ - ref["$ref"] - for composer in ("allOf", "anyOf", "oneOff") - if composer in prop - for ref in prop[composer] - if "$ref" in ref - ] + if schema.get("type", "") == "array" and schema.get("items", {}).get("$ref"): + result.append(schema["items"]["$ref"]) + + result += [ + ref + for composer in ("allOf", "anyOf", "oneOf") + for composer_schema in schema.get(composer, []) + for ref in self._get_schema_references_from_schema(composer_schema) + ] + + result += [ + ref + for prop in schema.get("properties", {}).values() + for ref in self._get_schema_references_from_schema(prop) + ] return result @@ -201,7 +206,7 @@ def register(self, schema: schemas.types.Schema, name: t.Optional[str] = None) - return schema_id def get_openapi_ref( - self, element: schemas.types.Schema, multiple: bool = False + self, element: schemas.types.Schema, multiple: t.Optional[bool] = None ) -> typing.Union[openapi.Schema, openapi.Reference]: """ Builds the reference for a single schema or the array schema containing the reference. @@ -210,12 +215,14 @@ def get_openapi_ref( :param multiple: True for building a schema containing an array of references instead of a single reference. :return: Reference or array schema. """ - reference = openapi.Reference(ref=self[element].ref) - - if multiple: - return openapi.Schema({"items": dataclasses.asdict(reference), "type": "array"}) + reference = self[element].ref - return reference + if multiple is True: + return openapi.Schema({"items": {"$ref": reference}, "type": "array"}) + elif multiple is None: + return openapi.Schema({"oneOf": [{"$ref": reference}, {"items": {"$ref": reference}, "type": "array"}]}) + else: + return openapi.Reference(ref=reference) class SchemaGenerator(starlette_schemas.BaseSchemaGenerator): @@ -355,9 +362,7 @@ def _build_endpoint_body( return openapi.RequestBody( content={ "application/json": openapi.MediaType( - schema=self.schemas.get_openapi_ref( - endpoint.body_parameter.schema.schema, multiple=endpoint.body_parameter.schema.multiple - ) + schema=self.schemas.get_openapi_ref(endpoint.body_parameter.schema.schema, multiple=False), ) }, **{ @@ -383,9 +388,7 @@ def _build_endpoint_response( content = { "application/json": openapi.MediaType( - schema=self.schemas.get_openapi_ref( - endpoint.response_parameter.schema.schema, multiple=endpoint.response_parameter.schema.multiple - ) + schema=self.schemas.get_openapi_ref(endpoint.response_parameter.schema.schema) ) } else: @@ -403,7 +406,9 @@ def _build_endpoint_default_response(self, metadata: typing.Dict[str, typing.Any return openapi.Response( description=metadata.get("responses", {}).get("default", {}).get("description", "Unexpected error."), content={ - "application/json": openapi.MediaType(schema=self.schemas.get_openapi_ref(schemas.schemas.APIError)) + "application/json": openapi.MediaType( + schema=self.schemas.get_openapi_ref(schemas.schemas.APIError, multiple=False) + ) }, ) diff --git a/poetry.lock b/poetry.lock index e218414d..c89dfbb9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1880,7 +1880,7 @@ typesystem = ["typesystem"] [metadata] lock-version = "1.1" python-versions = ">=3.7,<3.12" -content-hash = "363944d3bb558ea113e1260fa74bfec09993ee17f0d581662fa4b063fe86c25d" +content-hash = "2f189b9f7b605605355d9ebb419dde524cb02ea43d06ab076a70ca23c5ea0b11" [metadata.files] absl-py = [ diff --git a/tests/schemas/test_generator.py b/tests/schemas/test_generator.py index 69e55c80..cc7b18ba 100644 --- a/tests/schemas/test_generator.py +++ b/tests/schemas/test_generator.py @@ -367,11 +367,23 @@ def test_register_already_registered(self, registry, foo_schema): ( pytest.param(False, openapi.Reference(ref="#/components/schemas/Foo"), id="single"), pytest.param( - True, openapi.Schema({"type": "array", "items": {"ref": "#/components/schemas/Foo"}}), id="multiple" + True, openapi.Schema({"type": "array", "items": {"$ref": "#/components/schemas/Foo"}}), id="multiple" + ), + pytest.param( + None, + openapi.Schema( + { + "oneOf": [ + {"$ref": "#/components/schemas/Foo"}, + {"type": "array", "items": {"$ref": "#/components/schemas/Foo"}}, + ] + } + ), + id="multiple", ), ), ) - def test_get_openapi_ref_single(self, multiple, result, registry, foo_schema): + def test_get_openapi_ref(self, multiple, result, registry, foo_schema): assert registry.get_openapi_ref(foo_schema, multiple=multiple) == result @@ -419,10 +431,6 @@ def puppy_schema(self, app, owner_schema): app.schema.schemas["Puppy"] = schema return schema - @pytest.fixture(scope="function") - def puppy_array_schema(self, app, puppy_schema): - return t.List[puppy_schema] - @pytest.fixture(scope="function") def body_param_schema(self, app): from flama import schemas @@ -440,7 +448,7 @@ def body_param_schema(self, app): return schema @pytest.fixture(scope="function", autouse=True) - def add_endpoints(self, app, puppy_schema, puppy_array_schema, body_param_schema): + def add_endpoints(self, app, puppy_schema, body_param_schema): @app.route("/endpoint/", methods=["GET"]) class PuppyEndpoint(HTTPEndpoint): async def get(self) -> puppy_schema: @@ -463,7 +471,7 @@ async def get() -> puppy_schema: return {"name": "Canna"} @app.route("/many-components/", methods=["GET"]) - async def many_components() -> puppy_array_schema: + async def many_components() -> puppy_schema: """ description: Many custom components. responses: @@ -601,7 +609,16 @@ def test_components_schemas(self, app): "responses": { "200": { "description": "Component.", - "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}}, + "content": { + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/Puppy"}, + {"items": {"$ref": "#/components/schemas/Puppy"}, "type": "array"}, + ] + } + } + }, } }, }, @@ -618,8 +635,10 @@ def test_components_schemas(self, app): "content": { "application/json": { "schema": { - "items": {"$ref": "#/components/schemas/Puppy"}, - "type": "array", + "oneOf": [ + {"$ref": "#/components/schemas/Puppy"}, + {"items": {"$ref": "#/components/schemas/Puppy"}, "type": "array"}, + ] } } }, @@ -636,7 +655,16 @@ def test_components_schemas(self, app): "responses": { "200": { "description": "Component.", - "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}}, + "content": { + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/Puppy"}, + {"items": {"$ref": "#/components/schemas/Puppy"}, "type": "array"}, + ] + } + } + }, } }, }, @@ -650,7 +678,16 @@ def test_components_schemas(self, app): "responses": { "200": { "description": "Component.", - "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}}, + "content": { + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/Puppy"}, + {"items": {"$ref": "#/components/schemas/Puppy"}, "type": "array"}, + ] + } + } + }, } }, }, @@ -664,7 +701,16 @@ def test_components_schemas(self, app): "responses": { "200": { "description": "Component.", - "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}}, + "content": { + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/Puppy"}, + {"items": {"$ref": "#/components/schemas/Puppy"}, "type": "array"}, + ] + } + } + }, } }, }, diff --git a/tests/schemas/test_validation.py b/tests/schemas/test_validation.py index d76f9c90..178a2813 100644 --- a/tests/schemas/test_validation.py +++ b/tests/schemas/test_validation.py @@ -54,12 +54,8 @@ def rating_validator(cls, x): app.schema.schemas["Product"] = schema return schema - @pytest.fixture(scope="function") - def product_array_schema(self, product_schema): - return t.List[product_schema] - @pytest.fixture(scope="function", autouse=True) - def add_endpoints(self, app, product_schema, product_array_schema): + def add_endpoints(self, app, product_schema): @app.route("/product", methods=["GET"]) @output_validation() def validate_product() -> product_schema: @@ -71,7 +67,7 @@ def validate_product() -> product_schema: @app.route("/many-products", methods=["GET"]) @output_validation() - def validate_many_products() -> product_array_schema: + def validate_many_products() -> product_schema: return [ { "name": "foo", diff --git a/tests/test_pagination.py b/tests/test_pagination.py index a09f4b8b..4eaee884 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -79,7 +79,17 @@ def test_pagination_schema_return(self, app): assert response_schema == { "description": "Description not provided.", "content": { - "application/json": {"schema": {"$ref": "#/components/schemas/PageNumberPaginatedOutputSchema"}}, + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/PageNumberPaginatedOutputSchema"}, + { + "items": {"$ref": "#/components/schemas/PageNumberPaginatedOutputSchema"}, + "type": "array", + }, + ] + } + } }, } @@ -195,7 +205,17 @@ def test_pagination_schema_return(self, app): assert response_schema == { "description": "Description not provided.", "content": { - "application/json": {"schema": {"$ref": "#/components/schemas/LimitOffsetPaginatedOutputSchema"}} + "application/json": { + "schema": { + "oneOf": [ + {"$ref": "#/components/schemas/LimitOffsetPaginatedOutputSchema"}, + { + "items": {"$ref": "#/components/schemas/LimitOffsetPaginatedOutputSchema"}, + "type": "array", + }, + ] + } + } }, } diff --git a/tests/validation/test_schemas.py b/tests/validation/test_schemas.py index 01fadfce..b4e1b795 100644 --- a/tests/validation/test_schemas.py +++ b/tests/validation/test_schemas.py @@ -46,10 +46,6 @@ def product_schema(self, app): app.schema.register_schema("Product", schema) return schema - @pytest.fixture(scope="function") - def product_array_schema(self, app, product_schema): - return t.List[product_schema] - @pytest.fixture(scope="function") def reviewed_product_schema(self, app, product_schema): from flama import schemas @@ -135,9 +131,7 @@ def place_schema(self, app, location_schema): return schema @pytest.fixture(scope="function", autouse=True) - def add_endpoints( - self, app, product_schema, product_array_schema, reviewed_product_schema, location_schema, place_schema - ): + def add_endpoints(self, app, product_schema, reviewed_product_schema, location_schema, place_schema): @app.route("/product", methods=["POST"]) def product_identity(product: product_schema) -> product_schema: return product @@ -151,7 +145,7 @@ def place_identity(place: place_schema) -> place_schema: return place @app.route("/many-products", methods=["GET"]) - def many_products() -> product_array_schema: + def many_products() -> product_schema: return [ { "name": "foo",