From a24a5836541de6c78183fc6c5f45d0dec9b27f86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Antonio=20Perdiguero=20L=C3=B3pez?= Date: Fri, 3 May 2024 23:37:28 +0200 Subject: [PATCH] :sparkles: Use PartialSchema for all-optional schema version --- flama/schemas/_libs/marshmallow/adapter.py | 18 ++++--- flama/schemas/_libs/pydantic/adapter.py | 38 +++++++------ flama/schemas/_libs/typesystem/adapter.py | 18 ++++--- flama/schemas/adapter.py | 16 ++++++ flama/schemas/data_structures.py | 10 +++- flama/schemas/generator.py | 2 +- tests/schemas/test_data_structures.py | 62 ++++++++++++++-------- 7 files changed, 112 insertions(+), 52 deletions(-) diff --git a/flama/schemas/_libs/marshmallow/adapter.py b/flama/schemas/_libs/marshmallow/adapter.py index 953c8b7b..1a907877 100644 --- a/flama/schemas/_libs/marshmallow/adapter.py +++ b/flama/schemas/_libs/marshmallow/adapter.py @@ -60,11 +60,16 @@ def build_schema( name: t.Optional[str] = None, schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None, fields: t.Optional[t.Dict[str, Field]] = None, + partial: bool = False, ) -> t.Type[Schema]: - return Schema.from_dict( - fields={**(self.unique_schema(schema)().fields if schema else {}), **(fields or {})}, - name=name or self.DEFAULT_SCHEMA_NAME, - ) + fields_ = {**(self.unique_schema(schema)().fields if schema else {}), **(fields or {})} + + if partial: + for field in fields_: + fields_[field].required = False + fields_[field].allow_none = True + + return Schema.from_dict(fields=fields_, name=name or self.DEFAULT_SCHEMA_NAME) # type: ignore def validate( self, schema: t.Union[t.Type[Schema], Schema], values: t.Dict[str, t.Any], *, partial: bool = False @@ -90,9 +95,10 @@ def dump(self, schema: t.Union[t.Type[Schema], Schema], value: t.Dict[str, t.Any return dump_value - def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str: + def name(self, schema: t.Union[Schema, t.Type[Schema]], *, prefix: t.Optional[str] = None) -> str: s = self.unique_schema(schema) - return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}" + schema_name = f"{prefix or ''}{s.__qualname__}" + return schema_name if s.__module__ == "builtins" else f"{s.__module__}.{schema_name}" def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, Field]) -> JSONSchema: json_schema: t.Dict[str, t.Any] diff --git a/flama/schemas/_libs/pydantic/adapter.py b/flama/schemas/_libs/pydantic/adapter.py index a3a1874b..63bdebd7 100644 --- a/flama/schemas/_libs/pydantic/adapter.py +++ b/flama/schemas/_libs/pydantic/adapter.py @@ -57,21 +57,22 @@ def build_schema( name: t.Optional[str] = None, schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None, fields: t.Optional[t.Dict[str, t.Type[Field]]] = None, + partial: bool = False, ) -> t.Type[Schema]: - return pydantic.create_model( - name or self.DEFAULT_SCHEMA_NAME, + fields_ = { **{ - **( - { - name: (field_info.annotation, field_info) - for name, field_info in self.unique_schema(schema).model_fields.items() - } - if self.is_schema(schema) - else {} - ), - **({name: (field.annotation, field) for name, field in fields.items()} if fields else {}), - }, # type: ignore - ) + name: (field.annotation, field) + for name, field in (self.unique_schema(schema).model_fields.items() if self.is_schema(schema) else {}) + }, + **{name: (field.annotation, field) for name, field in (fields.items() if fields else {})}, + } + + if partial: + for name, (annotation, field) in fields_.items(): + field.default = None + fields_[name] = (t.Optional[annotation], field) + + return pydantic.create_model(name or self.DEFAULT_SCHEMA_NAME, **fields_) def validate( self, schema: t.Union[Schema, t.Type[Schema]], values: t.Dict[str, t.Any], *, partial: bool = False @@ -93,9 +94,10 @@ def dump(self, schema: t.Union[Schema, t.Type[Schema]], value: t.Dict[str, t.Any return self.validate(schema_cls, value) - def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str: + def name(self, schema: t.Union[Schema, t.Type[Schema]], *, prefix: t.Optional[str] = None) -> str: s = self.unique_schema(schema) - return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}" + schema_name = f"{prefix or ''}{s.__qualname__}" + return schema_name if s.__module__ == "builtins" else f"{s.__module__}.{schema_name}" def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field]]) -> JSONSchema: try: @@ -103,13 +105,17 @@ def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field]]) -> JSON json_schema = model_json_schema(schema, ref_template="#/components/schemas/{model}") if "$defs" in json_schema: del json_schema["$defs"] + for property in json_schema["properties"].values(): + if "anyOf" in property: # Simplify type from anyOf to a list of types + property["type"] = [x["type"] for x in property["anyOf"]] + del property["anyOf"] elif self.is_field(schema): json_schema = model_json_schema( self.build_schema(fields={"x": schema}), ref_template="#/components/schemas/{model}" )["properties"]["x"] if not schema.title: # Pydantic is introducing a default title, so we drop it del json_schema["title"] - if "anyOf" in json_schema: # Just simplifying type definition from anyOf to a list of types + if "anyOf" in json_schema: # Simplify type from anyOf to a list of types json_schema["type"] = [x["type"] for x in json_schema["anyOf"]] del json_schema["anyOf"] else: diff --git a/flama/schemas/_libs/typesystem/adapter.py b/flama/schemas/_libs/typesystem/adapter.py index f24720f5..85233c01 100644 --- a/flama/schemas/_libs/typesystem/adapter.py +++ b/flama/schemas/_libs/typesystem/adapter.py @@ -56,11 +56,16 @@ def build_schema( # type: ignore[return-value] name: t.Optional[str] = None, schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None, fields: t.Optional[t.Dict[str, Field]] = None, + partial: bool = False, ) -> Schema: - return Schema( - title=name or self.DEFAULT_SCHEMA_NAME, - fields={**(self.unique_schema(schema).fields if self.is_schema(schema) else {}), **(fields or {})}, - ) + fields_ = {**(self.unique_schema(schema).fields if self.is_schema(schema) else {}), **(fields or {})} + + if partial: + for field in fields_: + fields_[field].default = None + fields_[field].allow_null = True + + return Schema(title=name or self.DEFAULT_SCHEMA_NAME, fields=fields_) def validate(self, schema: Schema, values: t.Dict[str, t.Any], *, partial: bool = False) -> t.Any: try: @@ -83,11 +88,12 @@ def _dump(self, value: t.Any) -> t.Any: return value - def name(self, schema: Schema) -> str: + def name(self, schema: Schema, *, prefix: t.Optional[str] = None) -> str: if not schema.title: raise ValueError(f"Schema '{schema}' needs to define title attribute") - return schema.title if schema.__module__ == "builtins" else f"{schema.__module__}.{schema.title}" + schema_name = f"{prefix or ''}{schema.title}" + return schema_name if schema.__module__ == "builtins" else f"{schema.__module__}.{schema_name}" def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema: try: diff --git a/flama/schemas/adapter.py b/flama/schemas/adapter.py index 52166c91..80be8444 100644 --- a/flama/schemas/adapter.py +++ b/flama/schemas/adapter.py @@ -37,6 +37,11 @@ def build_schema(self, *, name: t.Optional[str] = None, fields: t.Dict[str, t.An def build_schema(self, *, name: t.Optional[str] = None, schema: t.Any) -> t.Any: ... + @t.overload + @abc.abstractmethod + def build_schema(self, *, name: t.Optional[str] = None, schema: t.Any, partial: bool) -> t.Any: + ... + @t.overload @abc.abstractmethod def build_schema( @@ -51,6 +56,7 @@ def build_schema( name: t.Optional[str] = None, schema: t.Optional[t.Any] = None, fields: t.Optional[t.Dict[str, t.Any]] = None, + partial: bool = False, ) -> t.Any: ... @@ -66,10 +72,20 @@ def load(self, schema: t.Any, value: t.Dict[str, t.Any]) -> _T_Schema: def dump(self, schema: t.Any, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: ... + @t.overload @abc.abstractmethod def name(self, schema: t.Any) -> str: ... + @t.overload + @abc.abstractmethod + def name(self, schema: t.Any, *, prefix: str) -> str: + ... + + @abc.abstractmethod + def name(self, schema: t.Any, *, prefix: t.Optional[str] = None) -> str: + ... + @abc.abstractmethod def to_json_schema(self, schema: t.Any) -> JSONSchema: ... diff --git a/flama/schemas/data_structures.py b/flama/schemas/data_structures.py index ac3d1020..30bf69e5 100644 --- a/flama/schemas/data_structures.py +++ b/flama/schemas/data_structures.py @@ -103,7 +103,15 @@ class Schema: @classmethod def from_type(cls, type_: t.Optional[t.Type]) -> "Schema": if types.Schema.is_schema(type_): - schema = type_.schema + schema = ( + type_.schema + if not type_.partial + else schemas.adapter.build_schema( + name=schemas.adapter.name(type_.schema, prefix="Partial").rsplit(".", 1)[1], + schema=type_.schema, + partial=True, + ) + ) elif t.get_origin(type_) in (list, tuple, set): return cls.from_type(t.get_args(type_)[0]) else: diff --git a/flama/schemas/generator.py b/flama/schemas/generator.py index 8cafd901..7f69b626 100644 --- a/flama/schemas/generator.py +++ b/flama/schemas/generator.py @@ -31,7 +31,7 @@ class EndpointInfo: @dataclasses.dataclass(frozen=True) class SchemaInfo: name: str - schema: t.Any + schema: types.Schema @property def ref(self) -> str: diff --git a/tests/schemas/test_data_structures.py b/tests/schemas/test_data_structures.py index 50718887..89e7b933 100644 --- a/tests/schemas/test_data_structures.py +++ b/tests/schemas/test_data_structures.py @@ -103,33 +103,44 @@ def test_json_schema(self): class TestCaseSchema: @pytest.fixture(scope="function") - def schema_type(self, request, foo_schema): + def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict_schema): if request.param is None: return None - elif request.param == "schema": + elif request.param == "bare_schema": return foo_schema.schema - elif request.param == "schema_wrapped": + elif request.param == "schema": return types.Schema[foo_schema.schema] - elif request.param == "list": - return t.List[foo_schema.schema] if inspect.isclass(foo_schema.schema) else foo_schema.schema + elif request.param == "list_of_schema": + return t.List[types.Schema[foo_schema.schema]] if inspect.isclass(foo_schema.schema) else foo_schema.schema + elif request.param == "schema_partial": + return types.PartialSchema[foo_schema.schema] + elif request.param == "schema_nested": + return types.Schema[bar_schema.schema] + elif request.param == "schema_nested_list": + return types.Schema[bar_list_schema.schema] + elif request.param == "schema_nested_dict": + return types.Schema[bar_dict_schema.schema] + else: raise ValueError("Wrong schema type") @pytest.mark.parametrize( ["schema_type", "exception"], ( + pytest.param("bare_schema", None, id="bare_schema"), pytest.param("schema", None, id="schema"), - pytest.param("schema_wrapped", None, id="schema_wrapped"), - pytest.param("list", None, id="list"), + pytest.param("schema_partial", None, id="schema_partial"), + pytest.param("list_of_schema", None, id="list_of_schema"), + pytest.param("schema_nested", None, id="schema_nested"), + pytest.param("schema_nested_list", None, id="schema_nested_list"), + pytest.param("schema_nested_dict", None, id="schema_nested_dict"), pytest.param(None, ValueError("Wrong schema type"), id="wrong"), ), indirect=["schema_type", "exception"], ) - def test_from_type(self, foo_schema, schema_type, exception): + def test_from_type(self, schema_type, exception): with exception: - schema = Schema.from_type(schema_type) - - assert schema.schema == foo_schema.schema + Schema.from_type(schema_type) def test_build(self, foo_schema): n = Mock() @@ -157,31 +168,37 @@ def test_name(self): assert schemas_mock.adapter.name.call_args_list == [call(mock)] @pytest.mark.parametrize( - ["schema", "json_schema", "key_to_replace"], + ["schema_type", "json_schema", "key_to_replace"], ( pytest.param( - "Foo", + "schema", {"properties": {"name": {"type": "string"}}, "type": "object"}, None, - id="no_nested", + id="plain", + ), + pytest.param( + "schema_partial", + {"properties": {"name": {"type": ["string", "null"]}}, "type": "object"}, + None, + id="partial", ), pytest.param( - "Bar", + "schema_nested", {"properties": {"foo": {"$ref": "#/components/schemas/Foo"}}, "type": "object"}, "properties.foo", - id="attribute_nested", + id="nested", ), pytest.param( - "BarList", + "schema_nested_list", { "properties": {"foo": {"items": {"$ref": "#/components/schemas/Foo"}, "type": "array"}}, "type": "object", }, "properties.foo.items", - id="list_nested", + id="nested_list", ), pytest.param( - "BarDict", + "schema_nested_dict", { "properties": { "foo": {"additionalProperties": {"$ref": "#/components/schemas/Foo"}, "type": "object"} @@ -189,12 +206,13 @@ def test_name(self): "type": "object", }, "properties.foo.additionalProperties", - id="dict_nested", + id="nested_dict", ), ), + indirect=["schema_type"], ) - def test_json_schema(self, schemas, schema, json_schema, key_to_replace): - result = Schema(schemas[schema].schema).json_schema({id(schemas["Foo"].schema): schemas["Foo"].name}) + def test_json_schema(self, schemas, schema_type, json_schema, key_to_replace): + result = Schema.from_type(schema_type).json_schema({id(schemas["Foo"].schema): schemas["Foo"].name}) expected_result = deepcopy(json_schema)