Skip to content

Commit

Permalink
python: clean up channel closing logic
Browse files Browse the repository at this point in the history
The logic for shutdown is now consolidated and simplified:

  - rename the .close_channel() method to .drop_channel() to make it
    clear that it's only about removing the reference from the router.
    Assert that the channel hasn't already been dropped.

  - the channel gets dropped from the routing table at exactly one
    point: when it sends its close message.  As such, all channels must
    send close messages.  Implement this in the default `do_close()`.

  - we no longer drop channels in response to incoming close messages.
    The Endpoint implementation is always responsible for calling one of
    the shutdown methods.

  - Peers no longer track the list of open channels: they request the
    router handle the shutdown for them, via a new `shutdown_endpoint()`
    method..  This results in more reliable sending of close messages,
    particularly in the case of frozen Peers where the channels list
    isn't filled in yet.

This makes Peer shutdown theoretically slower, since we have to iterate
the list of all open channels in the router, but the simplification is
well worth it.
  • Loading branch information
allisonkarlitskaya committed Feb 16, 2023
1 parent f987ffb commit f02af74
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def do_done(self):
pass

def do_close(self):
pass
self.close()

def do_options(self, message):
raise ChannelError('not-supported', message='This channel does not implement "options"')
Expand Down
22 changes: 7 additions & 15 deletions src/cockpit/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import os

from typing import Dict, List, Optional, Sequence, Set
from typing import Dict, List, Optional, Sequence

from .router import Endpoint, Router, RoutingError, RoutingRule
from .protocol import CockpitProtocolClient
Expand Down Expand Up @@ -50,8 +50,6 @@ class Peer(CockpitProtocolClient, SubprocessProtocol, Endpoint):
init_host: str
state_listener: Optional[PeerStateListener]

channels: Set[str]

authorize_pending: Optional[str] = None # the cookie of the pending request

def __init__(self,
Expand All @@ -70,7 +68,6 @@ def __init__(self,
assert router.init_host is not None
self.init_host = init_host or router.init_host

self.channels = set()
self.authorize_pending = None

def spawn(self, argv: Sequence[str], env: Sequence[str], **kwargs) -> asyncio.Transport:
Expand Down Expand Up @@ -98,12 +95,15 @@ def do_init(self, message: Dict[str, object]) -> None:
else:
logger.warning('Peer %s connection got duplicate init message', self.name)

def shutdown(self, problem: str, **kwargs: object) -> None:
self.shutdown_endpoint(problem=problem, **kwargs)
if self.transport is not None:
self.transport.close()

def do_closed(self, transport_was: asyncio.Transport, exc: Optional[Exception]) -> None:
logger.debug('Peer %s connection lost %s', self.name, exc)

# We need to synthesize close messages for all open channels
while self.channels:
self.send_channel_control(self.channels.pop(), 'close', problem='disconnected')
self.shutdown('disconnected')

if self.state_listener is not None:
# If we don't otherwise has an exception set, but we have stderr output, we can use it.
Expand Down Expand Up @@ -149,21 +149,13 @@ def close(self) -> None:

# Forwarding data: from the peer to the router
def channel_control_received(self, channel: str, command: str, message: Dict[str, object]) -> None:
if command == 'close':
self.channels.discard(channel)

self.send_channel_control(**message)

def channel_data_received(self, channel: str, data: bytes) -> None:
self.send_channel_data(channel, data)

# Forwarding data: from the router to the peer
def do_channel_control(self, channel: str, command: str, message: Dict[str, object]) -> None:
if command == 'open':
self.channels.add(channel)
elif command == 'close':
self.channels.discard(channel)

self.write_control(**message)

def do_channel_data(self, channel: str, data: bytes) -> None:
Expand Down
27 changes: 19 additions & 8 deletions src/cockpit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def send_channel_message(self, channel: str, **kwargs) -> None:
def send_channel_control(self, channel, command, **kwargs) -> None:
self.router.write_control(channel=channel, command=command, **kwargs)
if command == 'close':
self.router.close_channel(channel)
self.router.drop_channel(channel)

def shutdown_endpoint(self, **kwargs) -> None:
self.router.shutdown_endpoint(self, **kwargs)


class RoutingError(Exception):
Expand Down Expand Up @@ -144,16 +147,28 @@ def check_rules(self, options: Dict[str, object]) -> Endpoint:
logger.debug(' No rules matched')
raise RoutingError('not-supported')

def close_channel(self, channel: str) -> None:
self.open_channels.pop(channel, None)
def drop_channel(self, channel: str) -> None:
assert channel in self.open_channels, (self.open_channels, channel)
try:
self.open_channels.pop(channel)
logger.debug('router dropped channel %s', channel)
except KeyError:
logger.error('trying to drop non-existent channel %s', channel)
if channel in self.groups:
del self.groups[channel]

def shutdown_endpoint(self, endpoint: Endpoint, **kwargs) -> None:
channels = set(key for key, value in self.open_channels.items() if value == endpoint)
logger.debug('shutdown_endpoint(%s, %s) will close %s', endpoint, kwargs, channels)
for channel in channels:
self.write_control(command='close', channel=channel, **kwargs)
self.drop_channel(channel)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
if group:
to_close = set(ch for ch, gr in self.groups.items() if gr == group)
for channel in to_close:
self.close_channel(channel)
self.drop_channel(channel)

def channel_control_received(self, channel: str, command: str, message: Dict[str, object]) -> None:
# If this is an open message then we need to apply the routing rules to
Expand Down Expand Up @@ -185,10 +200,6 @@ def channel_control_received(self, channel: str, command: str, message: Dict[str
# At this point, we have the endpoint. Route the message.
endpoint.do_channel_control(channel, command, message)

# If that was a close message, we can remove the endpoint now.
if command == 'close':
self.close_channel(channel)

def channel_data_received(self, channel: str, data: bytes) -> None:
try:
endpoint = self.open_channels[channel]
Expand Down
1 change: 0 additions & 1 deletion test/pytest/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ async def verify_root_bridge_running(self):
# try to open dbus on the root bridge
root_dbus = await self.transport.check_open('dbus-json3', bus='internal', superuser=True)
assert self.bridge.open_channels[root_dbus].name == 'pseudo'
assert self.bridge.open_channels[root_dbus].channels == {root_dbus}

# verify that the bridge thinks that it's the root bridge
await self.transport.assert_bus_props('/superuser', 'cockpit.Superuser',
Expand Down

0 comments on commit f02af74

Please sign in to comment.