diff --git a/examples/background.py b/examples/background.py new file mode 100755 index 00000000..c1381641 --- /dev/null +++ b/examples/background.py @@ -0,0 +1,22 @@ +import asyncio + +import uvicorn + +from flama import BackgroundThreadTask, Flama +from flama.responses import JSONResponse + +app = Flama() + + +async def sleep_task(value: int): + await asyncio.sleep(value) + + +@app.route("/") +async def test(): + task = BackgroundThreadTask(sleep_task, 10) + return JSONResponse("hello", background=task) + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/data_schema.py b/examples/data_schema.py old mode 100644 new mode 100755 diff --git a/examples/hello_world.py b/examples/hello_world.py old mode 100644 new mode 100755 diff --git a/examples/pagination.py b/examples/pagination.py old mode 100644 new mode 100755 diff --git a/examples/resource.py b/examples/resource.py old mode 100644 new mode 100755 diff --git a/flama/__init__.py b/flama/__init__.py index 67e85c34..3ef820b8 100644 --- a/flama/__init__.py +++ b/flama/__init__.py @@ -1,6 +1,7 @@ from starlette.config import Config # noqa from flama.applications import * # noqa +from flama.background import * # noqa from flama.components import Component # noqa from flama.endpoints import * # noqa from flama.modules import Module # noqa diff --git a/flama/background.py b/flama/background.py new file mode 100644 index 00000000..88d8a635 --- /dev/null +++ b/flama/background.py @@ -0,0 +1,97 @@ +import asyncio +import enum +import functools +import sys +import typing +from multiprocessing import Process + +import starlette.background +from starlette.concurrency import run_in_threadpool + +if sys.version_info >= (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +__all__ = ["BackgroundTask", "BackgroundTasks", "Concurrency", "BackgroundThreadTask", "BackgroundProcessTask"] + +P = ParamSpec("P") + + +class Concurrency(enum.Enum): + thread = "thread" + process = "process" + + +class BackgroundTask(starlette.background.BackgroundTask): + def __init__( + self, + concurrency: typing.Union[Concurrency, str], + func: typing.Callable[P, typing.Any], + *args: P.args, + **kwargs: P.kwargs + ) -> None: + self.func = self._create_task_function(func) + self.args = args + self.kwargs = kwargs + self.concurrency = Concurrency(concurrency) + + def _create_task_function(self, func: typing.Callable[P, typing.Any]) -> typing.Callable[P, typing.Any]: + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def _inner(*args, **kwargs): + await func(*args, **kwargs) + + else: + + @functools.wraps(func) + async def _inner(*args, **kwargs): + await run_in_threadpool(func, *args, **kwargs) + + return _inner + + def _create_process_target(self, func: typing.Callable[P, typing.Any]): + @functools.wraps(func) + def process_target(*args: P.args, **kwargs: P.kwargs): + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + policy.set_event_loop(loop) + loop.run_until_complete(func(*args, **kwargs)) + + return process_target + + async def __call__(self): + if self.concurrency == Concurrency.process: + Process(target=self._create_process_target(self.func), args=self.args, kwargs=self.kwargs).start() + else: + await self.func(*self.args, **self.kwargs) + + +class BackgroundTasks(BackgroundTask): + def __init__(self, tasks: typing.Sequence[BackgroundTask] = None): + self.tasks = list(tasks) if tasks else [] + + def add_task( + self, + concurrency: typing.Union[Concurrency, str], + func: typing.Callable[P, typing.Any], + *args: P.args, + **kwargs: P.kwargs + ) -> None: + task = BackgroundTask(concurrency, func, *args, **kwargs) + self.tasks.append(task) + + async def __call__(self) -> None: + for task in self.tasks: + await task() + + +class BackgroundThreadTask(BackgroundTask): + def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs): + super().__init__(Concurrency.thread, func, *args, **kwargs) + + +class BackgroundProcessTask(BackgroundTask): + def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs): + super().__init__(Concurrency.process, func, *args, **kwargs) diff --git a/flama/responses.py b/flama/responses.py index d6232698..4c962eae 100644 --- a/flama/responses.py +++ b/flama/responses.py @@ -67,9 +67,9 @@ def render(self, content: typing.Any) -> bytes: class APIResponse(JSONResponse): media_type = "application/json" - def __init__(self, schema: typing.Optional[Schema] = None, *args, **kwargs): + def __init__(self, content: typing.Any = None, schema: typing.Optional[Schema] = None, *args, **kwargs): self.schema = schema - super().__init__(*args, **kwargs) + super().__init__(content=content, *args, **kwargs) def render(self, content: typing.Any): if content and is_schema_instance(content): # pragma: no cover (only apply to marshmallow) diff --git a/tests/test_background.py b/tests/test_background.py new file mode 100644 index 00000000..42399c2e --- /dev/null +++ b/tests/test_background.py @@ -0,0 +1,100 @@ +import time +from tempfile import NamedTemporaryFile + +import anyio +import pytest + +from flama import BackgroundProcessTask, BackgroundTasks, BackgroundThreadTask, Concurrency +from flama.responses import APIResponse + + +def sync_task(path: str, msg: str): + with open(path, "w") as f: + f.write(msg) + + +async def async_task(path: str, msg: str): + async with await anyio.open_file(path, "w") as f: + await f.write(msg) + + +class TestCaseBackgroundTask: + @pytest.fixture(params=["sync", "async"]) + def task(self, request): + if request.param == "sync": + + def _task(path: str, msg: str): + with open(path, "w") as f: + f.write(msg) + + else: + + async def _task(path: str, msg: str): + async with await anyio.open_file(path, "w") as f: + await f.write(msg) + + return _task + + @pytest.fixture + def tmp_file(self): + with NamedTemporaryFile() as tmp_file: + yield tmp_file + + def test_background_process_task(self, app, client, task, tmp_file): + @app.route("/") + async def test(path: str, msg: str): + return APIResponse({"foo": "bar"}, background=BackgroundProcessTask(task, path, msg)) + + response = client.get("/", params={"path": tmp_file.name, "msg": "foo"}) + assert response.status_code == 200 + assert response.json() == {"foo": "bar"} + + time.sleep(0.2) + with open(tmp_file.name) as f: + assert f.read() == "foo" + + def test_background_thread_task(self, app, client, task, tmp_file): + @app.route("/") + async def test(path: str, msg: str): + return APIResponse({"foo": "bar"}, background=BackgroundThreadTask(task, path, msg)) + + response = client.get("/", params={"path": tmp_file.name, "msg": "foo"}) + assert response.status_code == 200 + assert response.json() == {"foo": "bar"} + + time.sleep(0.2) + with open(tmp_file.name) as f: + assert f.read() == "foo" + + +class TestCaseBackgroundTasks: + @pytest.fixture + def tmp_file(self): + with NamedTemporaryFile() as tmp_file: + yield tmp_file + + @pytest.fixture + def tmp_file_2(self): + with NamedTemporaryFile() as tmp_file: + yield tmp_file + + def test_background_tasks(self, app, client, tmp_file, tmp_file_2): + @app.route("/") + async def test(path_1: str, msg_1: str, path_2: str, msg_2: str): + tasks = BackgroundTasks() + tasks.add_task(Concurrency.process, sync_task, path_1, msg_1) + tasks.add_task(Concurrency.thread, async_task, path_2, msg_2) + return APIResponse({"foo": "bar"}, background=tasks) + + response = client.get( + "/", params={"path_1": tmp_file.name, "msg_1": "foo", "path_2": tmp_file_2.name, "msg_2": "bar"} + ) + assert response.status_code == 200 + assert response.json() == {"foo": "bar"} + + time.sleep(0.2) + with open(tmp_file.name) as f: + assert f.read() == "foo" + + with open(tmp_file_2.name) as f: + assert f.read() == "bar"