Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert tags and metrics databases to async/await #8062

Merged
merged 3 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8062.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
20 changes: 6 additions & 14 deletions synapse/storage/databases/main/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import typing
from collections import Counter

from twisted.internet import defer

from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -69,8 +67,7 @@ def fetch(txn):
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])

@defer.inlineCallbacks
def count_daily_messages(self):
async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.

Expand All @@ -88,11 +85,9 @@ def _count_messages(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction("count_messages", _count_messages)
return ret
return await self.db_pool.runInteraction("count_messages", _count_messages)

@defer.inlineCallbacks
def count_daily_sent_messages(self):
async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then thats your own fault.
Expand All @@ -109,13 +104,11 @@ def _count_messages(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages
)
return ret

@defer.inlineCallbacks
def count_daily_active_rooms(self):
async def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
Expand All @@ -126,5 +119,4 @@ def _count(txn):
(count,) = txn.fetchone()
return count

ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count)
return ret
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
103 changes: 53 additions & 50 deletions synapse/storage/databases/main/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,40 @@
# limitations under the License.

import logging
from typing import List, Tuple
from typing import Dict, List, Tuple

from canonicaljson import json

from twisted.internet import defer

from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)


class TagsWorkerStore(AccountDataWorkerStore):
@cached()
def get_tags_for_user(self, user_id):
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user.


Args:
user_id(str): The user to get the tags for.
user_id: The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to dicts mapping from
tag strings to tag content.
A mapping from room_id strings to dicts mapping from tag strings to
tag content.
"""

deferred = self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)

@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room

return deferred
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room

async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
Expand Down Expand Up @@ -127,17 +122,19 @@ def get_tag_content(txn, tag_ids):

return results, upto_token, limited

@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
"""Get all the tags for the rooms where the tags have changed since the
given version

Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.

Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
A mapping from room_id strings to lists of tag strings for all the
rooms that changed since the stream_id token.
"""

def get_updated_tags_txn(txn):
Expand All @@ -155,47 +152,53 @@ def get_updated_tags_txn(txn):
if not changed:
return {}

room_ids = yield self.db_pool.runInteraction(
room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn
)

results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})

return results

def get_tags_for_room(self, user_id, room_id):
async def get_tags_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the tags for the given room

Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
user_id: The user to get tags for
room_id: The room to get tags for

Returns:
A deferred list of string tags.
A mapping of tags to tag content.
"""
return self.db_pool.simple_select_list(
rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(
lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}


class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user.

Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
room_id: The room to add a tag for.
tag: The tag name to add.
content: A json object to associate with the tag.

Returns:
A deferred that completes once the tag has been added.
The next account data ID.
"""
content_json = json.dumps(content)

Expand All @@ -209,18 +212,17 @@ def add_tag_txn(txn, next_id):
self._update_revision_txn(txn, user_id, room_id, next_id)

with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.

Returns:
A deferred that completes once the tag has been removed
The next account data ID.
"""

def remove_tag_txn(txn, next_id):
Expand All @@ -232,21 +234,22 @@ def remove_tag_txn(txn, next_id):
self._update_revision_txn(txn, user_id, room_id, next_id)

with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)

self.get_tags_for_user.invalidate((user_id,))

result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()

def _update_revision_txn(self, txn, user_id, room_id, next_id):
def _update_revision_txn(
self, txn, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room.

Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
user_id: The ID of the user.
room_id: The ID of the room.
next_id: The the revision to advance to.
"""

txn.call_after(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config

Expand Down Expand Up @@ -79,7 +80,9 @@ def prepare(self, reactor, clock, hs):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self._rlsn._store.get_tags_for_room = Mock(
side_effect=lambda user_id, room_id: make_awaitable({})
)

@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
Expand Down