Enable caching for section 2, including invalidation inside of PersistEventsStore when creating new event_auth_chain_links
realtyem committed Nov 6, 2023
1 parent 55dbbd3 commit 01d2c66
Showing 2 changed files with 112 additions and 29 deletions.
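
In outline, the commit introduces a cache of event_auth_chain_links rows keyed by origin_chain_id: it is consulted before the database in "section 2" of _get_auth_chain_ids_using_cover_index_txn, and invalidated whenever PersistEventsStore inserts new link rows. The sketch below shows the assumed shape of a cache entry and the get/set/invalidate surface the diff relies on; the dict-backed class is a simplified stand-in for illustration, not Synapse's actual cache implementation.

    from typing import Dict, Optional, Set, Tuple

    # origin_chain_id -> origin_sequence_number -> {(target_chain_id, target_sequence_number)}
    CacheEntry = Dict[int, Set[Tuple[int, int]]]

    class AuthChainLinksCache:
        """Dict-backed stand-in for _authchain_links_list (get/set/invalidate)."""

        def __init__(self) -> None:
            self._data: Dict[int, CacheEntry] = {}

        def get(self, origin_chain_id: int) -> Optional[CacheEntry]:
            return self._data.get(origin_chain_id)

        def set(self, origin_chain_id: int, entry: CacheEntry) -> None:
            self._data[origin_chain_id] = entry

        def invalidate(self, origin_chain_id: int) -> None:
            self._data.pop(origin_chain_id, None)

    # Example: chain 1 links, at sequence number 2, to chain 2 up to sequence 5.
    cache = AuthChainLinksCache()
    cache.set(1, {2: {(2, 5)}})
    assert cache.get(1) == {2: {(2, 5)}}
    cache.invalidate(1)
    assert cache.get(1) is None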
123 changes: 97 additions & 26 deletions synapse/storage/databases/main/event_federation.py
@@ -331,43 +331,89 @@ def _get_auth_chain_ids_using_cover_index_txn(

section_2_rows = set()

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_cache_retrieval",
):
# Take a copy of the event_chains dict, as it will be mutated to remove
# entries that don't have to be pulled from the database later.
for chain_id, seq_no in dict(event_chains).items():
# Add the initial set of chains, excluding the sequence corresponding
# to the initial event itself.
max_sequence_result = max(seq_no - 1, chains.get(chain_id, 0))
if max_sequence_result > 0:
chains[chain_id] = max_sequence_result

s2_cache_entry = self._authchain_links_list.get(chain_id)
# The seq_no above marks the point in this chain to start processing
# from. The cache entry (if one exists at all) contains every chain
# referenced by the origin chain_id.
if s2_cache_entry is not None:
for origin_seq_number, target_set_info in s2_cache_entry.items():
# This condition ensures that a sequence number GREATER than what
# is needed is never used, which avoids a paradox where future
# state could authorize or deny an event.
if origin_seq_number <= seq_no:
# Chains are only reachable if the origin sequence number of
# the link is no greater than the max sequence number in the
# origin chain.
for target_chain_id, target_seq_no in target_set_info:
# We use a (0, 0) tuple as a placeholder in the cache
# to represent that this particular target set doesn't
# exist in the database and therefore will never be in
# the cache. Typically this is an origin event, which
# has nothing prior to it and hence no chain.
if (target_chain_id, target_seq_no) != (0, 0):
# This is slightly faster than calling max()
target_seq_max_result = chains.get(
target_chain_id, 0
)
if target_seq_no > target_seq_max_result:
chains[target_chain_id] = target_seq_no
else:
logger.debug("JASON: hit (0, 0) warning condition")

del event_chains[chain_id]

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_database",
):
# Add all linked chains reachable from initial set of chains.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
if event_chains:
# Add all linked chains reachable from initial set of chains.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
for batch2 in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)

for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
section_2_rows.add(
(
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
section_2_rows.add(
(
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
)
)
)

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_postprocessing",
):
# If no rows came back from the database, this loop is skipped
for (
origin_chain_id,
origin_sequence_number,
@@ -390,6 +436,31 @@ def _get_auth_chain_ids_using_cover_index_txn(
if max_sequence_result > 0:
chains[chain_id] = max_sequence_result

with Measure(
self.hs.get_clock(),
"_get_auth_chain_ids_using_cover_index_txn.section_2_postprocessing_cache",
):
# For this block, first build the cache entries in an efficient
# way, then set them into the cache itself. Again, if the database
# wasn't queried, this is skipped.
cache_entries: Dict[int, Dict[int, Set[Tuple[int, int]]]] = {}
seen_during_batching = set()
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in section_2_rows:
seen_during_batching.add(origin_chain_id)
cache_entries.setdefault(origin_chain_id, {}).setdefault(
origin_sequence_number, set()
).add((target_chain_id, target_sequence_number))

# By not writing entries into the cache while iterating above, we
# avoid repeated updates to the same entry and the brittleness of
# incremental cache updates.
for origin_chain_id, cache_entry in cache_entries.items():
self._authchain_links_list.set(origin_chain_id, cache_entry)

# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence less than that are in the
# auth chain.
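Taken together, section 2 now proceeds in three passes: fold cached links into chains, query event_auth_chain_links only for the origins that missed the cache, then build and store fresh cache entries from the queried rows. The following self-contained sketch reproduces the read path; plain dicts stand in for the real cache, and returning the misses replaces the diff's in-place del event_chains[chain_id].

    from typing import Dict, Set, Tuple

    CacheEntry = Dict[int, Set[Tuple[int, int]]]

    def resolve_links_from_cache(
        event_chains: Dict[int, int],   # chain_id -> seq_no of the event
        cache: Dict[int, CacheEntry],   # origin_chain_id -> cached links
        chains: Dict[int, int],         # accumulator: chain_id -> max reachable seq
    ) -> Dict[int, int]:
        """Fold cached links into `chains`; return origins that missed the cache."""
        misses: Dict[int, int] = {}
        for chain_id, seq_no in event_chains.items():
            # Everything strictly before the event itself is in its auth chain.
            max_seq = max(seq_no - 1, chains.get(chain_id, 0))
            if max_seq > 0:
                chains[chain_id] = max_seq

            entry = cache.get(chain_id)
            if entry is None:
                misses[chain_id] = seq_no
                continue

            for origin_seq, targets in entry.items():
                # Links are reachable only at or below our sequence number.
                if origin_seq > seq_no:
                    continue
                for target_chain_id, target_seq in targets:
                    if (target_chain_id, target_seq) == (0, 0):
                        continue  # placeholder: no link rows exist in the DB
                    if target_seq > chains.get(target_chain_id, 0):
                        chains[target_chain_id] = target_seq
        return misses

    # Example: chain 1's event at seq 3 links (at seq 2) to chain 2 up to seq 5;
    # chain 7 is not cached, so it must be fetched from the database.
    cache = {1: {2: {(2, 5)}}}
    chains: Dict[int, int] = {}
    misses = resolve_links_from_cache({1: 3, 7: 4}, cache, chains)
    assert chains == {1: 2, 2: 5, 7: 3}
    assert misses == {7: 4}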
18 changes: 15 additions & 3 deletions synapse/storage/databases/main/events.py
@@ -538,7 +538,7 @@ def _persist_event_auth_chain_txn(
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}

self._add_chain_cover_index(
returned_chains = self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
@@ -547,6 +547,12 @@
event_to_auth_chain,
)

if returned_chains is not None:
for origin_chain_id in returned_chains:
txn.call_after(
self.store._authchain_links_list.invalidate, origin_chain_id
)

@classmethod
def _add_chain_cover_index(
cls,
@@ -556,7 +562,7 @@
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> None:
) -> Optional[Set[int]]:
"""Calculate the chain cover index for the given events.
Args:
@@ -698,7 +704,7 @@ def _add_chain_cover_index(
break

if not events_to_calc_chain_id_for:
return
return None

# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@@ -840,6 +846,12 @@ def _add_chain_cover_index(
),
values=list(chain_links.get_additions()),
)
# Return the origin chain IDs of the newly inserted link rows, so the
# caller can invalidate any cached entries for those chains.
rows_to_invalidate = set(chain_links.get_additions())
chains_to_invalidate = set()
for origin_chain_id, _, _, _ in rows_to_invalidate:
    chains_to_invalidate.add(origin_chain_id)
return chains_to_invalidate

@staticmethod
def _allocate_chain_ids(
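The invalidation half of the change hinges on txn.call_after, which Synapse uses to defer callbacks until the transaction commits, so a cache entry is never dropped for link rows that might yet be rolled back. Below is a hedged sketch of that handshake; FakeTxn and the list-based invalidation log are stand-ins for the real transaction and cache objects.

    from typing import Callable, List, Optional, Set, Tuple

    class FakeTxn:
        """Minimal stand-in for a DB transaction with after-commit callbacks."""

        def __init__(self) -> None:
            self._after: List[Callable[[], None]] = []

        def call_after(self, fn: Callable[..., None], *args: object) -> None:
            # Defer until commit, mirroring Synapse's txn.call_after.
            self._after.append(lambda: fn(*args))

        def commit(self) -> None:
            for cb in self._after:
                cb()

    def add_chain_cover_index(
        new_links: Set[Tuple[int, int, int, int]],
    ) -> Optional[Set[int]]:
        """Return the origin chain IDs of freshly inserted link rows, if any."""
        if not new_links:
            return None
        return {origin_chain_id for origin_chain_id, _, _, _ in new_links}

    # Usage mirroring the diff: invalidate each touched origin chain after commit.
    cache_invalidations: List[int] = []
    txn = FakeTxn()
    returned_chains = add_chain_cover_index({(1, 2, 2, 5), (1, 3, 4, 1)})
    if returned_chains is not None:
        for origin_chain_id in returned_chains:
            txn.call_after(cache_invalidations.append, origin_chain_id)
    txn.commit()
    assert cache_invalidations == [1]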
