Skip to content

Commit

Permalink
✨ Use PartialSchema for all-optional schema version
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed May 3, 2024
1 parent ec4ae36 commit a24a583
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 52 deletions.
18 changes: 12 additions & 6 deletions flama/schemas/_libs/marshmallow/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ def build_schema(
name: t.Optional[str] = None,
schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None,
fields: t.Optional[t.Dict[str, Field]] = None,
partial: bool = False,
) -> t.Type[Schema]:
return Schema.from_dict(
fields={**(self.unique_schema(schema)().fields if schema else {}), **(fields or {})},
name=name or self.DEFAULT_SCHEMA_NAME,
)
fields_ = {**(self.unique_schema(schema)().fields if schema else {}), **(fields or {})}

if partial:
for field in fields_:
fields_[field].required = False
fields_[field].allow_none = True

return Schema.from_dict(fields=fields_, name=name or self.DEFAULT_SCHEMA_NAME) # type: ignore

def validate(
self, schema: t.Union[t.Type[Schema], Schema], values: t.Dict[str, t.Any], *, partial: bool = False
Expand All @@ -90,9 +95,10 @@ def dump(self, schema: t.Union[t.Type[Schema], Schema], value: t.Dict[str, t.Any

return dump_value

def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str:
def name(self, schema: t.Union[Schema, t.Type[Schema]], *, prefix: t.Optional[str] = None) -> str:
s = self.unique_schema(schema)
return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}"
schema_name = f"{prefix or ''}{s.__qualname__}"
return schema_name if s.__module__ == "builtins" else f"{s.__module__}.{schema_name}"

def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, Field]) -> JSONSchema:
json_schema: t.Dict[str, t.Any]
Expand Down
38 changes: 22 additions & 16 deletions flama/schemas/_libs/pydantic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ def build_schema(
name: t.Optional[str] = None,
schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None,
fields: t.Optional[t.Dict[str, t.Type[Field]]] = None,
partial: bool = False,
) -> t.Type[Schema]:
return pydantic.create_model(
name or self.DEFAULT_SCHEMA_NAME,
fields_ = {
**{
**(
{
name: (field_info.annotation, field_info)
for name, field_info in self.unique_schema(schema).model_fields.items()
}
if self.is_schema(schema)
else {}
),
**({name: (field.annotation, field) for name, field in fields.items()} if fields else {}),
}, # type: ignore
)
name: (field.annotation, field)
for name, field in (self.unique_schema(schema).model_fields.items() if self.is_schema(schema) else {})
},
**{name: (field.annotation, field) for name, field in (fields.items() if fields else {})},
}

if partial:
for name, (annotation, field) in fields_.items():
field.default = None
fields_[name] = (t.Optional[annotation], field)

return pydantic.create_model(name or self.DEFAULT_SCHEMA_NAME, **fields_)

