From ebc1916ee2b03cde00897aa14169160dcf171435 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 11:35:53 -0400 Subject: [PATCH 01/12] Convert events worker database to async/await. --- changelog.d/8071.misc | 1 + synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- .../databases/main/event_federation.py | 30 ++++--- .../storage/databases/main/events_worker.py | 81 +++++++++---------- synapse/storage/databases/main/stream.py | 53 +++++------- .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- tests/storage/test_purge.py | 49 ++++------- 9 files changed, 96 insertions(+), 131 deletions(-) create mode 100644 changelog.d/8071.misc diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 000000000000..dfe4c03171d6 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96d8..4d9b13ac045d 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d388466734..dba8d91eef24 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ def get_events(self, event_ids, allow_rejected=False): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 484875f98992..02bd20af010f 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ def get_auth_chain(self, event_ids, include_given=False): Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -472,7 +473,7 @@ def get_forward_extremeties_for_room_txn(txn): "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -482,17 +483,15 @@ def get_backfill_events(self, room_id, event_list, limit): event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -553,8 +552,7 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 755b7a2a85d4..a962fa09bad3 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,7 +19,7 @@ import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from constantly import NamedConstant, Names @@ -32,7 +32,7 @@ EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -43,7 +43,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -173,8 +173,7 @@ def _get_approximate_received_ts_txn(txn): "get_approximate_received_ts", _get_approximate_received_ts_txn ) - @defer.inlineCallbacks - def get_event( + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -182,7 +181,7 @@ def get_event( allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -207,12 +206,12 @@ def get_event( If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -230,14 +229,13 @@ def get_event( return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, event_ids: List[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -256,9 +254,9 @@ def get_events( omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -267,8 +265,7 @@ def get_events( return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, event_ids: List[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -295,7 +292,7 @@ def get_events_as_list( omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The + list[EventBase]: List of events fetched from the database. The events are in the same order as `event_ids` arg. Note that the returned list may be smaller than the list of event @@ -306,7 +303,7 @@ def get_events_as_list( return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -341,7 +338,7 @@ def get_events_as_list( continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -407,7 +404,7 @@ def get_events_as_list( if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -419,8 +416,7 @@ def get_events_as_list( return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -435,7 +431,7 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -453,7 +449,7 @@ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -561,8 +557,7 @@ def fire(evs, exc): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -576,7 +571,7 @@ def _get_events_from_db(self, event_ids, allow_rejected=False): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -584,7 +579,7 @@ def _get_events_from_db(self, event_ids, allow_rejected=False): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -686,8 +681,7 @@ def _get_events_from_db(self, event_ids, allow_rejected=False): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -696,7 +690,7 @@ def _enqueue_events(self, events): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -719,7 +713,7 @@ def _enqueue_events(self, events): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -878,12 +872,11 @@ def _maybe_redact_event_row(self, original_ev, redactions, event_map): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -894,15 +887,14 @@ def have_events_in_timeline(self, event_ids): return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -918,7 +910,7 @@ def have_seen_events_txn(txn, chunk): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -978,8 +970,7 @@ def get_current_state_event_counts(self, room_id): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -990,9 +981,9 @@ def get_room_complexity(self, room_id): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1320,9 +1311,9 @@ async def is_event_after(self, event_id1, event_id2): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index aaf225894e23..dc8116db989f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -361,11 +361,9 @@ def get_rooms_that_changed(self, room_ids, from_key): if self._events_stream_cache.has_entity_changed(room_id, from_key) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( + async def get_room_events_stream_for_room( self, room_id, from_key, to_key, limit=0, order="DESC" ): - """Get new room events in stream ordering since `from_key`. Args: @@ -380,7 +378,7 @@ def get_room_events_stream_for_room( oldest `limit` events. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of + tuple[list[FrozenEvent], str]: Returns the list of events (in ascending order) and the token from the start of the chunk of events returned. """ @@ -390,9 +388,7 @@ def get_room_events_stream_for_room( from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -410,9 +406,9 @@ def f(txn): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -430,8 +426,7 @@ def f(txn): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -460,9 +455,9 @@ def f(txn): return rows - rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -470,8 +465,7 @@ def f(txn): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room(self, room_id, limit, end_token): """Get the most recent events in the room in topological ordering. Args: @@ -486,11 +480,11 @@ def get_recent_events_for_room(self, room_id, limit, end_token): The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -656,8 +650,7 @@ def _set_before_and_after(events, rows, topo_order=True): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( + async def get_events_around( self, room_id, event_id, before_limit, after_limit, event_filter=None ): """Retrieve events and pagination tokens around a given event in a @@ -674,7 +667,7 @@ def get_events_around( dict """ - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -684,11 +677,11 @@ def get_events_around( event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -758,8 +751,7 @@ def _get_events_around_txn( "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream(self, from_id, current_id, limit): """Get all new events Returns all events with from_id < stream_ordering <= current_id. @@ -770,7 +762,7 @@ def get_all_new_events_stream(self, from_id, current_id, limit): limit (int): the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where + Tuple[int, list[FrozenEvent]]: A tuple of (next_id, events), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, `current_id`. @@ -795,11 +787,11 @@ def get_all_new_events_stream_txn(txn): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events @@ -1008,8 +1000,7 @@ def _paginate_room_events_txn( return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( + async def paginate_room_events( self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None ): """Returns list of events before or after a given token. @@ -1036,7 +1027,7 @@ def paginate_room_events( if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -1047,7 +1038,7 @@ def paginate_room_events( event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d1355829..23db821fb7a0 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5bc..1c82fcfb4d34 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -349,7 +350,7 @@ def test_get_oldest_unsent_txn(self): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index a6012c973d51..918387733b20 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import NotFoundError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -46,30 +47,19 @@ def test_purge(self): storage = self.hs.get_storage() # Get the topological token - event = store.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) - - # Purge everything before this topological token - purge = defer.ensureDeferred( - storage.purge_events.purge_history(self.room_id, event, True) + event = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) - self.pump() - self.assertEqual(self.successResultOf(purge), None) - # Try and get the events - get_first = store.get_event(first["event_id"]) - get_second = store.get_event(second["event_id"]) - get_third = store.get_event(third["event_id"]) - get_last = store.get_event(last["event_id"]) - self.pump() + # Purge everything before this topological token + self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. - self.failureResultOf(get_first) - self.failureResultOf(get_second) - self.failureResultOf(get_third) - self.successResultOf(get_last) + self.get_failure(store.get_event(first["event_id"]), NotFoundError) + self.get_failure(store.get_event(second["event_id"]), NotFoundError) + self.get_failure(store.get_event(third["event_id"]), NotFoundError) + self.get_success(store.get_event(last["event_id"])) def test_purge_wont_delete_extrems(self): """ @@ -84,9 +74,9 @@ def test_purge_wont_delete_extrems(self): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = storage.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + storage.get_topological_token_for_event(last["event_id"]) + ) event = "t{}-{}".format( *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) ) @@ -98,14 +88,7 @@ def test_purge_wont_delete_extrems(self): self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) - self.pump() - - # Nothing is deleted. - self.successResultOf(get_first) - self.successResultOf(get_second) - self.successResultOf(get_third) - self.successResultOf(get_last) + self.get_success(storage.get_event(first["event_id"])) + self.get_success(storage.get_event(second["event_id"])) + self.get_success(storage.get_event(third["event_id"])) + self.get_success(storage.get_event(last["event_id"])) From 382e1c34a5acc58bc9969ac41ed5c39ec21064e1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 14:07:44 -0400 Subject: [PATCH 02/12] Fix typing information. --- synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 18 +++++------ synapse/handlers/message.py | 20 ++++++------ synapse/handlers/room_member.py | 32 +++++++++++++------ synapse/replication/tcp/client.py | 3 +- .../storage/databases/main/events_worker.py | 6 ++-- 6 files changed, 47 insertions(+), 34 deletions(-) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee6243..8c907ad5969a 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb788..63089fb9f1e6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1260,7 +1260,7 @@ async def send_invite(self, target_host, event): return pdu async def on_event_auth(self, event_id: str) -> List[EventBase]: - event = await self.store.get_event(event_id) + event = await self.store.get_event(event_id) # type: EventBase # type: ignore auth = await self.store.get_auth_chain( list(event.auth_event_ids()), include_given=True ) @@ -1778,8 +1778,8 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase """ event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event_id, check_room_id=room_id + ) # type: EventBase # type: ignore state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1806,8 +1806,8 @@ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event_id, check_room_id=room_id + ) # type: EventBase # type: ignore state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2155,9 @@ async def _check_for_soft_fail( auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2174,8 +2174,8 @@ async def on_query_auth( raise AuthError(403, "Host not in room.") event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event_id, check_room_id=room_id + ) # type: EventBase # type: ignore # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2643438e8490..68781bd8de93 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from canonicaljson import encode_canonical_json, json @@ -644,7 +644,7 @@ async def send_nonmember_event( event: EventBase, context: EventContext, ratelimit: bool = True, - ) -> int: + ) -> Union[int, EventBase]: """ Persists and notifies local clients and federation of an event. @@ -682,7 +682,7 @@ async def send_nonmember_event( async def deduplicate_state_event( self, event: EventBase, context: EventContext - ) -> None: + ) -> Optional[EventBase]: """ Checks whether event is in the latest resolved state in context. @@ -692,17 +692,17 @@ async def deduplicate_state_event( prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: - return + return None prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: - return + return None if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: return prev_event - return + return None async def create_and_send_nonmember_event( self, @@ -710,7 +710,7 @@ async def create_and_send_nonmember_event( event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, - ) -> Tuple[EventBase, int]: + ) -> Tuple[EventBase, Union[int, EventBase]]: """ Creates an event, then sends it. @@ -957,7 +957,7 @@ async def persist_and_notify_client_event( allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1077,8 +1077,8 @@ def is_inviter_member_event(e): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb7d..0020345da6d9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -224,13 +224,17 @@ async def _local_membership_update( # info. newly_joined = True if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event( + prev_member_event_id + ) # type: EventBase # type: ignore newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event( + prev_member_event_id + ) # type: EventBase # type: ignore if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) @@ -694,13 +698,17 @@ async def send_membership_event( # info. newly_joined = True if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event( + prev_member_event_id + ) # type: EventBase # type: ignore newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self.store.get_event( + prev_member_event_id + ) # type: EventBase # type: ignore if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) @@ -714,9 +722,11 @@ async def _can_guest_join( if not guest_access_id: return False - guest_access = await self.store.get_event(guest_access_id) + guest_access = await self.store.get_event( + guest_access_id + ) # type: EventBase # type: ignore - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -772,7 +782,7 @@ async def do_3pid_invite( requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + ) -> Union[int, EventBase]: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -806,7 +816,7 @@ async def do_3pid_invite( if invitee: _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id - ) + ) # type: Tuple[Any, Union[int, EventBase]] # type: ignore else: stream_id = await self._make_and_store_3pid_invite( requester, @@ -831,7 +841,7 @@ async def _make_and_store_3pid_invite( user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + ) -> Union[int, EventBase]: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -1066,7 +1076,9 @@ async def remote_reject_invite( Implements RoomMemberHandler.remote_reject_invite """ - invite_event = await self.store.get_event(invite_event_id) + invite_event = await self.store.get_event( + invite_event_id + ) # type: EventBase # type: ignore room_id = invite_event.room_id target_user = invite_event.state_key diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fcf8ebf1e74f..7380015809e4 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -22,6 +22,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams import TypingStream @@ -145,7 +146,7 @@ async def on_rdata( event = await self.store.get_event( row.data.event_id, allow_rejected=True - ) + ) # type: EventBase # type: ignore if event.rejected_reason: continue diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a962fa09bad3..8e3485c33635 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,7 +19,7 @@ import logging import threading from collections import namedtuple -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from constantly import NamedConstant, Names @@ -231,7 +231,7 @@ async def get_event( async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, @@ -267,7 +267,7 @@ async def get_events( async def get_events_as_list( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, From 831c4929ebf702766cac7da7986a857bec7920cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 14:14:52 -0400 Subject: [PATCH 03/12] lint --- synapse/handlers/room_member.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0020345da6d9..589e3a8b070e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 From 7b72cdf4701a1a53d7548c0a9946903e0dae1409 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 10:54:46 -0400 Subject: [PATCH 04/12] Return the previous stream token if a non-member event is a duplicate. --- changelog.d/8093.bugfix | 1 + synapse/handlers/message.py | 25 +++++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8093.bugfix diff --git a/changelog.d/8093.bugfix b/changelog.d/8093.bugfix new file mode 100644 index 000000000000..80045dde1af1 --- /dev/null +++ b/changelog.d/8093.bugfix @@ -0,0 +1 @@ +Return the previous stream token if a non-member event is a duplicate. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 48b0fc7279be..f242d3c6acb7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -667,14 +667,14 @@ async def send_nonmember_event( assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = await self.deduplicate_state_event(event, context) - if prev_state is not None: + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: logger.info( "Not bothering to persist state event %s duplicated by %s", event.event_id, - prev_state.event_id, + prev_event.event_id, ) - return prev_state + return await self.store.get_stream_token_for_event(prev_event.event_id) return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit @@ -682,27 +682,32 @@ async def send_nonmember_event( async def deduplicate_state_event( self, event: EventBase, context: EventContext - ) -> None: + ) -> Optional[EventBase]: """ Checks whether event is in the latest resolved state in context. - If so, returns the version of the event in context. - Otherwise, returns None. + Args: + event: The event to check for duplication. + context: The event context. + + Returns: + The previous verion of the event is returned, if it is found in the + event context. Otherwise, None is returned. """ prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: - return + return None prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: - return + return None if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: return prev_event - return + return None async def create_and_send_nonmember_event( self, From 1f675a991689e9ae2b4a1301297b80470831a992 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:17:18 -0400 Subject: [PATCH 05/12] Rollback type changes due to fixes. --- synapse/handlers/message.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 995cd8cef761..1de6d11e15ec 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json, json @@ -644,7 +644,7 @@ async def send_nonmember_event( event: EventBase, context: EventContext, ratelimit: bool = True, - ) -> Union[int, EventBase]: + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -715,7 +715,7 @@ async def create_and_send_nonmember_event( event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, - ) -> Tuple[EventBase, Union[int, EventBase]]: + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. From 1c6d051b4a11741a78a684e3ea6cada7718b474f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:21:04 -0400 Subject: [PATCH 06/12] Fix more typing information. --- synapse/handlers/room.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 442cca28e6b5..0df022f73dbd 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -266,7 +266,9 @@ async def _update_upgraded_room_pls( ) return - old_room_pl_state = await self.store.get_event(old_room_pl_event_id) + old_room_pl_state = await self.store.get_event( + old_room_pl_event_id + ) # type: EventBase # type: ignore # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally From 4ae631048143f44ab544a800a21ac28d413b652c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 11:48:40 -0400 Subject: [PATCH 07/12] Remove extraneous file. --- changelog.d/8093.bugfix | 1 - 1 file changed, 1 deletion(-) delete mode 100644 changelog.d/8093.bugfix diff --git a/changelog.d/8093.bugfix b/changelog.d/8093.bugfix deleted file mode 100644 index 80045dde1af1..000000000000 --- a/changelog.d/8093.bugfix +++ /dev/null @@ -1 +0,0 @@ -Return the previous stream token if a non-member event is a duplicate. From a2b64d2dcc03f82a9650a353ebea5d192da2d984 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 11:58:59 -0400 Subject: [PATCH 08/12] Fix up some types. --- synapse/handlers/room_member.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 589e3a8b070e..cd53d5c8f099 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -782,7 +782,7 @@ async def do_3pid_invite( requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> Union[int, EventBase]: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -816,7 +816,7 @@ async def do_3pid_invite( if invitee: _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id - ) # type: Tuple[Any, Union[int, EventBase]] # type: ignore + ) else: stream_id = await self._make_and_store_3pid_invite( requester, @@ -841,7 +841,7 @@ async def _make_and_store_3pid_invite( user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> Union[int, EventBase]: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" From 2981aa0b3e05aaed36b9e8dbf36935e70319b758 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 12:38:27 -0400 Subject: [PATCH 09/12] Overload type hints for get_event. --- synapse/handlers/federation.py | 14 +++------ synapse/handlers/room.py | 4 +-- synapse/handlers/room_member.py | 24 +++++---------- synapse/replication/tcp/client.py | 3 +- .../storage/databases/main/events_worker.py | 29 ++++++++++++++++++- 5 files changed, 42 insertions(+), 32 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 63089fb9f1e6..5b270228e784 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1260,7 +1260,7 @@ async def send_invite(self, target_host, event): return pdu async def on_event_auth(self, event_id: str) -> List[EventBase]: - event = await self.store.get_event(event_id) # type: EventBase # type: ignore + event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( list(event.auth_event_ids()), include_given=True ) @@ -1777,9 +1777,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, check_room_id=room_id - ) # type: EventBase # type: ignore + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, check_room_id=room_id - ) # type: EventBase # type: ignore + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2173,9 +2169,7 @@ async def on_query_auth( if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, check_room_id=room_id - ) # type: EventBase # type: ignore + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0df022f73dbd..442cca28e6b5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -266,9 +266,7 @@ async def _update_upgraded_room_pls( ) return - old_room_pl_state = await self.store.get_event( - old_room_pl_event_id - ) # type: EventBase # type: ignore + old_room_pl_state = await self.store.get_event(old_room_pl_event_id) # we try to stop regular users from speaking by setting the PL required # to send regular events and invites to 'Moderator' level. That's normally diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cd53d5c8f099..27f17719ac50 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -225,16 +225,16 @@ async def _local_membership_update( newly_joined = True if prev_member_event_id: prev_member_event = await self.store.get_event( - prev_member_event_id - ) # type: EventBase # type: ignore + prev_member_event_id, allow_none=False + ) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: prev_member_event = await self.store.get_event( - prev_member_event_id - ) # type: EventBase # type: ignore + prev_member_event_id, allow_none=False + ) if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) @@ -698,17 +698,13 @@ async def send_membership_event( # info. newly_joined = True if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id - ) # type: EventBase # type: ignore + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: await self._user_joined_room(target_user, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id - ) # type: EventBase # type: ignore + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) @@ -722,9 +718,7 @@ async def _can_guest_join( if not guest_access_id: return False - guest_access = await self.store.get_event( - guest_access_id - ) # type: EventBase # type: ignore + guest_access = await self.store.get_event(guest_access_id) return bool( guest_access @@ -1076,9 +1070,7 @@ async def remote_reject_invite( Implements RoomMemberHandler.remote_reject_invite """ - invite_event = await self.store.get_event( - invite_event_id - ) # type: EventBase # type: ignore + invite_event = await self.store.get_event(invite_event_id) room_id = invite_event.room_id target_user = invite_event.state_key diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7380015809e4..fcf8ebf1e74f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -22,7 +22,6 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes -from synapse.events import EventBase from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams import TypingStream @@ -146,7 +145,7 @@ async def on_rdata( event = await self.store.get_event( row.data.event_id, allow_rejected=True - ) # type: EventBase # type: ignore + ) if event.rejected_reason: continue diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 1a7111f5b412..71d1475172df 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import logging import threading from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -137,6 +138,32 @@ def get_received_ts(self, event_id): desc="get_received_ts", ) + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + async def get_event( self, event_id: str, From 78910742a7254f874596b566c4bce542349d1b11 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 12:40:32 -0400 Subject: [PATCH 10/12] Remove change from testing. --- synapse/handlers/room_member.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 27f17719ac50..aa1ccde2112f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -224,17 +224,13 @@ async def _local_membership_update( # info. newly_joined = True if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id, allow_none=False - ) + prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN if newly_joined: await self._user_joined_room(target, room_id) elif event.membership == Membership.LEAVE: if prev_member_event_id: - prev_member_event = await self.store.get_event( - prev_member_event_id, allow_none=False - ) + prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) From 15fe45079db999e8ce271d60ac098224e9ae56c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 13:03:37 -0400 Subject: [PATCH 11/12] Fix a couple more type hints. --- synapse/storage/databases/main/events_worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 71d1475172df..cdf808d86bcc 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,7 +19,7 @@ import logging import threading from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple, overload +from typing import Collection, Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names from typing_extensions import Literal @@ -258,11 +258,11 @@ async def get_events( async def get_events_as_list( self, - event_ids: Iterable[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -283,8 +283,8 @@ async def get_events_as_list( omits rejected events from the response. Returns: - list[EventBase]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. From 2f9ac48ee83e12c0c1affbb47c45f49004b8f375 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 13:17:07 -0400 Subject: [PATCH 12/12] Import collection from synapse.types. --- synapse/storage/databases/main/events_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cdf808d86bcc..e3a154a52705 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,7 +19,7 @@ import logging import threading from collections import namedtuple -from typing import Collection, Dict, Iterable, List, Optional, Tuple, overload +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names from typing_extensions import Literal @@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id +from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure