Skip to content

Commit

Permalink
Various cleanups in sync implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Feb 11, 2024
1 parent de768cf commit 50b6d20
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 49 deletions.
9 changes: 4 additions & 5 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class ClientConnection(Connection):
"""
Threaded implementation of a WebSocket client connection.
:mod:`threading` implementation of a WebSocket client connection.
:class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
Expand Down Expand Up @@ -157,7 +157,7 @@ def connect(
:func:`connect` may be used as a context manager::
async with websockets.sync.client.connect(...) as websocket:
with websockets.sync.client.connect(...) as websocket:
...
The connection is closed automatically when exiting the context.
Expand Down Expand Up @@ -273,19 +273,18 @@ def connect(
sock = ssl.wrap_socket(sock, server_hostname=server_hostname)
sock.settimeout(None)

# Initialize WebSocket connection
# Initialize WebSocket protocol

protocol = ClientProtocol(
wsuri,
origin=origin,
extensions=extensions,
subprotocols=subprotocols,
state=CONNECTING,
max_size=max_size,
logger=logger,
)

# Initialize WebSocket protocol
# Initialize WebSocket connection

connection = create_connection(
sock,
Expand Down
58 changes: 28 additions & 30 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@

__all__ = ["Connection"]

logger = logging.getLogger(__name__)


class Connection:
"""
Threaded implementation of a WebSocket connection.
:mod:`threading` implementation of a WebSocket connection.
:class:`Connection` provides APIs shared between WebSocket servers and
clients.
Expand Down Expand Up @@ -82,15 +80,15 @@ def __init__(
self.close_deadline: Optional[Deadline] = None

# Mapping of ping IDs to pong waiters, in chronological order.
self.pings: Dict[bytes, threading.Event] = {}
self.ping_waiters: Dict[bytes, threading.Event] = {}

# Receiving events from the socket.
self.recv_events_thread = threading.Thread(target=self.recv_events)
self.recv_events_thread.start()

# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_events_exc: Optional[BaseException] = None
self.recv_exc: Optional[BaseException] = None

# Public attributes

Expand Down Expand Up @@ -198,7 +196,7 @@ def recv(self, timeout: Optional[float] = None) -> Data:
try:
return self.recv_messages.get(timeout)
except EOFError:
raise self.protocol.close_exc from self.recv_events_exc
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
"cannot call recv while another thread "
Expand Down Expand Up @@ -229,9 +227,10 @@ def recv_streaming(self) -> Iterator[Data]:
"""
try:
yield from self.recv_messages.get_iter()
for frame in self.recv_messages.get_iter():
yield frame
except EOFError:
raise self.protocol.close_exc from self.recv_events_exc
raise self.protocol.close_exc from self.recv_exc
except RuntimeError:
raise RuntimeError(
"cannot call recv_streaming while another thread "
Expand Down Expand Up @@ -273,7 +272,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None:
Raises:
ConnectionClosed: When the connection is closed.
RuntimeError: If a connection is busy sending a fragmented message.
RuntimeError: If the connection is sending a fragmented message.
TypeError: If ``message`` doesn't have a supported type.
"""
Expand Down Expand Up @@ -449,15 +448,15 @@ def ping(self, data: Optional[Data] = None) -> threading.Event:

with self.send_context():
# Protect against duplicates if a payload is explicitly set.
if data in self.pings:
if data in self.ping_waiters:
raise RuntimeError("already waiting for a pong with the same data")

# Generate a unique random payload otherwise.
while data is None or data in self.pings:
while data is None or data in self.ping_waiters:
data = struct.pack("!I", random.getrandbits(32))

pong_waiter = threading.Event()
self.pings[data] = pong_waiter
self.ping_waiters[data] = pong_waiter
self.protocol.send_ping(data)
return pong_waiter

Expand Down Expand Up @@ -504,22 +503,22 @@ def acknowledge_pings(self, data: bytes) -> None:
"""
with self.protocol_mutex:
# Ignore unsolicited pong.
if data not in self.pings:
if data not in self.ping_waiters:
return
# Sending a pong for only the most recent ping is legal.
# Acknowledge all previous pings too in that case.
ping_id = None
ping_ids = []
for ping_id, ping in self.pings.items():
for ping_id, ping in self.ping_waiters.items():
ping_ids.append(ping_id)
ping.set()
if ping_id == data:
break
else:
raise AssertionError("solicited pong not found in pings")
# Remove acknowledged pings from self.pings.
# Remove acknowledged pings from self.ping_waiters.
for ping_id in ping_ids:
del self.pings[ping_id]
del self.ping_waiters[ping_id]

def recv_events(self) -> None:
"""
Expand All @@ -541,18 +540,18 @@ def recv_events(self) -> None:
self.logger.debug("error while receiving data", exc_info=True)
# When the closing handshake is initiated by our side,
# recv() may block until send_context() closes the socket.
# In that case, send_context() already set recv_events_exc.
# Calling set_recv_events_exc() avoids overwriting it.
# In that case, send_context() already set recv_exc.
# Calling set_recv_exc() avoids overwriting it.
with self.protocol_mutex:
self.set_recv_events_exc(exc)
self.set_recv_exc(exc)
break

if data == b"":
break

# Acquire the connection lock.
with self.protocol_mutex:
# Feed incoming data to the connection.
# Feed incoming data to the protocol.
self.protocol.receive_data(data)

# This isn't expected to raise an exception.
Expand All @@ -568,7 +567,7 @@ def recv_events(self) -> None:
# set by send_context(), in case of a race condition
# i.e. send_context() closes the socket after recv()
# returns above but before send_data() calls send().
self.set_recv_events_exc(exc)
self.set_recv_exc(exc)
break

if self.protocol.close_expected():
Expand All @@ -595,7 +594,7 @@ def recv_events(self) -> None:
# Breaking out of the while True: ... loop means that we believe
# that the socket doesn't work anymore.
with self.protocol_mutex:
# Feed the end of the data stream to the connection.
# Feed the end of the data stream to the protocol.
self.protocol.receive_eof()

# This isn't expected to generate events.
Expand All @@ -609,7 +608,7 @@ def recv_events(self) -> None:
# This branch should never run. It's a safety net in case of bugs.
self.logger.error("unexpected internal error", exc_info=True)
with self.protocol_mutex:
self.set_recv_events_exc(exc)
self.set_recv_exc(exc)
# We don't know where we crashed. Force protocol state to CLOSED.
self.protocol.state = CLOSED
finally:
Expand Down Expand Up @@ -668,7 +667,6 @@ def send_context(
wait_for_close = True
# If the connection is expected to close soon, set the
# close deadline based on the close timeout.

# Since we tested earlier that protocol.state was OPEN
# (or CONNECTING) and we didn't release protocol_mutex,
# it is certain that self.close_deadline is still None.
Expand Down Expand Up @@ -710,11 +708,11 @@ def send_context(
# original_exc is never set when wait_for_close is True.
assert original_exc is None
original_exc = TimeoutError("timed out while closing connection")
# Set recv_events_exc before closing the socket in order to get
# Set recv_exc before closing the socket in order to get
# proper exception reporting.
raise_close_exc = True
with self.protocol_mutex:
self.set_recv_events_exc(original_exc)
self.set_recv_exc(original_exc)

# If an error occurred, close the socket to terminate the connection and
# raise an exception.
Expand Down Expand Up @@ -745,16 +743,16 @@ def send_data(self) -> None:
except OSError: # socket already closed
pass

def set_recv_events_exc(self, exc: Optional[BaseException]) -> None:
def set_recv_exc(self, exc: Optional[BaseException]) -> None:
"""
Set recv_events_exc, if not set yet.
Set recv_exc, if not set yet.
This method requires holding protocol_mutex.
"""
assert self.protocol_mutex.locked()
if self.recv_events_exc is None:
self.recv_events_exc = exc
if self.recv_exc is None:
self.recv_exc = exc

def close_socket(self) -> None:
"""
Expand Down
27 changes: 13 additions & 14 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

class ServerConnection(Connection):
"""
Threaded implementation of a WebSocket server connection.
:mod:`threading` implementation of a WebSocket server connection.
:class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for
receiving and sending messages.
Expand Down Expand Up @@ -188,6 +188,8 @@ class WebSocketServer:
handler: Handler for one connection. Receives the socket and address
returned by :meth:`~socket.socket.accept`.
logger: Logger for this server.
It defaults to ``logging.getLogger("websockets.server")``.
See the :doc:`logging guide <../../topics/logging>` for details.
"""

Expand Down Expand Up @@ -311,16 +313,16 @@ def serve(
Whenever a client connects, the server creates a :class:`ServerConnection`,
performs the opening handshake, and delegates to the ``handler``.
The handler receives a :class:`ServerConnection` instance, which you can use
to send and receive messages.
The handler receives the :class:`ServerConnection` instance, which you can
use to send and receive messages.
Once the handler completes, either normally or with an exception, the server
performs the closing handshake and closes the connection.
:class:`WebSocketServer` mirrors the API of
This function returns a :class:`WebSocketServer` whose API mirrors
:class:`~socketserver.BaseServer`. Treat it as a context manager to ensure
that it will be closed and call the :meth:`~WebSocketServer.serve_forever`
method to serve requests::
that it will be closed and call :meth:`~WebSocketServer.serve_forever` to
serve requests::
def handler(websocket):
...
Expand Down Expand Up @@ -454,15 +456,13 @@ def conn_handler(sock: socket.socket, addr: Any) -> None:
sock.do_handshake()
sock.settimeout(None)

# Create a closure so that select_subprotocol has access to self.

# Create a closure to give select_subprotocol access to connection.
protocol_select_subprotocol: Optional[
Callable[
[ServerProtocol, Sequence[Subprotocol]],
Optional[Subprotocol],
]
] = None

if select_subprotocol is not None:

def protocol_select_subprotocol(
Expand All @@ -475,19 +475,18 @@ def protocol_select_subprotocol(
assert protocol is connection.protocol
return select_subprotocol(connection, subprotocols)

# Initialize WebSocket connection
# Initialize WebSocket protocol

protocol = ServerProtocol(
origins=origins,
extensions=extensions,
subprotocols=subprotocols,
select_subprotocol=protocol_select_subprotocol,
state=CONNECTING,
max_size=max_size,
logger=logger,
)

# Initialize WebSocket protocol
# Initialize WebSocket connection

assert create_connection is not None # help mypy
connection = create_connection(
Expand Down Expand Up @@ -522,7 +521,7 @@ def protocol_select_subprotocol(


def unix_serve(
handler: Callable[[ServerConnection], Any],
handler: Callable[[ServerConnection], None],
path: Optional[str] = None,
**kwargs: Any,
) -> WebSocketServer:
Expand All @@ -541,4 +540,4 @@ def unix_serve(
path: File system path to the Unix socket.
"""
return serve(handler, path=path, unix=True, **kwargs)
return serve(handler, unix=True, path=path, **kwargs)

0 comments on commit 50b6d20

Please sign in to comment.