Skip to content

Commit

Permalink
Store handshake exceptions in connection objects.
Browse files Browse the repository at this point in the history
Storing them in request/response objects was probably legacy.
  • Loading branch information
aaugustin committed Apr 3, 2022
1 parent 0796c43 commit 5dc16c2
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 46 deletions.
14 changes: 8 additions & 6 deletions docs/howto/sansio.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ the network, as described in `Send data`_ below.
The first event returned by :meth:`~connection.Connection.events_received` is
the WebSocket handshake response.

When the handshake fails, the reason is available in ``response.exception``::
When the handshake fails, the reason is available in
:attr:`~client.ClientConnection.handshake_exc`::

if response.exception is not None:
raise response.exception
if connection.handshake_exc is not None:
raise connection.handshake_exc

Else, the WebSocket connection is open.

Expand Down Expand Up @@ -96,10 +97,11 @@ the network, as described in `Send data`_ below.
Even when you call :meth:`~server.ServerConnection.accept`, the WebSocket
handshake may fail if the request is incorrect or unsupported.

When the handshake fails, the reason is available in ``request.exception``::
When the handshake fails, the reason is available in
:attr:`~server.ServerConnection.handshake_exc`::

if request.exception is not None:
raise request.exception
if connection.handshake_exc is not None:
raise connection.handshake_exc

Else, the WebSocket connection is open.

Expand Down
11 changes: 11 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ They may change at any time.

*In development*

Backwards-incompatible changes
..............................

.. admonition:: The ``exception`` attribute of :class:`~http11.Request` and :class:`~http11.Response` is deprecated.
:class: note

Use the ``handshake_exc`` attribute of :class:`~server.ServerConnection` and
:class:`~client.ClientConnection` instead.

See :doc:`../howto/sansio` for details.

10.2
----

Expand Down
2 changes: 2 additions & 0 deletions docs/reference/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ Sans-I/O

.. autoproperty:: state

.. autoattribute:: handshake_exc

.. autoproperty:: close_code

.. autoproperty:: close_reason
Expand Down
2 changes: 2 additions & 0 deletions docs/reference/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ Sans-I/O

.. autoproperty:: state

.. autoattribute:: handshake_exc

.. autoproperty:: close_code

