Skip to content

Commit

Permalink
Refactor tests for redirects.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Aug 11, 2021
1 parent 7861345 commit b343fc6
Showing 1 changed file with 29 additions and 87 deletions.
116 changes: 29 additions & 87 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ async def default_handler(ws, path):
await ws.send((await ws.recv()))


async def redirect_request(path, headers, test, status):
if path == "/redirect":
location = get_server_uri(test.server, test.secure, "/")
elif path == "/infinite":
location = get_server_uri(test.server, test.secure, "/infinite")
elif path == "/force_insecure":
location = get_server_uri(test.server, False, "/")
elif path == "/missing_location":
return status, {}, b""
else:
return None
return status, {"Location": location}, b""


@contextlib.contextmanager
def temp_test_server(test, **kwargs):
test.start_server(**kwargs)
Expand All @@ -84,15 +98,9 @@ def temp_test_server(test, **kwargs):
test.stop_server()


@contextlib.contextmanager
def temp_test_redirecting_server(
test, status, include_location=True, force_insecure=False, **kwargs
):
test.start_redirecting_server(status, include_location, force_insecure, **kwargs)
try:
yield
finally:
test.stop_redirecting_server()
def temp_test_redirecting_server(test, status=http.HTTPStatus.FOUND, **kwargs):
process_request = functools.partial(redirect_request, test=test, status=status)
return temp_test_server(test, process_request=process_request, **kwargs)


@contextlib.contextmanager
Expand Down Expand Up @@ -201,11 +209,6 @@ class ClientServerTestsMixin:
def setUp(self):
super().setUp()
self.server = None
self.redirecting_server = None

@property
def server_context(self):
return None

def start_server(self, deprecation_warnings=None, **kwargs):
handler = kwargs.pop("handler", default_handler)
Expand All @@ -226,42 +229,6 @@ def start_server(self, deprecation_warnings=None, **kwargs):
expected_warnings += ["There is no current event loop"]
self.assertDeprecationWarnings(recorded_warnings, expected_warnings)

def start_redirecting_server(
self,
status,
include_location=True,
force_insecure=False,
deprecation_warnings=None,
**kwargs,
):
async def process_request(path, headers):
server_uri = get_server_uri(self.server, self.secure, path)
if force_insecure:
server_uri = server_uri.replace("wss:", "ws:")
headers = {"Location": server_uri} if include_location else []
return status, headers, b""

with warnings.catch_warnings(record=True) as recorded_warnings:
start_server = serve(
default_handler,
"localhost",
0,
compression=None,
ping_interval=None,
process_request=process_request,
ssl=self.server_context,
**kwargs,
)
self.redirecting_server = self.loop.run_until_complete(start_server)

expected_warnings = [] if deprecation_warnings is None else deprecation_warnings
if (
sys.version_info[:2] >= (3, 10)
and "remove loop argument" not in expected_warnings
): # pragma: no cover
expected_warnings += ["There is no current event loop"]
self.assertDeprecationWarnings(recorded_warnings, expected_warnings)

def start_client(
self, resource_name="/", user_info=None, deprecation_warnings=None, **kwargs
):
Expand All @@ -274,8 +241,7 @@ def start_client(
try:
server_uri = kwargs.pop("uri")
except KeyError:
server = self.redirecting_server if self.redirecting_server else self.server
server_uri = get_server_uri(server, secure, resource_name, user_info)
server_uri = get_server_uri(self.server, secure, resource_name, user_info)

with warnings.catch_warnings(record=True) as recorded_warnings:
start_client = connect(server_uri, **kwargs)
Expand Down Expand Up @@ -306,17 +272,6 @@ def stop_server(self):
except asyncio.TimeoutError: # pragma: no cover
self.fail("Server failed to stop")

def stop_redirecting_server(self):
self.redirecting_server.close()
try:
self.loop.run_until_complete(
asyncio.wait_for(self.redirecting_server.wait_closed(), timeout=1)
)
except asyncio.TimeoutError: # pragma: no cover
self.fail("Redirecting server failed to stop")
finally:
self.redirecting_server = None

@contextlib.contextmanager
def temp_server(self, **kwargs):
with temp_test_server(self, **kwargs):
Expand Down Expand Up @@ -388,7 +343,6 @@ def test_basic(self):
reply = self.loop.run_until_complete(self.client.recv())
self.assertEqual(reply, "Hello!")

@with_server()
def test_redirect(self):
redirect_statuses = [
http.HTTPStatus.MOVED_PERMANENTLY,
Expand All @@ -399,40 +353,31 @@ def test_redirect(self):
]
for status in redirect_statuses:
with temp_test_redirecting_server(self, status):
with temp_test_client(self):
with self.temp_client("/redirect"):
self.loop.run_until_complete(self.client.send("Hello!"))
reply = self.loop.run_until_complete(self.client.recv())
self.assertEqual(reply, "Hello!")

def test_infinite_redirect(self):
with temp_test_redirecting_server(
self,
http.HTTPStatus.FOUND,
):
self.server = self.redirecting_server
with temp_test_redirecting_server(self):
with self.assertRaises(InvalidHandshake):
with temp_test_client(self):
with self.temp_client("/infinite"):
self.fail("Did not raise") # pragma: no cover

@with_server()
def test_redirect_missing_location(self):
with temp_test_redirecting_server(
self,
http.HTTPStatus.FOUND,
include_location=False,
loop=self.loop,
deprecation_warnings=["remove loop argument"],
):
with temp_test_redirecting_server(self):
with self.assertRaises(InvalidHeader):
with temp_test_client(self):
with self.temp_client("/missing_location"):
self.fail("Did not raise") # pragma: no cover

def test_loop_backwards_compatibility(self):
with self.temp_server(
loop=self.loop, deprecation_warnings=["remove loop argument"]
loop=self.loop,
deprecation_warnings=["remove loop argument"],
):
with self.temp_client(
loop=self.loop, deprecation_warnings=["remove loop argument"]
loop=self.loop,
deprecation_warnings=["remove loop argument"],
):
self.loop.run_until_complete(self.client.send("Hello!"))
reply = self.loop.run_until_complete(self.client.recv())
Expand Down Expand Up @@ -1274,13 +1219,10 @@ def test_ws_uri_is_rejected(self):
uri=get_server_uri(self.server, secure=False), ssl=self.client_context
)

@with_server()
def test_redirect_insecure(self):
with temp_test_redirecting_server(
self, http.HTTPStatus.FOUND, force_insecure=True
):
with temp_test_redirecting_server(self):
with self.assertRaises(InvalidHandshake):
with temp_test_client(self):
with self.temp_client("/force_insecure"):
self.fail("Did not raise") # pragma: no cover


Expand Down

0 comments on commit b343fc6

Please sign in to comment.