Skip to content

Commit

Permalink
Expect TCP close after a failed opening handshake.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Apr 3, 2022
1 parent 5dc16c2 commit 4034d8d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def parse(self) -> Generator[None, None, None]:
except InvalidHandshake as exc:
response._exception = exc
self.handshake_exc = exc
self.parser = self.discard()
next(self.parser) # start coroutine
else:
assert self.state is CONNECTING
self.state = OPEN
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def close_expected(self) -> bool:
# applies, except on EOFError where we don't send a close frame
# because we already received the TCP close, so we don't expect it.
# We already got a TCP Close if and only if the state is CLOSED.
return self.state is CLOSING
return self.state is CLOSING or self.handshake_exc is not None

# Private methods for receiving data.

Expand Down
11 changes: 10 additions & 1 deletion src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
InvalidHeader,
InvalidHeaderValue,
InvalidOrigin,
InvalidStatus,
InvalidUpgrade,
NegotiationError,
)
Expand Down Expand Up @@ -471,8 +472,14 @@ def reject(
("Server", USER_AGENT),
]
)
response = Response(status.value, status.phrase, headers, body)
# When reject() is called from accept(), handshake_exc is already set.
# If a user calls reject(), set handshake_exc to guarantee invariant:
# "handshake_exc is None if and only if opening handshake succeded."
if self.handshake_exc is None:
self.handshake_exc = InvalidStatus(response)
self.logger.info("connection failed (%d %s)", status.value, status.phrase)
return Response(status.value, status.phrase, headers, body)
return response

def send_response(self, response: Response) -> None:
"""
Expand All @@ -497,6 +504,8 @@ def send_response(self, response: Response) -> None:
self.state = OPEN
else:
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine

def parse(self) -> Generator[None, None, None]:
if self.state is CONNECTING:
Expand Down
5 changes: 5 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_send_connect(self):
f"\r\n".encode()
],
)
self.assertFalse(client.close_expected())

def test_connect_request(self):
with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
Expand Down Expand Up @@ -135,6 +136,8 @@ def test_receive_accept(self):
)
[response] = client.events_received()
self.assertIsInstance(response, Response)
self.assertEqual(client.data_to_send(), [])
self.assertFalse(client.close_expected())
self.assertEqual(client.state, OPEN)

def test_receive_reject(self):
Expand All @@ -155,6 +158,8 @@ def test_receive_reject(self):
)
[response] = client.events_received()
self.assertIsInstance(response, Response)
self.assertEqual(client.data_to_send(), [])
self.assertTrue(client.close_expected())
self.assertEqual(client.state, CONNECTING)

def test_accept_response(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def test_receive_connect(self):
)
[request] = server.events_received()
self.assertIsInstance(request, Request)
self.assertEqual(server.data_to_send(), [])
self.assertFalse(server.close_expected())

def test_connect_request(self):
server = ServerConnection()
Expand Down Expand Up @@ -104,6 +106,7 @@ def test_send_accept(self):
f"\r\n".encode()
],
)
self.assertFalse(server.close_expected())
self.assertEqual(server.state, OPEN)

def test_send_reject(self):
Expand All @@ -126,6 +129,7 @@ def test_send_reject(self):
b"",
],
)
self.assertTrue(server.close_expected())
self.assertEqual(server.state, CONNECTING)

def test_accept_response(self):
Expand Down

0 comments on commit 4034d8d

Please sign in to comment.