From 5716a6fab0e897236cd27d7341f997defbc0e46a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Sun, 12 May 2024 09:43:42 +0200 Subject: [PATCH] :bug: Allow nullable nested schemas for Pydantic --- flama/schemas/_libs/marshmallow/adapter.py | 6 +++ flama/schemas/_libs/pydantic/adapter.py | 7 --- flama/schemas/_libs/typesystem/adapter.py | 6 +++ flama/schemas/data_structures.py | 10 +++- tests/schemas/conftest.py | 30 +++++++++++ tests/schemas/test_data_structures.py | 24 +++++++-- tests/schemas/test_generator.py | 2 +- tests/test_pagination.py | 61 ++++++++++++++++------ 8 files changed, 118 insertions(+), 28 deletions(-) diff --git a/flama/schemas/_libs/marshmallow/adapter.py b/flama/schemas/_libs/marshmallow/adapter.py index 1a907877..0d00c057 100644 --- a/flama/schemas/_libs/marshmallow/adapter.py +++ b/flama/schemas/_libs/marshmallow/adapter.py @@ -1,4 +1,5 @@ import inspect +import itertools import sys import typing as t @@ -122,6 +123,11 @@ def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, json_schema = converter.schema2jsonschema(schema) else: raise SchemaGenerationError + + for property in itertools.chain(json_schema.get("properties", {}).values(), [json_schema]): + if isinstance(property.get("type"), list): + property["anyOf"] = [{"type": x} for x in property["type"]] + del property["type"] except Exception as e: raise SchemaGenerationError from e diff --git a/flama/schemas/_libs/pydantic/adapter.py b/flama/schemas/_libs/pydantic/adapter.py index 63bdebd7..7f7bc6af 100644 --- a/flama/schemas/_libs/pydantic/adapter.py +++ b/flama/schemas/_libs/pydantic/adapter.py @@ -105,19 +105,12 @@ def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field]]) -> JSON json_schema = model_json_schema(schema, ref_template="#/components/schemas/{model}") if "$defs" in json_schema: del json_schema["$defs"] - for property in json_schema["properties"].values(): - if "anyOf" in property: # Simplify type from anyOf to a list of types - property["type"] = [x["type"] for x in property["anyOf"]] - del property["anyOf"] elif self.is_field(schema): json_schema = model_json_schema( self.build_schema(fields={"x": schema}), ref_template="#/components/schemas/{model}" )["properties"]["x"] if not schema.title: # Pydantic is introducing a default title, so we drop it del json_schema["title"] - if "anyOf" in json_schema: # Simplify type from anyOf to a list of types - json_schema["type"] = [x["type"] for x in json_schema["anyOf"]] - del json_schema["anyOf"] else: raise TypeError("Not a valid schema class or field") diff --git a/flama/schemas/_libs/typesystem/adapter.py b/flama/schemas/_libs/typesystem/adapter.py index 85233c01..4912cd94 100644 --- a/flama/schemas/_libs/typesystem/adapter.py +++ b/flama/schemas/_libs/typesystem/adapter.py @@ -1,4 +1,5 @@ import inspect +import itertools import sys import typing as t import warnings @@ -102,6 +103,11 @@ def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema: if not isinstance(json_schema, dict): raise SchemaGenerationError + for property in itertools.chain(json_schema.get("properties", {}).values(), [json_schema]): + if isinstance(property.get("type"), list): + property["anyOf"] = [{"type": x} for x in property["type"]] + del property["type"] + json_schema.pop("components", None) except Exception as e: raise SchemaGenerationError from e diff --git a/flama/schemas/data_structures.py b/flama/schemas/data_structures.py index 30bf69e5..8dfbc34b 100644 --- a/flama/schemas/data_structures.py +++ b/flama/schemas/data_structures.py @@ -87,7 +87,12 @@ def is_http_valid_type(cls, type_: t.Type) -> bool: return ( (type_ in types.PARAMETERS_TYPES) - or (origin is t.Union and len(args) == 2 and args[0] in types.PARAMETERS_TYPES and args[1] is NoneType) + or ( + origin in (t.Union, type(int | str)) + and len(args) == 2 + and args[0] in types.PARAMETERS_TYPES + and args[1] is NoneType + ) or (origin is list and args[0] in types.PARAMETERS_TYPES) ) @@ -186,6 +191,9 @@ def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]: if schemas.adapter.is_schema(schema): return [schemas.adapter.unique_schema(schema)] + if t.get_origin(schema) in (t.Union, type(int | str)): + return [x for field in t.get_args(schema) for x in self.nested_schemas(field)] + if isinstance(schema, (list, tuple, set)): return [x for field in schema for x in self.nested_schemas(field)] diff --git a/tests/schemas/conftest.py b/tests/schemas/conftest.py index 9d9e10c6..40be108a 100644 --- a/tests/schemas/conftest.py +++ b/tests/schemas/conftest.py @@ -44,6 +44,36 @@ def bar_schema(app, foo_schema): return namedtuple("BarSchema", ("schema", "name"))(schema=schema, name=name) +@pytest.fixture(scope="function") +def bar_optional_schema(app, foo_schema): + child_schema = foo_schema.schema + if app.schema.schema_library.lib == pydantic: + schema = pydantic.create_model( + "BarOptional", foo=(t.Union[child_schema, None], None), __module__="pydantic.main" + ) + name = "pydantic.main.BarOptional" + elif app.schema.schema_library.lib == typesystem: + schema = typesystem.Schema( + title="BarOptional", + fields={ + "foo": typesystem.Reference( + to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}), allow_null=True, default=None + ) + }, + ) + name = "typesystem.schemas.BarOptional" + elif app.schema.schema_library.lib == marshmallow: + schema = type( + "BarOptional", + (marshmallow.Schema,), + {"foo": marshmallow.fields.Nested(child_schema(), required=False, default=None, allow_none=True)}, + ) + name = "abc.BarOptional" + else: + raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}") + return namedtuple("BarOptionalSchema", ("schema", "name"))(schema=schema, name=name) + + @pytest.fixture(scope="function") def bar_list_schema(app, foo_schema): child_schema = foo_schema.schema diff --git a/tests/schemas/test_data_structures.py b/tests/schemas/test_data_structures.py index 89e7b933..0c8594ce 100644 --- a/tests/schemas/test_data_structures.py +++ b/tests/schemas/test_data_structures.py @@ -6,7 +6,9 @@ from copy import deepcopy from unittest.mock import Mock, call, patch +import marshmallow import pytest +import typesystem from flama import types from flama.injection import Parameter as InjectionParameter @@ -103,7 +105,7 @@ def test_json_schema(self): class TestCaseSchema: @pytest.fixture(scope="function") - def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict_schema): + def schema_type(self, app, request, foo_schema, bar_schema, bar_optional_schema, bar_list_schema, bar_dict_schema): if request.param is None: return None elif request.param == "bare_schema": @@ -116,6 +118,10 @@ def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict return types.PartialSchema[foo_schema.schema] elif request.param == "schema_nested": return types.Schema[bar_schema.schema] + elif request.param == "schema_nested_optional": + if app.schema.schema_library.lib in (typesystem, marshmallow): + pytest.skip("Library does not support optional nested schemas") + return types.Schema[bar_optional_schema.schema] elif request.param == "schema_nested_list": return types.Schema[bar_list_schema.schema] elif request.param == "schema_nested_dict": @@ -132,6 +138,7 @@ def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict pytest.param("schema_partial", None, id="schema_partial"), pytest.param("list_of_schema", None, id="list_of_schema"), pytest.param("schema_nested", None, id="schema_nested"), + pytest.param("schema_nested_optional", None, id="schema_nested_optional"), pytest.param("schema_nested_list", None, id="schema_nested_list"), pytest.param("schema_nested_dict", None, id="schema_nested_dict"), pytest.param(None, ValueError("Wrong schema type"), id="wrong"), @@ -178,7 +185,7 @@ def test_name(self): ), pytest.param( "schema_partial", - {"properties": {"name": {"type": ["string", "null"]}}, "type": "object"}, + {"properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}]}}, "type": "object"}, None, id="partial", ), @@ -188,6 +195,15 @@ def test_name(self): "properties.foo", id="nested", ), + pytest.param( + "schema_nested_optional", + { + "properties": {"foo": {"anyOf": [{"$ref": "#/components/schemas/Foo"}, {"type": "null"}]}}, + "type": "object", + }, + "properties.foo.anyOf.0", + id="nested_optional", + ), pytest.param( "schema_nested_list", { @@ -217,7 +233,9 @@ def test_json_schema(self, schemas, schema_type, json_schema, key_to_replace): expected_result = deepcopy(json_schema) if key_to_replace: - subdict = functools.reduce(lambda x, k: x[k], key_to_replace.split("."), expected_result) + subdict = functools.reduce( + lambda x, k: x[int(k) if k.isnumeric() else k], key_to_replace.split("."), expected_result + ) subdict["$ref"] = subdict["$ref"].replace("Foo", schemas["Foo"].name) assert_recursive_contains(expected_result, result) diff --git a/tests/schemas/test_generator.py b/tests/schemas/test_generator.py index 588dc491..2ed6dd1c 100644 --- a/tests/schemas/test_generator.py +++ b/tests/schemas/test_generator.py @@ -1085,7 +1085,7 @@ def test_components_schemas(self, app, schemas): "name": "param2", "in": "query", "required": False, - "schema": {"type": ["string", "null"]}, + "schema": {"anyOf": [{"type": "string"}, {"type": "null"}]}, }, { "name": "param3", diff --git a/tests/test_pagination.py b/tests/test_pagination.py index b9f4a923..e35bdd2c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -10,6 +10,7 @@ from flama import types from flama.pagination import paginator +from tests.asserts import assert_recursive_contains @pytest.fixture(scope="function") @@ -71,19 +72,33 @@ def test_pagination_schema_parameters(self, app): schema = app.schema.schema["paths"]["/page-number/"]["get"] parameters = schema.get("parameters", []) - for parameter in parameters: - parameter["schema"] = {k: v for k, v in parameter["schema"].items() if k in ("type", "default")} - - assert parameters == [ + assert_recursive_contains( { "name": "count", "in": "query", "required": False, - "schema": {"type": ["boolean", "null"], "default": False}, + "schema": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "default": False}, + }, + parameters[0], + ) + assert_recursive_contains( + { + "name": "page", + "in": "query", + "required": False, + "schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + parameters[1], + ) + assert_recursive_contains( + { + "name": "page_size", + "in": "query", + "required": False, + "schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, }, - {"name": "page", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}}, - {"name": "page_size", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}}, - ] + parameters[2], + ) def test_pagination_schema_return(self, app, output_schema): prefix, name = output_schema.name.rsplit(".", 1) @@ -195,19 +210,33 @@ def test_pagination_schema_parameters(self, app): schema = app.schema.schema["paths"]["/limit-offset/"]["get"] parameters = schema.get("parameters", []) - for parameter in parameters: - parameter["schema"] = {k: v for k, v in parameter["schema"].items() if k in ("type", "default")} - - assert parameters == [ + assert_recursive_contains( { "name": "count", "in": "query", "required": False, - "schema": {"type": ["boolean", "null"], "default": False}, + "schema": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "default": False}, + }, + parameters[0], + ) + assert_recursive_contains( + { + "name": "limit", + "in": "query", + "required": False, + "schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, + }, + parameters[1], + ) + assert_recursive_contains( + { + "name": "offset", + "in": "query", + "required": False, + "schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, }, - {"name": "limit", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}}, - {"name": "offset", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}}, - ] + parameters[2], + ) def test_pagination_schema_return(self, app, output_schema): prefix, name = output_schema.name.rsplit(".", 1)