diff --git a/flama/applications.py b/flama/applications.py index 049a1f15..243054df 100644 --- a/flama/applications.py +++ b/flama/applications.py @@ -2,26 +2,22 @@ import typing from starlette.applications import Starlette -from starlette.middleware.exceptions import ExceptionMiddleware +from starlette.datastructures import State -from flama.debug.middleware import ServerErrorMiddleware -from flama.exceptions import HTTPException from flama.injection import Injector from flama.lifespan import Lifespan -from flama.middleware import Middleware +from flama.middleware import Middleware, MiddlewareStack from flama.models.modules import ModelsModule from flama.modules import Modules from flama.pagination import paginator from flama.resources import ResourcesModule -from flama.responses import APIErrorResponse from flama.routing import Router from flama.schemas.modules import SchemaModule from flama.sqlalchemy import SQLAlchemyModule if typing.TYPE_CHECKING: - from flama.asgi import App + from flama.asgi import App, Receive, Scope, Send from flama.components import Component, Components - from flama.http import Request, Response from flama.modules import Module from flama.routing import BaseRoute, Mount, WebSocketRoute @@ -37,7 +33,7 @@ def __init__( routes: typing.Sequence[typing.Union["BaseRoute", "Mount"]] = None, components: typing.Optional[typing.List["Component"]] = None, modules: typing.Optional[typing.List[typing.Type["Module"]]] = None, - middleware: typing.Sequence["Middleware"] = None, + middleware: typing.Optional[typing.Sequence["Middleware"]] = None, debug: bool = False, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, @@ -52,7 +48,8 @@ def __init__( *args, **kwargs ) -> None: - super().__init__(debug, *args, **kwargs) + self._debug = debug + self.state = State() # Initialize router and middleware stack self.router: Router = Router( @@ -64,13 +61,8 @@ def __init__( lifespan=Lifespan(self, lifespan), ) self.app = self.router - self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) - self.error_middleware = ServerErrorMiddleware(self.exception_middleware, debug=debug) - self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack = self.build_middleware_stack() - # Add exception handler for API exceptions - self.add_exception_handler(HTTPException, self.api_http_exception_handler) + self.middleware = MiddlewareStack(app=self.app, middleware=middleware or [], debug=debug) # Initialize Modules self.modules = Modules( @@ -98,11 +90,26 @@ def __init__( self.paginator = paginator def __getattr__(self, item: str) -> "Module": + """Retrieve a module by its name. + + :param item: Module name. + :return: Module. + """ try: return self.modules.__getattr__(item) except KeyError: return None # type: ignore[return-value] + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: + """Perform a request. + + :param scope: ASGI scope. + :param receive: ASGI receive event. + :param send: ASGI send event. + """ + scope["app"] = self + await self.middleware(scope, receive, send) + def add_route( # type: ignore[override] self, path: typing.Optional[str] = None, @@ -112,6 +119,15 @@ def add_route( # type: ignore[override] include_in_schema: bool = True, route: typing.Optional["BaseRoute"] = None, ) -> None: # pragma: no cover + """Register a new HTTP route or endpoint under given path. + + :param path: URL path. + :param endpoint: HTTP endpoint. + :param methods: List of valid HTTP methods (only applies for routes). + :param name: Endpoint or route name. + :param include_in_schema: True if this route or endpoint should be declared as part of the API schema. + :param route: HTTP route. + """ self.router.add_route( path, endpoint, methods=methods, name=name, include_in_schema=include_in_schema, route=route ) @@ -123,42 +139,72 @@ def add_websocket_route( # type: ignore[override] name: typing.Optional[str] = None, route: typing.Optional["WebSocketRoute"] = None, ) -> None: # pragma: no cover + """Register a new websocket route or endpoint under given path. + + :param path: URL path. + :param endpoint: Websocket endpoint. + :param name: Endpoint or route name. + :param route: Websocket route. + """ self.router.add_websocket_route(path, endpoint, name=name, route=route) @property def injector(self) -> Injector: + """Components dependency injector. + + :return: Injector instance. + """ return Injector(self.components) @property def components(self) -> "Components": + """Components register. + + :return: Components register. + """ return self.router.components def add_component(self, component: "Component"): + """Add a new component to the register. + + :param component: Component to include. + """ self.router.add_component(component) @property def routes(self) -> typing.List["BaseRoute"]: # type: ignore[override] + """List of registered routes. + + :return: Routes. + """ return self.router.routes def mount(self, path: str, app: "App", name: str = None) -> None: # type: ignore[override] + """Mount a new ASGI application under given path. + + :param path: URL path. + :param app: ASGI application. + :param name: Application name. + """ self.router.mount(path, app=app, name=name) - def build_middleware_stack(self) -> "App": # type: ignore[override] - debug = self.debug + def add_exception_handler( + self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], handler: typing.Callable + ): + """Add a new exception handler for given status code or exception class. - middleware = ( - [Middleware(ServerErrorMiddleware, debug=debug)] - + self.user_middleware - + [Middleware(ExceptionMiddleware, handlers=self.exception_handlers, debug=debug)] - ) + :param exc_class_or_status_code: Status code or exception class. + :param handler: Exception handler. + """ + self.middleware.add_exception_handler(exc_class_or_status_code, handler) - app = self.router - for cls, options in reversed(middleware): - app = cls(app=app, **options) - return app + def add_middleware(self, middleware_class: typing.Type, **options: typing.Any): + """Add a new middleware to the stack. - def api_http_exception_handler(self, request: "Request", exc: HTTPException) -> "Response": - return APIErrorResponse(detail=exc.detail, status_code=exc.status_code, exception=exc) + :param middleware_class: Middleware class. + :param options: Keyword arguments used to initialise middleware. + """ + self.middleware.add_middleware(Middleware(middleware_class, **options)) get = functools.partialmethod(Starlette.route, methods=["GET"]) head = functools.partialmethod(Starlette.route, methods=["HEAD"]) diff --git a/flama/debug/middleware.py b/flama/debug/middleware.py index 0f131a4a..b30a5030 100644 --- a/flama/debug/middleware.py +++ b/flama/debug/middleware.py @@ -1,27 +1,39 @@ +import abc import dataclasses +import inspect +import typing from pathlib import Path -from flama.asgi import App, Message, Receive, Scope, Send +from flama import concurrency from flama.debug.types import ErrorContext +from flama.exceptions import HTTPException, WebSocketException from flama.http import PlainTextResponse, Request, Response -from flama.responses import HTMLTemplateResponse +from flama.responses import APIErrorResponse, HTMLTemplateResponse +from flama.websockets import WebSocket + +if typing.TYPE_CHECKING: + from flama.asgi import App, Message, Receive, Scope, Send + +__all__ = ["ServerErrorMiddleware", "ExceptionMiddleware"] TEMPLATES_PATH = Path(__file__).parents[1].resolve() / "templates" / "debug" +Handler = typing.NewType("Handler", typing.Callable[[Request, Exception], Response]) + -class ServerErrorMiddleware: - def __init__(self, app: App, debug: bool = False) -> None: +class BaseErrorMiddleware(abc.ABC): + def __init__(self, app: "App", debug: bool = False) -> None: self.app = app self.debug = debug - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: if scope["type"] != "http": await self.app(scope, receive, send) return response_started = False - async def _send(message: Message) -> None: + async def sender(message: "Message") -> None: nonlocal response_started, send if message["type"] == "http.response.start": @@ -29,25 +41,114 @@ async def _send(message: Message) -> None: await send(message) try: - await self.app(scope, receive, _send) + await self.app(scope, receive, sender) except Exception as exc: - request = Request(scope) - response = self.debug_response(request, exc) if self.debug else self.error_response(request, exc) + await self.process_exception(scope, receive, send, exc, response_started) + + @abc.abstractmethod + async def process_exception( + self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool + ) -> None: + ... + + +class ServerErrorMiddleware(BaseErrorMiddleware): + def _get_handler(self) -> Handler: + return self.debug_response if self.debug else self.error_response - if not response_started: - await response(scope, receive, send) + async def process_exception( + self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool + ) -> None: + handler = self._get_handler() + response = handler(Request(scope), exc) - # We always continue to raise the exception. - # This allows servers to log the error, or test clients to optionally raise the error within the test case. - raise exc + if not response_started: + await response(scope, receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or test clients to optionally raise the error within the test case. + raise exc def debug_response(self, request: Request, exc: Exception) -> Response: accept = request.headers.get("accept", "") if "text/html" in accept: - context = dataclasses.asdict(ErrorContext.build(request, exc)) - return HTMLTemplateResponse("debug/error_500.html", context) + return HTMLTemplateResponse( + "debug/error_500.html", context=dataclasses.asdict(ErrorContext.build(request, exc)) + ) return PlainTextResponse("Internal Server Error", status_code=500) def error_response(self, request: Request, exc: Exception) -> Response: return PlainTextResponse("Internal Server Error", status_code=500) + + +class ExceptionMiddleware(BaseErrorMiddleware): + def __init__( + self, app: "App", handlers: typing.Optional[typing.Mapping[typing.Any, Handler]] = None, debug: bool = False + ): + super().__init__(app, debug) + handlers = handlers or {} + self._status_handlers: typing.Dict[int, typing.Callable] = { + status_code: handler for status_code, handler in handlers.items() if isinstance(status_code, int) + } + self._exception_handlers: typing.Dict[typing.Type[Exception], typing.Callable] = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, + **{e: handler for e, handler in handlers.items() if inspect.isclass(e) and issubclass(e, Exception)}, + } + + def add_exception_handler( + self, + handler: Handler, + status_code: typing.Optional[int] = None, + exc_class: typing.Optional[typing.Type[Exception]] = None, + ) -> None: + if status_code is None and exc_class is None: + raise ValueError("Status code or exception class must be defined") + + if status_code is not None: + self._status_handlers[status_code] = handler + + if exc_class is not None: + self._exception_handlers[exc_class] = handler + + def _get_handler(self, exc: Exception) -> Handler: + if isinstance(exc, HTTPException) and exc.status_code in self._status_handlers: + return self._status_handlers[exc.status_code] + else: + try: + return next( + self._exception_handlers[cls] for cls in type(exc).__mro__ if cls in self._exception_handlers + ) + except StopIteration: + raise exc + + async def process_exception( + self, scope: "Scope", receive: "Receive", send: "Send", exc: Exception, response_started: bool + ) -> None: + handler = self._get_handler(exc) + + if response_started: + raise RuntimeError("Caught handled exception, but response already started.") from exc + + if scope["type"] == "http": + request = Request(scope, receive=receive) + response = await concurrency.run(handler, request, exc) + await response(scope, receive, send) + elif scope["type"] == "websocket": + websocket = WebSocket(scope, receive=receive, send=send) + await concurrency.run(handler, websocket, exc) + + def http_exception(self, request: Request, exc: HTTPException) -> Response: + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + + accept = request.headers.get("accept", "") + + if self.debug and exc.status_code == 404 and "text/html" in accept: + return PlainTextResponse(content=exc.detail, status_code=exc.status_code) + + return APIErrorResponse(detail=exc.detail, status_code=exc.status_code, exception=exc) + + async def websocket_exception(self, websocket: WebSocket, exc: WebSocketException) -> None: + await websocket.close(code=exc.code, reason=exc.reason) diff --git a/flama/middleware.py b/flama/middleware.py index f019f32b..f34dabb3 100644 --- a/flama/middleware.py +++ b/flama/middleware.py @@ -1,19 +1,24 @@ -# pragma: no cover +import typing from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware -from starlette.middleware.exceptions import ExceptionMiddleware from starlette.middleware.gzip import GZipMiddleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware +from flama.debug.middleware import ExceptionMiddleware, ServerErrorMiddleware + try: from starlette.middleware.sessions import SessionMiddleware except Exception: SessionMiddleware = None # type: ignore +if typing.TYPE_CHECKING: + from flama.asgi import App, Receive, Scope, Send + from flama.http import Request, Response + __all__ = [ "AuthenticationMiddleware", "BaseHTTPMiddleware", @@ -22,6 +27,62 @@ "GZipMiddleware", "HTTPSRedirectMiddleware", "Middleware", + "MiddlewareStack", "SessionMiddleware", "TrustedHostMiddleware", ] + + +class MiddlewareStack: + def __init__(self, app: "App", middleware: typing.Sequence[Middleware], debug: bool): + self.app = app + self.middleware = list(middleware) + self.debug = debug + self._exception_handlers: typing.Dict[ + typing.Union[int, typing.Type[Exception]], typing.Callable[["Request", Exception], "Response"] + ] = {} + self._stack: typing.Optional["App"] = None + + @property + def stack(self) -> "App": + if self._stack is None: + app = self.app + for cls, options in reversed( + [ + Middleware(ServerErrorMiddleware, debug=self.debug), + *self.middleware, + Middleware(ExceptionMiddleware, handlers=self._exception_handlers, debug=self.debug), + ] + ): + app = cls(app=app, **options) + self._stack = app + + return self._stack + + @stack.deleter + def stack(self): + self._stack = None + + def add_exception_handler( + self, + key: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable[["Request", Exception], "Response"], + ): + """Adds a new handler for an exception type or a HTTP status code. + + :param key: Exception type or HTTP status code. + :param handler: Exception handler. + """ + self._exception_handlers[key] = handler + del self.stack + + def add_middleware(self, middleware: Middleware): + """Adds a new middleware to the stack. + + :param middleware: Middleware. + """ + self.middleware.insert(0, middleware) + del self.stack + + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: + await self.stack(scope, receive, send) diff --git a/flama/routing.py b/flama/routing.py index 9a5aa78e..c8ce399e 100644 --- a/flama/routing.py +++ b/flama/routing.py @@ -6,15 +6,17 @@ import starlette.routing from starlette.routing import Match -from flama import asgi, concurrency, http, websockets +from flama import concurrency, http, websockets from flama.components import Component, Components -from flama.responses import APIResponse, Response -from flama.schemas import adapter +from flama.exceptions import HTTPException +from flama.responses import APIResponse, PlainTextResponse, Response from flama.schemas.routing import RouteParametersMixin from flama.schemas.validation import get_output_schema from flama.types import HTTPMethod +from flama.websockets import WebSocketClose if typing.TYPE_CHECKING: + from flama import asgi from flama.applications import Flama from flama.lifespan import Lifespan @@ -31,9 +33,7 @@ async def prepare_http_request(app: "Flama", handler: typing.Callable, state: ty response = await concurrency.run(injected_func) # Wrap response data with a proper response class - if adapter.is_schema(response): - response = APIResponse(content=response, schema=response.__class__) - elif isinstance(response, (dict, list)): + if isinstance(response, (dict, list)): response = APIResponse(content=response, schema=get_output_schema(handler)) elif isinstance(response, str): response = APIResponse(content=response) @@ -109,18 +109,18 @@ def __init__( # Replace function with another wrapper that uses the injector if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): - self.app = self.endpoint_wrapper(endpoint) + self.app = self.endpoint_wrapper(endpoint) # type: ignore[assignment] if self.methods is None: self.methods = {m for m in HTTPMethod.__members__.keys() if hasattr(self.endpoint, m.lower())} - def endpoint_wrapper(self, endpoint: typing.Callable) -> asgi.App: + def endpoint_wrapper(self, endpoint: typing.Callable) -> "asgi.App": """ Wraps a http function into ASGI application. """ @wraps(endpoint) - async def _app(scope: asgi.Scope, receive: asgi.Receive, send: asgi.Send) -> None: + async def _app(scope: "asgi.Scope", receive: "asgi.Receive", send: "asgi.Send") -> None: app = scope["app"] route, route_scope = app.router.get_route_from_scope(scope) state = { @@ -148,15 +148,15 @@ def __init__(self, path: str, endpoint: typing.Callable, main_app: "Flama" = Non # Replace function with another wrapper that uses the injector if inspect.isfunction(endpoint): - self.app = self.endpoint_wrapper(endpoint) + self.app = self.endpoint_wrapper(endpoint) # type: ignore[assignment] - def endpoint_wrapper(self, endpoint: typing.Callable) -> asgi.App: + def endpoint_wrapper(self, endpoint: typing.Callable) -> "asgi.App": """ Wraps websocket function into ASGI application. """ @wraps(endpoint) - async def _app(scope: asgi.Scope, receive: asgi.Receive, send: asgi.Send) -> None: + async def _app(scope: "asgi.Scope", receive: "asgi.Receive", send: "asgi.Send") -> None: app = scope["app"] route, route_scope = app.router.get_route_from_scope(scope) @@ -189,13 +189,13 @@ def __init__( self, path: str, main_app: "Flama" = None, - app: asgi.App = None, + app: "asgi.App" = None, routes: typing.Sequence[BaseRoute] = None, components: typing.Sequence[Component] = None, name: str = None, ): if app is None: - app = Router(routes=routes, components=components) + app = Router(routes=routes, components=components) # type: ignore[assignment] super().__init__(path, app, routes, name) @@ -260,7 +260,7 @@ def components(self) -> Components: def add_component(self, component: Component): self._components.append(component) - def mount(self, path: str, app: asgi.App, name: str = None) -> None: + def mount(self, path: str, app: "asgi.App", name: str = None) -> None: try: main_app = self.main_app except AttributeError: @@ -335,9 +335,23 @@ def decorator(func: typing.Callable) -> typing.Callable: return decorator + async def not_found(self, scope: "asgi.Scope", receive: "asgi.Receive", send: "asgi.Send") -> None: + if scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + # If we're running inside a starlette application then raise an exception, so that the configurable exception + # handler can deal with returning the response. For plain ASGI apps, just return the response. + if "app" in scope: + raise HTTPException(status_code=404) + + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + def get_route_from_scope( self, scope, mounted=False - ) -> typing.Tuple[typing.Union[BaseRoute, asgi.App], typing.Optional[typing.Dict]]: + ) -> typing.Tuple[typing.Union[BaseRoute, "asgi.App"], typing.Optional[typing.Dict]]: partial = None for route in self.routes: diff --git a/flama/websockets.py b/flama/websockets.py index d18973c7..b0b30559 100644 --- a/flama/websockets.py +++ b/flama/websockets.py @@ -1,10 +1,19 @@ import typing -from starlette.websockets import WebSocket +from starlette.websockets import WebSocket, WebSocketClose, WebSocketDisconnect, WebSocketState from flama.asgi import Message -__all__ = ["WebSocket", "Message", "Code", "Encoding", "Data"] +__all__ = [ + "WebSocket", + "WebSocketClose", + "WebSocketState", + "WebSocketDisconnect", + "Message", + "Code", + "Encoding", + "Data", +] Code = typing.NewType("Code", int) diff --git a/tests/conftest.py b/tests/conftest.py index 4a209139..22baf8b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import asyncio from contextlib import ExitStack from time import sleep +from unittest.mock import AsyncMock import marshmallow import pytest @@ -8,7 +9,7 @@ import typesystem from faker import Faker -from flama import Flama +from flama import Flama, asgi from flama.sqlalchemy import metadata from flama.testclient import TestClient @@ -106,6 +107,31 @@ def client(app): yield client +@pytest.fixture(scope="function") +def asgi_scope(): + return asgi.Scope( + { + "type": "http", + "method": "GET", + "scheme": "https", + "path": "/", + "root_path": "/", + "query_string": b"", + "headers": [], + } + ) + + +@pytest.fixture(scope="function") +def asgi_receive(): + return AsyncMock(spec=asgi.Receive) + + +@pytest.fixture(scope="function") +def asgi_send(): + return AsyncMock(spec=asgi.Send) + + def assert_recursive_contains(first, second): if isinstance(first, dict) and isinstance(second, dict): assert first.keys() <= second.keys() diff --git a/tests/debug/__init__.py b/tests/debug/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/debug/test_middleware.py b/tests/debug/test_middleware.py new file mode 100644 index 00000000..b271ef5c --- /dev/null +++ b/tests/debug/test_middleware.py @@ -0,0 +1,296 @@ +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +from flama.asgi import Receive, Scope, Send +from flama.debug.middleware import BaseErrorMiddleware, ExceptionMiddleware, ServerErrorMiddleware +from flama.debug.types import ErrorContext +from flama.exceptions import HTTPException, WebSocketException +from flama.http import Request +from flama.responses import APIErrorResponse, HTMLTemplateResponse, PlainTextResponse, Response +from flama.websockets import WebSocket + + +class TestCaseBaseErrorMiddleware: + @pytest.fixture + def middleware_cls(self): + class FooMiddleware(BaseErrorMiddleware): + async def process_exception( + self, scope: Scope, receive: Receive, send: Send, exc: Exception, response_started: bool + ) -> None: + ... + + return FooMiddleware + + def test_init(self, middleware_cls): + app = AsyncMock() + middleware = middleware_cls(app=app, debug=True) + assert middleware.app == app + assert middleware.debug + + async def test_call_http(self, middleware_cls, asgi_scope, asgi_receive, asgi_send): + exc = ValueError() + app = AsyncMock(side_effect=exc) + middleware = middleware_cls(app=app, debug=True) + with patch.object(middleware, "process_exception", new_callable=AsyncMock): + await middleware(asgi_scope, asgi_receive, asgi_send) + assert middleware.app.call_count == 1 + assert middleware.process_exception.call_args_list == [ + call(asgi_scope, asgi_receive, asgi_send, exc, False) + ] + + async def test_call_websocket(self, middleware_cls, asgi_scope, asgi_receive, asgi_send): + asgi_scope["type"] = "websocket" + app = AsyncMock() + middleware = middleware_cls(app=app, debug=True) + with patch.object(middleware, "process_exception", new_callable=AsyncMock): + await middleware(asgi_scope, asgi_receive, asgi_send) + assert middleware.app.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] + assert middleware.process_exception.call_args_list == [] + + +class TestCaseServerErrorMiddleware: + @pytest.fixture + def middleware(self): + return ServerErrorMiddleware(app=AsyncMock()) + + @pytest.mark.parametrize( + ["debug"], + ( + pytest.param(True, id="debug"), + pytest.param(False, id="no_debug"), + ), + ) + def test_get_handler(self, middleware, debug): + middleware.debug = debug + handler = middleware.debug_response if debug else middleware.error_response + + result_handler = middleware._get_handler() + + assert result_handler == handler + + @pytest.mark.parametrize( + ["debug", "response_started"], + ( + pytest.param(True, False, id="debug_not_started"), + pytest.param(True, True, id="debug_started"), + pytest.param(False, False, id="error_not_started"), + pytest.param(False, True, id="error_started"), + ), + ) + async def test_process_exception(self, middleware, asgi_scope, asgi_receive, asgi_send, debug, response_started): + middleware.debug = debug + response_method = "debug_response" if debug else "error_response" + exc = ValueError("Foo") + with patch.object( + ServerErrorMiddleware, response_method, new=MagicMock(return_value=AsyncMock()) + ) as response, pytest.raises(ValueError, match="Foo"): + await middleware.process_exception(asgi_scope, asgi_receive, asgi_send, exc, response_started) + + if debug: + assert ServerErrorMiddleware.debug_response.call_args_list == [call(Request(asgi_scope), exc)] + else: + assert ServerErrorMiddleware.error_response.call_args_list == [call(Request(asgi_scope), exc)] + + if response_started: + assert response.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] + else: + assert response.call_args_list == [] + + def test_debug_response_html(self, middleware, asgi_scope): + asgi_scope["headers"].append((b"accept", b"text/html")) + request = Request(asgi_scope) + exc = ValueError() + error_context_mock, context_mock = MagicMock(), MagicMock() + with patch( + "flama.debug.middleware.dataclasses.asdict", return_value=context_mock + ) as dataclasses_dict, patch.object(ErrorContext, "build", return_value=error_context_mock), patch.object( + HTMLTemplateResponse, "__init__", return_value=None + ): + response = middleware.debug_response(request, exc) + assert ErrorContext.build.call_args_list == [call(request, exc)] + assert dataclasses_dict.call_args_list == [call(error_context_mock)] + assert isinstance(response, HTMLTemplateResponse) + assert HTMLTemplateResponse.__init__.call_args_list == [call("debug/error_500.html", context=context_mock)] + + def test_debug_response_text(self, middleware, asgi_scope): + request = Request(asgi_scope) + exc = ValueError() + with patch.object(PlainTextResponse, "__init__", return_value=None): + response = middleware.debug_response(request, exc) + assert isinstance(response, PlainTextResponse) + assert PlainTextResponse.__init__.call_args_list == [call("Internal Server Error", status_code=500)] + + def test_error_response(self, middleware, asgi_scope): + request = Request(asgi_scope) + exc = ValueError() + with patch.object(PlainTextResponse, "__init__", return_value=None): + response = middleware.error_response(request, exc) + assert isinstance(response, PlainTextResponse) + assert PlainTextResponse.__init__.call_args_list == [call("Internal Server Error", status_code=500)] + + +class TestCaseExceptionMiddleware: + @pytest.fixture + def middleware(self): + return ExceptionMiddleware(app=AsyncMock()) + + @pytest.fixture + def handler(self): + def _handler(): + ... + + return _handler + + def test_init(self, handler): + app = AsyncMock() + debug = True + + middleware = ExceptionMiddleware(app=app, handlers={400: handler, ValueError: handler}, debug=debug) + + assert middleware.app == app + assert middleware.debug == debug + assert middleware._status_handlers == {400: handler} + assert middleware._exception_handlers == { + ValueError: handler, + HTTPException: middleware.http_exception, + WebSocketException: middleware.websocket_exception, + } + + @pytest.mark.parametrize( + ["status_code", "exc_class", "exception"], + ( + pytest.param(400, None, None, id="status_code"), + pytest.param(None, ValueError, None, id="exc_class"), + pytest.param(400, ValueError, None, id="status_code_and_exc_class"), + pytest.param(None, None, ValueError("Status code or exception class must be defined"), id="no_key"), + ), + indirect=["exception"], + ) + def test_add_exception_handler(self, middleware, handler, status_code, exc_class, exception): + status_code_handlers = {} + if status_code is not None: + status_code_handlers[status_code] = handler + exc_class_handlers = { + HTTPException: middleware.http_exception, + WebSocketException: middleware.websocket_exception, + } + if exc_class is not None: + exc_class_handlers[exc_class] = handler + + with exception: + middleware.add_exception_handler(handler, status_code=status_code, exc_class=exc_class) + + assert middleware._status_handlers == status_code_handlers + assert middleware._exception_handlers == exc_class_handlers + + @pytest.mark.parametrize( + ["status_code", "exc_class", "key", "exception"], + ( + pytest.param(400, None, HTTPException(400), None, id="status_code"), + pytest.param(None, ValueError, ValueError("Foo"), None, id="exc_class"), + pytest.param(400, ValueError, ValueError("Foo"), None, id="status_code_and_exc_class"), + pytest.param(None, Exception, ValueError("Foo"), None, id="child_exc_class"), + pytest.param(400, None, HTTPException(401), HTTPException(401), id="handler_not_found"), + ), + indirect=["exception"], + ) + def test_get_handler(self, middleware, handler, status_code, exc_class, key, exception): + # Force clean all handlers + middleware._status_handlers = {} + middleware._exception_handlers = {} + middleware.add_exception_handler(handler=handler, status_code=status_code, exc_class=exc_class) + + with exception: + result_handler = middleware._get_handler(key) + if not exception: + assert result_handler == handler + + @pytest.mark.parametrize( + ["request_type", "response_started", "exception"], + ( + pytest.param("http", False, None, id="http"), + pytest.param("websocket", False, None, id="websocket"), + pytest.param( + None, + True, + RuntimeError("Caught handled exception, but response already started."), + id="response_started_error", + ), + ), + indirect=["exception"], + ) + async def test_process_exception( + self, middleware, asgi_scope, asgi_receive, asgi_send, request_type, response_started, exception + ): + expected_exc = ValueError() + asgi_scope["type"] = request_type + handler_mock = MagicMock() + response_mock = AsyncMock() + with exception, patch.object(middleware, "_get_handler", return_value=handler_mock), patch( + "flama.debug.middleware.concurrency.run", new=AsyncMock(return_value=response_mock) + ) as run_mock: + await middleware.process_exception(asgi_scope, asgi_receive, asgi_send, expected_exc, response_started) + + if request_type == "http": + assert run_mock.call_count == 1 + handler, request, exc = run_mock.call_args_list[0][0] + assert handler == handler_mock + assert isinstance(request, Request) + assert request.scope == asgi_scope + assert exc == expected_exc + assert response_mock.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] + + elif request_type == "websocket": + assert run_mock.call_count == 1 + handler, websocket, exc = run_mock.call_args_list[0][0] + assert handler == handler_mock + assert isinstance(websocket, WebSocket) + assert websocket.scope == asgi_scope + assert exc == expected_exc + + @pytest.mark.parametrize( + ["debug", "accept", "exc", "response_class", "response_params"], + ( + pytest.param(False, None, HTTPException(204), Response, {"status_code": 204, "headers": None}, id="204"), + pytest.param(False, None, HTTPException(304), Response, {"status_code": 304, "headers": None}, id="304"), + pytest.param( + True, + b"text/html", + HTTPException(404, "Foo"), + PlainTextResponse, + {"content": "Foo", "status_code": 404}, + id="debug_404", + ), + pytest.param( + False, + None, + HTTPException(400, "Foo"), + APIErrorResponse, + {"detail": "Foo", "status_code": 400}, + id="other", + ), + ), + ) + def test_http_exception(self, middleware, asgi_scope, debug, accept, exc, response_class, response_params): + middleware.debug = debug + + if accept: + asgi_scope["headers"].append((b"accept", accept)) + + if response_class == APIErrorResponse: + response_params["exception"] = exc + + request = Request(asgi_scope) + with patch.object(response_class, "__init__", return_value=None): + middleware.http_exception(request, exc) + + assert response_class.__init__.call_args_list == [call(**response_params)] + + async def test_websocket_exception(self, middleware): + websocket = AsyncMock() + exc = WebSocketException(1011, "Foo reason") + + await middleware.websocket_exception(websocket, exc) + + assert websocket.close.call_args_list == [call(code=exc.code, reason=exc.reason)] diff --git a/tests/test_applications.py b/tests/test_applications.py index 62d13340..61095bd7 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,11 +1,13 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest +from starlette.middleware import Middleware from flama import Component, Flama, Module, Mount, Route, Router from flama.applications import DEFAULT_MODULES from flama.components import Components from flama.injection import Injector +from flama.middleware import MiddlewareStack class TestCaseFlama: @@ -89,3 +91,36 @@ def test_declarative_recursion(self, component_mock, module_mock): # Check modules are isolated for each app assert mount_app.modules == [*DEFAULT_MODULES, module_mock] assert root_app.modules == DEFAULT_MODULES + + @pytest.mark.parametrize( + ["key", "handler"], + (pytest.param(400, MagicMock(), id="status_code"), pytest.param(ValueError, MagicMock(), id="exception_class")), + ) + def test_add_exception_handler(self, app, key, handler): + expected_call = [call(key, handler)] + + with patch.object(app, "middleware", spec=MiddlewareStack): + app.add_exception_handler(key, handler) + assert app.middleware.add_exception_handler.call_args_list == expected_call + + def test_add_middleware(self, app): + class FooMiddleware: + def __call__(self, *args, **kwargs): + ... + + options = {"foo": "bar"} + + with patch.object(app, "middleware", spec=MiddlewareStack): + app.add_middleware(FooMiddleware, **options) + assert len(app.middleware.add_middleware.call_args_list) == 1 + middleware = app.middleware.add_middleware.call_args[0][0] + assert isinstance(middleware, Middleware) + assert middleware.cls == FooMiddleware + assert middleware.options == options + + async def test_call(self, app, asgi_scope, asgi_receive, asgi_send): + with patch.object(app, "middleware", new=AsyncMock(spec=MiddlewareStack)): + await app(asgi_scope, asgi_receive, asgi_send) + assert "app" in asgi_scope + assert asgi_scope["app"] == app + assert app.middleware.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..81732b78 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,71 @@ +from unittest.mock import AsyncMock, PropertyMock, call, patch + +import pytest + +from flama import http +from flama.debug.middleware import ExceptionMiddleware, ServerErrorMiddleware +from flama.middleware import Middleware, MiddlewareStack + + +class TestCaseMiddlewareStack: + @pytest.fixture + def middleware(self): + class FooMiddleware: + def __init__(self, *args, **kwargs): + ... + + def __call__(self, *args, **kwargs): + return None + + return Middleware(FooMiddleware) + + @pytest.fixture + def app(self): + return AsyncMock() + + @pytest.fixture + def stack(self, app, middleware): + return MiddlewareStack(app=app, middleware=[], debug=True) + + def test_init(self, app, middleware): + stack = MiddlewareStack(app=app, middleware=[middleware], debug=True) + + assert stack.app == app + assert stack.middleware == [middleware] + assert stack.debug + assert stack._exception_handlers == {} + + def test_stack(self, stack): + assert stack._stack is None + assert stack.stack + assert isinstance(stack._stack, ServerErrorMiddleware) + assert isinstance(stack._stack.app, ExceptionMiddleware) + + def test_add_exception_handler(self, stack): + def handler(request: http.Request, exc: Exception) -> http.Response: + ... + + assert stack._stack is None + assert stack.stack + assert stack._stack is not None + + stack.add_exception_handler(400, handler) + stack.add_exception_handler(ValueError, handler) + + assert stack._stack is None + assert stack._exception_handlers == {400: handler, ValueError: handler} + + def test_add_middleware(self, stack, middleware): + assert stack._stack is None + assert stack.stack + assert stack._stack is not None + + stack.add_middleware(middleware) + + assert stack._stack is None + assert stack.middleware == [middleware] + + async def test_call(self, stack, asgi_scope, asgi_receive, asgi_send): + with patch.object(MiddlewareStack, "stack", new=PropertyMock(return_value=AsyncMock())): + await stack(asgi_scope, asgi_receive, asgi_send) + assert stack.stack.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] diff --git a/tests/test_routing.py b/tests/test_routing.py index 4276c48f..eec78b84 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,11 +1,51 @@ -from unittest.mock import MagicMock +import json +from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch import pytest from flama.applications import Flama from flama.components import Component, Components from flama.endpoints import HTTPEndpoint, WebSocketEndpoint -from flama.routing import Mount, Route, Router, WebSocketRoute +from flama.exceptions import HTTPException +from flama.responses import APIResponse +from flama.routing import Mount, Route, Router, WebSocketRoute, prepare_http_request + + +class TestCasePrepareHTTPRequest: + @pytest.fixture + def handler(self): + def _handler(): + ... + + return _handler + + @pytest.mark.parametrize( + ["content"], + ( + pytest.param({"name": "Canna", "custom_id": 6}, id="dict"), + pytest.param("foo", id="str"), + pytest.param(None, id="none"), + ), + ) + async def test_prepare_http_request(self, app, puppy_schema, handler, content): + with patch.object(type(app), "injector", new_callable=PropertyMock) as injector_mock, patch( + "flama.routing.concurrency" + ) as concurrency_mock, patch("flama.routing.get_output_schema", return_value=puppy_schema): + injector_mock().inject = AsyncMock() + concurrency_mock.run = AsyncMock(return_value=content) + response = await prepare_http_request(app, handler, {}) + + assert isinstance(response, APIResponse) + if response.body: + assert json.loads(response.body) == content + + async def test_prepare_http_exception(self, app, handler): + with patch.object(type(app), "injector", new_callable=PropertyMock) as injector_mock, patch( + "flama.routing.concurrency" + ) as concurrency_mock, pytest.raises(ValueError): + injector_mock().inject = AsyncMock() + concurrency_mock.run = AsyncMock(side_effect=ValueError) + await prepare_http_request(app, handler, {}) class TestCaseRouter: @@ -83,6 +123,10 @@ async def get(self): assert router.routes[0].path == "/" assert router.routes[0].endpoint == FooEndpoint + def test_add_route_wrong(self, router): + with pytest.raises(ValueError, match="Either 'path' and 'endpoint' or 'route' variables are needed"): + router.add_route() + def test_add_websocket_route(self, router): async def foo(): return "foo" @@ -115,6 +159,10 @@ async def on_receive(self, websocket): assert router.routes[0].path == "/" assert router.routes[0].endpoint == FooEndpoint + def test_add_websocket_route_wrong(self, router): + with pytest.raises(ValueError, match="Either 'path' and 'endpoint' or 'route' variables are needed"): + router.add_websocket_route() + def test_mount_app(self, app, app_mock): app.mount("/app/", app=app_mock) @@ -223,6 +271,37 @@ def test_mount_declarative(self, component_mock): with pytest.raises(AttributeError): app.routes[1].routes[0].main_app + async def test_not_found_websocket(self, router, asgi_scope, asgi_receive, asgi_send): + asgi_scope["type"] = "websocket" + + websocket_close_instance_mock = AsyncMock() + websocket_close_mock = MagicMock(return_value=websocket_close_instance_mock) + with patch("flama.routing.WebSocketClose", new=websocket_close_mock): + await router.not_found(asgi_scope, asgi_receive, asgi_send) + assert websocket_close_mock.call_args_list == [call()] + assert websocket_close_instance_mock.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] + + async def test_not_found_flama_app(self, router, asgi_scope, asgi_receive, asgi_send): + asgi_scope["app"] = MagicMock() + + with pytest.raises(HTTPException) as exc_info: + await router.not_found(asgi_scope, asgi_receive, asgi_send) + + assert exc_info.type is HTTPException + assert exc_info.value.args == [400] + + async def test_not_found_no_app(self, router, asgi_scope, asgi_receive, asgi_send): + if "app" in asgi_scope: + del asgi_scope["app"] + + response_instance_mock = AsyncMock() + response_mock = MagicMock(return_value=response_instance_mock) + with patch("flama.routing.PlainTextResponse", new=response_mock): + await router.not_found(asgi_scope, asgi_receive, asgi_send) + + assert response_mock.call_args_list == [call("Not Found", status_code=404)] + assert response_instance_mock.call_args_list == [call(asgi_scope, asgi_receive, asgi_send)] + def test_get_route_from_scope_route(self, app, scope): @app.route("/foo/") async def foo():