Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert groups and visibility code to async / await. #7951

Merged
merged 3 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7951.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert groups and visibility code to async / await.
25 changes: 11 additions & 14 deletions synapse/groups/attestations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@

from signedjson.sign import sign_json

from twisted.internet import defer

from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
Expand Down Expand Up @@ -72,8 +70,9 @@ def __init__(self, hs):
self.server_name = hs.hostname
self.signing_key = hs.signing_key

@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
async def verify_attestation(
self, attestation, group_id, user_id, server_name=None
):
"""Verifies that the given attestation matches the given parameters.

An optional server_name can be supplied to explicitly set which server's
Expand Down Expand Up @@ -102,7 +101,7 @@ def verify_attestation(self, attestation, group_id, user_id, server_name=None):
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")

yield self.keyring.verify_json_for_server(
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
)

Expand Down Expand Up @@ -142,20 +141,19 @@ def __init__(self, hs):
self._start_renew_attestations, 30 * 60 * 1000
)

@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
async def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation
"""
attestation = content["attestation"]

if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
raise SynapseError(400, "Neither user not group are on this server")

yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
attestation, user_id=user_id, group_id=group_id
)

yield self.store.update_remote_attestion(group_id, user_id, attestation)
await self.store.update_remote_attestion(group_id, user_id, attestation)

return {}

Expand All @@ -172,8 +170,7 @@ async def _renew_attestations(self):
now + UPDATE_ATTESTATION_TIME_MS
)

@defer.inlineCallbacks
def _renew_attestation(group_user: Tuple[str, str]):
async def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
Expand All @@ -186,16 +183,16 @@ def _renew_attestation(group_user: Tuple[str, str]):
user_id,
group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
await self.store.remove_attestation_renewal(group_id, user_id)
return

attestation = self.attestations.create_attestation(group_id, user_id)

yield self.transport_client.renew_group_attestation(
await self.transport_client.renew_group_attestation(
destination, group_id, user_id, content={"attestation": attestation}
)

yield self.store.update_attestation_renewal(
await self.store.update_attestation_renewal(
group_id, user_id, attestation
)
except (RequestSendFailed, HttpResponseException) as e:
Expand Down
30 changes: 13 additions & 17 deletions synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import logging
import operator

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
from synapse.storage import Storage
Expand All @@ -39,8 +37,7 @@
)


@defer.inlineCallbacks
def filter_events_for_client(
async def filter_events_for_client(
storage: Storage,
user_id,
events,
Expand All @@ -67,19 +64,19 @@ def filter_events_for_client(
also be called to check whether a user can see the state at a given point.

Returns:
Deferred[list[synapse.events.EventBase]]
list[synapse.events.EventBase]
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
events = [e for e in events if not e.internal_metadata.is_soft_failed()]

types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield storage.state.get_state_for_events(
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)

ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id
)

Expand All @@ -90,7 +87,7 @@ def filter_events_for_client(
else []
)

erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased((e.sender for e in events))

if filter_send_to_client:
room_ids = {e.room_id for e in events}
Expand All @@ -99,7 +96,7 @@ def filter_events_for_client(
for room_id in room_ids:
retention_policies[
room_id
] = yield storage.main.get_retention_policy_for_room(room_id)
] = await storage.main.get_retention_policy_for_room(room_id)

def allowed(event):
"""
Expand Down Expand Up @@ -254,8 +251,7 @@ def allowed(event):
return list(filtered_events)


@defer.inlineCallbacks
def filter_events_for_server(
async def filter_events_for_server(
storage: Storage,
server_name,
events,
Expand All @@ -277,7 +273,7 @@ def filter_events_for_server(
backfill or not.

Returns
Deferred[list[FrozenEvent]]
list[FrozenEvent]
"""

def is_sender_erased(event, erased_senders):
Expand Down Expand Up @@ -321,7 +317,7 @@ def check_event_is_visible(event, state):
# Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield storage.state.get_state_ids_for_events(
event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),)
Expand All @@ -339,14 +335,14 @@ def check_event_is_visible(event, state):
if not visibility_ids:
all_open = True
else:
event_map = yield storage.main.get_events(visibility_ids)
event_map = await storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.values()
)

if not check_history_visibility_only:
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
Expand Down Expand Up @@ -375,7 +371,7 @@ def check_event_is_visible(event, state):

# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
event_to_state_ids = yield storage.state.get_state_ids_for_events(
event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
Expand Down Expand Up @@ -405,7 +401,7 @@ def include(typ, state_key):
return False
return state_key[idx + 1 :] == server_name

event_map = yield storage.main.get_events(
event_map = await storage.main.get_events(
[e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
)

Expand Down
12 changes: 6 additions & 6 deletions tests/test_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_filtering(self):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)

filtered = yield filter_events_for_server(
self.storage, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)

# the result should be 5 redacted events, and 5 unredacted events.
Expand Down Expand Up @@ -102,8 +102,8 @@ def test_erased_user(self):
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")

# ... and the filtering happens.
filtered = yield filter_events_for_server(
self.storage, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)

for i in range(0, len(events_to_filter)):
Expand Down Expand Up @@ -265,8 +265,8 @@ def test_large_room(self):
storage.main = test_store
storage.state = test_store

filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)

Expand Down