Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clean up the test code for client disconnections #12929

Merged
merged 7 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12929.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up the test code for client disconnection.
10 changes: 4 additions & 6 deletions tests/federation/transport/server/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from synapse.util.ratelimitutils import FederationRateLimiter

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


class CancellableFederationServlet(BaseFederationServlet):
Expand Down Expand Up @@ -54,9 +54,7 @@ async def on_POST(
return HTTPStatus.OK, {"result": True}


class BaseFederationServletCancellationTests(
unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
):
class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""

skip = "`BaseFederationServlet` does not support cancellation yet."
Expand Down Expand Up @@ -86,7 +84,7 @@ def test_cancellable_disconnect(self) -> None:
# request won't be processed.
self.pump()

self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -106,7 +104,7 @@ def test_uncancellable_disconnect(self) -> None:
# request won't be processed.
self.pump()

self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
138 changes: 72 additions & 66 deletions tests/http/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,76 +25,82 @@
)
from synapse.types import JsonDict

from tests import unittest
from tests.server import FakeChannel, ThreadedMemoryReactorClock


class EndpointCancellationTestHelperMixin(unittest.TestCase):
"""Provides helper methods for testing cancellation of endpoints."""

def _test_disconnect(
self,
reactor: ThreadedMemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.

Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be
cancelled, `False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200`
or `499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK

request = channel.request
self.assertFalse(
channel.is_finished(),
def test_disconnect(
reactor: ThreadedMemoryReactorClock,
channel: FakeChannel,
expect_cancellation: bool,
expected_body: Union[bytes, JsonDict],
expected_code: Optional[int] = None,
) -> None:
"""Disconnects an in-flight request and checks the response.

Args:
reactor: The twisted reactor running the request handler.
channel: The `FakeChannel` for the request.
expect_cancellation: `True` if request processing is expected to be cancelled,
`False` if the request should run to completion.
expected_body: The expected response for the request.
expected_code: The expected status code for the request. Defaults to `200` or
`499` depending on `expect_cancellation`.
"""
# Determine the expected status code.
if expected_code is None:
if expect_cancellation:
expected_code = HTTP_STATUS_REQUEST_CANCELLED
else:
expected_code = HTTPStatus.OK

request = channel.request
if channel.is_finished():
raise AssertionError(
"Request finished before we could disconnect - "
"was `await_result=False` passed to `make_request`?",
"ensure `await_result=False` is passed to `make_request`.",
)

# We're about to disconnect the request. This also disconnects the channel, so
# we have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
# We're about to disconnect the request. This also disconnects the channel, so we
# have to rely on mocks to extract the response.
respond_method: Callable[..., Any]
if isinstance(expected_body, bytes):
respond_method = respond_with_html_bytes
else:
respond_method = respond_with_json

with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())

if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
else:
respond_method = respond_with_json

with mock.patch(
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
) as respond_mock:
# Disconnect the request.
request.connectionLost(reason=ConnectionDone())

if expect_cancellation:
# An immediate cancellation is expected.
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
else:
respond_mock.assert_not_called()

# The handler is expected to run to completion.
reactor.pump([1.0])
respond_mock.assert_called_once()
args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]
self.assertEqual(code, expected_code)
self.assertEqual(request.code, expected_code)
self.assertEqual(body, expected_body)
respond_mock.assert_not_called()

# The handler is expected to run to completion.
reactor.advance(1.0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we can't use channel.await_result() here because we've already disconnected the channel above.

respond_mock.assert_called_once()

args, _kwargs = respond_mock.call_args
code, body = args[1], args[2]

if code != expected_code:
raise AssertionError(
f"{code} != {expected_code} : "
"Request did not finish with the expected status code."
)

if request.code != expected_code:
raise AssertionError(
f"{request.code} != {expected_code} : "
"Request did not finish with the expected status code."
)

if body != expected_body:
raise AssertionError(
f"{body!r} != {expected_body!r} : "
"Request did not finish with the expected status code."
)
10 changes: 4 additions & 6 deletions tests/http/test_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from synapse.types import JsonDict

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


def make_request(content):
Expand Down Expand Up @@ -108,9 +108,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return HTTPStatus.OK, {"result": True}


class TestRestServletCancellation(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""

servlets = [
Expand All @@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -130,7 +128,7 @@ def test_cancellable_disconnect(self) -> None:
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
10 changes: 4 additions & 6 deletions tests/replication/http/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.types import JsonDict

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect


class CancellableReplicationEndpoint(ReplicationEndpoint):
Expand Down Expand Up @@ -69,9 +69,7 @@ async def _handle_request( # type: ignore[override]
return HTTPStatus.OK, {"result": True}


class ReplicationEndpointCancellationTestCase(
unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
):
class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""

def create_test_resource(self):
Expand All @@ -87,7 +85,7 @@ def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -98,7 +96,7 @@ def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from synapse.util import Clock

from tests import unittest
from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.http.server._base import test_disconnect
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
Expand Down Expand Up @@ -407,7 +407,7 @@ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]
return HTTPStatus.OK, b"ok"


class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""

def setUp(self):
Expand All @@ -421,7 +421,7 @@ def test_cancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -433,15 +433,15 @@ def test_uncancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
expected_body={"result": True},
)


class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""

def setUp(self):
Expand All @@ -455,7 +455,7 @@ def test_cancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
Expand All @@ -467,6 +467,6 @@ def test_uncancellable_disconnect(self) -> None:
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
self._test_disconnect(
test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)