Skip to content

Commit

Permalink
python: add task management to Channel class
Browse files Browse the repository at this point in the history
Add a new Channel.create_task() method to create a task bound to the
lifetime of the channel.

Implement logic in .close() to delay the actual delivery of the "close"
message until after all of the tasks have run to completion.

Adjust various Channel-based users of tasks to use the new API.  This is
mostly changing the DBus channel to remove similar logic and make sure
it exits cleanly.  Specifically: once we decide to close the channel, we
let any existing operations run to completion, but cancel all match
rules (and prevent the creation of new ones) in order to avoid creating
new work in response to incoming signals.

Rename a method in ProtocolChannel to avoid a conflict.  This is a
method, so it should have a verb name anyway.
  • Loading branch information
allisonkarlitskaya committed Feb 16, 2023
1 parent f02af74 commit dba8270
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 28 deletions.
63 changes: 58 additions & 5 deletions src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import asyncio
import logging

from typing import ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Type
from typing import ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type

from .router import Endpoint, Router, RoutingRule

Expand Down Expand Up @@ -90,6 +90,10 @@ class Channel(Endpoint):
_out_sequence: int = 0
_out_window: int = SEND_WINDOW

# Task management
_tasks: Set[asyncio.Task]
_close_args: Optional[Dict[str, object]] = None

# Must be filled in by the channel implementation
payload: ClassVar[str]
restrictions: ClassVar[Sequence[Tuple[str, object]]] = ()
Expand All @@ -102,6 +106,7 @@ def do_control(self, command, message):
# things that our subclass may be interested in handling. We drop the
# 'message' field for handlers that don't need it.
if command == 'open':
self._tasks = set()
self.channel = message['channel']
if message.get('flow-control'):
self._send_pings = True
Expand All @@ -121,6 +126,10 @@ def do_control(self, command, message):
self.do_options(message)

def do_channel_control(self, channel, command, message):
# Already closing? Ignore.
if self._close_args is not None:
return

# Catch errors and turn them into close messages
try:
self.do_control(command, message)
Expand Down Expand Up @@ -149,6 +158,10 @@ def do_ping(self, message):
self.send_pong(message)

def do_channel_data(self, channel, data):
# Already closing? Ignore.
if self._close_args is not None:
return

# Catch errors and turn them into close messages
try:
self.do_data(data)
Expand All @@ -167,8 +180,48 @@ def ready(self, **kwargs):
def done(self):
self.send_control(command='done')

# tasks and close management
def is_closing(self) -> bool:
return self._close_args is not None

def _close_now(self):
self.send_control('close', **self._close_args)

def _task_done(self, task):
# Strictly speaking, we should read the result and check for exceptions but:
# - exceptions bubbling out of the task are programming errors
# - the only thing we'd do with it anyway, is to show it
# - Python already does that with its "Task exception was never retrieved" messages
self._tasks.remove(task)
if self._close_args is not None and not self._tasks:
self._close_now()

def create_task(self, coroutine, name=None):
"""Create a task associated with the channel.
All tasks must exit before the channel can close. You may not create
new tasks after calling .close().
"""
assert self._close_args is None
task = asyncio.create_task(coroutine)
self._tasks.add(task)
task.add_done_callback(self._task_done)
return task

def close(self, **kwargs):
self.send_control('close', **kwargs)
"""Requests the channel to be closed.
After you call this method, you won't get anymore `.do_*()` calls.
This will wait for any running tasks to complete before sending the
close message.
"""
if self._close_args is not None:
# close already requested
return
self._close_args = kwargs
if not self._tasks:
self._close_now()

