Skip to content

Commit

Permalink
🐛 Remove multiple from Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent 62a885d commit 0e1356f
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 88 deletions.
2 changes: 1 addition & 1 deletion flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, content: t.Any = None, schema: t.Optional["flama.schemas.type
def render(self, content: t.Any):
if self.schema is not None:
try:
content = schemas.Schema(self.schema, multiple=isinstance(content, list)).dump(content)
content = schemas.Schema(self.schema).dump(content)
except schemas.SchemaValidationError as e:
raise SerializationError(status_code=500, detail=e.errors)

Expand Down
3 changes: 1 addition & 2 deletions flama/pagination/limit_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def _inner(func: t.Callable):
schema = schemas.Schema.build(
paginated_schema_name,
schema=schemas.schemas.LimitOffset,
fields=[schemas.Field("data", t.List[resource_schema])], # type: ignore[valid-type]
multiple=False,
fields=[schemas.Field("data", resource_schema, multiple=True)],
).unique_schema

forge_revision_list = (
Expand Down
3 changes: 1 addition & 2 deletions flama/pagination/page_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def _inner(func: t.Callable):
schema = schemas.Schema.build(
paginated_schema_name,
schema=schemas.schemas.PageNumber,
fields=[schemas.Field("data", t.List[resource_schema])], # type: ignore[valid-type]
multiple=False,
fields=[schemas.Field("data", resource_schema, multiple=True)],
).unique_schema

forge_revision_list = (
Expand Down
45 changes: 18 additions & 27 deletions flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,28 @@ class Field:
type: t.Type
nullable: bool = dataclasses.field(init=False)
field: t.Any = dataclasses.field(hash=False, init=False, compare=False)
multiple: bool = dataclasses.field(hash=False, init=False, compare=False)
multiple: t.Optional[bool] = dataclasses.field(hash=False, compare=False, default=None)
required: bool = True
default: t.Any = InjectionParameter.empty

def __post_init__(self) -> None:
object.__setattr__(self, "nullable", type(None) in t.get_args(self.type) or self.default is None)
object.__setattr__(self, "multiple", t.get_origin(self.type) is list)

field_type = t.get_args(self.type)[0] if t.get_origin(self.type) in (list, t.Union) else self.type

if not Schema.is_schema(field_type) and self.multiple is None:
object.__setattr__(self, "multiple", t.get_origin(self.type) is list)

object.__setattr__(
self,
"field",
schemas.adapter.build_field(
self.name,
t.get_args(self.type)[0] if t.get_origin(self.type) in (list, t.Union) else self.type,
field_type,
nullable=self.nullable,
required=self.required,
default=self.default,
multiple=self.multiple,
multiple=bool(self.multiple),
),
)

Expand Down Expand Up @@ -69,42 +74,33 @@ def is_http_valid_type(cls, type_: t.Type) -> bool:

@property
def json_schema(self) -> schemas.types.JSONSchema:
schema = schemas.adapter.to_json_schema(self.field)

if self.multiple:
schema = {"items": {"$ref": schema}, "type": "array"}

return schema
return schemas.adapter.to_json_schema(self.field)


@dataclasses.dataclass(frozen=True)
class Schema:
schema: t.Any = dataclasses.field(hash=False, compare=False)
multiple: bool = dataclasses.field(hash=False, compare=False, default=False)

@classmethod
def from_type(cls, type: t.Optional[t.Type]) -> "Schema":
multiple = t.get_origin(type) is list
schema = t.get_args(type)[0] if multiple else type
def from_type(cls, type_: t.Optional[t.Type]) -> "Schema":
schema = t.get_args(type_)[0] if t.get_origin(type_) is list else type_

if not schemas.adapter.is_schema(schema):
raise ValueError("Wrong schema type")

return cls(schema=schema, multiple=multiple)
return cls(schema=schema)

@classmethod
def build(
cls,
name: t.Optional[str] = None,
schema: t.Any = None,
fields: t.Optional[t.List[Field]] = None,
multiple: bool = False,
) -> "Schema":
return cls(
schema=schemas.adapter.build_schema(
name=name, schema=schema, fields={f.name: f.field for f in (fields or [])}
),
multiple=multiple,
)

@classmethod
Expand All @@ -113,12 +109,7 @@ def is_schema(cls, obj: t.Any) -> bool:

@property
def json_schema(self) -> t.Dict[str, t.Any]:
schema = schemas.adapter.to_json_schema(self.schema)

if self.multiple:
schema = {"items": {"$ref": schema}, "type": "array"}

return schema
return schemas.adapter.to_json_schema(self.schema)

@property
def unique_schema(self) -> t.Any:
Expand All @@ -133,7 +124,7 @@ def validate(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.A
...

def validate(self, values):
if self.multiple and isinstance(values, (list, tuple)):
if isinstance(values, (list, tuple)):
return [schemas.adapter.validate(self.schema, value) for value in values]

return schemas.adapter.validate(self.schema, values)
Expand All @@ -147,7 +138,7 @@ def load(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Any]:
...

def load(self, values):
if self.multiple:
if isinstance(values, (list, tuple)):
return [schemas.adapter.load(self.schema, value) for value in values]

return schemas.adapter.load(self.schema, values)
Expand All @@ -161,7 +152,7 @@ def dump(self, values: t.List[t.Dict[str, t.Any]]) -> t.List[t.Dict[str, t.Any]]
...

def dump(self, values):
if self.multiple and isinstance(values, (list, tuple)):
if isinstance(values, (list, tuple)):
return [schemas.adapter.dump(self.schema, value) for value in values]

return schemas.adapter.dump(self.schema, values)
Expand All @@ -186,7 +177,7 @@ def __post_init__(self) -> None:
field = None
except ValueError:
if self.type in (None, InjectionParameter.empty):
schema = Schema(schema=None, multiple=False)
schema = Schema(schema=None)
field = None
else:
schema = None
Expand Down
55 changes: 30 additions & 25 deletions flama/schemas/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,25 @@ def _get_schema_references_from_schema(
return [schema.ref]

result = []
for name, prop in schema.get("properties", {}).items():
if "$ref" in prop:
result.append(prop["$ref"])

if prop.get("type", "") == "array" and prop.get("items", {}).get("$ref"):
result.append(prop["items"]["$ref"])
if "$ref" in schema:
result.append(schema["$ref"])

result += [
ref["$ref"]
for composer in ("allOf", "anyOf", "oneOff")
if composer in prop
for ref in prop[composer]
if "$ref" in ref
]
if schema.get("type", "") == "array" and schema.get("items", {}).get("$ref"):
result.append(schema["items"]["$ref"])

result += [
ref
for composer in ("allOf", "anyOf", "oneOf")
for composer_schema in schema.get(composer, [])
for ref in self._get_schema_references_from_schema(composer_schema)
]

result += [
ref
for prop in schema.get("properties", {}).values()
for ref in self._get_schema_references_from_schema(prop)
]

return result

Expand Down Expand Up @@ -201,7 +206,7 @@ def register(self, schema: schemas.types.Schema, name: t.Optional[str] = None) -
return schema_id

def get_openapi_ref(
self, element: schemas.types.Schema, multiple: bool = False
self, element: schemas.types.Schema, multiple: t.Optional[bool] = None
) -> typing.Union[openapi.Schema, openapi.Reference]:
"""
Builds the reference for a single schema or the array schema containing the reference.
Expand All @@ -210,12 +215,14 @@ def get_openapi_ref(
:param multiple: True for building a schema containing an array of references instead of a single reference.
:return: Reference or array schema.
"""
reference = openapi.Reference(ref=self[element].ref)

if multiple:
return openapi.Schema({"items": dataclasses.asdict(reference), "type": "array"})
reference = self[element].ref

return reference
if multiple is True:
return openapi.Schema({"items": {"$ref": reference}, "type": "array"})
elif multiple is None:
return openapi.Schema({"oneOf": [{"$ref": reference}, {"items": {"$ref": reference}, "type": "array"}]})
else:
return openapi.Reference(ref=reference)


class SchemaGenerator(starlette_schemas.BaseSchemaGenerator):
Expand Down Expand Up @@ -355,9 +362,7 @@ def _build_endpoint_body(
return openapi.RequestBody(
content={
"application/json": openapi.MediaType(
schema=self.schemas.get_openapi_ref(
endpoint.body_parameter.schema.schema, multiple=endpoint.body_parameter.schema.multiple
)
schema=self.schemas.get_openapi_ref(endpoint.body_parameter.schema.schema, multiple=False),
)
},
**{
Expand All @@ -383,9 +388,7 @@ def _build_endpoint_response(

content = {
"application/json": openapi.MediaType(
schema=self.schemas.get_openapi_ref(
endpoint.response_parameter.schema.schema, multiple=endpoint.response_parameter.schema.multiple
)
schema=self.schemas.get_openapi_ref(endpoint.response_parameter.schema.schema)
)
}
else:
Expand All @@ -403,7 +406,9 @@ def _build_endpoint_default_response(self, metadata: typing.Dict[str, typing.Any
return openapi.Response(
description=metadata.get("responses", {}).get("default", {}).get("description", "Unexpected error."),
content={
"application/json": openapi.MediaType(schema=self.schemas.get_openapi_ref(schemas.schemas.APIError))
"application/json": openapi.MediaType(
schema=self.schemas.get_openapi_ref(schemas.schemas.APIError, multiple=False)
)
},
)

Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0e1356f

Please sign in to comment.