.. autoproperty:: close_reason
Expand Down
3 changes: 2 additions & 1 deletion src/websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def parse(self) -> Generator[None, None, None]:
try:
self.process_response(response)
except InvalidHandshake as exc:
response.exception = exc
response._exception = exc
self.handshake_exc = exc
else:
assert self.state is CONNECTING
self.state = OPEN
Expand Down
9 changes: 9 additions & 0 deletions src/websockets/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ def __init__(
self.close_sent: Optional[Close] = None
self.close_rcvd_then_sent: Optional[bool] = None

# Track if an exception happened during the handshake.
self.handshake_exc: Optional[Exception] = None
"""
Exception to raise if the opening handshake failed.
:obj:`None` if the opening handshake succeeded.
"""

# Track if send_eof() was called.
self.eof_sent = False

Expand Down
27 changes: 21 additions & 6 deletions src/websockets/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dataclasses
import re
import warnings
from typing import Callable, Generator, Optional

from . import datastructures, exceptions
Expand Down Expand Up @@ -57,15 +58,22 @@ class Request:
Attributes:
path: Request path, including optional query.
headers: Request headers.
exception: If processing the response triggers an exception,
the exception is stored in this attribute.
"""

path: str
headers: datastructures.Headers
# body isn't useful is the context of this library.

exception: Optional[Exception] = None
_exception: Optional[Exception] = None

@property
def exception(self) -> Optional[Exception]: # pragma: no cover
warnings.warn(
"Request.exception is deprecated; "
"use ServerConnection.handshake_exc instead",
DeprecationWarning,
)
return self._exception

@classmethod
def parse(
Expand Down Expand Up @@ -152,8 +160,6 @@ class Response:
reason_phrase: Response reason.
headers: Response headers.
body: Response body, if any.
exception: if processing the response triggers an exception,
the exception is stored in this attribute.
"""

Expand All @@ -162,7 +168,16 @@ class Response:
headers: datastructures.Headers
body: Optional[bytes] = None

exception: Optional[Exception] = None
_exception: Optional[Exception] = None

@property
def exception(self) -> Optional[Exception]: # pragma: no cover
warnings.warn(
"Response.exception is deprecated; "
"use ClientConnection.handshake_exc instead",
DeprecationWarning,
)
return self._exception

@classmethod
def parse(
Expand Down
12 changes: 8 additions & 4 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,17 @@ def accept(self, request: Request) -> Response:
protocol_header,
) = self.process_request(request)
except InvalidOrigin as exc:
request.exception = exc
request._exception = exc
self.handshake_exc = exc
if self.debug:
self.logger.debug("! invalid origin", exc_info=True)
return self.reject(
http.HTTPStatus.FORBIDDEN,
f"Failed to open a WebSocket connection: {exc}.\n",
)
except InvalidUpgrade as exc:
request.exception = exc
request._exception = exc
self.handshake_exc = exc
if self.debug:
self.logger.debug("! invalid upgrade", exc_info=True)
response = self.reject(
Expand All @@ -133,15 +135,17 @@ def accept(self, request: Request) -> Response:
response.headers["Upgrade"] = "websocket"
return response
except InvalidHandshake as exc:
request.exception = exc
request._exception = exc
self.handshake_exc = exc
if self.debug:
self.logger.debug("! invalid handshake", exc_info=True)
return self.reject(
http.HTTPStatus.BAD_REQUEST,
f"Failed to open a WebSocket connection: {exc}.\n",
)
except Exception as exc:
request.exception = exc
request._exception = exc
self.handshake_exc = exc
self.logger.error("opening handshake failed", exc_info=True)
return self.reject(
http.HTTPStatus.INTERNAL_SERVER_ERROR,
Expand Down
26 changes: 13 additions & 13 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_missing_connection(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "missing Connection header")

def test_invalid_connection(self):
Expand All @@ -268,7 +268,7 @@ def test_invalid_connection(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "invalid Connection header: close")

def test_missing_upgrade(self):
Expand All @@ -280,7 +280,7 @@ def test_missing_upgrade(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "missing Upgrade header")

def test_invalid_upgrade(self):
Expand All @@ -293,7 +293,7 @@ def test_invalid_upgrade(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "invalid Upgrade header: h2c")

def test_missing_accept(self):
Expand All @@ -305,7 +305,7 @@ def test_missing_accept(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "missing Sec-WebSocket-Accept header")

def test_multiple_accept(self):
Expand All @@ -317,7 +317,7 @@ def test_multiple_accept(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(
str(raised.exception),
"invalid Sec-WebSocket-Accept header: "
Expand All @@ -334,7 +334,7 @@ def test_invalid_accept(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHeader) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(
str(raised.exception), f"invalid Sec-WebSocket-Accept header: {ACCEPT}"
)
Expand Down Expand Up @@ -383,7 +383,7 @@ def test_unexpected_extension(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "no extensions supported")

def test_unsupported_extension(self):
Expand All @@ -398,7 +398,7 @@ def test_unsupported_extension(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(
str(raised.exception),
"Unsupported extension: name = x-op, params = [('op', None)]",
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_unsupported_extension_parameters(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(
str(raised.exception),
"Unsupported extension: name = x-op, params = [('op', 'that')]",
Expand Down Expand Up @@ -520,7 +520,7 @@ def test_unexpected_subprotocol(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "no subprotocols supported")

def test_multiple_subprotocols(self):
Expand All @@ -536,7 +536,7 @@ def test_multiple_subprotocols(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(
str(raised.exception), "multiple subprotocols: superchat, chat"
)
Expand Down Expand Up @@ -566,7 +566,7 @@ def test_unsupported_subprotocol(self):

self.assertEqual(client.state, CONNECTING)
with self.assertRaises(InvalidHandshake) as raised:
raise response.exception
raise client.handshake_exc
self.assertEqual(str(raised.exception), "unsupported subprotocol: otherchat")


Expand Down
Loading

0 comments on commit 5dc16c2

Please sign in to comment.