Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_use…
Browse files Browse the repository at this point in the history
…rs database to async (#8042)
  • Loading branch information
clokep authored Aug 7, 2020
1 parent f3fe696 commit 7f83795
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 120 deletions.
1 change: 1 addition & 0 deletions changelog.d/8042.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
12 changes: 7 additions & 5 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
cross_signing_key = yield defer.ensureDeferred(
self.get_e2e_cross_signing_key(user, "master")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
Expand All @@ -149,8 +151,8 @@ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"device_id": verify_key.version,
}

cross_signing_key = yield self.get_e2e_cross_signing_key(
user, "self_signing"
cross_signing_key = yield defer.ensureDeferred(
self.get_e2e_cross_signing_key(user, "self_signing")
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
Expand Down Expand Up @@ -246,7 +248,7 @@ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_m
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevent json-encoded
user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
Expand Down Expand Up @@ -599,7 +601,7 @@ async def get_all_device_list_changes_for_remotes(
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
Expand Down
51 changes: 25 additions & 26 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,29 @@
# limitations under the License.

from collections import namedtuple
from typing import Optional

from twisted.internet import defer
from typing import Iterable, Optional

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached

RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))


class DirectoryWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias
async def get_association_from_room_alias(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
"""Gets the room_id and server list for a given room_alias
Args:
room_alias (RoomAlias)
room_alias: The alias to translate to an ID.
Returns:
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
The room alias mapping or None if no association can be found.
"""
room_id = yield self.db_pool.simple_select_one_onecol(
room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
Expand All @@ -48,7 +47,7 @@ def get_association_from_room_alias(self, room_alias):
if not room_id:
return None

servers = yield self.db_pool.simple_select_onecol(
servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
Expand Down Expand Up @@ -79,18 +78,20 @@ def get_aliases_for_room(self, room_id):


class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
async def create_room_alias_association(
self,
room_alias: RoomAlias,
room_id: str,
servers: Iterable[str],
creator: Optional[str] = None,
) -> None:
""" Creates an association between a room alias and room_id/servers
Args:
room_alias (RoomAlias)
room_id (str)
servers (list)
creator (str): Optional user_id of creator.
Returns:
Deferred
room_alias: The alias to create.
room_id: The target of the alias.
servers: A list of servers through which it may be possible to join the room
creator: Optional user_id of creator.
"""

def alias_txn(txn):
Expand Down Expand Up @@ -118,24 +119,22 @@ def alias_txn(txn):
)

try:
ret = yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
return ret

@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.db_pool.runInteraction(
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)

return room_id

def _delete_room_alias_txn(self, txn, room_alias):
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
Expand Down
30 changes: 14 additions & 16 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util import json_encoder


class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
async def update_e2e_room_key(
self, user_id, version, room_id, session_id, room_key
):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
Expand All @@ -37,7 +36,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
StoreError
"""

yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
Expand All @@ -54,8 +53,7 @@ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
desc="update_e2e_room_key",
)

@defer.inlineCallbacks
def add_e2e_room_keys(self, user_id, version, room_keys):
async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
Expand Down Expand Up @@ -88,13 +86,12 @@ def add_e2e_room_keys(self, user_id, version, room_keys):
}
)

yield self.db_pool.simple_insert_many(
await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)

@trace
@defer.inlineCallbacks
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Expand All @@ -109,7 +106,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
the backup (or for the specified room)
Returns:
A deferred list of dicts giving the session_data and message metadata for
A list of dicts giving the session_data and message metadata for
these room keys.
"""

Expand All @@ -124,7 +121,7 @@ def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
if session_id:
keyvalues["session_id"] = session_id

rows = yield self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
Expand Down Expand Up @@ -242,8 +239,9 @@ def count_e2e_room_keys(self, user_id, version):
)

@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
async def delete_e2e_room_keys(
self, user_id, version, room_id=None, session_id=None
):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
Expand All @@ -258,7 +256,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
the backup (or for the specified room)
Returns:
A deferred of the deletion transaction
The deletion transaction
"""

keyvalues = {"user_id": user_id, "version": int(version)}
Expand All @@ -267,7 +265,7 @@ def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
if session_id:
keyvalues["session_id"] = session_id

yield self.db_pool.simple_delete(
await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)

Expand Down
Loading

0 comments on commit 7f83795

Please sign in to comment.