diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 2b0cac77ac50..1be366702b51 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -24,9 +24,11 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.appservice import ApplicationService from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.rest import admin +from synapse.rest.client import devices, login, register from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import create_requester, JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests import unittest @@ -401,11 +403,19 @@ def test_on_federation_query_user_devices_appservice(self) -> None: class DehydrationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) handler = hs.get_device_handler() assert isinstance(handler, DeviceHandler) self.handler = handler + self.message_handler = hs.get_device_message_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -501,10 +511,11 @@ def test_dehydrate_v2_and_fetch_events(self) -> None: ) ) - retrieved_device_id, device_data = self.get_success( + device_info = self.get_success( self.handler.get_dehydrated_device(user_id=user_id) ) - + assert device_info is not None + retrieved_device_id, device_data = device_info self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) @@ -566,21 +577,20 @@ def test_dehydrate_v2_and_fetch_events(self) -> None: self.assertTrue(len(res["next_batch"]) > 1) self.assertEqual(len(res["events"]), 0) - # Fetching messages without since should return nothing, since the messages got deleted - res = self.get_success( + # Fetching messages again should fail, since the messages and dehydrated device + # were deleted + self.get_failure( self.message_handler.get_events_for_dehydrated_device( requester=requester, device_id=stored_dehydrated_device_id, since_token=None, limit=10, - ) + ), + SynapseError, ) - self.assertTrue(len(res["next_batch"]) > 1) - self.assertEqual(len(res["events"]), 0) - # We don't delete the device when fetch messages for now. - # # make sure that the device ID that we were initially assigned no longer exists - # self.get_failure( - # self.handler.get_device(user_id, device_id), - # NotFoundError, - # ) + # make sure that the dehydrated device ID is deleted after fetching messages + res2 = self.get_success( + self.handler.get_dehydrated_device(requester.user.to_string()), + ) + self.assertEqual(res2, None) diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index d80eea17d3af..25e425b0fe6e 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -13,12 +13,14 @@ # limitations under the License. from http import HTTPStatus +from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError from synapse.rest import admin, devices, room, sync -from synapse.rest.client import account, login, register +from synapse.rest.client import account, keys, login, register from synapse.server import HomeServer +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests import unittest @@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): login.register_servlets, register.register_servlets, devices.register_servlets, + keys.register_servlets, ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.registration = hs.get_registration_handler() + self.message_handler = hs.get_device_message_handler() + def test_PUT(self) -> None: """Sanity-check that we can PUT a dehydrated device. @@ -234,3 +241,128 @@ def test_PUT(self) -> None: self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) device_id = channel.json_body.get("device_id") self.assertIsInstance(device_id, str) + + @unittest.override_config( + {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}} + ) + def test_dehydrate_msc3814(self) -> None: + user = self.register_user("mikey", "pass") + token = self.login(user, "pass", device_id="device1") + content: JsonDict = { + "device_data": { + "algorithm": "m.dehydration.v1.olm", + }, + "initial_device_display_name": "foo bar", + } + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + device_id = channel.json_body.get("device_id") + assert device_id is not None + self.assertIsInstance(device_id, str) + + # test that you can upload keys for this device + content = { + "device_keys": { + "algorithms": ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "device_id": f"{device_id}", + "keys": { + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": { + "@alice:example.com": { + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA" + } + }, + "user_id": f"{user}", + }, + } + channel = self.make_request( + "POST", + f"/_matrix/client/r0/keys/upload/{device_id}", + content=content, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # test that we can now GET the dehydrated device info + channel = self.make_request( + "GET", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + returned_device_id = channel.json_body.get("device_id") + self.assertEqual(returned_device_id, device_id) + device_data = channel.json_body.get("device_data") + expected_device_data = { + "algorithm": "m.dehydration.v1.olm", + } + self.assertEqual(device_data, expected_device_data) + + # create another device for the user + ( + new_device_id, + _, + _, + _, + ) = self.get_success( + self.registration.register_device( + user_id=user, + device_id=None, + initial_display_name="new device", + ) + ) + requester = create_requester(user, device_id=new_device_id) + + # Send a message to the dehydrated device + ensureDeferred( + self.message_handler.send_device_message( + requester=requester, + message_type="test.message", + messages={user: {device_id: {"body": "test_message"}}}, + ) + ) + self.pump() + + # make sure we can fetch the message with our dehydrated device id + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content={}, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + expected_content = {"body": "test_message"} + self.assertEqual(channel.json_body["events"][0]["content"], expected_content) + next_batch_token = channel.json_body.get("next_batch") + + # fetch messages again and make sure that the message was deleted and we are returned an + # empty array + content = {"next_batch": next_batch_token} + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["events"], []) + + # make sure that the dehydrated device id is deleted after we received the messages + channel = self.make_request( + "GET", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 404)