Skip to content

Commit

Permalink
🐛 Allow nullable nested schemas for Pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent c4d0915 commit 5716a6f
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 28 deletions.
6 changes: 6 additions & 0 deletions flama/schemas/_libs/marshmallow/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import itertools
import sys
import typing as t

Expand Down Expand Up @@ -122,6 +123,11 @@ def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema,
json_schema = converter.schema2jsonschema(schema)
else:
raise SchemaGenerationError

for property in itertools.chain(json_schema.get("properties", {}).values(), [json_schema]):
if isinstance(property.get("type"), list):
property["anyOf"] = [{"type": x} for x in property["type"]]
del property["type"]
except Exception as e:
raise SchemaGenerationError from e

Expand Down
7 changes: 0 additions & 7 deletions flama/schemas/_libs/pydantic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,12 @@ def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field]]) -> JSON
json_schema = model_json_schema(schema, ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
del json_schema["$defs"]
for property in json_schema["properties"].values():
if "anyOf" in property: # Simplify type from anyOf to a list of types
property["type"] = [x["type"] for x in property["anyOf"]]
del property["anyOf"]
elif self.is_field(schema):
json_schema = model_json_schema(
self.build_schema(fields={"x": schema}), ref_template="#/components/schemas/{model}"
)["properties"]["x"]
if not schema.title: # Pydantic is introducing a default title, so we drop it
del json_schema["title"]
if "anyOf" in json_schema: # Simplify type from anyOf to a list of types
json_schema["type"] = [x["type"] for x in json_schema["anyOf"]]
del json_schema["anyOf"]
else:
raise TypeError("Not a valid schema class or field")

Expand Down
6 changes: 6 additions & 0 deletions flama/schemas/_libs/typesystem/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import itertools
import sys
import typing as t
import warnings
Expand Down Expand Up @@ -102,6 +103,11 @@ def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema:
if not isinstance(json_schema, dict):
raise SchemaGenerationError

for property in itertools.chain(json_schema.get("properties", {}).values(), [json_schema]):
if isinstance(property.get("type"), list):
property["anyOf"] = [{"type": x} for x in property["type"]]
del property["type"]

json_schema.pop("components", None)
except Exception as e:
raise SchemaGenerationError from e
Expand Down
10 changes: 9 additions & 1 deletion flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ def is_http_valid_type(cls, type_: t.Type) -> bool:

return (
(type_ in types.PARAMETERS_TYPES)
or (origin is t.Union and len(args) == 2 and args[0] in types.PARAMETERS_TYPES and args[1] is NoneType)
or (
origin in (t.Union, type(int | str))
and len(args) == 2
and args[0] in types.PARAMETERS_TYPES
and args[1] is NoneType
)
or (origin is list and args[0] in types.PARAMETERS_TYPES)
)

Expand Down Expand Up @@ -186,6 +191,9 @@ def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]:
if schemas.adapter.is_schema(schema):
return [schemas.adapter.unique_schema(schema)]

if t.get_origin(schema) in (t.Union, type(int | str)):
return [x for field in t.get_args(schema) for x in self.nested_schemas(field)]

if isinstance(schema, (list, tuple, set)):
return [x for field in schema for x in self.nested_schemas(field)]