def send_data(self, data: bytes) -> bool:
"""Send data and handle book-keeping for flow control.
Expand Down Expand Up @@ -268,11 +321,11 @@ def connection_made(self, transport: asyncio.BaseTransport):
assert isinstance(transport, asyncio.Transport)
self._transport = transport

def _close_args(self) -> Dict[str, object]:
def _get_close_args(self) -> Dict[str, object]:
return {}

def connection_lost(self, exc: Optional[Exception]) -> None:
self.close(**self._close_args())
self.close(**self._get_close_args())

def do_data(self, data: bytes) -> None:
assert self._transport is not None
Expand Down Expand Up @@ -389,7 +442,7 @@ def do_resume_send(self) -> None:

def do_open(self, options):
self.receive_queue = asyncio.Queue()
asyncio.create_task(self.run_wrapper(options), name=f'{self.__class__.__name__}.run_wrapper({options})')
self.create_task(self.run_wrapper(options), name=f'{self.__class__.__name__}.run_wrapper({options})')

def do_done(self):
self.receive_queue.put_nowait(b'')
Expand Down
27 changes: 11 additions & 16 deletions src/cockpit/channels/dbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def notify_update(notify, path, interface_name, props):
class DBusChannel(Channel):
payload = 'dbus-json3'

tasks = None
matches = None
name = None
bus = None
Expand Down Expand Up @@ -216,7 +215,6 @@ def do_open(self, options):
self.cache = InterfaceCache()
self.name = options.get('name')
self.matches = []
self.tasks = set()

bus = options.get('bus')
address = options.get('address')
Expand Down Expand Up @@ -260,9 +258,7 @@ async def get_ready():
self.ready(unique_name=self.owner)
else:
self.close(problem="not-found")
task = asyncio.create_task(get_ready())
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
self.create_task(get_ready())
else:
self.ready()

Expand All @@ -285,13 +281,12 @@ def filter_owner(message):
else:
func = handler
r_string = ','.join(f"{key}='{value}'" for key, value in r.items())
self.matches.append(self.bus.add_match(r_string, func))
if not self.is_closing():
self.matches.append(self.bus.add_match(r_string, func))

def add_async_signal_handler(self, handler, **kwargs):
def sync_handler(message):
task = asyncio.create_task(handler(message))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
self.create_task(handler(message))
self.add_signal_handler(sync_handler, **kwargs)

async def do_call(self, message):
Expand Down Expand Up @@ -483,19 +478,19 @@ def do_data(self, data):
logger.debug('receive dbus request %s %s', self.name, message)

if 'call' in message:
task = asyncio.create_task(self.do_call(message))
self.create_task(self.do_call(message))
elif 'add-match' in message:
task = asyncio.create_task(self.do_add_match(message))
self.create_task(self.do_add_match(message))
elif 'watch' in message:
task = asyncio.create_task(self.do_watch(message))
self.create_task(self.do_watch(message))
elif 'meta' in message:
task = asyncio.create_task(self.do_meta(message))
self.create_task(self.do_meta(message))
else:
logger.debug('ignored dbus request %s', message)
return

self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

def do_close(self):
for slot in self.matches:
slot.cancel()
self.matches = None # error out
self.close()
2 changes: 1 addition & 1 deletion src/cockpit/channels/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class SubprocessStreamChannel(ProtocolChannel, SubprocessProtocol):
def process_exited(self) -> None:
self.close_on_eof()

def _close_args(self) -> Dict[str, object]:
def _get_close_args(self) -> Dict[str, object]:
assert isinstance(self._transport, SubprocessTransport)
args: Dict[str, object] = {'exit-status': self._transport.get_returncode()}
stderr = self._transport.get_stderr()
Expand Down
13 changes: 7 additions & 6 deletions test/pytest/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def send_done(self, channel, **kwargs):
def send_close(self, channel, **kwargs):
self.send_json('', command='close', channel=channel, **kwargs)

async def check_close(self, channel, **kwargs):
self.send_close(channel, **kwargs)
await self.assert_msg('', command='close', channel=channel)

def send_ping(self, **kwargs):
self.send_json('', command='ping', **kwargs)

Expand Down Expand Up @@ -313,8 +317,7 @@ async def verify_root_bridge_running(self):
{'Bridges': self.superuser_bridges, 'Current': 'root'}, bus=root_dbus)

# close up
self.transport.send_close(channel=root_dbus)
await self.transport.assert_msg('', command='close', channel=root_dbus)
await self.transport.check_close(channel=root_dbus)

async def test_superuser_dbus(self):
await self.start()
Expand Down Expand Up @@ -529,8 +532,7 @@ async def test_fslist1_no_watch(self):
# empty
ch = self.transport.send_open('fslist1', path=str(dir_path), watch=False)
await self.transport.assert_msg('', command='done', channel=ch)
self.transport.send_close(channel=ch)
await self.transport.assert_msg('', command='close', channel=ch)
await self.transport.check_close(channel=ch)

# create a file and a directory in some_dir
Path(dir_path, 'somefile').touch()
Expand All @@ -546,8 +548,7 @@ async def test_fslist1_no_watch(self):
assert msg2 == {'event': 'present', 'path': 'somefile', 'type': 'file'}

await self.transport.assert_msg('', command='done', channel=ch)
self.transport.send_close(channel=ch)
await self.transport.assert_msg('', command='close', channel=ch)
await self.transport.check_close(channel=ch)

async def test_fslist1_notexist(self):
await self.start()
Expand Down

0 comments on commit dba8270

Please sign in to comment.