diff --git a/src/websockets/client.py b/src/websockets/client.py index 8d826fea..df8e5342 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -336,6 +336,8 @@ def parse(self) -> Generator[None, None, None]: except InvalidHandshake as exc: response._exception = exc self.handshake_exc = exc + self.parser = self.discard() + next(self.parser) # start coroutine else: assert self.state is CONNECTING self.state = OPEN diff --git a/src/websockets/connection.py b/src/websockets/connection.py index 967bd8fa..db8b5369 100644 --- a/src/websockets/connection.py +++ b/src/websockets/connection.py @@ -498,7 +498,7 @@ def close_expected(self) -> bool: # applies, except on EOFError where we don't send a close frame # because we already received the TCP close, so we don't expect it. # We already got a TCP Close if and only if the state is CLOSED. - return self.state is CLOSING + return self.state is CLOSING or self.handshake_exc is not None # Private methods for receiving data. diff --git a/src/websockets/server.py b/src/websockets/server.py index 214417ad..5dad50b6 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHeader, InvalidHeaderValue, InvalidOrigin, + InvalidStatus, InvalidUpgrade, NegotiationError, ) @@ -471,8 +472,14 @@ def reject( ("Server", USER_AGENT), ] ) + response = Response(status.value, status.phrase, headers, body) + # When reject() is called from accept(), handshake_exc is already set. + # If a user calls reject(), set handshake_exc to guarantee invariant: + # "handshake_exc is None if and only if opening handshake succeded." + if self.handshake_exc is None: + self.handshake_exc = InvalidStatus(response) self.logger.info("connection failed (%d %s)", status.value, status.phrase) - return Response(status.value, status.phrase, headers, body) + return response def send_response(self, response: Response) -> None: """ @@ -497,6 +504,8 @@ def send_response(self, response: Response) -> None: self.state = OPEN else: self.send_eof() + self.parser = self.discard() + next(self.parser) # start coroutine def parse(self) -> Generator[None, None, None]: if self.state is CONNECTING: diff --git a/tests/test_client.py b/tests/test_client.py index 12fd8726..a843d327 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,6 +42,7 @@ def test_send_connect(self): f"\r\n".encode() ], ) + self.assertFalse(client.close_expected()) def test_connect_request(self): with unittest.mock.patch("websockets.client.generate_key", return_value=KEY): @@ -135,6 +136,8 @@ def test_receive_accept(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertFalse(client.close_expected()) self.assertEqual(client.state, OPEN) def test_receive_reject(self): @@ -155,6 +158,8 @@ def test_receive_reject(self): ) [response] = client.events_received() self.assertIsInstance(response, Response) + self.assertEqual(client.data_to_send(), []) + self.assertTrue(client.close_expected()) self.assertEqual(client.state, CONNECTING) def test_accept_response(self): diff --git a/tests/test_server.py b/tests/test_server.py index 43bc03e1..e3e80223 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,6 +38,8 @@ def test_receive_connect(self): ) [request] = server.events_received() self.assertIsInstance(request, Request) + self.assertEqual(server.data_to_send(), []) + self.assertFalse(server.close_expected()) def test_connect_request(self): server = ServerConnection() @@ -104,6 +106,7 @@ def test_send_accept(self): f"\r\n".encode() ], ) + self.assertFalse(server.close_expected()) self.assertEqual(server.state, OPEN) def test_send_reject(self): @@ -126,6 +129,7 @@ def test_send_reject(self): b"", ], ) + self.assertTrue(server.close_expected()) self.assertEqual(server.state, CONNECTING) def test_accept_response(self):