diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5669d4354984..bba206ee5de8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict +from collections import ChainMap, defaultdict from typing import ( TYPE_CHECKING, Any, @@ -92,8 +92,11 @@ def __init__( prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ): - if state is None and state_group is None: - raise Exception("Either state or state group must be not None") + if state is None and state_group is None and prev_group is None: + raise Exception("One of state, state_group or prev_group must be not None") + + if prev_group is not None and delta_ids is None: + raise Exception("If prev_group is set so must delta_ids") # A map from (type, state_key) to event_id. # @@ -120,10 +123,29 @@ async def get_state( if self._state is not None: return self._state - assert self.state_group is not None + if self.state_group is not None: + return await state_storage.get_state_ids_for_group( + self.state_group, state_filter + ) + + assert self.prev_group is not None and self.delta_ids is not None + + prev_state = await state_storage.get_state_ids_for_group( + self.prev_group, state_filter + ) + + return ChainMap(prev_state, self.delta_ids, {}) + + def copy_remove_state(self) -> "_StateCacheEntry": + """Copy the state cache entry, removing the stored state if possible.""" + + include_state = self.state_group is None and self.prev_group is None - return await state_storage.get_state_ids_for_group( - self.state_group, state_filter + return _StateCacheEntry( + state=self._state if include_state else None, + state_group=self.state_group, + prev_group=self.prev_group, + delta_ids=self.delta_ids, ) def __len__(self) -> int: @@ -594,7 +616,7 @@ async def resolve_state_groups( with Measure(self.clock, "state.create_group_ids"): cache = _make_state_cache_entry(new_state, state_groups_ids) - self._state_cache[group_names] = cache + self._state_cache[group_names] = cache.copy_remove_state() return cache