Skip to content

Commit

Permalink
Refactor python async trampoline handling
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrattli committed Oct 16, 2023
1 parent 12830a2 commit e3b1f21
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 78 deletions.
25 changes: 16 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,22 @@ pythonVersion = "3.10"
typeCheckingMode = "strict"

[tool.isort]
profile = "black"
atomic = true
lines_after_imports = 2
lines_between_types = 1
multi_line_output = 3 # corresponds to -m flag
include_trailing_comma = true # corresponds to -tc flag
line_length = 88
known_third_party = ["cognite","pytest"]
py_version=310

[tool.ruff]
# Keep in sync with .pre-commit-config.yaml
line-length = 120
ignore = []
target-version = "py310"
select = ["E", "W", "F", "I", "T", "RUF", "TID", "UP"]
exclude = ["tests", "build", "temp", "src/fable_library", "src/fable_library_rust", "src/fable_library_php"]
include =["*.py"]

[tool.ruff.pydocstyle]
convention = "google"

[tool.ruff.isort]
lines-after-imports = 2


[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
65 changes: 53 additions & 12 deletions src/fable-library-py/fable_library/async_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio

from asyncio import Future, ensure_future
from concurrent.futures import ThreadPoolExecutor
from threading import Timer
Expand All @@ -16,9 +15,9 @@
)

from .async_builder import (
Continuations,
Async,
CancellationToken,
Continuations,
IAsyncContext,
OperationCanceledError,
Trampoline,
Expand All @@ -30,11 +29,12 @@
)

# F# generated code (from Choice.fs)
from .choice import Choice_makeChoice1Of2 # type: ignore
from .choice import Choice_makeChoice2Of2 # type: ignore
from .choice import (
Choice_makeChoice1Of2, # type: ignore
Choice_makeChoice2Of2, # type: ignore
)
from .task import TaskCompletionSource


_T = TypeVar("_T")


Expand All @@ -47,6 +47,7 @@ def cont(ctx: IAsyncContext[Any]):

default_cancellation_token = CancellationToken()


# see AsyncBuilder.Delay
def delay(generator: Callable[[], Async[_T]]):
def cont(ctx: IAsyncContext[_T]):
Expand Down Expand Up @@ -81,7 +82,6 @@ def is_cancellation_requested(token: CancellationToken) -> bool:
def sleep(millisecondsDueTime: int) -> Async[None]:
def cont(ctx: IAsyncContext[None]):
def cancel():
timer.cancel()
ctx.on_cancel(OperationCanceledError())

token_id = ctx.cancel_token.add_listener(cancel)
Expand All @@ -90,8 +90,8 @@ def timeout():
ctx.cancel_token.remove_listener(token_id)
ctx.on_success(None)

timer = Timer(millisecondsDueTime / 1000.0, timeout)
timer.start()
due_time = millisecondsDueTime / 1000.0
ctx.trampoline.run_later(timeout, due_time)

return protected_cont(cont)

Expand All @@ -106,8 +106,10 @@ def binder(_: Optional[Any] = None) -> Async[None]:
def parallel(computations: Iterable[Async[_T]]) -> Async[List[_T]]:
def delayed() -> Async[List[_T]]:
tasks: Iterable[Future[_T]] = map(start_as_task, computations) # type: ignore
all: Future[List[_T]] = asyncio.gather(*tasks)

try:
all: Future[List[_T]] = asyncio.gather(*tasks)
except Exception as ex:
raise ex
return await_task(all)

return delay(delayed)
Expand Down Expand Up @@ -189,7 +191,7 @@ def callback(conts: Continuations[_T]) -> None:
continuation = conts

task.add_done_callback(done)
return from_continuations(callback) # type: ignore
return from_continuations(callback)


