From 267679428afeb60b27d96712d89d561c3fd0ec6b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 11 Feb 2022 12:56:32 +0000 Subject: [PATCH] Make `get_auth_chain_ids` return a Set It has a set internally, and a set is often useful where it gets used, so let's avoid converting to an intermediate list. --- synapse/federation/federation_server.py | 2 +- synapse/storage/databases/main/event_federation.py | 12 ++++++------ tests/storage/test_event_federation.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index af9cb98f67b1..ea51012baa4f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -571,7 +571,7 @@ async def _on_state_ids_request_compute( ) -> JsonDict: state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) - return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} + return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)} async def _on_context_state_request_compute( self, room_id: str, event_id: Optional[str] diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 22f64741277a..277e6422ebd6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -121,7 +121,7 @@ async def get_auth_chain_ids( room_id: str, event_ids: Collection[str], include_given: bool = False, - ) -> List[str]: + ) -> Set[str]: """Get auth events for given event_ids. The events *must* be state events. Args: @@ -130,7 +130,7 @@ async def get_auth_chain_ids( include_given: include the given events in result Returns: - list of event_ids + set of event_ids """ # Check if we have indexed the room so we can use the chain cover @@ -159,7 +159,7 @@ async def get_auth_chain_ids( def _get_auth_chain_ids_using_cover_index_txn( self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" # First we look up the chain ID/sequence numbers for the given events. @@ -272,11 +272,11 @@ def _get_auth_chain_ids_using_cover_index_txn( txn.execute(sql, (chain_id, max_no)) results.update(r for r, in txn) - return list(results) + return results def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs. This is used when we don't have a cover index for the room. @@ -331,7 +331,7 @@ def _get_auth_chain_ids_txn( front = new_front results.update(front) - return list(results) + return results async def get_auth_chain_difference( self, room_id: str, state_sets: List[Set[str]] diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2bc89512f8ae..667ca90a4d3e 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -260,16 +260,16 @@ def test_auth_chain_ids(self, use_chain_cover_index: bool): self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) - self.assertEqual(auth_chain_ids, ["k"]) + self.assertEqual(auth_chain_ids, {"k"}) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) - self.assertEqual(auth_chain_ids, ["j"]) + self.assertEqual(auth_chain_ids, {"j"}) # j and k have no parents. auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) - self.assertEqual(auth_chain_ids, []) + self.assertEqual(auth_chain_ids, set()) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) - self.assertEqual(auth_chain_ids, []) + self.assertEqual(auth_chain_ids, set()) # More complex input sequences. auth_chain_ids = self.get_success(