def validate(
self, schema: t.Union[Schema, t.Type[Schema]], values: t.Dict[str, t.Any], *, partial: bool = False
Expand All @@ -93,23 +94,28 @@ def dump(self, schema: t.Union[Schema, t.Type[Schema]], value: t.Dict[str, t.Any

return self.validate(schema_cls, value)

def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str:
def name(self, schema: t.Union[Schema, t.Type[Schema]], *, prefix: t.Optional[str] = None) -> str:
s = self.unique_schema(schema)
return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}"
schema_name = f"{prefix or ''}{s.__qualname__}"
return schema_name if s.__module__ == "builtins" else f"{s.__module__}.{schema_name}"

def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field]]) -> JSONSchema:
try:
if self.is_schema(schema):
json_schema = model_json_schema(schema, ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
del json_schema["$defs"]
for property in json_schema["properties"].values():
if "anyOf" in property: # Simplify type from anyOf to a list of types
property["type"] = [x["type"] for x in property["anyOf"]]
del property["anyOf"]
elif self.is_field(schema):
json_schema = model_json_schema(
self.build_schema(fields={"x": schema}), ref_template="#/components/schemas/{model}"
)["properties"]["x"]
if not schema.title: # Pydantic is introducing a default title, so we drop it
del json_schema["title"]
if "anyOf" in json_schema: # Just simplifying type definition from anyOf to a list of types
if "anyOf" in json_schema: # Simplify type from anyOf to a list of types
json_schema["type"] = [x["type"] for x in json_schema["anyOf"]]
del json_schema["anyOf"]
else:
Expand Down
18 changes: 12 additions & 6 deletions flama/schemas/_libs/typesystem/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@ def build_schema( # type: ignore[return-value]
name: t.Optional[str] = None,
schema: t.Optional[t.Union[Schema, t.Type[Schema]]] = None,
fields: t.Optional[t.Dict[str, Field]] = None,
partial: bool = False,
) -> Schema:
return Schema(
title=name or self.DEFAULT_SCHEMA_NAME,
fields={**(self.unique_schema(schema).fields if self.is_schema(schema) else {}), **(fields or {})},
)
fields_ = {**(self.unique_schema(schema).fields if self.is_schema(schema) else {}), **(fields or {})}

if partial:
for field in fields_:
fields_[field].default = None
fields_[field].allow_null = True

return Schema(title=name or self.DEFAULT_SCHEMA_NAME, fields=fields_)

def validate(self, schema: Schema, values: t.Dict[str, t.Any], *, partial: bool = False) -> t.Any:
try:
Expand All @@ -83,11 +88,12 @@ def _dump(self, value: t.Any) -> t.Any:

return value

def name(self, schema: Schema) -> str:
def name(self, schema: Schema, *, prefix: t.Optional[str] = None) -> str:
if not schema.title:
raise ValueError(f"Schema '{schema}' needs to define title attribute")

return schema.title if schema.__module__ == "builtins" else f"{schema.__module__}.{schema.title}"
schema_name = f"{prefix or ''}{schema.title}"
return schema_name if schema.__module__ == "builtins" else f"{schema.__module__}.{schema_name}"

def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema:
try:
Expand Down
16 changes: 16 additions & 0 deletions flama/schemas/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def build_schema(self, *, name: t.Optional[str] = None, fields: t.Dict[str, t.An
def build_schema(self, *, name: t.Optional[str] = None, schema: t.Any) -> t.Any:
...

@t.overload
@abc.abstractmethod
def build_schema(self, *, name: t.Optional[str] = None, schema: t.Any, partial: bool) -> t.Any:
...

@t.overload
@abc.abstractmethod
def build_schema(
Expand All @@ -51,6 +56,7 @@ def build_schema(
name: t.Optional[str] = None,
schema: t.Optional[t.Any] = None,
fields: t.Optional[t.Dict[str, t.Any]] = None,
partial: bool = False,
) -> t.Any:
...

Expand All @@ -66,10 +72,20 @@ def load(self, schema: t.Any, value: t.Dict[str, t.Any]) -> _T_Schema:
def dump(self, schema: t.Any, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
...

@t.overload
@abc.abstractmethod
def name(self, schema: t.Any) -> str:
...

@t.overload
@abc.abstractmethod
def name(self, schema: t.Any, *, prefix: str) -> str:
...

@abc.abstractmethod
def name(self, schema: t.Any, *, prefix: t.Optional[str] = None) -> str:
...

@abc.abstractmethod
def to_json_schema(self, schema: t.Any) -> JSONSchema:
...
Expand Down
10 changes: 9 additions & 1 deletion flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,15 @@ class Schema:
@classmethod
def from_type(cls, type_: t.Optional[t.Type]) -> "Schema":
if types.Schema.is_schema(type_):
schema = type_.schema
schema = (
type_.schema
if not type_.partial
else schemas.adapter.build_schema(
name=schemas.adapter.name(type_.schema, prefix="Partial").rsplit(".", 1)[1],
schema=type_.schema,
partial=True,
)
)
elif t.get_origin(type_) in (list, tuple, set):
return cls.from_type(t.get_args(type_)[0])
else:
Expand Down
2 changes: 1 addition & 1 deletion flama/schemas/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class EndpointInfo:
@dataclasses.dataclass(frozen=True)
class SchemaInfo:
name: str
schema: t.Any
schema: types.Schema

@property
def ref(self) -> str:
Expand Down
62 changes: 40 additions & 22 deletions tests/schemas/test_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,33 +103,44 @@ def test_json_schema(self):

class TestCaseSchema:
@pytest.fixture(scope="function")
def schema_type(self, request, foo_schema):
def schema_type(self, request, foo_schema, bar_schema, bar_list_schema, bar_dict_schema):
if request.param is None:
return None
elif request.param == "schema":
elif request.param == "bare_schema":
return foo_schema.schema
elif request.param == "schema_wrapped":
elif request.param == "schema":
return types.Schema[foo_schema.schema]
elif request.param == "list":
return t.List[foo_schema.schema] if inspect.isclass(foo_schema.schema) else foo_schema.schema
elif request.param == "list_of_schema":
return t.List[types.Schema[foo_schema.schema]] if inspect.isclass(foo_schema.schema) else foo_schema.schema
elif request.param == "schema_partial":
return types.PartialSchema[foo_schema.schema]
elif request.param == "schema_nested":
return types.Schema[bar_schema.schema]
elif request.param == "schema_nested_list":
return types.Schema[bar_list_schema.schema]
elif request.param == "schema_nested_dict":
return types.Schema[bar_dict_schema.schema]

else:
raise ValueError("Wrong schema type")

@pytest.mark.parametrize(
["schema_type", "exception"],
(
pytest.param("bare_schema", None, id="bare_schema"),
pytest.param("schema", None, id="schema"),
pytest.param("schema_wrapped", None, id="schema_wrapped"),
pytest.param("list", None, id="list"),
pytest.param("schema_partial", None, id="schema_partial"),
pytest.param("list_of_schema", None, id="list_of_schema"),
pytest.param("schema_nested", None, id="schema_nested"),
pytest.param("schema_nested_list", None, id="schema_nested_list"),
pytest.param("schema_nested_dict", None, id="schema_nested_dict"),
pytest.param(None, ValueError("Wrong schema type"), id="wrong"),
),
indirect=["schema_type", "exception"],
)
def test_from_type(self, foo_schema, schema_type, exception):
def test_from_type(self, schema_type, exception):
with exception:
schema = Schema.from_type(schema_type)

assert schema.schema == foo_schema.schema
Schema.from_type(schema_type)

def test_build(self, foo_schema):
n = Mock()
Expand Down Expand Up @@ -157,44 +168,51 @@ def test_name(self):
assert schemas_mock.adapter.name.call_args_list == [call(mock)]

@pytest.mark.parametrize(
["schema", "json_schema", "key_to_replace"],
["schema_type", "json_schema", "key_to_replace"],
(
pytest.param(
"Foo",
"schema",
{"properties": {"name": {"type": "string"}}, "type": "object"},
None,
id="no_nested",
id="plain",
),
pytest.param(
"schema_partial",
{"properties": {"name": {"type": ["string", "null"]}}, "type": "object"},
None,
id="partial",
),
pytest.param(
"Bar",
"schema_nested",
{"properties": {"foo": {"$ref": "#/components/schemas/Foo"}}, "type": "object"},
"properties.foo",
id="attribute_nested",
id="nested",
),
pytest.param(
"BarList",
"schema_nested_list",
{
"properties": {"foo": {"items": {"$ref": "#/components/schemas/Foo"}, "type": "array"}},
"type": "object",
},
"properties.foo.items",
id="list_nested",
id="nested_list",
),
pytest.param(
"BarDict",
"schema_nested_dict",
{
"properties": {
"foo": {"additionalProperties": {"$ref": "#/components/schemas/Foo"}, "type": "object"}
},
"type": "object",
},
"properties.foo.additionalProperties",
id="dict_nested",
id="nested_dict",
),
),
indirect=["schema_type"],
)
def test_json_schema(self, schemas, schema, json_schema, key_to_replace):
result = Schema(schemas[schema].schema).json_schema({id(schemas["Foo"].schema): schemas["Foo"].name})
def test_json_schema(self, schemas, schema_type, json_schema, key_to_replace):
result = Schema.from_type(schema_type).json_schema({id(schemas["Foo"].schema): schemas["Foo"].name})

expected_result = deepcopy(json_schema)

Expand Down

0 comments on commit a24a583

Please sign in to comment.