Skip to content

Commit

Permalink
Fix initializing Headers from Headers.
Browse files Browse the repository at this point in the history
It didn't work when a header had multiple values.
  • Loading branch information
aaugustin committed Aug 12, 2021
1 parent dc42ecb commit 5de7b41
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 27 deletions.
16 changes: 13 additions & 3 deletions src/websockets/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ class Headers(MutableMapping[str, str]):

__slots__ = ["_dict", "_list"]

def __init__(self, *args: Any, **kwargs: str) -> None:
# Like dict, Headers accepts an optional "mapping or iterable" argument.
def __init__(self, *args: HeadersLike, **kwargs: str) -> None:
self._dict: Dict[str, List[str]] = {}
self._list: List[Tuple[str, str]] = []
# MutableMapping.update calls __setitem__ for each (name, value) pair.
self.update(*args, **kwargs)

def __str__(self) -> str:
Expand All @@ -86,7 +86,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._list!r})"

def copy(self) -> "Headers":
def copy(self) -> Headers:
copy = self.__class__()
copy._dict = self._dict.copy()
copy._list = self._list.copy()
Expand Down Expand Up @@ -139,6 +139,16 @@ def clear(self) -> None:
self._dict = {}
self._list = []

def update(self, *args: HeadersLike, **kwargs: str) -> None:
"""
Update from a Headers instance and/or keyword arguments.
"""
args = tuple(
arg.raw_items() if isinstance(arg, Headers) else arg for arg in args
)
super().update(*args, **kwargs)

# Methods for handling multiple values

def get_all(self, key: str) -> List[str]:
Expand Down
149 changes: 125 additions & 24 deletions tests/test_datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,68 @@
from websockets.datastructures import *


class MultipleValuesErrorTests(unittest.TestCase):
def test_multiple_values_error_str(self):
self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'")
self.assertEqual(str(MultipleValuesError()), "")


class HeadersTests(unittest.TestCase):
def setUp(self):
self.headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")])

def test_init(self):
self.assertEqual(
Headers(),
Headers(),
)

def test_init_from_kwargs(self):
self.assertEqual(
Headers(connection="Upgrade", server="websockets"),
self.headers,
)

def test_init_from_headers(self):
self.assertEqual(
Headers(self.headers),
self.headers,
)

def test_init_from_headers_and_kwargs(self):
self.assertEqual(
Headers(Headers(connection="Upgrade"), server="websockets"),
self.headers,
)

def test_init_from_mapping(self):
self.assertEqual(
Headers({"Connection": "Upgrade", "Server": "websockets"}),
self.headers,
)

def test_init_from_mapping_and_kwargs(self):
self.assertEqual(
Headers({"Connection": "Upgrade"}, server="websockets"),
self.headers,
)

def test_init_from_iterable(self):
self.assertEqual(
Headers([("Connection", "Upgrade"), ("Server", "websockets")]),
self.headers,
)

def test_init_from_iterable_and_kwargs(self):
self.assertEqual(
Headers([("Connection", "Upgrade")], server="websockets"),
self.headers,
)

def test_init_multiple_positional_arguments(self):
with self.assertRaises(TypeError):
Headers(Headers(connection="Upgrade"), Headers(server="websockets"))

def test_str(self):
self.assertEqual(
str(self.headers), "Connection: Upgrade\r\nServer: websockets\r\n\r\n"
Expand All @@ -27,10 +85,6 @@ def test_serialize(self):
b"Connection: Upgrade\r\nServer: websockets\r\n\r\n",
)

def test_multiple_values_error_str(self):
self.assertEqual(str(MultipleValuesError("Connection")), "'Connection'")
self.assertEqual(str(MultipleValuesError()), "")

def test_contains(self):
self.assertIn("Server", self.headers)

Expand Down Expand Up @@ -59,11 +113,6 @@ def test_getitem_key_error(self):
with self.assertRaises(KeyError):
self.headers["Upgrade"]

def test_getitem_multiple_values_error(self):
self.headers["Server"] = "2"
with self.assertRaises(MultipleValuesError):
self.headers["Server"]

def test_setitem(self):
self.headers["Upgrade"] = "websocket"
self.assertEqual(self.headers["Upgrade"], "websocket")
Expand All @@ -72,11 +121,6 @@ def test_setitem_case_insensitive(self):
self.headers["upgrade"] = "websocket"
self.assertEqual(self.headers["Upgrade"], "websocket")

def test_setitem_multiple_values(self):
self.headers["Connection"] = "close"
with self.assertRaises(MultipleValuesError):
self.headers["Connection"]

def test_delitem(self):
del self.headers["Connection"]
with self.assertRaises(KeyError):
Expand All @@ -87,12 +131,6 @@ def test_delitem_case_insensitive(self):
with self.assertRaises(KeyError):
self.headers["Connection"]

def test_delitem_multiple_values(self):
self.headers["Connection"] = "close"
del self.headers["Connection"]
with self.assertRaises(KeyError):
self.headers["Connection"]

def test_eq(self):
other_headers = Headers([("Connection", "Upgrade"), ("Server", "websockets")])
self.assertEqual(self.headers, other_headers)
Expand Down Expand Up @@ -124,12 +162,75 @@ def test_get_all_case_insensitive(self):
def test_get_all_no_values(self):
self.assertEqual(self.headers.get_all("Upgrade"), [])

def test_get_all_multiple_values(self):
self.headers["Connection"] = "close"
self.assertEqual(self.headers.get_all("Connection"), ["Upgrade", "close"])

def test_raw_items(self):
self.assertEqual(
list(self.headers.raw_items()),
[("Connection", "Upgrade"), ("Server", "websockets")],
)


class MultiValueHeadersTests(unittest.TestCase):
def setUp(self):
self.headers = Headers([("Server", "Python"), ("Server", "websockets")])

def test_init_from_headers(self):
self.assertEqual(
Headers(self.headers),
self.headers,
)

def test_init_from_headers_and_kwargs(self):
self.assertEqual(
Headers(Headers(server="Python"), server="websockets"),
self.headers,
)

def test_str(self):
self.assertEqual(
str(self.headers), "Server: Python\r\nServer: websockets\r\n\r\n"
)

def test_repr(self):
self.assertEqual(
repr(self.headers),
"Headers([('Server', 'Python'), ('Server', 'websockets')])",
)

def test_copy(self):
self.assertEqual(repr(self.headers.copy()), repr(self.headers))

def test_serialize(self):
self.assertEqual(
self.headers.serialize(),
b"Server: Python\r\nServer: websockets\r\n\r\n",
)

def test_iter(self):
self.assertEqual(set(iter(self.headers)), {"server"})

def test_len(self):
self.assertEqual(len(self.headers), 1)

def test_getitem_multiple_values_error(self):
with self.assertRaises(MultipleValuesError):
self.headers["Server"]

def test_setitem(self):
self.headers["Server"] = "redux"
self.assertEqual(
self.headers.get_all("Server"), ["Python", "websockets", "redux"]
)

def test_delitem(self):
del self.headers["Server"]
with self.assertRaises(KeyError):
self.headers["Server"]

def test_get_all(self):
self.assertEqual(self.headers.get_all("Server"), ["Python", "websockets"])

def test_raw_items(self):
self.assertEqual(
list(self.headers.raw_items()),
[("Server", "Python"), ("Server", "websockets")],
)

0 comments on commit 5de7b41

Please sign in to comment.