def start_with_continuations(
Expand Down Expand Up @@ -251,6 +253,32 @@ def cancel(_: OperationCanceledError) -> None:
return tcs.get_task()


def start_child(computation: Async[_T], ms: Optional[int] = None) -> Async[Async[_T]]:
if ms:
computation_with_timeout = protected_bind(
parallel(computation, throw_after(ms)), lambda xs: protected_return(xs[0])
)
return start_child(computation_with_timeout)

task = start_as_task(computation)

def cont(ctx: IAsyncContext[Async[_T]]) -> None:
def on_success(_: Async[_T]) -> None:
ctx.on_success(await_task(task))

on_error = ctx.on_error
on_cancel = ctx.on_cancel
trampoline = ctx.trampoline
cancel_token = ctx.cancel_token

ctx_ = IAsyncContext.create(
on_success, on_error, on_cancel, trampoline, cancel_token
)
computation(ctx_)

return protected_cont(cont)


def start_immediate(
computation: Async[Any],
cancellation_token: Optional[CancellationToken] = None,
Expand All @@ -260,7 +288,20 @@ def start_immediate(
Runs an asynchronous computation, starting immediately on the
current operating system thread
"""
return start_with_continuations(computation, cancellation_token=cancellation_token)
try:
asyncio.get_event_loop()
except RuntimeError:

async def runner() -> None:
return start_with_continuations(
computation, cancellation_token=cancellation_token
)

return asyncio.run(runner())
else:
return start_with_continuations(
computation, cancellation_token=cancellation_token
)


_executor: Optional[ThreadPoolExecutor] = None
Expand Down
111 changes: 59 additions & 52 deletions src/fable-library-py/fable_library/async_builder.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from __future__ import annotations

import asyncio
from abc import abstractmethod
from collections import deque
from threading import Lock, RLock, Timer
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from threading import Lock, RLock
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Optional,
Literal,
Protocol,
Tuple,
TypeVar,
overload,
)
Expand All @@ -25,19 +23,19 @@


class OperationCanceledError(Exception):
def __init__(self, msg: Optional[str] = None) -> None:
def __init__(self, msg: str | None = None) -> None:
super().__init__(msg or "The operation was canceled")


Continuations = Tuple[
Continuations = tuple[
Callable[[_T], None],
Callable[[Exception], None],
Callable[[OperationCanceledError], None],
]


class _Listener(Protocol):
def __call__(self, __state: Optional[Any] = None) -> None:
def __call__(self, __state: Any | None = None) -> None:
...


Expand All @@ -46,7 +44,7 @@ class CancellationToken:

def __init__(self, cancelled: bool = False):
self.cancelled = cancelled
self.listeners: Dict[int, Callable[[], None]] = {}
self.listeners: dict[int, Callable[[], None]] = {}
self.idx = 0
self.lock = RLock()

Expand Down Expand Up @@ -79,7 +77,7 @@ def remove_listener(self, id: int) -> None:
with self.lock:
del self.listeners[id]

def register(self, f: _Listener, state: Optional[Any] = None) -> None:
def register(self, f: _Listener, state: Any | None = None) -> None:
if state:
id = self.add_listener(lambda: f(state))
else:
Expand Down Expand Up @@ -108,7 +106,7 @@ def on_cancel(self, error: OperationCanceledError) -> None:

@property
@abstractmethod
def trampoline(self) -> "Trampoline":
def trampoline(self) -> Trampoline:
...

@trampoline.setter
Expand All @@ -128,12 +126,12 @@ def cancel_token(self, val: CancellationToken):

@staticmethod
def create(
on_success: Optional[Callable[[_T], None]],
on_error: Optional[Callable[[Exception], None]],
on_cancel: Optional[Callable[[OperationCanceledError], None]],
trampoline: Optional[Trampoline],
cancel_token: Optional[CancellationToken],
) -> AnonymousAsyncContext[_T]:
on_success: Callable[[_T], None] | None,
on_error: Callable[[Exception], None] | None,
on_cancel: Callable[[OperationCanceledError], None] | None,
trampoline: Trampoline | None,
cancel_token: CancellationToken | None,
) -> IAsyncContext[_T]:
return AnonymousAsyncContext(
on_success, on_error, on_cancel, trampoline, cancel_token
)
Expand All @@ -152,11 +150,11 @@ class AnonymousAsyncContext(IAsyncContext[_T]):

def __init__(
self,
on_success: Optional[Callable[[_T], None]] = None,
on_error: Optional[Callable[[Exception], None]] = None,
on_cancel: Optional[Callable[[OperationCanceledError], None]] = None,
trampoline: Optional[Trampoline] = None,
cancel_token: Optional[CancellationToken] = None,
on_success: Callable[[_T], None] | None = None,
on_error: Callable[[Exception], None] | None = None,
on_cancel: Callable[[OperationCanceledError], None] | None = None,
trampoline: Trampoline | None = None,
cancel_token: CancellationToken | None = None,
):
self._on_success: Callable[[_T], None] = on_success or empty_continuation
self._on_error: Callable[[Exception], None] = on_error or empty_continuation
Expand Down Expand Up @@ -197,45 +195,45 @@ def cancel_token(self, val: CancellationToken):
self._cancel_token = val


@dataclass(order=True)
class ScheduledItem:
due_time: float
action: Callable[[], None] = field(compare=False)
cancel_token: CancellationToken | None = field(compare=False)


class Trampoline:
__slots__ = "queue", "lock", "running", "call_count"
__slots__ = "lock", "running", "call_count"

MaxTrampolineCallCount = 150 # Max recursion depth: 1000
MaxTrampolineCallCount = 75 # Max recursion depth: 1000

def __init__(self):
self.call_count: int = 0
self.lock = Lock()
self.queue: deque[Callable[[], None]] = deque()
self.running: bool = False
self.running = False

def increment_and_check(self):
with self.lock:
self.call_count = self.call_count + 1
return self.call_count > Trampoline.MaxTrampolineCallCount

def run_later(
self,
action: Callable[[], None],
due_time: float = 0.0,
):
loop = asyncio.get_event_loop()
loop.call_later(due_time, action)

def run(self, action: Callable[[], None]):
loop = asyncio.get_event_loop()

if self.increment_and_check():
with self.lock:
self.queue.append(action)

if not self.running:
self.running = True
timer = Timer(0.0, self._run)
timer.start()
self.call_count = 0
loop.call_soon(action)
else:
action()

def _run(self) -> None:
while len(self.queue):
with self.lock:
self.call_count = 0
action = self.queue.popleft()

action()

self.running = False


def protected_cont(f: Async[_T]) -> Async[_T]:
def _protected_cont(ctx: IAsyncContext[_T]):
Expand Down Expand Up @@ -275,7 +273,9 @@ def on_success(x: _T) -> None:


def protected_return(value: _T) -> Async[_T]:
f: Callable[[IAsyncContext[_T]], None] = lambda ctx: ctx.on_success(value)
def f(ctx: IAsyncContext[_T]) -> None:
return ctx.on_success(value)

return protected_cont(f)


Expand All @@ -288,7 +288,9 @@ def Bind(
return protected_bind(computation, binder)

def Combine(self, computation1: Async[Any], computation2: Async[_T]) -> Async[_T]:
binder: Callable[[_T], Async[_T]] = lambda _: computation2
def binder(_: _T) -> Async[_T]:
return computation2

return self.Bind(computation1, binder)

def Delay(self, generator: Callable[[], Async[_T]]) -> Async[_T]:
Expand Down Expand Up @@ -373,11 +375,15 @@ def on_error(err: Exception) -> None:
return protected_cont(fn)

def Using(self, resource: _D, binder: Callable[[_D], Async[_U]]) -> Async[_U]:
compensation: Callable[[], None] = lambda: resource.Dispose()
def compensation() -> None:
return resource.Dispose()

return self.TryFinally(binder(resource), compensation)

@overload
def While(self, guard: Callable[[], bool], computation: Async[None]) -> Async[None]:
def While(
self, guard: Callable[[], bool], computation: Async[Literal[None]]
) -> Async[None]:
...

@overload
Expand All @@ -386,9 +392,10 @@ def While(self, guard: Callable[[], bool], computation: Async[_T]) -> Async[_T]:

def While(self, guard: Callable[[], bool], computation: Async[Any]) -> Async[Any]:
if guard():
binder: Callable[[Any], Async[Any]] = lambda _: self.While(
guard, computation
)

def binder(_: Any) -> Async[Any]:
return self.While(guard, computation)

return self.Bind(computation, binder)
else:
return self.Return()
Expand Down
Loading

0 comments on commit e3b1f21

Please sign in to comment.