Skip to content

Commit

Permalink
✨ Generic schemas module for abstracting schema lib
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent fd96e1e commit bf68305
Show file tree
Hide file tree
Showing 15 changed files with 150 additions and 71 deletions.
2 changes: 1 addition & 1 deletion flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flama.injection import Injector
from flama.responses import APIErrorResponse
from flama.routing import Router
from flama.schemas import SchemaMixin
from flama.schemas.generator import SchemaMixin

if typing.TYPE_CHECKING:
from flama.resources import BaseResource
Expand Down
22 changes: 5 additions & 17 deletions flama/pagination/limit_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import functools
import typing

import marshmallow

from flama import schemas
from flama.responses import APIResponse
from flama.validation import get_output_schema

Expand All @@ -12,18 +11,7 @@
except Exception: # pragma: no cover
forge = None # type: ignore

__all__ = ["LimitOffsetSchema", "LimitOffsetResponse", "limit_offset"]


class LimitOffsetMeta(marshmallow.Schema):
limit = marshmallow.fields.Integer(title="limit", description="Number of retrieved items")
offset = marshmallow.fields.Integer(title="offset", description="Collection offset")
count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True)


class LimitOffsetSchema(marshmallow.Schema):
meta = marshmallow.fields.Nested(LimitOffsetMeta)
data = marshmallow.fields.List(marshmallow.fields.Dict())
__all__ = ["LimitOffsetResponse", "limit_offset"]


class LimitOffsetResponse(APIResponse):
Expand All @@ -40,7 +28,7 @@ class LimitOffsetResponse(APIResponse):

def __init__(
self,
schema: marshmallow.Schema,
schema: schemas.Schema,
offset: typing.Optional[typing.Union[int, str]] = None,
limit: typing.Optional[typing.Union[int, str]] = None,
count: typing.Optional[bool] = True,
Expand Down Expand Up @@ -79,11 +67,11 @@ def limit_offset(func):
assert forge is not None, "`python-forge` must be installed to use Paginator."

resource_schema = get_output_schema(func)
data_schema = marshmallow.fields.Nested(resource_schema, many=True) if resource_schema else marshmallow.fields.Raw()
data_schema = schemas.fields.Nested(resource_schema, many=True) if resource_schema else schemas.fields.Raw()

schema = type(
"LimitOffsetPaginated" + resource_schema.__class__.__name__, # Add a prefix to avoid collision
(LimitOffsetSchema,),
(schemas.core.LimitOffsetSchema,),
{"data": data_schema}, # Replace generic with resource schema
)()

Expand Down
22 changes: 5 additions & 17 deletions flama/pagination/page_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import functools
import typing

import marshmallow

from flama import schemas
from flama.responses import APIResponse
from flama.validation import get_output_schema

Expand All @@ -12,18 +11,7 @@
except Exception: # pragma: no cover
forge = None # type: ignore

__all__ = ["PageNumberSchema", "PageNumberResponse", "page_number"]


class PageNumberMeta(marshmallow.Schema):
page = marshmallow.fields.Integer(title="page", description="Current page number")
page_size = marshmallow.fields.Integer(title="page_size", description="Page size")
count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True)


class PageNumberSchema(marshmallow.Schema):
meta = marshmallow.fields.Nested(PageNumberMeta)
data = marshmallow.fields.List(marshmallow.fields.Dict())
__all__ = ["PageNumberResponse", "page_number"]


class PageNumberResponse(APIResponse):
Expand All @@ -42,7 +30,7 @@ class PageNumberResponse(APIResponse):

def __init__(
self,
schema: marshmallow.Schema,
schema: schemas.Schema,
page: typing.Optional[typing.Union[int, str]] = None,
page_size: typing.Optional[typing.Union[int, str]] = None,
count: typing.Optional[bool] = True,
Expand Down Expand Up @@ -86,11 +74,11 @@ def page_number(func):
assert forge is not None, "`python-forge` must be installed to use Paginator."

resource_schema = get_output_schema(func)
data_schema = marshmallow.fields.Nested(resource_schema, many=True) if resource_schema else marshmallow.fields.Raw()
data_schema = schemas.fields.Nested(resource_schema, many=True) if resource_schema else schemas.fields.Raw()

schema = type(
"PageNumberPaginated" + resource_schema.__class__.__name__, # Add a prefix to avoid collision
(PageNumberSchema,),
(schemas.core.PageNumberSchema,),
{"data": data_schema}, # Replace generic with resource schema
)()

Expand Down
10 changes: 3 additions & 7 deletions flama/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import marshmallow

from flama import pagination
from flama import pagination, schemas
from flama.exceptions import HTTPException
from flama.responses import APIResponse
from flama.types import Model, PrimaryKey, ResourceMeta, ResourceMethodMeta
Expand Down Expand Up @@ -38,10 +38,6 @@
}


