This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix have_seen_event cache not being invalidated #13863

Merged
1 change: 1 addition & 0 deletions changelog.d/13863.bugfix
@@ -0,0 +1 @@
+Fix `have_seen_event` cache not being invalidated after we persist an event, which caused inefficiencies like extra `/state` federation calls.
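
For context, a minimal sketch of the failure mode this changelog entry describes (plain Python; the names and data structures below are illustrative stand-ins, not Synapse's real API): a cached negative answer that is never invalidated keeps reporting an event as unseen even after it has been persisted.

    # Illustrative stand-ins only: a set for the events table, a dict for the cache.
    seen_in_db: set = set()
    cache: dict = {}

    def have_seen_event(room_id: str, event_id: str) -> bool:
        key = (room_id, event_id)
        if key not in cache:
            # Cache miss: fall through to the "database".
            cache[key] = key in seen_in_db
        return cache[key]

    def persist_event(room_id: str, event_id: str) -> None:
        seen_in_db.add((room_id, event_id))
        # The essence of the fix: drop (or prefill) the cache entry on persist.
        # Without this line, have_seen_event() keeps returning the stale False,
        # and callers fall back to expensive work such as /state federation calls.
        cache.pop((room_id, event_id), None)

    assert not have_seen_event("!room:a", "$event")  # primes a negative cache entry
    persist_event("!room:a", "$event")
    assert have_seen_event("!room:a", "$event")      # correct after invalidation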
40 changes: 22 additions & 18 deletions synapse/storage/databases/main/events_worker.py
@@ -1474,32 +1474,38 @@ async def have_seen_events(
         # the batches as big as possible.

         results: Set[str] = set()
-        for chunk in batch_iter(event_ids, 500):
-            r = await self._have_seen_events_dict(
-                [(room_id, event_id) for event_id in chunk]
+        for event_ids_chunk in batch_iter(event_ids, 500):
+            events_seen_dict = await self._have_seen_events_dict(
+                room_id, event_ids_chunk
             )
-            results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+            results.update(
+                eid for (eid, have_event) in events_seen_dict.items() if have_event
+            )

         return results

-    @cachedList(cached_method_name="have_seen_event", list_name="keys")
+    @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
     async def _have_seen_events_dict(
-        self, keys: Collection[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], bool]:
+        self,
+        room_id: str,
+        event_ids: Collection[str],
+    ) -> Dict[str, bool]:
Comment on lines +1487 to +1492
@MadLittleMods (Contributor, Author) commented on Sep 23, 2022:

This is the fix as described by @erikjohnston in #13865 (comment).

The rest of the changes in this file are adapting to this change.
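
To sketch why the new signature matters (a simplification of the behaviour described in #13865, not the real descriptor code): `@cached(tree=True)` on `have_seen_event(room_id, event_id)` stores entries under the key `(room_id, event_id)`, which lets a whole room be invalidated via the `(room_id,)` prefix. `@cachedList` populates that same cache, deriving one key per list item from the shared arguments plus the item, so the list method's arguments must mirror the cached method's:

    # Simplified key derivation, assuming the behaviour described in #13865.

    def new_batch_keys(room_id: str, event_ids: list) -> list:
        # After this PR: one (room_id, event_id) key per item, matching the
        # entries that have_seen_event writes and that invalidation removes.
        return [(room_id, event_id) for event_id in event_ids]

    def old_batch_keys(keys: list) -> list:
        # Before this PR: the whole (room_id, event_id) tuple was the single
        # list argument, yielding 1-tuples like ((room_id, event_id),) that
        # never lined up with the (room_id, event_id) entries being invalidated.
        return [(key,) for key in keys]

    assert new_batch_keys("!r:a", ["$e1"]) == [("!r:a", "$e1")]
    assert old_batch_keys([("!r:a", "$e1")]) == [(("!r:a", "$e1"),)]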

"""Helper for have_seen_events
Returns:
a dict {(room_id, event_id)-> bool}
a dict {event_id -> bool}
"""
         # if the event cache contains the event, obviously we've seen it.

         cache_results = {
-            (rid, eid)
-            for (rid, eid) in keys
-            if await self._get_event_cache.contains((eid,))
+            event_id
+            for event_id in event_ids
+            if await self._get_event_cache.contains((event_id,))
         }
         results = dict.fromkeys(cache_results, True)
-        remaining = [k for k in keys if k not in cache_results]
+        remaining = [
+            event_id for event_id in event_ids if event_id not in cache_results
+        ]
         if not remaining:
             return results

@@ -1511,23 +1517,21 @@ def have_seen_events_txn(txn: LoggingTransaction) -> None:

sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
txn.database_engine, "e.event_id", remaining
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}

# ... and then we can update the results for each key
results.update(
{(rid, eid): (eid in found_events) for (rid, eid) in remaining}
)
results.update({eid: (eid in found_events) for eid in remaining})

await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results

     @cached(max_entries=100000, tree=True)
     async def have_seen_event(self, room_id: str, event_id: str) -> bool:
-        res = await self._have_seen_events_dict(((room_id, event_id),))
-        return res[(room_id, event_id)]
+        res = await self._have_seen_events_dict(room_id, [event_id])
+        return res[event_id]

     def _get_current_state_event_counts_txn(
         self, txn: LoggingTransaction, room_id: str
6 changes: 6 additions & 0 deletions synapse/util/caches/descriptors.py
@@ -431,6 +431,12 @@ def __get__(
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args

+        if num_args != self.num_args:
+            raise Exception(
+                "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
+                % (self.num_args, self.cached_method_name, num_args)
+            )
Comment on lines +434 to +438
@MadLittleMods (Contributor, Author) commented on Sep 23, 2022:

Added a safety check so others don't run into the same pitfall, and so the error is obvious.

This also has a test to make sure the safety check works (see tests/util/caches/test_descriptors.py).

Example error: builtins.Exception: Number of args (1) does not match underlying cache_method_name=have_seen_event (2).

Full error:
Traceback (most recent call last):
  File "/Users/eric/Documents/github/element/synapse/tests/unittest.py", line 539, in get_success
    deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
  File "/Users/eric/Documents/github/element/synapse/.venv/lib/python3.9/site-packages/twisted/internet/defer.py", line 1162, in ensureDeferred
    return Deferred.fromCoroutine(coro)
  File "/Users/eric/Documents/github/element/synapse/.venv/lib/python3.9/site-packages/twisted/internet/defer.py", line 1137, in fromCoroutine
    return _cancellableInlineCallbacks(coro)
  File "/Users/eric/Documents/github/element/synapse/.venv/lib/python3.9/site-packages/twisted/internet/defer.py", line 1856, in _cancellableInlineCallbacks
    _inlineCallbacks(None, gen, status, _copy_context())
--- <exception caught here> ---
  File "/Users/eric/Documents/github/element/synapse/.venv/lib/python3.9/site-packages/twisted/internet/defer.py", line 1696, in _inlineCallbacks
    result = context.run(gen.send, result)
  File "/Users/eric/Documents/github/element/synapse/synapse/logging/opentracing.py", line 889, in _wrapper
    return await func(*args, **kwargs)  # type: ignore[misc]
  File "/Users/eric/Documents/github/element/synapse/synapse/logging/opentracing.py", line 889, in _wrapper
    return await func(*args, **kwargs)  # type: ignore[misc]
  File "/Users/eric/Documents/github/element/synapse/synapse/storage/databases/main/events_worker.py", line 1478, in have_seen_events
    r = await self._have_seen_events_dict(
  File "/Users/eric/Documents/github/element/synapse/synapse/util/caches/descriptors.py", line 435, in __get__
    raise Exception(
builtins.Exception: Number of args (1) does not match underlying cache_method_name=have_seen_event (2).

There are other ways the args could mismatch (such as their types), but this check would have caught the problem encountered here with `have_seen_event`.
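
A rough sketch of the invariant this check enforces (simplified; the real check lives in the descriptor's `__get__` and compares pre-computed `num_args` values rather than inspecting signatures):

    import inspect

    def check_num_args(cached_fn, list_fn) -> None:
        # Hypothetical standalone version of the safety check: a @cachedList
        # helper must take the same number of args as the @cached method it
        # populates, with exactly one of them turned into an iterable.
        cached_args = len(inspect.signature(cached_fn).parameters)
        list_args = len(inspect.signature(list_fn).parameters)
        if cached_args != list_args:
            raise Exception(
                "Number of args (%s) does not match underlying cache_method_name=%s (%s)."
                % (list_args, cached_fn.__name__, cached_args)
            )

    def have_seen_event(room_id, event_id): ...
    def matched_list_fn(room_id, event_ids): ...   # same arity: OK
    def mismatched_list_fn(keys): ...              # arity 1 vs 2: raises

    check_num_args(have_seen_event, matched_list_fn)       # passes silently
    # check_num_args(have_seen_event, mismatched_list_fn)  # would raise, as above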


         @functools.wraps(self.orig)
         def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
152 changes: 104 additions & 48 deletions tests/storage/databases/main/test_events_worker.py
@@ -35,66 +35,45 @@
 from synapse.util.async_helpers import yieldable_gather_results

 from tests import unittest
+from tests.test_utils.event_injection import create_event, inject_event


 class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
         self.hs = hs
         self.store: EventsWorkerStore = hs.get_datastores().main

-        # insert some test data
-        for rid in ("room1", "room2"):
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "rooms",
-                    {"room_id": rid, "room_version": 4},
-                )
-            )
+        self.user = self.register_user("user", "pass")
+        self.token = self.login(self.user, "pass")
+        self.room_id = self.helper.create_room_as(self.user, tok=self.token)

         self.event_ids: List[str] = []
-        for idx, rid in enumerate(
-            (
-                "room1",
-                "room1",
-                "room1",
-                "room2",
-            )
-        ):
-            event_json = {"type": f"test {idx}", "room_id": rid}
-            event = make_event_from_dict(event_json, room_version=RoomVersions.V4)
-            event_id = event.event_id
-
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "events",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "topological_ordering": idx,
-                        "stream_ordering": idx,
-                        "type": event.type,
-                        "processed": True,
-                        "outlier": False,
-                    },
-                )
-            )
-            self.get_success(
-                self.store.db_pool.simple_insert(
-                    "event_json",
-                    {
-                        "event_id": event_id,
-                        "room_id": rid,
-                        "json": json.dumps(event_json),
-                        "internal_metadata": "{}",
-                        "format_version": 3,
-                    },
-                )
-            )
+        for i in range(3):
+            event = self.get_success(
+                inject_event(
+                    hs,
+                    room_version=RoomVersions.V7.identifier,
+                    room_id=self.room_id,
+                    sender=self.user,
+                    type="test_event_type",
+                    content={"body": f"foobarbaz{i}"},
+                )
+            )
Comment on lines -66 to -91
@MadLittleMods (Contributor, Author):

Simplified this test logic by using the real thing.

It also makes it easier to create another event down the line in the new test with create_event(...) without needing to worry about incrementing the stream_ordering manually.

-            self.event_ids.append(event_id)
+            self.event_ids.append(event.event_id)

     def test_simple(self):
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})

@@ -104,7 +83,9 @@ def test_simple(self):
         # a second lookup of the same events should cause no queries
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0], "event19"])
+                self.store.have_seen_events(
+                    self.room_id, [self.event_ids[0], "eventdoesnotexist"]
+                )
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
@@ -116,11 +97,86 @@ def test_query_via_event_cache(self):
         # looking it up should now cause no db hits
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
-                self.store.have_seen_events("room1", [self.event_ids[0]])
+                self.store.have_seen_events(self.room_id, [self.event_ids[0]])
             )
             self.assertEqual(res, {self.event_ids[0]})
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)

+    def test_persisting_event_invalidates_cache(self):
+        """
+        Test to make sure that the `have_seen_event` cache
+        is invalidated after we persist an event and returns
+        the updated value.
+        """
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "garply"},
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # First, check `have_seen_event` for an event we have not seen yet
+            # to prime the cache with a `false` value.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, set())
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Persist the event which should invalidate or prefill the
+        # `have_seen_event` cache so we don't return stale values.
+        persistence = self.hs.get_storage_controllers().persistence
+        self.get_success(
+            persistence.persist_event(
+                event,
+                event_context,
+            )
+        )
+
+        with LoggingContext(name="test") as ctx:
+            # Check `have_seen_event` again and we should see the updated fact
+            # that we have now seen the event after persisting it.
+            res = self.get_success(
+                self.store.have_seen_events(event.room_id, [event.event_id])
+            )
+            self.assertEqual(res, {event.event_id})
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+    def test_invalidate_cache_by_room_id(self):
+        """
+        Test to make sure that all events associated with the given `(room_id,)`
+        are invalidated in the `have_seen_event` cache.
+        """
+        with LoggingContext(name="test") as ctx:
+            # Prime the cache with some values
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # That should result in a single db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+
+        # Clear the cache with any events associated with the `room_id`
+        self.store.have_seen_event.invalidate((self.room_id,))
+
+        with LoggingContext(name="test") as ctx:
+            res = self.get_success(
+                self.store.have_seen_events(self.room_id, self.event_ids)
+            )
+            self.assertEqual(res, set(self.event_ids))
+
+            # Since we cleared the cache, it should result in another db query to lookup
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)


 class EventCacheTestCase(unittest.HomeserverTestCase):
     """Test that the various layers of event cache works."""
33 changes: 32 additions & 1 deletion tests/util/caches/test_descriptors.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Set
+from typing import Iterable, Set, Tuple
 from unittest import mock

 from twisted.internet import defer, reactor
@@ -1008,3 +1008,34 @@ async def do_lookup():
             obj.inner_context_was_finished, "Tried to restart a finished logcontext"
         )
         self.assertEqual(current_context(), SENTINEL_CONTEXT)

+    def test_num_args_mismatch(self):
+        """
+        Make sure someone does not accidentally use @cachedList on a method with
+        a mismatch in the number of args to the underlying single cache method.
+        """
+
+        class Cls:
+            @descriptors.cached(tree=True)
+            def fn(self, room_id, event_id):
+                pass
+
+            # This is wrong ❌. `@cachedList` expects to be given the same number
+            # of arguments as the underlying cached function, just with one of
+            # the arguments being an iterable
+            @descriptors.cachedList(cached_method_name="fn", list_name="keys")
+            def list_fn(self, keys: Iterable[Tuple[str, str]]):
+                pass
+
+            # Corrected syntax ✅
+            #
+            # @cachedList(cached_method_name="fn", list_name="event_ids")
+            # async def list_fn(
+            #     self, room_id: str, event_ids: Collection[str],
+            # )
+
+        obj = Cls()
+
+        # Make sure this raises an error about the arg mismatch
+        with self.assertRaises(Exception):
+            obj.list_fn([("foo", "bar")])