Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add some type hints to datastore. (#12255)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Mar 28, 2022
1 parent 4ba55a6 commit ac95167
Show file tree
Hide file tree
Showing 10 changed files with 61 additions and 42 deletions.
1 change: 1 addition & 0 deletions changelog.d/12255.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints for storage.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,11 @@ exclude = (?x)
|synapse/_scripts/update_synapse_database.py

|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
|synapse/storage/schema/

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname

async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
Expand Down
37 changes: 25 additions & 12 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
Optional,
Set,
Tuple,
cast,
)

from twisted.internet import defer

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
Expand All @@ -38,7 +37,11 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
Expand All @@ -58,6 +61,7 @@ def __init__(
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
self._receipts_id_gen: AbstractStreamIdTracker

if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
Expand Down Expand Up @@ -161,7 +165,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
return txn.fetchall()
return cast(List[Tuple[str, str, int, int]], txn.fetchall())

rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
Expand Down Expand Up @@ -257,7 +261,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
if not rows:
return []

content = {}
content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
Expand Down Expand Up @@ -305,7 +309,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
"_get_linearized_receipts_for_rooms", f
)

results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
Expand Down Expand Up @@ -370,7 +374,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
"get_linearized_receipts_for_all_rooms", f
)

results = {}
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
Expand Down Expand Up @@ -399,7 +403,7 @@ async def get_users_sent_receipts_between(
"""

if last_id == current_id:
return defer.succeed([])
return []

def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
Expand Down Expand Up @@ -453,7 +457,10 @@ def get_all_updated_receipts_txn(
"""
txn.execute(sql, (last_id, current_id, limit))

updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
)

limited = False
upper_bound = current_id
Expand Down Expand Up @@ -496,7 +503,13 @@ def invalidate_caches_for_receipt(
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
Expand Down Expand Up @@ -584,7 +597,7 @@ def insert_linearized_receipt_txn(
)

if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)

Expand Down Expand Up @@ -637,7 +650,7 @@ def graph_to_linear(txn: LoggingTransaction) -> str:
"insert_receipt_conv", graph_to_linear
)

async with self._receipts_id_gen.get_next() as stream_id:
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.config = hs.config
self.config: HomeServerConfig = hs.config

# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.config = hs.config
self.config: HomeServerConfig = hs.config

async def store_room(
self,
Expand Down
26 changes: 13 additions & 13 deletions synapse/storage/databases/main/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
import re
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set

import attr

Expand Down Expand Up @@ -74,7 +74,7 @@ def store_search_entries_txn(
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)

args = (
args1 = (
(
entry.event_id,
entry.room_id,
Expand All @@ -86,14 +86,14 @@ def store_search_entries_txn(
for entry in entries
)

txn.execute_batch(sql, args)
txn.execute_batch(sql, args1)

elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
args2 = (
(
entry.event_id,
entry.room_id,
Expand All @@ -102,7 +102,7 @@ def store_search_entries_txn(
)
for entry in entries
)
txn.execute_batch(sql, args)
txn.execute_batch(sql, args2)

else:
# This should be unreachable.
Expand Down Expand Up @@ -427,7 +427,7 @@ async def search_msgs(

search_query = _parse_query(self.database_engine, search_term)

args = []
args: List[Any] = []

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand Down Expand Up @@ -496,7 +496,7 @@ async def search_msgs(

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
Expand Down Expand Up @@ -530,7 +530,7 @@ async def search_rooms(
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
limit,
limit: int,
pagination_token: Optional[str] = None,
) -> JsonDict:
"""Performs a full text search over events with given keys.
Expand All @@ -549,7 +549,7 @@ async def search_rooms(

search_query = _parse_query(self.database_engine, search_term)

args = []
args: List[Any] = []

# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
Expand All @@ -573,9 +573,9 @@ async def search_rooms(

if pagination_token:
try:
origin_server_ts, stream = pagination_token.split(",")
origin_server_ts = int(origin_server_ts)
stream = int(stream)
origin_server_ts_str, stream_str = pagination_token.split(",")
origin_server_ts = int(origin_server_ts_str)
stream = int(stream_str)
except Exception:
raise SynapseError(400, "Invalid pagination token")

Expand Down Expand Up @@ -654,7 +654,7 @@ async def search_rooms(

# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list(
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
Expand Down
24 changes: 15 additions & 9 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
Expand All @@ -29,7 +29,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

Expand Down Expand Up @@ -241,7 +241,9 @@ async def get_filtered_current_state_ids(
# We delegate to the cached version
return await self.get_current_state_ids(room_id)

def _get_filtered_current_state_ids_txn(txn):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
Expand Down Expand Up @@ -281,11 +283,11 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:

event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return
return None

event = await self.get_event(event_id, allow_none=True)
if not event:
return
return None

return event.content.get("canonical_alias")

Expand All @@ -304,7 +306,7 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
Expand Down Expand Up @@ -355,7 +357,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname

self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
Expand All @@ -375,15 +377,19 @@ def __init__(
self._background_remove_left_rooms,
)

async def _background_remove_left_rooms(self, progress, batch_size):
async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
"""

last_room_id = progress.get("last_room_id", "")

def _background_remove_left_rooms_txn(txn):
def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/user_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
) -> None:
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
self.server_name: str = hs.hostname

self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
Expand Down

0 comments on commit ac95167

Please sign in to comment.