From ede07cac66c9b145b7c178d59bac050488d8da34 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 18 Nov 2022 19:44:45 -0600 Subject: [PATCH 1/9] Add more tracing to filter_events_for_client --- synapse/storage/databases/state/store.py | 8 ++++++++ synapse/visibility.py | 21 +++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f8cfcaca83e1..0145fee4960c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,7 @@ import attr from synapse.api.constants import EventTypes +from synapse.logging.tracing import SynapseTags, set_tag, tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -158,6 +159,8 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta: ) @cancellable + @trace + @tag_args async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: @@ -171,6 +174,11 @@ async def _get_state_groups_from_groups( Returns: Dict of state group to state map. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "groups.length", + str(len(groups)), + ) + results: Dict[int, StateMap[str]] = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] diff --git a/synapse/visibility.py b/synapse/visibility.py index 40a9c5b53f83..c92eade62b3b 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -23,7 +23,13 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import ( + start_active_span, + SynapseTags, + set_tag, + tag_args, + trace, +) from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore from synapse.storage.state import StateFilter @@ -53,6 +59,7 @@ @trace +@tag_args async def filter_events_for_client( storage: StorageControllers, user_id: str, @@ -82,6 +89,11 @@ async def filter_events_for_client( Returns: The filtered events. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "events.length", + str(len(events)), + ) + # Filter out events that have been soft failed so that we don't relay them # to clients. events_before_filtering = events @@ -130,9 +142,10 @@ def allowed(event: EventBase) -> Optional[EventBase]: sender_erased=erased_senders.get(event.sender, False), ) - # Check each event: gives an iterable of None or (a potentially modified) - # EventBase. - filtered_events = map(allowed, events) + with start_active_span("filtering events against allowed function"): + # Check each event: gives an iterable of None or (a potentially modified) + # EventBase. + filtered_events = map(allowed, events) # Turn it into a list and remove None entries before returning. 
return [ev for ev in filtered_events if ev] From 8340906573fe9f1cae00994cb1ca82936b0bab0e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 18 Nov 2022 20:57:15 -0600 Subject: [PATCH 2/9] Add suspicion --- synapse/storage/databases/state/bg_updates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index a7fcc564a992..224bebf3bfd4 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -101,9 +101,11 @@ def _get_state_groups_from_groups_txn( where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): + # Suspicion start # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in txn.execute("SET LOCAL enable_seqscan=off") + # Suspicion end # The below query walks the state_group tree so that the "state" # table includes all state_groups in the tree. It then joins From 3ee285f08f870d80851edb06e8751a1650753250 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 21 Nov 2022 18:13:00 -0600 Subject: [PATCH 3/9] Slight changes --- synapse/storage/databases/state/store.py | 2 ++ synapse/storage/state.py | 6 +++--- synapse/visibility.py | 9 +++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0145fee4960c..ead23ac9fb06 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -246,6 +246,8 @@ def _get_state_for_group_using_cache( return state_filter.filter_state(state_dict_ids), not missing_types @cancellable + @trace + @tag_args async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0004d955b434..f5761d318563 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -100,16 +100,16 @@ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": The new state filter. """ type_dict: Dict[str, Optional[Set[str]]] = {} - for typ, s in types: + for typ, state_key in types: if typ in type_dict: if type_dict[typ] is None: continue - if s is None: + if state_key is None: type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) # type: ignore + type_dict.setdefault(typ, set()).add(state_key) # type: ignore return StateFilter( types=frozendict( diff --git a/synapse/visibility.py b/synapse/visibility.py index c92eade62b3b..ee4afd4607b6 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -106,12 +106,13 @@ async def filter_events_for_client( [event.event_id for event in events], ) - types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) - - # we exclude outliers at this point, and then handle them separately later + # Grab the history visibility and membership for each of the events. That's all we + # need to know in order to filter them. + filter_types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) event_id_to_state = await storage.state.get_state_for_events( + # we exclude outliers at this point, and then handle them separately later frozenset(e.event_id for e in events if not e.internal_metadata.outlier), - state_filter=StateFilter.from_types(types), + state_filter=StateFilter.from_types(filter_types), ) # Get the users who are ignored by the requesting user. 
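For context on the `StateFilter.from_types(filter_types)` call introduced above, here is a minimal standalone sketch of the dictionary that the (renamed) `from_types` loop builds for the only two pieces of room state `filter_events_for_client` needs: the room's history visibility and the viewing user's own membership. This is an illustration only, not Synapse's actual `StateFilter` class, and the viewer user ID is a made-up example.

from typing import Dict, Iterable, Optional, Set, Tuple


def from_types_sketch(
    types: Iterable[Tuple[str, Optional[str]]]
) -> Dict[str, Optional[Set[str]]]:
    # Mirrors the loop in StateFilter.from_types shown in the patch above:
    # a state_key of None means "all state_keys of this event type".
    type_dict: Dict[str, Optional[Set[str]]] = {}
    for typ, state_key in types:
        if typ in type_dict:
            if type_dict[typ] is None:
                continue

        if state_key is None:
            type_dict[typ] = None
            continue

        type_dict.setdefault(typ, set()).add(state_key)

    return type_dict


# The two keys used when filtering events for a client:
print(
    from_types_sketch(
        [
            ("m.room.history_visibility", ""),
            ("m.room.member", "@viewer:example.com"),  # hypothetical user ID
        ]
    )
)
# -> {'m.room.history_visibility': {''}, 'm.room.member': {'@viewer:example.com'}}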
From 2e86455f0cfa24fb4a46e9f26052f80de4b2bbfc Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 21 Nov 2022 20:57:06 -0600 Subject: [PATCH 4/9] Add alternative lookup --- changelog.d/14494.misc | 1 + synapse/storage/controllers/state.py | 51 +++++- synapse/storage/databases/state/bg_updates.py | 2 - synapse/storage/databases/state/store.py | 154 +++++++++++++++++- synapse/visibility.py | 18 +- 5 files changed, 217 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14494.misc diff --git a/changelog.d/14494.misc b/changelog.d/14494.misc new file mode 100644 index 000000000000..fbf48bf70f19 --- /dev/null +++ b/changelog.d/14494.misc @@ -0,0 +1 @@ +Speed-up `/messages` with `filter_events_for_client` optimizations. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb75..9285b64ed122 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -29,7 +29,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.logging.opentracing import tag_args, trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( @@ -182,6 +182,53 @@ def _get_state_groups_from_groups( return self.stores.state._get_state_groups_from_groups(groups, state_filter) + @trace + @tag_args + async def _get_state_for_client_filtering_for_events( + self, event_ids: Collection[str], user_id_viewing_events: str + ) -> Dict[str, StateMap[EventBase]]: + """TODO""" + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + + # Since we're making decisions based on the state, we need to wait. 
+ await_full_state = True + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + logger.info( + "_get_state_for_client_filtering_for_events: groups=%s", + groups, + ) + group_to_state = await self.stores.state._get_state_for_client_filtering( + groups, user_id_viewing_events + ) + logger.info( + "_get_state_for_client_filtering_for_events: group_to_state=%s", + group_to_state, + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + @trace async def get_state_for_events( self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None @@ -209,9 +256,11 @@ async def get_state_for_events( ) groups = set(event_to_groups.values()) + logger.info("get_state_for_events: groups=%s", groups) group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) + logger.info("get_state_for_events: group_to_state=%s", group_to_state) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 224bebf3bfd4..a7fcc564a992 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -101,11 +101,9 @@ def _get_state_groups_from_groups_txn( where_clause = " AND (%s)" % (where_clause,) if isinstance(self.database_engine, PostgresEngine): - # Suspicion start # Temporarily disable sequential scans in this transaction. This is # a temporary hack until we can add the right indices in txn.execute("SET LOCAL enable_seqscan=off") - # Suspicion end # The below query walks the state_group tree so that the "state" # table includes all state_groups in the tree. It then joins diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index ead23ac9fb06..0bd4bad57b86 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,12 +13,22 @@ # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, +) import attr from synapse.api.constants import EventTypes -from synapse.logging.tracing import SynapseTags, set_tag, tag_args, trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -30,9 +40,11 @@ from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -107,6 +119,9 @@ def __init__( "*stateGroupMembersCache*", 500000, ) + # TODO: Remove cache invalidation + self._state_group_cache.invalidate_all() + self._state_group_members_cache.invalidate_all() def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") @@ -245,6 +260,140 @@ def _get_state_for_group_using_cache( return state_filter.filter_state(state_dict_ids), not missing_types + async def _get_state_groups_from_cache( + self, state_groups: Iterable[int], state_filter: StateFilter + ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: + """TODO + + Returns: + A map from each state_group to the complete/incomplete state map (filled in by cached + values) and the set of incomplete groups + """ + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = self._get_state_for_groups_using_cache( + state_groups, self._state_group_cache, state_filter=non_member_filter + ) + + (member_state, incomplete_groups_m) = self._get_state_for_groups_using_cache( + state_groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for state_group in state_groups: + state[state_group].update(member_state[state_group]) + + # We may have only got one of the events for the group + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + return (state, incomplete_groups) + + @cancellable + @trace + @tag_args + async def _get_state_for_client_filtering( + self, groups: Iterable[int], user_id_viewing_events: str + ) -> Dict[int, StateMap[str]]: + """ + TODO + """ + + def _get_state_for_client_filtering_txn( + txn: LoggingTransaction, groups: Iterable[int] + ) -> Mapping[int, StateMap[str]]: + sql = """ + WITH RECURSIVE sgs(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, sgs s + WHERE s.state_group = e.state_group + ) + SELECT + type, state_key, event_id + FROM state_groups_state + WHERE + state_group IN ( + SELECT state_group FROM sgs + ) + AND (type = ? AND state_key = ?) 
+ ORDER BY + type, + state_key, + -- Use the lastest state in the chain (highest numbered state_group in the chain) + state_group DESC + LIMIT 1 + """ + + results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} + for group in groups: + row_info_list: List[Tuple] = [] + txn.execute(sql, (group, EventTypes.RoomHistoryVisibility, "")) + history_vis_info = txn.fetchone() + if history_vis_info is not None: + row_info_list.append(history_vis_info) + + txn.execute(sql, (group, EventTypes.Member, user_id_viewing_events)) + membership_info = txn.fetchone() + if membership_info is not None: + row_info_list.append(membership_info) + + for row in row_info_list: + typ, state_key, event_id = row + key = (intern_string(typ), intern_string(state_key)) + results[group][key] = event_id + + # The results should be considered immutable because we are using + # `intern_string` (TODO: Should we? copied from _get_state_groups_from_groups_txn). + return results + + # Craft a StateFilter to use with the cache + state_filter_for_cache_lookup = StateFilter.from_types( + ( + (EventTypes.RoomHistoryVisibility, ""), + (EventTypes.Member, user_id_viewing_events), + ) + ) + ( + results_from_cache, + incomplete_groups, + ) = await self._get_state_groups_from_cache( + groups, state_filter_for_cache_lookup + ) + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + results: Dict[int, StateMap[str]] = results_from_cache + for batch in batch_iter(incomplete_groups, 100): + group_to_state_mapping = await self.db_pool.runInteraction( + "_get_state_for_client_filtering_txn", + _get_state_for_client_filtering_txn, + batch, + ) + logger.info("group_to_state_mapping=%s", group_to_state_mapping) + + # Now lets update the caches + # Help the cache hit ratio by expanding the filter a bit + state_filter_for_cache_insertion = ( + state_filter_for_cache_lookup.return_expanded() + ) + group_to_state_dict: Dict[int, StateMap[str]] = {} + group_to_state_dict.update(group_to_state_mapping) + self._insert_into_cache( + group_to_state_dict, + state_filter_for_cache_insertion, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + results.update(group_to_state_mapping) + + return results + @cancellable @trace @tag_args @@ -264,6 +413,7 @@ async def _get_state_for_groups( """ state_filter = state_filter or StateFilter.all() + # TODO: Replace with _get_state_groups_from_cache member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches diff --git a/synapse/visibility.py b/synapse/visibility.py index ee4afd4607b6..37a1b1541b1c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -24,9 +24,9 @@ from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.opentracing import ( - start_active_span, SynapseTags, set_tag, + start_active_span, tag_args, trace, ) @@ -108,11 +108,21 @@ async def filter_events_for_client( # Grab the history visibility and membership for each of the events. That's all we # need to know in order to filter them. 
- filter_types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) - event_id_to_state = await storage.state.get_state_for_events( + event_id_to_state = await storage.state._get_state_for_client_filtering_for_events( # we exclude outliers at this point, and then handle them separately later + event_ids=frozenset( + e.event_id for e in events if not e.internal_metadata.outlier + ), + user_id_viewing_events=user_id, + ) + + # TODO: Remove comparison + logger.info("----------------------------------------------------") + logger.info("----------------------------------------------------") + types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) + event_id_to_state_orig = await storage.state.get_state_for_events( frozenset(e.event_id for e in events if not e.internal_metadata.outlier), - state_filter=StateFilter.from_types(filter_types), + state_filter=StateFilter.from_types(types), ) # Get the users who are ignored by the requesting user. From 92a1aaf80c8b135ab514464a2604209c29b64f0e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 21 Nov 2022 20:58:58 -0600 Subject: [PATCH 5/9] Clean up logging for fair comparison --- synapse/storage/controllers/state.py | 15 +++++---------- synapse/storage/databases/state/store.py | 1 - synapse/visibility.py | 4 ++-- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 9285b64ed122..7a0d149635ba 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -201,17 +201,13 @@ async def _get_state_for_client_filtering_for_events( ) groups = set(event_to_groups.values()) - logger.info( - "_get_state_for_client_filtering_for_events: groups=%s", - groups, - ) group_to_state = await self.stores.state._get_state_for_client_filtering( groups, user_id_viewing_events ) - logger.info( - "_get_state_for_client_filtering_for_events: group_to_state=%s", - group_to_state, - ) + # logger.info( + # "_get_state_for_client_filtering_for_events: group_to_state=%s", + # group_to_state, + # ) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], @@ -256,11 +252,10 @@ async def get_state_for_events( ) groups = set(event_to_groups.values()) - logger.info("get_state_for_events: groups=%s", groups) group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) - logger.info("get_state_for_events: group_to_state=%s", group_to_state) + # logger.info("get_state_for_events: group_to_state=%s", group_to_state) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0bd4bad57b86..ad639cadc362 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -374,7 +374,6 @@ def _get_state_for_client_filtering_txn( _get_state_for_client_filtering_txn, batch, ) - logger.info("group_to_state_mapping=%s", group_to_state_mapping) # Now lets update the caches # Help the cache hit ratio by expanding the filter a bit diff --git a/synapse/visibility.py b/synapse/visibility.py index 37a1b1541b1c..3a8966d37a98 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -117,8 +117,8 @@ async def filter_events_for_client( ) # TODO: Remove comparison - logger.info("----------------------------------------------------") - 
logger.info("----------------------------------------------------") + # logger.info("----------------------------------------------------") + # logger.info("----------------------------------------------------") types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) event_id_to_state_orig = await storage.state.get_state_for_events( frozenset(e.event_id for e in events if not e.internal_metadata.outlier), From 0459a9c42fb4d2195147451f7b3130c55d6b5dcd Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 21 Nov 2022 21:28:12 -0600 Subject: [PATCH 6/9] Compare old and new --- synapse/storage/databases/state/store.py | 3 --- synapse/visibility.py | 18 ++++++++++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index ad639cadc362..8209e131acda 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -119,9 +119,6 @@ def __init__( "*stateGroupMembersCache*", 500000, ) - # TODO: Remove cache invalidation - self._state_group_cache.invalidate_all() - self._state_group_members_cache.invalidate_all() def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") diff --git a/synapse/visibility.py b/synapse/visibility.py index 3a8966d37a98..c63da1dca806 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -106,22 +106,32 @@ async def filter_events_for_client( [event.event_id for event in events], ) + non_outlier_event_ids = event_ids = frozenset( + e.event_id for e in events if not e.internal_metadata.outlier + ) + + # TODO: Remove: We do this just to remove await_full_state from the comparison + await storage.state.get_state_group_for_events( + non_outlier_event_ids, await_full_state=True + ) + # Grab the history visibility and membership for each of the events. That's all we # need to know in order to filter them. 
event_id_to_state = await storage.state._get_state_for_client_filtering_for_events( # we exclude outliers at this point, and then handle them separately later - event_ids=frozenset( - e.event_id for e in events if not e.internal_metadata.outlier - ), + event_ids=non_outlier_event_ids, user_id_viewing_events=user_id, ) # TODO: Remove comparison + # TODO: Remove cache invalidation + storage.state.stores.state._state_group_cache.invalidate_all() + storage.state.stores.state._state_group_members_cache.invalidate_all() # logger.info("----------------------------------------------------") # logger.info("----------------------------------------------------") types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) event_id_to_state_orig = await storage.state.get_state_for_events( - frozenset(e.event_id for e in events if not e.internal_metadata.outlier), + non_outlier_event_ids, state_filter=StateFilter.from_types(types), ) From 2939eadd003e62af2b17e903fc7fd5ae9f41b967 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 21 Nov 2022 23:22:03 -0600 Subject: [PATCH 7/9] Fix lints --- synapse/storage/controllers/state.py | 20 +++++--- synapse/storage/databases/state/store.py | 65 +++++++++++------------- synapse/visibility.py | 21 +------- 3 files changed, 46 insertions(+), 60 deletions(-) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 7a0d149635ba..baf89c0fc0d1 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -184,10 +184,21 @@ def _get_state_groups_from_groups( @trace @tag_args - async def _get_state_for_client_filtering_for_events( + async def _get_state_for_events_when_filtering_for_client( self, event_ids: Collection[str], user_id_viewing_events: str ) -> Dict[str, StateMap[EventBase]]: - """TODO""" + """Get the state at each event that is necessary to filter + them before being displayed to clients from the perspective of the + `user_id_viewing_events`. Will fetch `m.room.history_visibility` and + `m.room.member` event of `user_id_viewing_events`. + + Args: + event_ids: List of event ID's that will be displayed to the client + user_id_viewing_events: User ID that will be viewing these events + + Returns: + Dict of event_id to state map. 
+ """ set_tag( SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", str(len(event_ids)), @@ -204,10 +215,6 @@ async def _get_state_for_client_filtering_for_events( group_to_state = await self.stores.state._get_state_for_client_filtering( groups, user_id_viewing_events ) - # logger.info( - # "_get_state_for_client_filtering_for_events: group_to_state=%s", - # group_to_state, - # ) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], @@ -255,7 +262,6 @@ async def get_state_for_events( group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter or StateFilter.all() ) - # logger.info("get_state_for_events: group_to_state=%s", group_to_state) state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 8209e131acda..0b71f1e5d716 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -260,11 +260,17 @@ def _get_state_for_group_using_cache( async def _get_state_groups_from_cache( self, state_groups: Iterable[int], state_filter: StateFilter ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]: - """TODO + """Given a `state_filter`, pull out the relevant cached state groups that match + the filter. + + Args: + state_groups: List of state group ID's to fetch from the cache + state_filter: The relevant StateFilter to pull against Returns: - A map from each state_group to the complete/incomplete state map (filled in by cached - values) and the set of incomplete groups + A map from each state_group ID to the complete/incomplete state map (filled + in by cached values) and the set of incomplete state_groups that still need + to be filled in. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -284,7 +290,8 @@ async def _get_state_groups_from_cache( for state_group in state_groups: state[state_group].update(member_state[state_group]) - # We may have only got one of the events for the group + # We may have only got one or none of the events for the group so mark those as + # incomplete that need fetching from the database. incomplete_groups = incomplete_groups_m | incomplete_groups_nm return (state, incomplete_groups) @@ -293,15 +300,24 @@ async def _get_state_groups_from_cache( @trace @tag_args async def _get_state_for_client_filtering( - self, groups: Iterable[int], user_id_viewing_events: str - ) -> Dict[int, StateMap[str]]: - """ - TODO + self, state_group_ids: Iterable[int], user_id_viewing_events: str + ) -> Dict[int, MutableStateMap[str]]: + """Get a state map for each state group ID provided that is necessary to filter + the corresponding events before being displayed to clients from the perspective + of the `user_id_viewing_events`. + + Args: + state_group_ids: The state groups to fetch + user_id_viewing_events: User ID that will be viewing the events that correspond + to the state groups + + Returns: + Dict of state_group ID to state map. 
""" def _get_state_for_client_filtering_txn( txn: LoggingTransaction, groups: Iterable[int] - ) -> Mapping[int, StateMap[str]]: + ) -> Mapping[int, MutableStateMap[str]]: sql = """ WITH RECURSIVE sgs(state_group) AS ( VALUES(?::bigint) @@ -343,8 +359,6 @@ def _get_state_for_client_filtering_txn( key = (intern_string(typ), intern_string(state_key)) results[group][key] = event_id - # The results should be considered immutable because we are using - # `intern_string` (TODO: Should we? copied from _get_state_groups_from_groups_txn). return results # Craft a StateFilter to use with the cache @@ -358,13 +372,13 @@ def _get_state_for_client_filtering_txn( results_from_cache, incomplete_groups, ) = await self._get_state_groups_from_cache( - groups, state_filter_for_cache_lookup + state_group_ids, state_filter_for_cache_lookup ) cache_sequence_nm = self._state_group_cache.sequence cache_sequence_m = self._state_group_members_cache.sequence - results: Dict[int, StateMap[str]] = results_from_cache + results = results_from_cache for batch in batch_iter(incomplete_groups, 100): group_to_state_mapping = await self.db_pool.runInteraction( "_get_state_for_client_filtering_txn", @@ -408,30 +422,13 @@ async def _get_state_for_groups( Dict of state group to state map. """ state_filter = state_filter or StateFilter.all() - - # TODO: Replace with _get_state_groups_from_cache - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) + results_from_cache, + incomplete_groups, + ) = await self._get_state_groups_from_cache(groups, state_filter) # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - + state = results_from_cache if not incomplete_groups: return state diff --git a/synapse/visibility.py b/synapse/visibility.py index c63da1dca806..f19793faf9f3 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -106,35 +106,18 @@ async def filter_events_for_client( [event.event_id for event in events], ) - non_outlier_event_ids = event_ids = frozenset( + non_outlier_event_ids = frozenset( e.event_id for e in events if not e.internal_metadata.outlier ) - # TODO: Remove: We do this just to remove await_full_state from the comparison - await storage.state.get_state_group_for_events( - non_outlier_event_ids, await_full_state=True - ) - # Grab the history visibility and membership for each of the events. That's all we # need to know in order to filter them. 
- event_id_to_state = await storage.state._get_state_for_client_filtering_for_events( + event_id_to_state = await storage.state._get_state_for_events_when_filtering_for_client( # we exclude outliers at this point, and then handle them separately later event_ids=non_outlier_event_ids, user_id_viewing_events=user_id, ) - # TODO: Remove comparison - # TODO: Remove cache invalidation - storage.state.stores.state._state_group_cache.invalidate_all() - storage.state.stores.state._state_group_members_cache.invalidate_all() - # logger.info("----------------------------------------------------") - # logger.info("----------------------------------------------------") - types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) - event_id_to_state_orig = await storage.state.get_state_for_events( - non_outlier_event_ids, - state_filter=StateFilter.from_types(types), - ) - # Get the users who are ignored by the requesting user. ignore_list = await storage.main.ignored_users(user_id) From 65a5d8ffdb6e889e88a2177dc28fdfb50c2ed26b Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 22 Nov 2022 00:28:31 -0600 Subject: [PATCH 8/9] SQLite compatible cast --- synapse/storage/databases/state/store.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0b71f1e5d716..d09419b218b3 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -318,9 +318,10 @@ async def _get_state_for_client_filtering( def _get_state_for_client_filtering_txn( txn: LoggingTransaction, groups: Iterable[int] ) -> Mapping[int, MutableStateMap[str]]: + sql = """ WITH RECURSIVE sgs(state_group) AS ( - VALUES(?::bigint) + VALUES(CAST(? AS bigint)) UNION ALL SELECT prev_state_group FROM state_group_edges e, sgs s WHERE s.state_group = e.state_group From d4b647b4bf766bab3dfef580ac7425a3bfd886c7 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 22 Nov 2022 00:41:13 -0600 Subject: [PATCH 9/9] No need to order by stuff that isn't different See https://github.com/matrix-org/synapse/pull/14494#discussion_r1028925226 --- synapse/storage/databases/state/store.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index d09419b218b3..356f79090ccb 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -335,8 +335,6 @@ def _get_state_for_client_filtering_txn( ) AND (type = ? AND state_key = ?) ORDER BY - type, - state_key, -- Use the lastest state in the chain (highest numbered state_group in the chain) state_group DESC LIMIT 1
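For illustration, here is a self-contained sketch of the final form of the recursive lookup above, run against SQLite (which is why the patch switched the cast to `CAST(? AS bigint)`). The schema is trimmed to just the columns the query touches, and the state group chain and event IDs are made up; it only demonstrates that walking `state_group_edges` and ordering by `state_group DESC` with `LIMIT 1` returns the most recent value of a single (type, state_key) pair in the chain.

import sqlite3

# Minimal sketch, not Synapse's real schema: only the columns the query touches.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT);
    CREATE TABLE state_groups_state (
        state_group BIGINT, type TEXT, state_key TEXT, event_id TEXT
    );
    -- State group chain: 3 -> 2 -> 1
    INSERT INTO state_group_edges VALUES (3, 2), (2, 1);
    -- History visibility set at group 1, then overridden at group 3
    INSERT INTO state_groups_state VALUES
        (1, 'm.room.history_visibility', '', '$hist_vis_old'),
        (3, 'm.room.history_visibility', '', '$hist_vis_new');
    """
)

SQL = """
    WITH RECURSIVE sgs(state_group) AS (
        VALUES(CAST(? AS bigint))
        UNION ALL
        SELECT prev_state_group FROM state_group_edges e, sgs s
        WHERE s.state_group = e.state_group
    )
    SELECT type, state_key, event_id
    FROM state_groups_state
    WHERE
        state_group IN (SELECT state_group FROM sgs)
        AND (type = ? AND state_key = ?)
    ORDER BY
        -- Use the latest state in the chain (highest numbered state_group)
        state_group DESC
    LIMIT 1
"""

# Walking the chain from group 3 picks up the override at group 3 ...
print(conn.execute(SQL, (3, "m.room.history_visibility", "")).fetchone())
# ('m.room.history_visibility', '', '$hist_vis_new')

# ... while group 2 only sees the value inherited from group 1.
print(conn.execute(SQL, (2, "m.room.history_visibility", "")).fetchone())
# ('m.room.history_visibility', '', '$hist_vis_old')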