Skip to content

Commit

Permalink
Simplify extra_headers implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Aug 12, 2021
1 parent 5de7b41 commit 26c1779
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 77 deletions.
9 changes: 1 addition & 8 deletions src/websockets/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import collections
from typing import Generator, List, Optional, Sequence

from .connection import CLIENT, CONNECTING, OPEN, Connection, State
Expand Down Expand Up @@ -105,13 +104,7 @@ def connect(self) -> Request: # noqa: F811
headers["Sec-WebSocket-Protocol"] = protocol_header

if self.extra_headers is not None:
extra_headers = self.extra_headers
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
headers[name] = value
headers.update(self.extra_headers)

headers.setdefault("User-Agent", USER_AGENT)

Expand Down
10 changes: 2 additions & 8 deletions src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import asyncio
import collections.abc
import functools
import logging
import urllib.parse
Expand Down Expand Up @@ -387,13 +386,8 @@ async def handshake(
protocol_header = build_subprotocol(available_subprotocols)
request_headers["Sec-WebSocket-Protocol"] = protocol_header

if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
request_headers[name] = value
if self.extra_headers is not None:
request_headers.update(self.extra_headers)

request_headers.setdefault("User-Agent", USER_AGENT)

Expand Down
8 changes: 1 addition & 7 deletions src/websockets/legacy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import asyncio
import collections.abc
import email.utils
import functools
import http
Expand Down Expand Up @@ -697,12 +696,7 @@ async def handshake(
if callable(extra_headers):
extra_headers = extra_headers(path, self.request_headers)
if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
response_headers[name] = value
response_headers.update(extra_headers)

response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
response_headers.setdefault("Server", USER_AGENT)
Expand Down
8 changes: 1 addition & 7 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import base64
import binascii
import collections
import email.utils
import http
from typing import Callable, Generator, List, Optional, Sequence, Tuple, Union, cast
Expand Down Expand Up @@ -142,12 +141,7 @@ def accept(self, request: Request) -> Response:
else:
extra_headers = self.extra_headers
if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
headers[name] = value
headers.update(extra_headers)

headers.setdefault("Date", email.utils.formatdate(usegmt=True))
headers.setdefault("Server", USER_AGENT)
Expand Down
52 changes: 5 additions & 47 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,52 +608,24 @@ def test_protocol_headers(self):
self.assertEqual(server_req, repr(client_req))
self.assertEqual(server_resp, repr(client_resp))

@with_server()
@with_client("/headers", extra_headers=Headers({"X-Spam": "Eggs"}))
def test_protocol_custom_request_headers(self):
req_headers = self.loop.run_until_complete(self.client.recv())
self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", req_headers)

@with_server()
@with_client("/headers", extra_headers={"X-Spam": "Eggs"})
def test_protocol_custom_request_headers_dict(self):
req_headers = self.loop.run_until_complete(self.client.recv())
self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", req_headers)

@with_server()
@with_client("/headers", extra_headers=[("X-Spam", "Eggs")])
def test_protocol_custom_request_headers_list(self):
def test_protocol_custom_request_headers(self):
req_headers = self.loop.run_until_complete(self.client.recv())
self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", req_headers)

@with_server()
@with_client("/headers", extra_headers=[("User-Agent", "Eggs")])
@with_client("/headers", extra_headers={"User-Agent": "Eggs"})
def test_protocol_custom_request_user_agent(self):
req_headers = self.loop.run_until_complete(self.client.recv())
self.loop.run_until_complete(self.client.recv())
self.assertEqual(req_headers.count("User-Agent"), 1)
self.assertIn("('User-Agent', 'Eggs')", req_headers)

@with_server(extra_headers=lambda p, r: Headers({"X-Spam": "Eggs"}))
@with_client("/headers")
def test_protocol_custom_response_headers_callable(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

@with_server(extra_headers=lambda p, r: {"X-Spam": "Eggs"})
@with_client("/headers")
def test_protocol_custom_response_headers_callable_dict(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

@with_server(extra_headers=lambda p, r: [("X-Spam", "Eggs")])
@with_client("/headers")
def test_protocol_custom_response_headers_callable_list(self):
def test_protocol_custom_response_headers_callable(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)
Expand All @@ -664,28 +636,14 @@ def test_protocol_custom_response_headers_callable_none(self):
self.loop.run_until_complete(self.client.recv()) # doesn't crash
self.loop.run_until_complete(self.client.recv()) # nothing to check

@with_server(extra_headers=Headers({"X-Spam": "Eggs"}))
@with_client("/headers")
def test_protocol_custom_response_headers(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

@with_server(extra_headers={"X-Spam": "Eggs"})
@with_client("/headers")
def test_protocol_custom_response_headers_dict(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

@with_server(extra_headers=[("X-Spam", "Eggs")])
@with_client("/headers")
def test_protocol_custom_response_headers_list(self):
def test_protocol_custom_response_headers(self):
self.loop.run_until_complete(self.client.recv())
resp_headers = self.loop.run_until_complete(self.client.recv())
self.assertIn("('X-Spam', 'Eggs')", resp_headers)

@with_server(extra_headers=[("Server", "Eggs")])
@with_server(extra_headers={"Server": "Eggs"})
@with_client("/headers")
def test_protocol_custom_response_user_agent(self):
self.loop.run_until_complete(self.client.recv())
Expand Down

0 comments on commit 26c1779

Please sign in to comment.