Expand Down
30 changes: 30 additions & 0 deletions tests/schemas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,36 @@ def bar_schema(app, foo_schema):
return namedtuple("BarSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def bar_optional_schema(app, foo_schema):
child_schema = foo_schema.schema
if app.schema.schema_library.lib == pydantic:
schema = pydantic.create_model(
"BarOptional", foo=(t.Union[child_schema, None], None), __module__="pydantic.main"
)
name = "pydantic.main.BarOptional"
elif app.schema.schema_library.lib == typesystem:
schema = typesystem.Schema(
title="BarOptional",
fields={
"foo": typesystem.Reference(
to="Foo", definitions=typesystem.Definitions({"Foo": child_schema}), allow_null=True, default=None
)
},
)
name = "typesystem.schemas.BarOptional"
elif app.schema.schema_library.lib == marshmallow:
schema = type(
"BarOptional",
(marshmallow.Schema,),
{"foo": marshmallow.fields.Nested(child_schema(), required=False, default=None, allow_none=True)},
)
name = "abc.BarOptional"
else:
raise ValueError(f"Wrong schema lib: {app.schema.schema_library.lib}")
return namedtuple("BarOptionalSchema", ("schema", "name"))(schema=schema, name=name)


@pytest.fixture(scope="function")
def bar_list_schema(app, foo_schema):
child_schema = foo_schema.schema
Expand Down
24 changes: 21 additions & 3 deletions tests/schemas/test_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from copy import deepcopy
from unittest.mock import Mock, call, patch

import marshmallow
import pytest
import typesystem

from flama import types
from flama.injection import Parameter as InjectionParameter
Expand Down Expand Up @@ -103,7 +105,7 @@ def test_json_schema(self):

class TestCaseSchema:
@pytest.fixture(scope="function")
def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict_schema):
def schema_type(self, app, request, foo_schema, bar_schema, bar_optional_schema, bar_list_schema, bar_dict_schema):
if request.param is None:
return None
elif request.param == "bare_schema":
Expand All @@ -116,6 +118,10 @@ def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict
return types.PartialSchema[foo_schema.schema]
elif request.param == "schema_nested":
return types.Schema[bar_schema.schema]
elif request.param == "schema_nested_optional":
if app.schema.schema_library.lib in (typesystem, marshmallow):
pytest.skip("Library does not support optional nested schemas")
return types.Schema[bar_optional_schema.schema]
elif request.param == "schema_nested_list":
return types.Schema[bar_list_schema.schema]
elif request.param == "schema_nested_dict":
Expand All @@ -132,6 +138,7 @@ def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict
pytest.param("schema_partial", None, id="schema_partial"),
pytest.param("list_of_schema", None, id="list_of_schema"),
pytest.param("schema_nested", None, id="schema_nested"),
pytest.param("schema_nested_optional", None, id="schema_nested_optional"),
pytest.param("schema_nested_list", None, id="schema_nested_list"),
pytest.param("schema_nested_dict", None, id="schema_nested_dict"),
pytest.param(None, ValueError("Wrong schema type"), id="wrong"),
Expand Down Expand Up @@ -178,7 +185,7 @@ def test_name(self):
),
pytest.param(
"schema_partial",
{"properties": {"name": {"type": ["string", "null"]}}, "type": "object"},
{"properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}]}}, "type": "object"},
None,
id="partial",
),
Expand All @@ -188,6 +195,15 @@ def test_name(self):
"properties.foo",
id="nested",
),
pytest.param(
"schema_nested_optional",
{
"properties": {"foo": {"anyOf": [{"$ref": "#/components/schemas/Foo"}, {"type": "null"}]}},
"type": "object",
},
"properties.foo.anyOf.0",
id="nested_optional",
),
pytest.param(
"schema_nested_list",
{
Expand Down Expand Up @@ -217,7 +233,9 @@ def test_json_schema(self, schemas, schema_type, json_schema, key_to_replace):
expected_result = deepcopy(json_schema)

if key_to_replace:
subdict = functools.reduce(lambda x, k: x[k], key_to_replace.split("."), expected_result)
subdict = functools.reduce(
lambda x, k: x[int(k) if k.isnumeric() else k], key_to_replace.split("."), expected_result
)
subdict["$ref"] = subdict["$ref"].replace("Foo", schemas["Foo"].name)

assert_recursive_contains(expected_result, result)
Expand Down
2 changes: 1 addition & 1 deletion tests/schemas/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def test_components_schemas(self, app, schemas):
"name": "param2",
"in": "query",
"required": False,
"schema": {"type": ["string", "null"]},
"schema": {"anyOf": [{"type": "string"}, {"type": "null"}]},
},
{
"name": "param3",
Expand Down
61 changes: 45 additions & 16 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from flama import types
from flama.pagination import paginator
from tests.asserts import assert_recursive_contains


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -71,19 +72,33 @@ def test_pagination_schema_parameters(self, app):
schema = app.schema.schema["paths"]["/page-number/"]["get"]
parameters = schema.get("parameters", [])

for parameter in parameters:
parameter["schema"] = {k: v for k, v in parameter["schema"].items() if k in ("type", "default")}

assert parameters == [
assert_recursive_contains(
{
"name": "count",
"in": "query",
"required": False,
"schema": {"type": ["boolean", "null"], "default": False},
"schema": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "default": False},
},
parameters[0],
)
assert_recursive_contains(
{
"name": "page",
"in": "query",
"required": False,
"schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
},
parameters[1],
)
assert_recursive_contains(
{
"name": "page_size",
"in": "query",
"required": False,
"schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
},
{"name": "page", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}},
{"name": "page_size", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}},
]
parameters[2],
)

def test_pagination_schema_return(self, app, output_schema):
prefix, name = output_schema.name.rsplit(".", 1)
Expand Down Expand Up @@ -195,19 +210,33 @@ def test_pagination_schema_parameters(self, app):
schema = app.schema.schema["paths"]["/limit-offset/"]["get"]
parameters = schema.get("parameters", [])

for parameter in parameters:
parameter["schema"] = {k: v for k, v in parameter["schema"].items() if k in ("type", "default")}

assert parameters == [
assert_recursive_contains(
{
"name": "count",
"in": "query",
"required": False,
"schema": {"type": ["boolean", "null"], "default": False},
"schema": {"anyOf": [{"type": "boolean"}, {"type": "null"}], "default": False},
},
parameters[0],
)
assert_recursive_contains(
{
"name": "limit",
"in": "query",
"required": False,
"schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
},
parameters[1],
)
assert_recursive_contains(
{
"name": "offset",
"in": "query",
"required": False,
"schema": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
},
{"name": "limit", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}},
{"name": "offset", "in": "query", "required": False, "schema": {"type": ["integer", "null"]}},
]
parameters[2],
)

def test_pagination_schema_return(self, app, output_schema):
prefix, name = output_schema.name.rsplit(".", 1)
Expand Down

0 comments on commit 5716a6f

Please sign in to comment.