-
Notifications
You must be signed in to change notification settings - Fork 11
/
utils.py
146 lines (122 loc) · 4.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
A set of higher-level coroutine aggregation utilities based on :class:`Supervisor`.
"""
from __future__ import annotations
import asyncio
from contextlib import aclosing
from contextvars import Context
from typing import (
Any,
AsyncGenerator,
Awaitable,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from .supervisor import Supervisor
# Names exported by ``from <this module> import *``; everything else in this
# module is treated as non-public.
__all__ = (
"as_completed_safe",
"gather_safe",
"race",
)
# Generic result type of the awaitables accepted by the helpers in this module.
T = TypeVar("T")
async def as_completed_safe(
    coros: Iterable[Awaitable[T]],
    *,
    context: Optional[Context] = None,
) -> AsyncGenerator[Awaitable[T], None]:
    """
    This is a safer version of :func:`asyncio.as_completed()` which uses
    :class:`aiotools.Supervisor` as an underlying coroutine lifecycle keeper.

    Every item of ``coros`` is scheduled as a task on an internal supervisor
    (optionally within the given ``context``), and each task is yielded as
    soon as its done-callback fires.  Await the yielded task to obtain its
    result or to re-raise its exception.

    If the generator is closed early, cancelled, or otherwise interrupted
    while it is waiting for or handing out a task, the supervisor's
    ``shutdown()`` is awaited before the interruption is propagated.

    This requires Python 3.11 or higher to work properly with timeouts.

    .. versionadded:: 1.6

    .. versionchanged:: 2.0

       It now uses :class:`aiotools.Supervisor` internally and handles
       timeouts in a better way.
    """
    q: asyncio.Queue[asyncio.Task[Any]] = asyncio.Queue()
    remaining = 0

    def result_callback(t: asyncio.Task[Any]) -> None:
        # Invoked by the event loop when a task finishes; hand the finished
        # task over to the consumer loop below.
        q.put_nowait(t)

    async with Supervisor() as supervisor:
        for coro in coros:
            t = supervisor.create_task(coro, context=context)
            t.add_done_callback(result_callback)
            remaining += 1
        while remaining:
            try:
                t = await q.get()
                remaining -= 1
                try:
                    yield t
                finally:
                    q.task_done()
            except BaseException:
                # ``BaseException`` alone covers every way this generator can
                # be interrupted (``GeneratorExit`` is one of its subclasses):
                #  - GeneratorExit: injected when aclose() is called
                #    (e.g., the async-for body raises an exception)
                #  - CancelledError: injected when a timeout occurs
                #    (i.e., the outer scope cancels the inner)
                #  - anything else: e.g., the process is going to terminate
                await supervisor.shutdown()
                raise
async def gather_safe(
    coros: Iterable[Awaitable[T]],
    *,
    context: Optional[Context] = None,
) -> List[T | Exception]:
    """
    A safer version of :func:`asyncio.gather()`.

    All given coroutines are spawned as tasks under a :class:`Supervisor`
    so that they are guaranteed to be terminated by the time this function
    returns.  The ``context`` argument, when given, is passed to every
    spawned subtask.

    The returned list is produced by :func:`asyncio.gather()` with
    ``return_exceptions=True``, i.e., a failed subtask contributes its
    exception object instead of a result.

    Note that if it is cancelled from an outer scope (e.g., timeout), there
    is no way to retrieve partially completed or failed results.
    If you need to process them anyway, you must store the results in a
    separate place in the passed coroutines or use :func:`as_completed_safe()`
    instead.

    .. versionadded:: 2.0
    """
    async with Supervisor() as sv:
        spawned = [sv.create_task(aw, context=context) for aw in coros]
        # To ensure safety, the Python version must be 3.7 or higher.
        outcomes = await asyncio.gather(*spawned, return_exceptions=True)
        return outcomes
async def race(
    coros: Iterable[Awaitable[T]],
    *,
    continue_on_error: bool = False,
    context: Optional[Context] = None,
) -> Tuple[T, Sequence[Exception]]:
    """
    Return the first available result, safely cancelling all remaining
    coroutines.

    Passing an empty iterable of coroutines is not allowed and raises
    :exc:`ValueError`.

    If ``continue_on_error`` is False (the default), the exception of the
    first finished coroutine is raised immediately and all remaining
    coroutines are cancelled.  This is the same behavior as Javascript's
    ``Promise.race()``.  In this mode the second item of the returned tuple
    is always empty.

    If ``continue_on_error`` is True, failures are collected and the race
    goes on until the first successful result arrives; the collected
    exceptions are then returned as a list in the second item of the tuple.
    If every coroutine fails, an :exc:`ExceptionGroup` holding all of the
    collected exceptions is raised to indicate the explicit failure of the
    entire operation.

    You may use this function to implement a "happy eyeballs" algorithm.

    .. versionadded:: 2.0
    """
    failures: list[Exception] = []
    async with aclosing(as_completed_safe(coros, context=context)) as finished:
        async for done in finished:
            try:
                winner = await done
            except Exception as exc:
                if not continue_on_error:
                    raise
                failures.append(exc)
            else:
                # First success: leaving the ``aclosing`` block closes the
                # generator, which takes care of the still-running coroutines.
                return winner, failures
        # The generator is exhausted without any successful result.
        if not failures:
            raise ValueError("No coroutines were given to race()")
        raise ExceptionGroup("All coroutines have failed in race()", failures)