Skip to content

Commit

Permalink
✨ BackgroundTask using multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent 0c7622a commit 9038b40
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 2 deletions.
22 changes: 22 additions & 0 deletions examples/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import asyncio

import uvicorn

from flama import BackgroundThreadTask, Flama
from flama.responses import JSONResponse

app = Flama()


async def sleep_task(value: int):
await asyncio.sleep(value)


@app.route("/")
async def test():
task = BackgroundThreadTask(sleep_task, 10)
return JSONResponse("hello", background=task)


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
Empty file modified examples/data_schema.py
100644 → 100755
Empty file.
Empty file modified examples/hello_world.py
100644 → 100755
Empty file.
Empty file modified examples/pagination.py
100644 → 100755
Empty file.
Empty file modified examples/resource.py
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions flama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from starlette.config import Config # noqa

from flama.applications import * # noqa
from flama.background import * # noqa
from flama.components import Component # noqa
from flama.endpoints import * # noqa
from flama.modules import Module # noqa
Expand Down
97 changes: 97 additions & 0 deletions flama/background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import asyncio
import enum
import functools
import sys
import typing
from multiprocessing import Process

import starlette.background
from starlette.concurrency import run_in_threadpool

if sys.version_info >= (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec

__all__ = ["BackgroundTask", "BackgroundTasks", "Concurrency", "BackgroundThreadTask", "BackgroundProcessTask"]

P = ParamSpec("P")


class Concurrency(enum.Enum):
thread = "thread"
process = "process"


class BackgroundTask(starlette.background.BackgroundTask):
def __init__(
self,
concurrency: typing.Union[Concurrency, str],
func: typing.Callable[P, typing.Any],
*args: P.args,
**kwargs: P.kwargs
) -> None:
self.func = self._create_task_function(func)
self.args = args
self.kwargs = kwargs
self.concurrency = Concurrency(concurrency)

def _create_task_function(self, func: typing.Callable[P, typing.Any]) -> typing.Callable[P, typing.Any]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _inner(*args, **kwargs):
await func(*args, **kwargs)

else:

@functools.wraps(func)
async def _inner(*args, **kwargs):
await run_in_threadpool(func, *args, **kwargs)

return _inner

def _create_process_target(self, func: typing.Callable[P, typing.Any]):
@functools.wraps(func)
def process_target(*args: P.args, **kwargs: P.kwargs):
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
policy.set_event_loop(loop)
loop.run_until_complete(func(*args, **kwargs))

return process_target

async def __call__(self):
if self.concurrency == Concurrency.process:
Process(target=self._create_process_target(self.func), args=self.args, kwargs=self.kwargs).start()
else:
await self.func(*self.args, **self.kwargs)


class BackgroundTasks(BackgroundTask):
def __init__(self, tasks: typing.Sequence[BackgroundTask] = None):
self.tasks = list(tasks) if tasks else []

def add_task(
self,
concurrency: typing.Union[Concurrency, str],
func: typing.Callable[P, typing.Any],
*args: P.args,
**kwargs: P.kwargs
) -> None:
task = BackgroundTask(concurrency, func, *args, **kwargs)
self.tasks.append(task)

async def __call__(self) -> None:
for task in self.tasks:
await task()


class BackgroundThreadTask(BackgroundTask):
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs):
super().__init__(Concurrency.thread, func, *args, **kwargs)


class BackgroundProcessTask(BackgroundTask):
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs):
super().__init__(Concurrency.process, func, *args, **kwargs)
4 changes: 2 additions & 2 deletions flama/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def render(self, content: typing.Any) -> bytes:
class APIResponse(JSONResponse):
media_type = "application/json"

def __init__(self, schema: typing.Optional[Schema] = None, *args, **kwargs):
def __init__(self, content: typing.Any = None, schema: typing.Optional[Schema] = None, *args, **kwargs):
self.schema = schema
super().__init__(*args, **kwargs)
super().__init__(content=content, *args, **kwargs)

def render(self, content: typing.Any):
if content and is_schema_instance(content): # pragma: no cover (only apply to marshmallow)
Expand Down
100 changes: 100 additions & 0 deletions tests/test_background.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import time
from tempfile import NamedTemporaryFile

import anyio
import pytest

from flama import BackgroundProcessTask, BackgroundTasks, BackgroundThreadTask, Concurrency
from flama.responses import APIResponse


def sync_task(path: str, msg: str):
with open(path, "w") as f:
f.write(msg)


async def async_task(path: str, msg: str):
async with await anyio.open_file(path, "w") as f:
await f.write(msg)


class TestCaseBackgroundTask:
@pytest.fixture(params=["sync", "async"])
def task(self, request):
if request.param == "sync":

def _task(path: str, msg: str):
with open(path, "w") as f:
f.write(msg)

else:

async def _task(path: str, msg: str):
async with await anyio.open_file(path, "w") as f:
await f.write(msg)

return _task

@pytest.fixture
def tmp_file(self):
with NamedTemporaryFile() as tmp_file:
yield tmp_file

def test_background_process_task(self, app, client, task, tmp_file):
@app.route("/")
async def test(path: str, msg: str):
return APIResponse({"foo": "bar"}, background=BackgroundProcessTask(task, path, msg))

response = client.get("/", params={"path": tmp_file.name, "msg": "foo"})
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(0.2)
with open(tmp_file.name) as f:
assert f.read() == "foo"

def test_background_thread_task(self, app, client, task, tmp_file):
@app.route("/")
async def test(path: str, msg: str):
return APIResponse({"foo": "bar"}, background=BackgroundThreadTask(task, path, msg))

response = client.get("/", params={"path": tmp_file.name, "msg": "foo"})
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(0.2)
with open(tmp_file.name) as f:
assert f.read() == "foo"


class TestCaseBackgroundTasks:
@pytest.fixture
def tmp_file(self):
with NamedTemporaryFile() as tmp_file:
yield tmp_file

@pytest.fixture
def tmp_file_2(self):
with NamedTemporaryFile() as tmp_file:
yield tmp_file

def test_background_tasks(self, app, client, tmp_file, tmp_file_2):
@app.route("/")
async def test(path_1: str, msg_1: str, path_2: str, msg_2: str):
tasks = BackgroundTasks()
tasks.add_task(Concurrency.process, sync_task, path_1, msg_1)
tasks.add_task(Concurrency.thread, async_task, path_2, msg_2)
return APIResponse({"foo": "bar"}, background=tasks)

response = client.get(
"/", params={"path_1": tmp_file.name, "msg_1": "foo", "path_2": tmp_file_2.name, "msg_2": "bar"}
)
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(0.2)
with open(tmp_file.name) as f:
assert f.read() == "foo"

with open(tmp_file_2.name) as f:
assert f.read() == "bar"

0 comments on commit 9038b40

Please sign in to comment.