Skip to content

Commit

Permalink
Avoid leaking sockets when connect() is canceled.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Feb 20, 2022
1 parent 88d2e2f commit 8516801
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ Improvements

* Made compression negotiation more lax for compatibility with Firefox.

Bug fixes
.........

* Avoided leaking open sockets when :func:`~client.connect` is canceled.

10.1
----

Expand Down
3 changes: 2 additions & 1 deletion src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,8 @@ async def __await_impl__(self) -> WebSocketClientProtocol:
protocol.fail_connection()
await protocol.wait_closed()
self.handle_redirect(exc.uri)
except Exception:
# Avoid leaking a connected socket when the handshake fails.
except (Exception, asyncio.CancelledError):
protocol.fail_connection()
await protocol.wait_closed()
raise
Expand Down
37 changes: 31 additions & 6 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,11 @@ def with_client(*args, **kwargs):
return with_manager(temp_test_client, *args, **kwargs)


def get_server_uri(server, secure=False, resource_name="/", user_info=None):
def get_server_address(server):
"""
Return a WebSocket URI for connecting to the given server.
Return an address on which the given server listens.
"""
proto = "wss" if secure else "ws"

user_info = ":".join(user_info) + "@" if user_info else ""

# Pick a random socket in order to test both IPv4 and IPv6 on systems
# where both are available. Randomizing tests is usually a bad idea. If
# needed, either use the first socket, or test separately IPv4 and IPv6.
Expand All @@ -169,6 +165,17 @@ def get_server_uri(server, secure=False, resource_name="/", user_info=None):
else: # pragma: no cover
raise ValueError("expected an IPv6, IPv4, or Unix socket")

return host, port


def get_server_uri(server, secure=False, resource_name="/", user_info=None):
"""
Return a WebSocket URI for connecting to the given server.
"""
proto = "wss" if secure else "ws"
user_info = ":".join(user_info) + "@" if user_info else ""
host, port = get_server_address(server)
return f"{proto}://{user_info}{host}:{port}{resource_name}"


Expand Down Expand Up @@ -1067,6 +1074,21 @@ def test_server_error_in_handshake(self, _process_request):
with self.assertRaises(InvalidHandshake):
self.start_client()

@with_server(create_protocol=SlowOpeningHandshakeProtocol)
def test_client_connect_canceled_during_handshake(self):
sock = socket.create_connection(get_server_address(self.server))
sock.send(b"") # socket is connected

async def cancelled_client():
start_client = connect(get_server_uri(self.server), sock=sock)
await asyncio.wait_for(start_client, 5 * MS)

with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(cancelled_client())

with self.assertRaises(OSError):
sock.send(b"") # socket is closed

@with_server()
@unittest.mock.patch("websockets.legacy.server.WebSocketServerProtocol.send")
def test_server_handler_crashes(self, send):
Expand Down Expand Up @@ -1199,6 +1221,9 @@ class SecureClientServerTests(
CommonClientServerTests, SecureClientServerTestsMixin, AsyncioTestCase
):

# The implementation of this test makes it hard to run it over TLS.
test_client_connect_canceled_during_handshake = None

# TLS over Unix sockets doesn't make sense.
test_unix_socket = None

Expand Down

0 comments on commit 8516801

Please sign in to comment.