-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ BackgroundTask using multiprocessing
- Loading branch information
Showing
9 changed files
with
222 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import asyncio | ||
|
||
import uvicorn | ||
|
||
from flama import BackgroundThreadTask, Flama | ||
from flama.responses import JSONResponse | ||
|
||
app = Flama() | ||
|
||
|
||
async def sleep_task(value: int): | ||
await asyncio.sleep(value) | ||
|
||
|
||
@app.route("/") | ||
async def test(): | ||
task = BackgroundThreadTask(sleep_task, 10) | ||
return JSONResponse("hello", background=task) | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
Empty file.
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import asyncio | ||
import enum | ||
import functools | ||
import sys | ||
import typing | ||
from multiprocessing import Process | ||
|
||
import starlette.background | ||
from starlette.concurrency import run_in_threadpool | ||
|
||
if sys.version_info >= (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover | ||
from typing import ParamSpec | ||
else: # pragma: no cover | ||
from typing_extensions import ParamSpec | ||
|
||
__all__ = ["BackgroundTask", "BackgroundTasks", "Concurrency", "BackgroundThreadTask", "BackgroundProcessTask"] | ||
|
||
P = ParamSpec("P") | ||
|
||
|
||
class Concurrency(enum.Enum): | ||
thread = "thread" | ||
process = "process" | ||
|
||
|
||
class BackgroundTask(starlette.background.BackgroundTask): | ||
def __init__( | ||
self, | ||
concurrency: typing.Union[Concurrency, str], | ||
func: typing.Callable[P, typing.Any], | ||
*args: P.args, | ||
**kwargs: P.kwargs | ||
) -> None: | ||
self.func = self._create_task_function(func) | ||
self.args = args | ||
self.kwargs = kwargs | ||
self.concurrency = Concurrency(concurrency) | ||
|
||
def _create_task_function(self, func: typing.Callable[P, typing.Any]) -> typing.Callable[P, typing.Any]: | ||
if asyncio.iscoroutinefunction(func): | ||
|
||
@functools.wraps(func) | ||
async def _inner(*args, **kwargs): | ||
await func(*args, **kwargs) | ||
|
||
else: | ||
|
||
@functools.wraps(func) | ||
async def _inner(*args, **kwargs): | ||
await run_in_threadpool(func, *args, **kwargs) | ||
|
||
return _inner | ||
|
||
def _create_process_target(self, func: typing.Callable[P, typing.Any]): | ||
@functools.wraps(func) | ||
def process_target(*args: P.args, **kwargs: P.kwargs): | ||
policy = asyncio.get_event_loop_policy() | ||
loop = policy.new_event_loop() | ||
policy.set_event_loop(loop) | ||
loop.run_until_complete(func(*args, **kwargs)) | ||
|
||
return process_target | ||
|
||
async def __call__(self): | ||
if self.concurrency == Concurrency.process: | ||
Process(target=self._create_process_target(self.func), args=self.args, kwargs=self.kwargs).start() | ||
else: | ||
await self.func(*self.args, **self.kwargs) | ||
|
||
|
||
class BackgroundTasks(BackgroundTask): | ||
def __init__(self, tasks: typing.Sequence[BackgroundTask] = None): | ||
self.tasks = list(tasks) if tasks else [] | ||
|
||
def add_task( | ||
self, | ||
concurrency: typing.Union[Concurrency, str], | ||
func: typing.Callable[P, typing.Any], | ||
*args: P.args, | ||
**kwargs: P.kwargs | ||
) -> None: | ||
task = BackgroundTask(concurrency, func, *args, **kwargs) | ||
self.tasks.append(task) | ||
|
||
async def __call__(self) -> None: | ||
for task in self.tasks: | ||
await task() | ||
|
||
|
||
class BackgroundThreadTask(BackgroundTask): | ||
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs): | ||
super().__init__(Concurrency.thread, func, *args, **kwargs) | ||
|
||
|
||
class BackgroundProcessTask(BackgroundTask): | ||
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs): | ||
super().__init__(Concurrency.process, func, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import time | ||
from tempfile import NamedTemporaryFile | ||
|
||
import anyio | ||
import pytest | ||
|
||
from flama import BackgroundProcessTask, BackgroundTasks, BackgroundThreadTask, Concurrency | ||
from flama.responses import APIResponse | ||
|
||
|
||
def sync_task(path: str, msg: str): | ||
with open(path, "w") as f: | ||
f.write(msg) | ||
|
||
|
||
async def async_task(path: str, msg: str): | ||
async with await anyio.open_file(path, "w") as f: | ||
await f.write(msg) | ||
|
||
|
||
class TestCaseBackgroundTask: | ||
@pytest.fixture(params=["sync", "async"]) | ||
def task(self, request): | ||
if request.param == "sync": | ||
|
||
def _task(path: str, msg: str): | ||
with open(path, "w") as f: | ||
f.write(msg) | ||
|
||
else: | ||
|
||
async def _task(path: str, msg: str): | ||
async with await anyio.open_file(path, "w") as f: | ||
await f.write(msg) | ||
|
||
return _task | ||
|
||
@pytest.fixture | ||
def tmp_file(self): | ||
with NamedTemporaryFile() as tmp_file: | ||
yield tmp_file | ||
|
||
def test_background_process_task(self, app, client, task, tmp_file): | ||
@app.route("/") | ||
async def test(path: str, msg: str): | ||
return APIResponse({"foo": "bar"}, background=BackgroundProcessTask(task, path, msg)) | ||
|
||
response = client.get("/", params={"path": tmp_file.name, "msg": "foo"}) | ||
assert response.status_code == 200 | ||
assert response.json() == {"foo": "bar"} | ||
|
||
time.sleep(0.2) | ||
with open(tmp_file.name) as f: | ||
assert f.read() == "foo" | ||
|
||
def test_background_thread_task(self, app, client, task, tmp_file): | ||
@app.route("/") | ||
async def test(path: str, msg: str): | ||
return APIResponse({"foo": "bar"}, background=BackgroundThreadTask(task, path, msg)) | ||
|
||
response = client.get("/", params={"path": tmp_file.name, "msg": "foo"}) | ||
assert response.status_code == 200 | ||
assert response.json() == {"foo": "bar"} | ||
|
||
time.sleep(0.2) | ||
with open(tmp_file.name) as f: | ||
assert f.read() == "foo" | ||
|
||
|
||
class TestCaseBackgroundTasks: | ||
@pytest.fixture | ||
def tmp_file(self): | ||
with NamedTemporaryFile() as tmp_file: | ||
yield tmp_file | ||
|
||
@pytest.fixture | ||
def tmp_file_2(self): | ||
with NamedTemporaryFile() as tmp_file: | ||
yield tmp_file | ||
|
||
def test_background_tasks(self, app, client, tmp_file, tmp_file_2): | ||
@app.route("/") | ||
async def test(path_1: str, msg_1: str, path_2: str, msg_2: str): | ||
tasks = BackgroundTasks() | ||
tasks.add_task(Concurrency.process, sync_task, path_1, msg_1) | ||
tasks.add_task(Concurrency.thread, async_task, path_2, msg_2) | ||
return APIResponse({"foo": "bar"}, background=tasks) | ||
|
||
response = client.get( | ||
"/", params={"path_1": tmp_file.name, "msg_1": "foo", "path_2": tmp_file_2.name, "msg_2": "bar"} | ||
) | ||
assert response.status_code == 200 | ||
assert response.json() == {"foo": "bar"} | ||
|
||
time.sleep(0.2) | ||
with open(tmp_file.name) as f: | ||
assert f.read() == "foo" | ||
|
||
with open(tmp_file_2.name) as f: | ||
assert f.read() == "bar" |