Skip to content

Commit

Permalink
🐛 Include nested applications in Lifespan (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Oct 18, 2023
1 parent 12a1b46 commit d76385a
Show file tree
Hide file tree
Showing 10 changed files with 601 additions and 191 deletions.
10 changes: 8 additions & 2 deletions flama/applications.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import typing as t

from flama import asgi, http, injection, types, url, validation, websockets
from flama import asgi, exceptions, http, injection, types, url, validation, websockets
from flama.ddd.components import WorkerComponent
from flama.events import Events
from flama.middleware import MiddlewareStack
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
:param schema_library: Schema library to use.
"""
self._debug = debug
self._status = types.AppStatus.NOT_INITIALIZED
self._status = types.AppStatus.NOT_STARTED
self._shutdown = False

# Create Dependency Injector
Expand Down Expand Up @@ -131,6 +131,12 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
:param receive: ASGI receive event.
:param send: ASGI send event.
"""
if scope["type"] != "lifespan" and self._status in (types.AppStatus.NOT_STARTED, types.AppStatus.STARTING):
raise exceptions.ApplicationError("Application is not ready to process requests yet.")

if scope["type"] != "lifespan" and self._status in (types.AppStatus.SHUT_DOWN, types.AppStatus.SHUTTING_DOWN):
raise exceptions.ApplicationError("Application is already shut down.")

scope["app"] = self
await self.middleware(scope, receive, send)

Expand Down
42 changes: 16 additions & 26 deletions flama/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextlib
import functools
import importlib.metadata
import logging
Expand Down Expand Up @@ -46,36 +45,23 @@ async def _send(self, message: types.Message) -> None:
self._shutdown_complete.set()

async def _app_task(self) -> None:
with contextlib.suppress(asyncio.CancelledError):
scope = types.Scope({"type": "lifespan"})
scope = types.Scope({"type": "lifespan"})

try:
await self.app(scope, self._receive, self._send)
except BaseException as exc:
self._exception = exc
self._startup_complete.set()
self._shutdown_complete.set()

raise

def _run_app(self) -> None:
self._task = asyncio.get_event_loop().create_task(self._app_task())

async def _stop_app(self) -> None:
assert self._task is not None

if not self._task.done():
self._task.cancel()
try:
await self.app(scope, self._receive, self._send)
except BaseException as exc:
self._exception = exc
self._startup_complete.set()
self._shutdown_complete.set()

await self._task
raise

async def __aenter__(self) -> "LifespanContextManager":
self._run_app()
asyncio.create_task(self._app_task())

try:
await self._startup()
except BaseException:
await self._stop_app()
raise

return self
Expand All @@ -86,8 +72,12 @@ async def __aexit__(
exc_value: t.Optional[BaseException] = None,
traceback: t.Optional[TracebackType] = None,
):
await self._shutdown()
await self._stop_app()
asyncio.create_task(self._app_task())

try:
await self._shutdown()
except BaseException:
raise


class _BaseClient:
Expand Down Expand Up @@ -193,7 +183,7 @@ async def __aexit__(
await self.lifespan.__aexit__(exc_type, exc_value, traceback)
await super().__aexit__(exc_type, exc_value, traceback)

async def model_request(self, model: str, method: str, url: str, **kwargs) -> t.Awaitable[httpx.Response]:
def model_request(self, model: str, method: str, url: str, **kwargs) -> t.Awaitable[httpx.Response]:
assert self.models, "No models found for request."
return self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

Expand Down
5 changes: 4 additions & 1 deletion flama/debug/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ async def sender(message: types.Message) -> None:
try:
await concurrency.run(self.app, scope, receive, sender)
except Exception as exc:
await self.process_exception(scope, receive, send, exc, response_started)
if scope["type"] in ("http", "websocket"):
await self.process_exception(scope, receive, send, exc, response_started)
else:
raise

@abc.abstractmethod
async def process_exception(
Expand Down
39 changes: 34 additions & 5 deletions flama/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import typing as t

Expand All @@ -13,6 +14,12 @@

class Lifespan(types.AppClass):
def __init__(self, lifespan: t.Optional[t.Callable[[t.Optional["Flama"]], t.AsyncContextManager]] = None):
"""A class that handles the lifespan of an application.
It is responsible for calling the startup and shutdown events and the user defined lifespan.
:param lifespan: A user defined lifespan. It must be a callable that returns an async context manager.
"""
self.lifespan = lifespan

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
Expand All @@ -22,36 +29,49 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
:param receive: ASGI receive.
:param send: ASGI send.
"""
app = scope["app"]
while True:
async with asyncio.Lock():
app = scope["app"]
message = await receive()
if message["type"] == "lifespan.startup":
if app._status not in (types.AppStatus.NOT_STARTED, types.AppStatus.SHUT_DOWN):
msg = f"Trying to start application from '{app._status}' state"
await send(types.Message({"type": "lifespan.startup.failed", "message": msg}))
raise exceptions.ApplicationError(msg)

try:
logger.info("Application starting")
app._status = types.AppStatus.STARTING
await self._startup(app)
await send(types.Message({"type": "lifespan.startup.complete"}))
await self._child_propagation(app, scope, message, send)
app._status = types.AppStatus.READY
await send(types.Message({"type": "lifespan.startup.complete"}))
logger.info("Application ready")
except BaseException as e:
logger.exception("Application start failed")
app._status = types.AppStatus.FAILED
await send(types.Message({"type": "lifespan.startup.failed", "message": str(e)}))
raise exceptions.ApplicationError("Lifespan startup failed") from e
elif message["type"] == "lifespan.shutdown":
if app._status != types.AppStatus.READY:
msg = f"Trying to shutdown application from '{app._status}' state"
await send(types.Message({"type": "lifespan.shutdown.failed", "message": msg}))
raise exceptions.ApplicationError(msg)

try:
logger.info("Application shutting down")
app._status = types.AppStatus.SHUTTING_DOWN
await self._child_propagation(app, scope, message, send)
await self._shutdown(app)
await send(types.Message({"type": "lifespan.shutdown.complete"}))
app._status = types.AppStatus.SHUT_DOWN
await send(types.Message({"type": "lifespan.shutdown.complete"}))
logger.info("Application shut down")
return
except BaseException as e:
await send(types.Message({"type": "lifespan.shutdown.failed", "message": str(e)}))
app._status = types.AppStatus.FAILED
logger.exception("Application shutdown failed")
raise exceptions.ApplicationError("Lifespan shutdown failed") from e
else:
logger.warning("Unknown lifespan message received: %s", str(message))

async def _startup(self, app: "Flama") -> None:
await concurrency.run_task_group(*(f() for f in app.events.startup))
Expand All @@ -62,3 +82,12 @@ async def _shutdown(self, app: "Flama") -> None:
if self.lifespan:
await self.lifespan(app).__aexit__(None, None, None)
await concurrency.run_task_group(*(f() for f in app.events.shutdown))

async def _child_propagation(
self, app: "Flama", scope: types.Scope, message: types.Message, send: types.Send
) -> None:
async def child_receive() -> types.Message:
return message

if app.routes:
await concurrency.run_task_group(*(route(scope, child_receive, send) for route in app.routes))
78 changes: 46 additions & 32 deletions flama/routing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import enum
import functools
import inspect
Expand Down Expand Up @@ -190,7 +191,7 @@ def __eq__(self, other) -> bool:
return isinstance(other, EndpointWrapper) and self.handler == other.handler


class BaseRoute(RouteParametersMixin):
class BaseRoute(abc.ABC, RouteParametersMixin):
def __init__(
self,
path: t.Union[str, url.RegexPath],
Expand All @@ -216,8 +217,9 @@ def __init__(
self.tags = tags or {}
super().__init__()

@abc.abstractmethod
async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)
...

def __eq__(self, other: t.Any) -> bool:
return (
Expand Down Expand Up @@ -344,6 +346,10 @@ def __init__(

self.app: EndpointWrapper

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] == "http":
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, Route) and self.methods == other.methods

Expand Down Expand Up @@ -427,6 +433,10 @@ def __init__(

self.app: EndpointWrapper

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] == "websocket":
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, WebSocketRoute)

Expand Down Expand Up @@ -489,6 +499,12 @@ def __init__(

super().__init__(url.RegexPath(path.rstrip("/") + "{path:path}"), app, name=name, tags=tags)

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] in ("http", "websocket") or (
scope["type"] == "lifespan" and types.is_flama_instance(self.app)
):
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, Mount)

Expand All @@ -499,12 +515,10 @@ def build(self, app: t.Optional["Flama"] = None) -> None:
:param app: Flama app.
"""
from flama import Flama

if app and isinstance(self.app, Flama):
if app and types.is_flama_instance(self.app):
self.app.router.components = Components(self.app.router.components + app.components)

if root := (self.app if isinstance(self.app, Flama) else app):
if root := (self.app if types.is_flama_instance(self.app) else app):
for route in self.routes:
route.build(root)

Expand All @@ -531,30 +545,33 @@ async def handle(self, scope: types.Scope, receive: types.Receive, send: types.S
def route_scope(self, scope: types.Scope) -> types.Scope:
"""Build route scope from given scope.
It generates an updated scope parameters for the route:
* app: The app of this mount point. If it's mounting a Flama app, it will replace the app with this one
* path_params: The matched path parameters of this mount point
* endpoint: The endpoint of this mount point
* root_path: The root path of this mount point (if it's mounting a Flama app, it will be empty)
* path: The remaining path to be matched
:param scope: ASGI scope.
:return: Route scope.
"""
from flama import Flama
result = {"app": self.app if types.is_flama_instance(self.app) else scope["app"]}

if "path" in scope:
path = scope["path"]
matched_params = self.path.values(path)
remaining_path = matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
result.update(
{
"path_params": {**dict(scope.get("path_params", {})), **matched_params},
"endpoint": self.endpoint,
"root_path": "" if types.is_flama_instance(self.app) else scope.get("root_path", "") + matched_path,
"path": remaining_path,
}
)

path = scope["path"]
matched_params = self.path.values(path)
remaining_path = matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
if isinstance(self.app, Flama):
app = self.app
root_path = ""
else:
app = scope["app"]
root_path = scope.get("root_path", "") + matched_path
return types.Scope(
{
"app": app,
"path_params": {**dict(scope.get("path_params", {})), **matched_params},
"endpoint": self.endpoint,
"root_path": root_path,
"path": remaining_path,
}
)
return types.Scope(result)

def resolve_url(self, name: str, **params: t.Any) -> url.URL:
"""Builds URL path for given name and params.
Expand Down Expand Up @@ -620,16 +637,13 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
logger.debug("Request: %s", str(scope))
assert scope["type"] in ("http", "websocket", "lifespan")

if "app" in scope and scope["app"]._status != types.AppStatus.READY and scope["type"] != "lifespan":
raise exceptions.ApplicationError("Application is not ready to process requests yet.")

if "router" not in scope:
scope["router"] = self

if scope["type"] == "lifespan":
await self.lifespan(scope, receive, send)
return

if "router" not in scope:
scope["router"] = self

route, route_scope = self.resolve_route(scope)
await route(route_scope, receive, send)

Expand Down
29 changes: 27 additions & 2 deletions flama/types/applications.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,37 @@
import enum
import sys
import typing as t

__all__ = ["AppStatus"]
if t.TYPE_CHECKING:
from flama import Flama


if sys.version_info < (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover
from typing_extensions import TypeGuard

t.TypeGuard = TypeGuard # type: ignore


__all__ = ["AppStatus", "is_flama_instance"]


class AppStatus(enum.Enum):
NOT_INITIALIZED = enum.auto()
NOT_STARTED = enum.auto()
STARTING = enum.auto()
READY = enum.auto()
SHUTTING_DOWN = enum.auto()
SHUT_DOWN = enum.auto()
FAILED = enum.auto()


def is_flama_instance(
obj: t.Any,
) -> t.TypeGuard["Flama"]: # type: ignore # PORT: Remove this comment when stop supporting 3.9
"""Checks if an object is an instance of Flama.
:param obj: The object to check.
:return: True if the object is an instance of Flama, False otherwise.
"""
from flama import Flama

return isinstance(obj, Flama)
Loading

0 comments on commit d76385a

Please sign in to comment.