class DropCollection(marshmallow.Schema):
deleted = marshmallow.fields.Integer(title="deleted", description="Number of deleted elements", required=True)


def resource_method(path: str, methods: typing.List[str] = None, name: str = None, **kwargs) -> typing.Callable:
def wrapper(func: typing.Callable) -> typing.Callable:
func._meta = ResourceMethodMeta(
Expand Down Expand Up @@ -415,14 +411,14 @@ def _add_drop(
) -> typing.Dict[str, typing.Any]:
@resource_method("/", methods=["DELETE"], name=f"{name}-drop")
@database.transaction()
async def drop(self) -> DropCollection:
async def drop(self) -> schemas.core.DropCollection:
query = sqlalchemy.select([sqlalchemy.func.count(self.model.c[model.primary_key.name])])
result = next((i for i in (await self.database.fetch_one(query)).values()))

query = self.model.delete()
await self.database.execute(query)

return APIResponse(schema=DropCollection(), content={"deleted": result}, status_code=204)
return APIResponse(schema=schemas.core.DropCollection(), content={"deleted": result}, status_code=204)

drop.__doc__ = f"""
tags:
Expand Down
12 changes: 3 additions & 9 deletions flama/responses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing

import marshmallow
from starlette.responses import (
FileResponse,
HTMLResponse,
Expand All @@ -11,6 +10,7 @@
StreamingResponse,
)

from flama import schemas
from flama.exceptions import HTTPException, SerializationError

__all__ = [
Expand All @@ -27,16 +27,10 @@
]


class APIError(marshmallow.Schema):
status_code = marshmallow.fields.Integer(title="status_code", description="HTTP status code", required=True)
detail = marshmallow.fields.Raw(title="detail", description="Error detail", required=True)
error = marshmallow.fields.String(title="type", description="Exception or error type")


class APIResponse(JSONResponse):
media_type = "application/json"

def __init__(self, schema: typing.Optional[marshmallow.Schema] = None, *args, **kwargs):
def __init__(self, schema: typing.Optional[schemas.Schema] = None, *args, **kwargs):
self.schema = schema
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -64,7 +58,7 @@ def __init__(
"status_code": status_code,
}

super().__init__(schema=APIError(), content=content, status_code=status_code, *args, **kwargs)
super().__init__(schema=schemas.core.APIError(), content=content, status_code=status_code, *args, **kwargs)

self.detail = detail
self.exception = exception
Expand Down
25 changes: 25 additions & 0 deletions flama/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import importlib.util
import sys

_SCHEMA_LIBS = ("marshmallow",)
_INSTALLED = [x for x in _SCHEMA_LIBS if x in sys.modules or importlib.util.find_spec(x) is not None]
_LIB = None


for lib in _INSTALLED:
try:
_LIB = importlib.import_module(f"flama.schemas.{lib}")
break
except ModuleNotFoundError:
pass


# Check that at least one of the schema libs is installed
assert _LIB is not None, f"Any of the schema libraries ({', '.join(_SCHEMA_LIBS)}) must be installed."

lib = _LIB.lib
fields = _LIB.fields
Schema = _LIB.Schema
core = _LIB.core

__all__ = ["Schema", "fields", "lib", "core"]
18 changes: 11 additions & 7 deletions flama/schemas.py → flama/schemas/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from string import Template

import marshmallow
from starlette import routing, schemas
from starlette import routing
from starlette import schemas as starlette_schemas
from starlette.responses import HTMLResponse

from flama.responses import APIError
from flama import schemas
from flama.templates import PATH as TEMPLATES_PATH
from flama.types import EndpointInfo
from flama.utils import dict_safe_add

Expand All @@ -33,7 +35,7 @@ def ignore_aliases(self, data):
__all__ = ["OpenAPIResponse", "SchemaGenerator", "SchemaMixin"]


class OpenAPIResponse(schemas.OpenAPIResponse):
class OpenAPIResponse(starlette_schemas.OpenAPIResponse):
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert apispec is not None, "`apispec` must be installed to use OpenAPIResponse."
Expand Down Expand Up @@ -65,7 +67,7 @@ def __getitem__(self, item):
return schema


class SchemaGenerator(schemas.BaseSchemaGenerator):
class SchemaGenerator(starlette_schemas.BaseSchemaGenerator):
def __init__(self, title: str, version: str, description: str, openapi_version="3.0.0"):
assert apispec is not None, "`apispec` must be installed to use SchemaGenerator."

Expand Down Expand Up @@ -169,7 +171,9 @@ def _add_endpoint_response(self, endpoint: EndpointInfo, schema: typing.Dict):
)

def _add_endpoint_default_response(self, schema: typing.Dict):
dict_safe_add(schema, self.schemas[APIError], "responses", "default", "content", "application/json", "schema")
dict_safe_add(
schema, self.schemas[schemas.core.APIError], "responses", "default", "content", "application/json", "schema"
)

# Default description
schema["responses"]["default"]["description"] = schema["responses"]["default"].get(
Expand Down Expand Up @@ -257,7 +261,7 @@ def schema():

def add_docs_route(self):
def swagger_ui() -> HTMLResponse:
with open(os.path.join(os.path.dirname(__file__), "templates/swagger_ui.html")) as f:
with open(os.path.join(TEMPLATES_PATH, "swagger_ui.html")) as f:
content = Template(f.read()).substitute(title=self.title, schema_url=self.schema_url)

return HTMLResponse(content)
Expand All @@ -266,7 +270,7 @@ def swagger_ui() -> HTMLResponse:

def add_redoc_route(self):
def redoc() -> HTMLResponse:
with open(os.path.join(os.path.dirname(__file__), "templates/redoc.html")) as f:
with open(os.path.join(TEMPLATES_PATH, "redoc.html")) as f:
content = Template(f.read()).substitute(title=self.title, schema_url=self.schema_url)

return HTMLResponse(content)
Expand Down
8 changes: 8 additions & 0 deletions flama/schemas/marshmallow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import marshmallow
from marshmallow import Schema, fields

from flama.schemas.marshmallow import core

lib = marshmallow

__all__ = ["Schema", "fields", "lib", "core"]
33 changes: 33 additions & 0 deletions flama/schemas/marshmallow/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import marshmallow


class APIError(marshmallow.Schema):
status_code = marshmallow.fields.Integer(title="status_code", description="HTTP status code", required=True)
detail = marshmallow.fields.Raw(title="detail", description="Error detail", required=True)
error = marshmallow.fields.String(title="type", description="Exception or error type")


class DropCollection(marshmallow.Schema):
deleted = marshmallow.fields.Integer(title="deleted", description="Number of deleted elements", required=True)


class LimitOffsetMeta(marshmallow.Schema):
limit = marshmallow.fields.Integer(title="limit", description="Number of retrieved items")
offset = marshmallow.fields.Integer(title="offset", description="Collection offset")
count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True)


class LimitOffsetSchema(marshmallow.Schema):
meta = marshmallow.fields.Nested(LimitOffsetMeta)
data = marshmallow.fields.List(marshmallow.fields.Dict())


class PageNumberMeta(marshmallow.Schema):
page = marshmallow.fields.Integer(title="page", description="Current page number")
page_size = marshmallow.fields.Integer(title="page_size", description="Page size")
count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True)


class PageNumberSchema(marshmallow.Schema):
meta = marshmallow.fields.Nested(PageNumberMeta)
data = marshmallow.fields.List(marshmallow.fields.Dict())
3 changes: 3 additions & 0 deletions flama/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os.path

PATH = os.path.dirname(os.path.abspath(__file__))
8 changes: 3 additions & 5 deletions flama/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import inspect
from functools import wraps

import marshmallow

from flama import exceptions
from flama import exceptions, schemas

__all__ = ["get_output_schema", "output_validation"]

Expand All @@ -17,9 +15,9 @@ def get_output_schema(func):
:returns: Output schema.
"""
return_annotation = inspect.signature(func).return_annotation
if inspect.isclass(return_annotation) and issubclass(return_annotation, marshmallow.Schema):
if inspect.isclass(return_annotation) and issubclass(return_annotation, schemas.Schema):
return return_annotation()
elif isinstance(return_annotation, marshmallow.Schema):
elif isinstance(return_annotation, schemas.Schema):
return return_annotation

return None
Expand Down
Empty file modified make
100644 → 100755
Empty file.
Loading

0 comments on commit bf68305

Please sign in to comment.