Skip to content

Commit

Permalink
✨ Middleware stack
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent 88b198e commit abfceb6
Show file tree
Hide file tree
Showing 11 changed files with 806 additions and 68 deletions.
102 changes: 74 additions & 28 deletions flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@
import typing

from starlette.applications import Starlette
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.datastructures import State

from flama.debug.middleware import ServerErrorMiddleware
from flama.exceptions import HTTPException
from flama.injection import Injector
from flama.lifespan import Lifespan
from flama.middleware import Middleware
from flama.middleware import Middleware, MiddlewareStack
from flama.models.modules import ModelsModule
from flama.modules import Modules
from flama.pagination import paginator
from flama.resources import ResourcesModule
from flama.responses import APIErrorResponse
from flama.routing import Router
from flama.schemas.modules import SchemaModule
from flama.sqlalchemy import SQLAlchemyModule

if typing.TYPE_CHECKING:
from flama.asgi import App
from flama.asgi import App, Receive, Scope, Send
from flama.components import Component, Components
from flama.http import Request, Response
from flama.modules import Module
from flama.routing import BaseRoute, Mount, WebSocketRoute

Expand All @@ -37,7 +33,7 @@ def __init__(
routes: typing.Sequence[typing.Union["BaseRoute", "Mount"]] = None,
components: typing.Optional[typing.List["Component"]] = None,
modules: typing.Optional[typing.List[typing.Type["Module"]]] = None,
middleware: typing.Sequence["Middleware"] = None,
middleware: typing.Optional[typing.Sequence["Middleware"]] = None,
debug: bool = False,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
Expand All @@ -52,7 +48,8 @@ def __init__(
*args,
**kwargs
) -> None:
super().__init__(debug, *args, **kwargs)
self._debug = debug
self.state = State()

# Initialize router and middleware stack
self.router: Router = Router(
Expand All @@ -64,13 +61,8 @@ def __init__(
lifespan=Lifespan(self, lifespan),
)
self.app = self.router
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(self.exception_middleware, debug=debug)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack = self.build_middleware_stack()

# Add exception handler for API exceptions
self.add_exception_handler(HTTPException, self.api_http_exception_handler)
self.middleware = MiddlewareStack(app=self.app, middleware=middleware or [], debug=debug)

# Initialize Modules
self.modules = Modules(
Expand Down Expand Up @@ -98,11 +90,26 @@ def __init__(
self.paginator = paginator

def __getattr__(self, item: str) -> "Module":
"""Retrieve a module by its name.
:param item: Module name.
:return: Module.
"""
try:
return self.modules.__getattr__(item)
except KeyError:
return None # type: ignore[return-value]

async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""Perform a request.
:param scope: ASGI scope.
:param receive: ASGI receive event.
:param send: ASGI send event.
"""
scope["app"] = self
await self.middleware(scope, receive, send)

def add_route( # type: ignore[override]
self,
path: typing.Optional[str] = None,
Expand All @@ -112,6 +119,15 @@ def add_route( # type: ignore[override]
include_in_schema: bool = True,
route: typing.Optional["BaseRoute"] = None,
) -> None: # pragma: no cover
"""Register a new HTTP route or endpoint under given path.
:param path: URL path.
:param endpoint: HTTP endpoint.
:param methods: List of valid HTTP methods (only applies for routes).
:param name: Endpoint or route name.
:param include_in_schema: True if this route or endpoint should be declared as part of the API schema.
:param route: HTTP route.
"""
self.router.add_route(
path, endpoint, methods=methods, name=name, include_in_schema=include_in_schema, route=route
)
Expand All @@ -123,42 +139,72 @@ def add_websocket_route( # type: ignore[override]
name: typing.Optional[str] = None,
route: typing.Optional["WebSocketRoute"] = None,
) -> None: # pragma: no cover
"""Register a new websocket route or endpoint under given path.
:param path: URL path.
:param endpoint: Websocket endpoint.
:param name: Endpoint or route name.
:param route: Websocket route.
"""
self.router.add_websocket_route(path, endpoint, name=name, route=route)

@property
def injector(self) -> Injector:
"""Components dependency injector.
:return: Injector instance.
"""
return Injector(self.components)

@property
def components(self) -> "Components":
"""Components register.
:return: Components register.
"""
return self.router.components

def add_component(self, component: "Component"):
"""Add a new component to the register.
:param component: Component to include.
"""
self.router.add_component(component)

@property
def routes(self) -> typing.List["BaseRoute"]: # type: ignore[override]
"""List of registered routes.
:return: Routes.
"""
return self.router.routes

def mount(self, path: str, app: "App", name: str = None) -> None: # type: ignore[override]
"""Mount a new ASGI application under given path.
:param path: URL path.
:param app: ASGI application.
:param name: Application name.
"""
self.router.mount(path, app=app, name=name)

def build_middleware_stack(self) -> "App": # type: ignore[override]
debug = self.debug
def add_exception_handler(
self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable
):
"""Add a new exception handler for given status code or exception class.
middleware = (
[Middleware(ServerErrorMiddleware, debug=debug)]
+ self.user_middleware
+ [Middleware(ExceptionMiddleware, handlers=self.exception_handlers, debug=debug)]
)
:param exc_class_or_status_code: Status code or exception class.
:param handler: Exception handler.
"""
self.middleware.add_exception_handler(exc_class_or_status_code, handler)

app = self.router
for cls, options in reversed(middleware):
app = cls(app=app, **options)
return app
def add_middleware(self, middleware_class: typing.Type, **options: typing.Any):
"""Add a new middleware to the stack.
def api_http_exception_handler(self, request: "Request", exc: HTTPException) -> "Response":
return APIErrorResponse(detail=exc.detail, status_code=exc.status_code, exception=exc)
:param middleware_class: Middleware class.
:param options: Keyword arguments used to initialise middleware.
"""
self.middleware.add_middleware(Middleware(middleware_class, **options))

get = functools.partialmethod(Starlette.route, methods=["GET"])
head = functools.partialmethod(Starlette.route, methods=["HEAD"])
Expand Down
133 changes: 117 additions & 16 deletions flama/debug/middleware.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,154 @@
import abc
import dataclasses
import inspect
import typing
from pathlib import Path

from flama.asgi import App, Message, Receive, Scope, Send
from flama import concurrency
from flama.debug.types import ErrorContext
from flama.exceptions import HTTPException, WebSocketException
from flama.http import PlainTextResponse, Request, Response
from flama.responses import HTMLTemplateResponse
from flama.responses import APIErrorResponse, HTMLTemplateResponse
from flama.websockets import WebSocket

if typing.TYPE_CHECKING:
from flama.asgi import App, Message, Receive, Scope, Send

__all__ = ["ServerErrorMiddleware", "ExceptionMiddleware"]

TEMPLATES_PATH = Path(__file__).parents[1].resolve() / "templates" / "debug"

Handler = typing.NewType("Handler", typing.Callable[[Request, Exception], Response])


class ServerErrorMiddleware:
def __init__(self, app: App, debug: bool = False) -> None:
class BaseErrorMiddleware(abc.ABC):
def __init__(self, app: "App", debug: bool = False) -> None:
self.app = app
self.debug = debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

response_started = False

async def _send(message: Message) -> None:
async def sender(message: "Message") -> None:
nonlocal response_started, send

if message["type"] == "http.response.start":
response_started = True
await send(message)

try:
await self.app(scope, receive, _send)
await self.app(scope, receive, sender)
except Exception as exc:
request = Request(scope)
response = self.debug_response(request, exc) if self.debug else self.error_response(request, exc)
await self.process_exception(scope, receive, send, exc, response_started)

@abc.abstractmethod
async def process_exception(
self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool
) -> None:
...


class ServerErrorMiddleware(BaseErrorMiddleware):
def _get_handler(self) -> Handler:
return self.debug_response if self.debug else self.error_response

if not response_started:
await response(scope, receive, send)
async def process_exception(
self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool
) -> None:
handler = self._get_handler()
response = handler(Request(scope), exc)

# We always continue to raise the exception.
# This allows servers to log the error, or test clients to optionally raise the error within the test case.
raise exc
if not response_started:
await response(scope, receive, send)

# We always continue to raise the exception.
# This allows servers to log the error, or test clients to optionally raise the error within the test case.
raise exc

def debug_response(self, request: Request, exc: Exception) -> Response:
accept = request.headers.get("accept", "")

if "text/html" in accept:
context = dataclasses.asdict(ErrorContext.build(request, exc))
return HTMLTemplateResponse("debug/error_500.html", context)
return HTMLTemplateResponse(
"debug/error_500.html", context=dataclasses.asdict(ErrorContext.build(request, exc))
)
return PlainTextResponse("Internal Server Error", status_code=500)

def error_response(self, request: Request, exc: Exception) -> Response:
return PlainTextResponse("Internal Server Error", status_code=500)


class ExceptionMiddleware(BaseErrorMiddleware):
def __init__(
self, app: "App", handlers: typing.Optional[typing.Mapping[typing.Any, Handler]] = None, debug: bool = False
):
super().__init__(app, debug)
handlers = handlers or {}
self._status_handlers: typing.Dict[int, typing.Callable] = {
status_code: handler for status_code, handler in handlers.items() if isinstance(status_code, int)
}
self._exception_handlers: typing.Dict[typing.Type[Exception], typing.Callable] = {
HTTPException: self.http_exception,
WebSocketException: self.websocket_exception,
**{e: handler for e, handler in handlers.items() if inspect.isclass(e) and issubclass(e, Exception)},
}

def add_exception_handler(
self,
handler: Handler,
status_code: typing.Optional[int] = None,
exc_class: typing.Optional[typing.Type[Exception]] = None,
) -> None:
if status_code is None and exc_class is None:
raise ValueError("Status code or exception class must be defined")

if status_code is not None:
self._status_handlers[status_code] = handler

if exc_class is not None:
self._exception_handlers[exc_class] = handler

def _get_handler(self, exc: Exception) -> Handler:
if isinstance(exc, HTTPException) and exc.status_code in self._status_handlers:
return self._status_handlers[exc.status_code]
else:
try:
return next(
self._exception_handlers[cls] for cls in type(exc).__mro__ if cls in self._exception_handlers
)
except StopIteration:
raise exc

async def process_exception(
self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool
) -> None:
handler = self._get_handler(exc)

if response_started:
raise RuntimeError("Caught handled exception, but response already started.") from exc

if scope["type"] == "http":
request = Request(scope, receive=receive)
response = await concurrency.run(handler, request, exc)
await response(scope, receive, send)
elif scope["type"] == "websocket":
websocket = WebSocket(scope, receive=receive, send=send)
await concurrency.run(handler, websocket, exc)

def http_exception(self, request: Request, exc: HTTPException) -> Response:
if exc.status_code in {204, 304}:
return Response(status_code=exc.status_code, headers=exc.headers)

accept = request.headers.get("accept", "")

if self.debug and exc.status_code == 404 and "text/html" in accept:
return PlainTextResponse(content=exc.detail, status_code=exc.status_code)

return APIErrorResponse(detail=exc.detail, status_code=exc.status_code, exception=exc)

async def websocket_exception(self, websocket: WebSocket, exc: WebSocketException) -> None:
await websocket.close(code=exc.code, reason=exc.reason)
Loading

0 comments on commit abfceb6

Please sign in to comment.