Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert events worker database to async/await. #8071

Merged
merged 15 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8071.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
auth_events: the existing room state.

Raises:
AuthError if the checks fail
Expand Down
16 changes: 5 additions & 11 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase
"""Returns the state at the event. i.e. not including said event.
"""

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

state_groups = await self.state_store.get_state_groups(room_id, [event_id])

Expand All @@ -1805,9 +1803,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])

Expand Down Expand Up @@ -2155,9 +2151,9 @@ async def _check_for_soft_fail(
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]

current_auth_events = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
(e.type, e.state_key): e for e in auth_events_map.values()
}

try:
Expand All @@ -2173,9 +2169,7 @@ async def on_query_auth(
if not in_room:
raise AuthError(403, "Host not in room.")

event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)

# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ async def persist_and_notify_client_event(
allow_none=True,
)

is_admin_redaction = (
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)

Expand Down Expand Up @@ -1080,8 +1080,8 @@ def is_inviter_member_event(e):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}

room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ async def _can_guest_join(

guest_access = await self.store.get_event(guest_access_id)

return (
return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
Expand Down
2 changes: 1 addition & 1 deletion synapse/spam_checker_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def get_events(self, event_ids, allow_rejected=False):
allow_rejected (bool): If True return rejected events.

Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""

return self.store.get_events(
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.

Args:
Expand All @@ -40,9 +40,10 @@ def get_auth_chain(self, event_ids, include_given=False):
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
)
return await self.get_events_as_list(event_ids)

def get_auth_chain_ids(
self,
Expand Down Expand Up @@ -459,7 +460,7 @@ def get_forward_extremeties_for_room_txn(txn):
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)

def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`

Expand All @@ -469,17 +470,15 @@ def get_backfill_events(self, room_id, event_list, limit):
event_list (list)
limit (int)
"""
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)

def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
Expand Down Expand Up @@ -540,8 +539,7 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi
latest_events,
limit,
)
events = await self.get_events_as_list(ids)
return events
return await self.get_events_as_list(ids)

def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

Expand Down
Loading