Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Make get_auth_chain_ids return a Set
Browse files Browse the repository at this point in the history
It has a set internally, and a set is often useful where it gets used, so let's
avoid converting to an intermediate list.
  • Loading branch information
richvdh committed Feb 11, 2022
1 parent a121507 commit 2676794
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ async def _on_state_ids_request_compute(
) -> JsonDict:
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}

async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
Expand Down
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def get_auth_chain_ids(
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
) -> Set[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
Expand All @@ -130,7 +130,7 @@ async def get_auth_chain_ids(
include_given: include the given events in result
Returns:
list of event_ids
set of event_ids
"""

# Check if we have indexed the room so we can use the chain cover
Expand Down Expand Up @@ -159,7 +159,7 @@ async def get_auth_chain_ids(

def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""

# First we look up the chain ID/sequence numbers for the given events.
Expand Down Expand Up @@ -272,11 +272,11 @@ def _get_auth_chain_ids_using_cover_index_txn(
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)

return list(results)
return results

def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
) -> Set[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
Expand Down Expand Up @@ -331,7 +331,7 @@ def _get_auth_chain_ids_txn(
front = new_front
results.update(front)

return list(results)
return results

async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
Expand Down
8 changes: 4 additions & 4 deletions tests/storage/test_event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,16 @@ def test_auth_chain_ids(self, use_chain_cover_index: bool):
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
self.assertEqual(auth_chain_ids, ["k"])
self.assertEqual(auth_chain_ids, {"k"})

auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
self.assertEqual(auth_chain_ids, ["j"])
self.assertEqual(auth_chain_ids, {"j"})

# j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
self.assertEqual(auth_chain_ids, [])
self.assertEqual(auth_chain_ids, set())
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
self.assertEqual(auth_chain_ids, [])
self.assertEqual(auth_chain_ids, set())

# More complex input sequences.
auth_chain_ids = self.get_success(
Expand Down

0 comments on commit 2676794

Please sign in to comment.