From 7233d386902d03bf2785d19324289f535ca7a81d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 27 Feb 2020 13:07:39 +0000 Subject: [PATCH 01/16] Move stream fetch DB queries to worker stores. --- synapse/app/generic_worker.py | 3 + synapse/replication/slave/storage/_base.py | 14 ++- synapse/replication/slave/storage/pushers.py | 3 + synapse/storage/data_stores/main/cache.py | 44 +++---- .../storage/data_stores/main/deviceinbox.py | 88 +++++++------- synapse/storage/data_stores/main/events.py | 114 ------------------ .../storage/data_stores/main/events_worker.py | 114 ++++++++++++++++++ synapse/storage/data_stores/main/room.py | 40 +++--- 8 files changed, 218 insertions(+), 202 deletions(-) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cdc078cf1106..5caba3160c3a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -390,6 +390,9 @@ def process_replication_rows(self, token, rows): self._room_serials[row.room_id] = token self._room_typing[row.room_id] = row.user_ids + def get_current_token(self) -> int: + return self._latest_room_serial + class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index f45cbd37a0f5..751c799d9432 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,8 +18,10 @@ import six -from synapse.storage._base import SQLBaseStore -from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.data_stores.main.cache import ( + CURRENT_STATE_CACHE_NAME, + CacheInvalidationWorkerStore, +) from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -35,7 +37,7 @@ def __func__(inp): return inp.__func__ -class BaseSlavedStore(SQLBaseStore): +class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): @@ -60,6 +62,12 @@ def stream_positions(self) -> Dict[str, int]: pos["caches"] = self._cache_id_gen.get_current_token() return pos + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 + def process_replication_rows(self, stream_name, token, rows): if stream_name == "caches": if self._cache_id_gen: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index f22c2d44a327..bce8a3d115ca 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -33,6 +33,9 @@ def stream_positions(self): result["pushers"] = self._pushers_id_gen.get_current_token() return result + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + def process_replication_rows(self, stream_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc7586..4dc5da3fe8b6 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, 
current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ def _send_invalidation_to_replication( }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a8a..9a1178fb3947 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ def delete_messages_for_remote_destination_txn(txn): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" 
+ " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ def _add_messages_to_local_device_inbox_txn( rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8a5..e71c23541d09 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ def _count(txn): ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" 
- " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ def _get_event_ordering(self, event_id): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? 
- """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ca237c6f129e..1361e4e667fa 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -963,3 +963,117 @@ def get_room_complexity(self, room_id): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? 
> event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c631676..aaebe427d3ac 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ def _quarantine_media_txn( return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ def add_event_report( def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. From 811d2ecf2ed50613d2f8a0231c4b9487be2ff925 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Mar 2020 15:11:54 +0000 Subject: [PATCH 02/16] Don't panic if streams get behind. The catchup will in future happen on workers, so master process won't need to protect itself by dropping the connection. 
--- synapse/replication/tcp/protocol.py | 22 +++++---- synapse/replication/tcp/resource.py | 5 +- synapse/replication/tcp/streams/_base.py | 61 ++++++++++-------------- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bbf2..d7ef2398fab7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -485,15 +485,19 @@ async def subscribe_to_stream(self, stream_name, token): self.connecting_streams.add(stream_name) try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) + limited = True + while limited: + # Get missing updates + ( + updates, + current_token, + limited, + ) = await self.streamer.get_stream_updates(stream_name, token) + + # Send all the missing updates + for update in updates: + token, row = update[0], update[1] + self.send_command(RdataCommand(stream_name, token, row)) # We send a POSITION command to ensure that they have an up to # date token (especially useful if we didn't send any updates diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 6e2ebaf614d7..5be31024b70e 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -190,7 +190,8 @@ async def _run_notifier_loop(self): stream.current_token(), ) try: - updates, current_token = await stream.get_updates() + updates, current_token, limited = await stream.get_updates() + self.pending_updates |= limited except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise @@ -235,7 +236,7 @@ async def get_stream_updates(self, stream_name, token): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token) + return await stream.get_updates_since(token, stream.current_token()) @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index abf5c6c6a840..99cef975320f 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,10 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import attr @@ -153,61 +152,53 @@ def discard_updates_and_advance(self): """ self.last_token = self.current_token() - async def get_updates(self): + async def get_updates(self) -> Tuple[List[Tuple[int, JsonDict]], int, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + Resolves to a pair `(updates, new_last_token, limited)`, where + `updates` is a list of `(token, row)` entries, `new_last_token` is + the new position in stream, and `limited` is whether there are + more updates to fetch. 
""" - updates, current_token = await self.get_updates_since(self.last_token) + current_token = self.current_token() + updates, current_token, limited = await self.get_updates_since( + self.last_token, current_token + ) self.last_token = current_token - return updates, current_token + return updates, current_token, limited async def get_updates_since( - self, from_token: int - ) -> Tuple[List[Tuple[int, JsonDict]], int]: + self, from_token: Union[int, str], upto_token: int, limit: int = 100 + ) -> Tuple[List[Tuple[int, JsonDict]], int, bool]: """Like get_updates except allows specifying from when we should stream updates Returns: - Resolves to a pair `(updates, new_last_token)`, where `updates` is - a list of `(token, row)` entries and `new_last_token` is the new - position in stream. + Resolves to a pair `(updates, new_last_token, limited)`, where + `updates` is a list of `(token, row)` entries, `new_last_token` is + the new position in stream, and `limited` is whether there are + more updates to fetch. """ if from_token in ("NOW", "now"): - return [], self.current_token() - - current_token = self.current_token() + return [], upto_token, False from_token = int(from_token) - if from_token == current_token: - return [], current_token - - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) - - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) + if from_token == upto_token: + return [], upto_token, False + limited = False + rows = await self.update_function(from_token, upto_token, limit=limit) updates = [(row[0], row[1:]) for row in rows] + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True - # check we didn't get more rows than the limit. - # doing it like this allows the update_function to be a generator. - if len(updates) >= MAX_EVENTS_BEHIND: - raise Exception("stream %s has fallen behind" % (self.NAME)) - - # The update function didn't hit the limit, so we must have got all - # the updates to `current_token`, and can return that as our new - # stream position. - return updates, current_token + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided From ba90596687986c28503dc77b6079bf45bd7f4eb9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Mar 2020 15:17:01 +0000 Subject: [PATCH 03/16] Add ability to catchup on stream by talking to master. --- synapse/federation/sender/__init__.py | 9 +++ synapse/replication/http/__init__.py | 2 + synapse/replication/http/streams.py | 65 +++++++++++++++++++ synapse/replication/tcp/streams/__init__.py | 4 +- synapse/replication/tcp/streams/_base.py | 45 ++++++++++--- synapse/replication/tcp/streams/federation.py | 19 ++++-- 6 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 synapse/replication/http/streams.py diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 233cb33daf94..a477578e445f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -499,4 +499,13 @@ def wake_destination(self, destination: str): self._get_per_destination_queue(destination).attempt_new_transaction() def get_current_token(self) -> int: + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. 
return 0 + + async def get_replication_rows( + self, from_token, to_token, limit, federation_ack=None + ): + # Dummy implementation for case where federation sender isn't offloaded + # to a worker. + return [] diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 28dbc6fcbaf1..4613b2538ce8 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ membership, register, send_event, + streams, ) REPLICATION_PREFIX = "/_synapse/replication" @@ -38,3 +39,4 @@ def register_servlets(self, hs): login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) + streams.register_servlets(hs, self) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py new file mode 100644 index 000000000000..3889278b2aa9 --- /dev/null +++ b/synapse/replication/http/streams.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import parse_integer +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationGetStreamUpdates(ReplicationEndpoint): + """Fetches stream updates from a server. Used for streams not persisted to + the database, e.g. typing notifications. + """ + + NAME = "get_repl_stream_updates" + PATH_ARGS = ("stream_name",) + METHOD = "GET" + + def __init__(self, hs): + super(ReplicationGetStreamUpdates, self).__init__(hs) + + from synapse.replication.tcp.streams import STREAMS_MAP + + self.streams = {stream.NAME: stream(hs) for stream in STREAMS_MAP.values()} + + @staticmethod + def _serialize_payload(stream_name, from_token, upto_token, limit): + return {"from_token": from_token, "upto_token": upto_token, "limit": limit} + + async def _handle_request(self, request, stream_name): + stream = self.streams.get(stream_name) + if stream is None: + raise SynapseError(400, "Unknown stream") + + from_token = parse_integer(request, "from_token", required=True) + upto_token = parse_integer(request, "upto_token", required=True) + limit = parse_integer(request, "limit", required=True) + + updates, upto_token, limited = await stream.get_updates_since( + from_token, upto_token, limit + ) + + return ( + 200, + {"updates": updates, "upto_token": upto_token, "limited": limited}, + ) + + +def register_servlets(hs, http_server): + ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 5f52264e8432..c3b9a90ca57f 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -25,6 +25,8 @@ update_function: The function that returns a list of updates between two tokens """ +from typing import Dict, Type + from . 
import _base, events, federation STREAMS_MAP = { @@ -47,4 +49,4 @@ _base.GroupServerStream, _base.UserSignatureStream, ) -} +} # type: Dict[str, Type[_base.Stream]] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 99cef975320f..6dea523f8c3d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -20,6 +20,7 @@ import attr +from synapse.replication.http.streams import ReplicationGetStreamUpdates from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -127,6 +128,10 @@ class Stream(object): # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any + # Whether the update function is only available on master. If True then + # calls to get updates are proxied to the master via a HTTP call. + _QUERY_MASTER = False + @classmethod def parse_row(cls, row): """Parse a row received over replication @@ -143,6 +148,11 @@ def parse_row(cls, row): return cls.ROW_TYPE(*row) def __init__(self, hs): + self._is_worker = hs.config.worker_app is not None + + if self._QUERY_MASTER and self._is_worker: + self._replication_client = ReplicationGetStreamUpdates.make_client(hs) + # The token from which we last asked for updates self.last_token = self.current_token() @@ -191,14 +201,23 @@ async def get_updates_since( if from_token == upto_token: return [], upto_token, False - limited = False - rows = await self.update_function(from_token, upto_token, limit=limit) - updates = [(row[0], row[1:]) for row in rows] - if len(updates) == limit: - upto_token = rows[-1][0] - limited = True - - return updates, upto_token, limited + if self._is_worker and self._QUERY_MASTER: + result = await self._replication_client( + stream_name=self.NAME, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + return result["updates"], result["upto_token"], result["limited"] + else: + limited = False + rows = await self.update_function(from_token, upto_token, limit=limit) + updates = [(row[0], row[1:]) for row in rows] + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. 
Should be provided @@ -239,13 +258,16 @@ def __init__(self, hs): class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow + _QUERY_MASTER = True def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() self.current_token = store.get_current_presence_token # type: ignore - self.update_function = presence_handler.get_all_presence_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = presence_handler.get_all_presence_updates # type: ignore super(PresenceStream, self).__init__(hs) @@ -253,12 +275,15 @@ def __init__(self, hs): class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow + _QUERY_MASTER = True def __init__(self, hs): typing_handler = hs.get_typing_handler() self.current_token = typing_handler.get_current_token # type: ignore - self.update_function = typing_handler.get_all_typing_updates # type: ignore + + if hs.config.worker_app is None: + self.update_function = typing_handler.get_all_typing_updates # type: ignore super(TypingStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 615f3dc9ac9e..5d9e87188b45 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,9 @@ # limitations under the License. from collections import namedtuple -from ._base import Stream +from twisted.internet import defer + +from synapse.replication.tcp.streams._base import Stream FederationStreamRow = namedtuple( "FederationStreamRow", @@ -33,11 +35,18 @@ class FederationStream(Stream): NAME = "federation" ROW_TYPE = FederationStreamRow + _QUERY_MASTER = True def __init__(self, hs): - federation_sender = hs.get_federation_sender() - - self.current_token = federation_sender.get_current_token # type: ignore - self.update_function = federation_sender.get_replication_rows # type: ignore + # Not all synapse instances will have a federation sender instance, + # whether that's a `FederationSender` or a `FederationRemoteSendQueue`, + # so we stub the stream out when that is the case. + if hs.config.worker_app is None or hs.should_send_federation(): + federation_sender = hs.get_federation_sender() + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = federation_sender.get_replication_rows # type: ignore + else: + self.current_token = lambda: 0 # type: ignore + self.update_function = lambda *args, **kwargs: defer.succeed([]) # type: ignore super(FederationStream, self).__init__(hs) From 1f83255de17eb2de35fc42b91ebaaaf895771aa6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Mar 2020 14:19:23 +0000 Subject: [PATCH 04/16] Move stream catchup to workers. 
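In outline, the worker-side catchup this patch adds works as below (condensed
from the new `on_POSITION` handler in the diff that follows; stream lookup and
error handling omitted):

```python
async def catch_up(self, stream, cmd):
    # Where this worker had previously streamed up to.
    current_token = self.handler.get_streams_to_replicate()[cmd.stream_name]

    # Page forward until we reach the position the master advertised.
    limited = True
    while limited:
        updates, current_token, limited = await stream.get_updates_since(
            current_token, cmd.token
        )
        if updates:
            rows = [stream.parse_row(update[1]) for update in updates]
            await self.handler.on_rdata(cmd.stream_name, current_token, rows)

    # We've now caught up to the position sent to us; notify the handler.
    await self.handler.on_position(cmd.stream_name, cmd.token)
```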
--- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/protocol.py | 105 ++++++++---------- synapse/replication/tcp/resource.py | 5 +- synapse/replication/tcp/streams/__init__.py | 6 +- tests/replication/tcp/streams/_base.py | 51 +++++++-- .../replication/tcp/streams/test_receipts.py | 50 +++++++-- 6 files changed, 135 insertions(+), 85 deletions(-) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66eab7..7e7ad0f7980b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ def startedConnecting(self, connector): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d7ef2398fab7..649312f0223f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -82,7 +82,8 @@ SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.server import HomeServer from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string @@ -414,9 +415,6 @@ def __init__(self, server_name, clock, streamer): # The streams the client has subscribed to and is up to date with self.replication_streams = set() # type: Set[str] - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - # Map from stream name to list of updates to send once we've finished # subscribing the client to the stream. self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] @@ -482,67 +480,21 @@ async def subscribe_to_stream(self, stream_name, token): are queued and sent once we've sent down any missed updates. """ self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) try: - limited = True - while limited: - # Get missing updates - ( - updates, - current_token, - limited, - ) = await self.streamer.get_stream_updates(stream_name, token) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) + # Get current stream position. + current_token = self.streamer.get_stream_token(stream_name) # We send a POSITION command to ensure that they have an up to # date token (especially useful if we didn't send any updates # above) self.send_command(PositionCommand(stream_name, current_token)) - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. 
If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - # They're now fully subscribed self.replication_streams.add(stream_name) except Exception as e: logger.exception("[%s] Failed to handle REPLICATE command", self.id()) self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. @@ -552,10 +504,6 @@ def stream_update(self, stream_name, token, data): if stream_name in self.replication_streams: # The client is subscribed to the stream self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) else: # The client isn't subscribed logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) @@ -642,6 +590,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: HomeServer, client_name: str, server_name: str, clock: Clock, @@ -653,6 +602,10 @@ def __init__( self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. @@ -660,7 +613,7 @@ def __init__( # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -701,7 +654,7 @@ async def on_RDATA(self, cmd): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -711,14 +664,46 @@ async def on_RDATA(self, cmd): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. 
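+        # (This may loop several times: each call returns a bounded batch
+        # and keeps `limited` True until everything up to the advertised
+        # token has been fetched.)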
+ limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 5be31024b70e..757129b6d5ef 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -227,8 +227,7 @@ async def _run_notifier_loop(self): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -236,7 +235,7 @@ async def get_stream_updates(self, stream_name, token): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token, stream.current_token()) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index c3b9a90ca57f..6f5da99f8503 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -27,7 +27,8 @@ from typing import Dict, Type -from . import _base, events, federation +from synapse.replication.tcp.streams import _base, events, federation +from synapse.replication.tcp.streams._base import Stream STREAMS_MAP = { stream.NAME: stream @@ -50,3 +51,6 @@ _base.UserSignatureStream, ) } # type: Dict[str, Type[_base.Stream]] + + +__all__ = ["Stream", "STREAMS_MAP"] diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index e96ad4ca4e48..b7a61e22f21c 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from mock import Mock from synapse.replication.tcp.commands import ReplicateCommand @@ -29,19 +30,37 @@ def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer - server = server_factory.buildProtocol(None) + self.server = server_factory.buildProtocol(None) - # build a replication client, with a dummy handler - handler_factory = Mock() - self.test_handler = TestReplicationClientHandler() - self.test_handler.factory = handler_factory + self.test_handler = Mock(wraps=TestReplicationClientHandler()) self.client = ClientReplicationStreamProtocol( - "client", "test", clock, self.test_handler + hs, "client", "test", clock, self.test_handler, ) - # wire them together - self.client.makeConnection(FakeTransport(server, reactor)) - server.makeConnection(FakeTransport(self.client, reactor)) + self._client_transport = None + self._server_transport = None + + def reconnect(self): + if self._client_transport: + self.client.close() + + if self._server_transport: + self.server.close() + + self._client_transport = FakeTransport(self.server, self.reactor) + self.client.makeConnection(self._client_transport) + + self._server_transport = FakeTransport(self.client, self.reactor) + self.server.makeConnection(self._server_transport) + + def disconnect(self): + if self._client_transport: + self._client_transport = None + self.client.close() + + if self._server_transport: + self._server_transport = None + self.server.close() def replicate(self): """Tell the master side of replication that something has happened, and then @@ -59,10 +78,15 @@ class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" def __init__(self): - self.received_rdata_rows = [] + self.streams = set() + self._received_rdata_rows = [] def get_streams_to_replicate(self): - return {} + positions = {s: 0 for s in self.streams} + for stream, token, _ in self._received_rdata_rows: + if stream in self.streams: + positions[stream] = max(token, positions.get(stream, 0)) + return positions def get_currently_syncing_users(self): return [] @@ -73,6 +97,9 @@ def update_connection(self, connection): def finished_connecting(self): pass + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + async def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append((stream_name, token, r)) + self._received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index d5a99f6caaf3..28862b2fe5ca 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -17,30 +17,64 @@ from tests.replication.tcp.streams._base import BaseStreamTestCase USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): + self.reconnect() + # make the client subscribe to the receipts stream self.replicate_stream("receipts", "NOW") + self.test_handler.streams.add("receipts") # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} ) ) self.replicate() # there should be one RDATA command - rdata_rows = self.test_handler.received_rdata_rows + 
self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        self.assertEqual(rdata_rows[0][0], "receipts")
-        row = rdata_rows[0][2]  # type: ReceiptsStreamRow
-        self.assertEqual(ROOM_ID, row.room_id)
+
+        row = rdata_rows[0]  # type: ReceiptsStreamRow
+        self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
-        self.assertEqual(EVENT_ID, row.event_id)
+        self.assertEqual("$event:blue", row.event_id)
         self.assertEqual({"a": 1}, row.data)
+
+        # Now let's disconnect and insert some data.
+        self.disconnect()
+
+        self.test_handler.on_rdata.reset_mock()
+
+        self.get_success(
+            self.hs.get_datastore().insert_receipt(
+                "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+            )
+        )
+        self.replicate()
+
+        # Nothing should have happened as we are disconnected
+        self.test_handler.on_rdata.assert_not_called()
+
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now have caught up and get the missing data
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
+        self.assertEqual(token, 3)
+        self.assertEqual(1, len(rdata_rows))
+
+        row = rdata_rows[0]  # type: ReceiptsStreamRow
+        self.assertEqual("!room2:blue", row.room_id)
+        self.assertEqual("m.read", row.receipt_type)
+        self.assertEqual(USER_ID, row.user_id)
+        self.assertEqual("$event2:foo", row.event_id)
+        self.assertEqual({"a": 2}, row.data)

From 8734b75ca8b4b81f5998f5c2ef57dfa0998c66ac Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 3 Mar 2020 16:51:34 +0000
Subject: [PATCH 05/16] Remove unused token param from REPLICATE cmd

---
 synapse/replication/tcp/commands.py    | 23 +++++------------------
 synapse/replication/tcp/protocol.py    | 24 ++++++++----------------
 tests/replication/tcp/streams/_base.py |  2 +-
 3 files changed, 14 insertions(+), 35 deletions(-)

diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d15..e506f529351f 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -183,35 +183,22 @@ class ReplicateCommand(Command):
 
     Format::
 
-        REPLICATE <stream_name> <token>
+        REPLICATE <stream_name>
 
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+    The <stream_name> can be "ALL" to subscribe to all known streams
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
+    def __init__(self, stream_name):
         self.stream_name = stream_name
-        self.token = token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls(line)
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
+        return self.stream_name
 
     def get_logcontext_id(self):
         return "REPLICATE-" + self.stream_name
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 649312f0223f..817b84ad7fcc 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -435,12 +435,11 @@ async def on_USER_SYNC(self, cmd):
 
     async def on_REPLICATE(self, cmd):
         stream_name = cmd.stream_name
-        token = cmd.token
 
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
             deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
+                run_in_background(self.subscribe_to_stream, stream)
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
@@ -448,7 +447,7 @@
                 defer.gatherResults(deferreds, consumeErrors=True)
             )
         else:
-            await self.subscribe_to_stream(stream_name, token)
+            await self.subscribe_to_stream(stream_name)
 
     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -472,12 +471,8 @@ async def on_USER_IP(self, cmd):
             cmd.last_seen,
         )
 
-    async def subscribe_to_stream(self, stream_name, token):
+    async def subscribe_to_stream(self, stream_name):
         """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
""" self.replication_streams.discard(stream_name) @@ -620,8 +615,8 @@ def connectionMade(self): BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + for stream_name in self.handler.get_streams_to_replicate(): + self.replicate(stream_name) # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -711,22 +706,19 @@ async def on_SYNC(self, cmd): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self, stream_name): """Send the subscription request to the server """ if stream_name not in STREAMS_MAP: raise Exception("Invalid stream name %r" % (stream_name,)) logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, + "[%s] Subscribing to replication stream: %r", self.id(), stream_name, ) self.streams_connecting.add(stream_name) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand(stream_name)) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index b7a61e22f21c..f69564cd32d5 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -71,7 +71,7 @@ def replicate(self): def replicate_stream(self, stream, token="NOW"): """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand(stream, token)) + self.client.send_command(ReplicateCommand(stream)) class TestReplicationClientHandler(object): From 32c656865a9d1a7bcd998fb872f6d08e18f39be2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Mar 2020 17:05:35 +0000 Subject: [PATCH 06/16] Always subscribe to all streams. This already happens since the worker merge. --- synapse/replication/tcp/protocol.py | 80 +++++------------------------ 1 file changed, 12 insertions(+), 68 deletions(-) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 817b84ad7fcc..b371d66ce7e8 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -53,17 +53,15 @@ import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -412,13 +410,6 @@ def __init__(self, server_name, clock, streamer): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. 
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -434,20 +425,10 @@ async def on_USER_SYNC(self, cmd): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -471,37 +452,12 @@ async def on_USER_IP(self, cmd): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name): - """Subscribe the remote to a stream. - """ - self.replication_streams.discard(stream_name) - - try: - # Get current stream position. - current_token = self.streamer.get_stream_token(stream_name) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -604,7 +560,7 @@ def __init__( # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. 
@@ -615,8 +571,7 @@ def connectionMade(self):
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected, subscribe to the necessary streams
-        for stream_name in self.handler.get_streams_to_replicate():
-            self.replicate(stream_name)
+        self.replicate()
 
         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)
@@ -628,10 +583,6 @@ def connectionMade(self):
         # We've now finished connecting, so inform the client handler
         self.handler.update_connection(self)
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -706,19 +657,12 @@ async def on_SYNC(self, cmd):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)
 
-    def replicate(self, stream_name):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r", self.id(), stream_name,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand(stream_name))
+        self.send_command(ReplicateCommand("ALL"))
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)

From 259cdffa96058d909a491a1cbc992876699e7920 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 3 Mar 2020 14:30:33 +0000
Subject: [PATCH 07/16] Newsfile

---
 changelog.d/7024.misc | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 changelog.d/7024.misc

diff --git a/changelog.d/7024.misc b/changelog.d/7024.misc
new file mode 100644
index 000000000000..676f285377f5
--- /dev/null
+++ b/changelog.d/7024.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.

From a2070a2c4e008ceff6decce3d569f984d5e0f902 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 23 Mar 2020 14:56:22 +0000
Subject: [PATCH 08/16] Remove unused 'stream' param of REPLICATE and update
 docs

---
 docs/tcp_replication.md                       | 40 +++++--------------
 synapse/replication/tcp/commands.py           | 17 +++-----
 synapse/replication/tcp/protocol.py           |  6 +--
 tests/replication/tcp/streams/_base.py        |  4 +-
 .../replication/tcp/streams/test_receipts.py  |  2 +-
 5 files changed, 21 insertions(+), 48 deletions(-)

diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index e3a4634b1407..15a61f6fcf8c 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and '<' worker to
 master flows):
 
     > SERVER example.com
-    < REPLICATE events 53
+    < REPLICATE
+    > POSITION events 53
    > RDATA events 54 ["$foo1:bar.com", ...]
    > RDATA events 55 ["$foo4:bar.com", ...]
 
-The example shows the server accepting a new connection and sending its
-identity with the `SERVER` command, followed by the client asking to
-subscribe to the `events` stream from the token `53`. The server then
-periodically sends `RDATA` commands which have the format
-`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
-defined by the individual streams.
+The example shows the server accepting a new connection and sending its identity
+with the `SERVER` command, followed by the client asking the server to respond
+with the position of all streams. The server then periodically sends `RDATA`
+commands which have the format `RDATA <stream_name> <token> <row>`, where the
+format of `<row>` is defined by the individual streams.
 
 Error reporting happens by either the client or server sending an ERROR
 command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is simple and line-based, it's possible to manually connect to
 the server using a tool like netcat. A few things should be noted when manually
 using the protocol:
 
--   When subscribing to a stream using `REPLICATE`, the special token
-    `NOW` can be used to get all future updates. The special stream name
-    `ALL` can be used with `NOW` to subscribe to all available streams.
 -   The federation stream is only available if federation sending has
     been disabled on the main process.
 -   The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
 -   Sends a `NAME` command, allowing the server to associate a human
     friendly name with the connection. This is optional.
 -   Sends a `PING` as above
--   For each stream the client wishes to subscribe to it sends a
-    `REPLICATE` with the `stream_name` and token it wants to subscribe
-    from.
+-   Sends a `REPLICATE` to get the current position of all streams.
 -   On receipt of a `SERVER` command, checks that the server name
     matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -199,20 +192,7 @@ client (C):
 
 #### REPLICATE (C)
 
-Asks the server to replicate a given stream. The syntax is:
-
-```
-    REPLICATE <stream_name> <token>
-```
-
-Where `<token>` may be either:
- * a numeric stream_id to stream updates since (exclusive)
 * `NOW` to stream all subsequent updates.
-
-The `<stream_name>` is the name of a replication stream to subscribe
-to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
-of streams). It can also be `ALL` to subscribe to all known streams,
-in which case the `<token>` must be set to `NOW`.
+Asks the server for the current position of all streams.
 
 #### USER_SYNC (C)
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index e506f529351f..b0f06c6d83d9 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -179,29 +179,24 @@ class NameCommand(Command):
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name>
-
-    The <stream_name> can be "ALL" to subscribe to all known streams
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name):
-        self.stream_name = stream_name
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        return cls(line)
+        return cls()
 
     def to_line(self):
-        return self.stream_name
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b371d66ce7e8..13e5fa9b1205 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
    > POSITION events 1
    > POSITION backfill 1
    > POSITION caches 1
@@ -662,7 +660,7 @@ def replicate(self):
         """
         logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand("ALL"))
+        self.send_command(ReplicateCommand())
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index f69564cd32d5..a755fe28794f 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -69,9 +69,9 @@ def replicate(self):
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def replicate_stream(self, stream, token="NOW"):
+    def replicate_stream(self):
         """Make the client send a REPLICATE command to set up a subscription to a stream"""
-        self.client.send_command(ReplicateCommand(stream))
+        self.client.send_command(ReplicateCommand())
 
 
 class TestReplicationClientHandler(object):
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 0411809e740c..0ec0825a0e62 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -24,7 +24,7 @@ def test_receipt(self):
         self.reconnect()
 
         # make the client subscribe to the receipts stream
-        self.replicate_stream("receipts", "NOW")
+        self.replicate_stream()
         self.test_handler.streams.add("receipts")
 
         # tell the master to send a new receipt

From ba1a8be9300595104c580e2c8e652ba2c58afff3 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 23 Mar 2020 16:13:12 +0000
Subject: [PATCH 09/16] Review comments

---
 docs/tcp_replication.md                  |  5 ++---
 synapse/replication/http/streams.py      | 15 ++++++++++++++-
 synapse/replication/tcp/commands.py      |  4 ++--
 synapse/replication/tcp/protocol.py      |  3 +--
 synapse/replication/tcp/streams/_base.py | 16 ++++++++--------
 5 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index 15a61f6fcf8c..5b26f70f88c1 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -174,9 +174,8 @@ client (C):
 
 #### POSITION (S)
 
-    The position of the stream has been updated. Sent to the client
-    after all missing updates for a stream have been sent to the client
-    and they're now up to date.
+    On receipt of a POSITION command clients should check if they have missed any
+    updates, and if so then fetch them out of band.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
         current_token = self.current_token()
         updates, current_token, limited = await self.get_updates_since(
@@ -99,10 +99,10 @@ async def get_updates_since(
             stream updates
 
         Returns:
-            Resolves to a pair `(updates, new_last_token, limited)`, where
-            `updates` is a list of `(token, row)` entries, `new_last_token` is
-            the new position in stream, and `limited` is whether there are
-            more updates to fetch.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
 
         if from_token in ("NOW", "now"):

From 3204b0e79fc0521281de6e5270375b9855201dfb Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Mon, 23 Mar 2020 18:29:21 +0000
Subject: [PATCH 10/16] Handle connection closing under us

---
 synapse/replication/tcp/protocol.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 8aa749265c81..e266c7241761 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -628,6 +628,12 @@ async def on_POSITION(self, cmd: PositionCommand):
             updates, current_token, limited = await stream.get_updates_since(
                 current_token, cmd.token
             )
+
+            # Check if the connection was closed underneath us, if so we bail
+            # rather than risk having concurrent catch ups going on.
+            if self.state == ConnectionStates.CLOSED:
+                return
+
             if updates:
                 await self.handler.on_rdata(
                     cmd.stream_name,
@@ -643,6 +649,11 @@ async def on_POSITION(self, cmd: PositionCommand):
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
+        # Check if the connection was closed underneath us, if so we bail
+        # rather than risk having concurrent catch ups going on.
+        if self.state == ConnectionStates.CLOSED:
+            return
+
         # Handle any RDATA that came in while we were catching up.
         rows = self.pending_batches.pop(cmd.stream_name, [])
         if rows:

From 2380e401e43f626c2dc64e5d1ab2297088097746 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 24 Mar 2020 11:47:57 +0000
Subject: [PATCH 11/16] Remove import loop

---
 synapse/replication/http/streams.py |  6 +++---
 synapse/replication/tcp/protocol.py |  8 ++++++--
 synapse/replication/tcp/resource.py | 11 ++++++++---
 synapse/server.py                   |  5 +++++
 4 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index 141df687870d..ffd4c6199378 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -47,9 +47,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     def __init__(self, hs):
         super().__init__(hs)
 
-        from synapse.replication.tcp.streams import STREAMS_MAP
-
-        self.streams = {stream.NAME: stream(hs) for stream in STREAMS_MAP.values()}
+        # We pull the streams from the replication streamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token, limit):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e266c7241761..67de5c3e7ede 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -79,11 +79,15 @@
     UserSyncCommand,
 )
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
-from synapse.server import HomeServer
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -539,7 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
-        hs: HomeServer,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 757129b6d5ef..4374e99e3253 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import Any, List
+from typing import Any, Dict, List
 
 from six import itervalues
 
@@ -30,7 +30,7 @@
 from synapse.util.metrics import Measure, measure_func
 
 from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
 from .streams.federation import FederationStream
 
 stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.streamer = hs.get_replication_streamer()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
@@ -133,6 +133,11 @@ def on_shutdown(self):
         for conn in self.connections:
             conn.send_error("server shutting down")
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
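The cycle is broken because stream instances are now obtained at call time from a lazily built singleton on the HomeServer, rather than by importing and instantiating them at module-load time. A toy sketch of that cached-builder pattern (these are stand-ins, not the real HomeServer internals):

```
class ToyStreamer:
    """Stand-in for ReplicationStreamer; built only when first requested."""

    def __init__(self, hs):
        self.hs = hs
        self.streams_by_name = {}

    def get_streams(self):
        return self.streams_by_name


class ToyHomeServer:
    def __init__(self):
        self._cache = {}

    def _get_or_build(self, name, builder):
        # Each dependency is built at most once and then cached, so every
        # caller shares the same instance.
        if name not in self._cache:
            self._cache[name] = builder()
        return self._cache[name]

    def get_replication_streamer(self):
        return self._get_or_build("replication_streamer", lambda: ToyStreamer(self))


hs = ToyHomeServer()
# The HTTP endpoint and the TCP factory now share one streamer, and neither
# module needs to import the other at module-load time.
assert hs.get_replication_streamer() is hs.get_replication_streamer()
```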
diff --git a/synapse/server.py b/synapse/server.py
index 1b980371de31..9426eb167279 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -85,6 +85,7 @@
 from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepositoryResource,
@@ -199,6 +200,7 @@ def build_DEPENDENCY(self)
         "saml_handler",
         "event_client_serializer",
         "storage",
+        "replication_streamer",
     ]
 
     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -536,6 +538,9 @@ def build_event_client_serializer(self):
     def build_storage(self) -> Storage:
         return Storage(self, self.datastores)
 
+    def build_replication_streamer(self) -> ReplicationStreamer:
+        return ReplicationStreamer(self)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 

From e4c5b1d9d6be6a38224d274e8f38099d0ab550ac Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 24 Mar 2020 16:00:54 +0000
Subject: [PATCH 12/16] Review comments

---
 docs/tcp_replication.md                  | 3 ++-
 synapse/replication/tcp/protocol.py      | 1 -
 synapse/replication/tcp/streams/_base.py | 3 ---
 3 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index 5b26f70f88c1..d4f7d9ec18f0 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -175,7 +175,8 @@ client (C):
 #### POSITION (S)
 
     On receipt of a POSITION command clients should check if they have missed any
-    updates, and if so then fetch them out of band.
+    updates, and if so then fetch them out of band. Sent in response to a
+    REPLICATE command (but can happen at any time).
 
 #### ERROR (S, C)
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 67de5c3e7ede..f81d2e2442d2 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -648,7 +648,6 @@ async def on_POSITION(self, cmd: PositionCommand):
         # We've now caught up to position sent to us, notify handler.
         await self.handler.on_position(cmd.stream_name, cmd.token)
 
-        # We're now up to date with the stream
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d5b9c2831b3a..d7e9371a00a5 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -105,9 +105,6 @@ async def get_updates_since(
             to fetch.
""" - if from_token in ("NOW", "now"): - return [], upto_token, False - from_token = int(from_token) if from_token == upto_token: From 309aee4636217092ace6b657f444ea49891c5bce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Mar 2020 16:20:05 +0000 Subject: [PATCH 13/16] Move calling http replication out of base stream --- synapse/replication/tcp/streams/_base.py | 96 +++++++++++-------- synapse/replication/tcp/streams/events.py | 5 +- synapse/replication/tcp/streams/federation.py | 6 +- 3 files changed, 60 insertions(+), 47 deletions(-) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index d7e9371a00a5..d64cbc5cc87b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union import attr @@ -40,10 +40,6 @@ class Stream(object): # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - # Whether the update function is only available on master. If True then - # calls to get updates are proxied to the master via a HTTP call. - _QUERY_MASTER = False - @classmethod def parse_row(cls, row): """Parse a row received over replication @@ -60,10 +56,6 @@ def parse_row(cls, row): return cls.ROW_TYPE(*row) def __init__(self, hs): - self._is_worker = hs.config.worker_app is not None - - if self._QUERY_MASTER and self._is_worker: - self._replication_client = ReplicationGetStreamUpdates.make_client(hs) # The token from which we last asked for updates self.last_token = self.current_token() @@ -110,23 +102,10 @@ async def get_updates_since( if from_token == upto_token: return [], upto_token, False - if self._is_worker and self._QUERY_MASTER: - result = await self._replication_client( - stream_name=self.NAME, - from_token=from_token, - upto_token=upto_token, - limit=limit, - ) - return result["updates"], result["upto_token"], result["limited"] - else: - limited = False - rows = await self.update_function(from_token, upto_token, limit=limit) - updates = [(row[0], row[1:]) for row in rows] - if len(updates) == limit: - upto_token = rows[-1][0] - limited = True - - return updates, upto_token, limited + updates, upto_token, limited = await self.update_function( + from_token, upto_token, limit=limit, + ) + return updates, upto_token, limited def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -148,6 +127,26 @@ def update_function(self, from_token, current_token, limit): raise NotImplementedError() +def db_query_to_update_function( + query_function: Callable[[int, int, int], Awaitable[List[tuple]]] +) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]: + """Wraps a db query function which returns a list of rows to make it + suitable for use as an `update_function` for the Stream class + """ + + async def update_function(from_token, upto_token, limit): + rows = await query_function(from_token, upto_token, limit) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) == limit: + upto_token = rows[-1][0] + limited = True + + return updates, upto_token, limited + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. 
@@ -171,7 +170,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_backfill_token # type: ignore - self.update_function = store.get_all_new_backfill_event_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore super(BackfillStream, self).__init__(hs) @@ -192,16 +191,20 @@ class PresenceStream(Stream): NAME = "presence" ROW_TYPE = PresenceStreamRow - _QUERY_MASTER = True def __init__(self, hs): store = hs.get_datastore() presence_handler = hs.get_presence_handler() + self._is_worker = hs.config.worker_app is not None + self.current_token = store.get_current_presence_token # type: ignore if hs.config.worker_app is None: - self.update_function = presence_handler.get_all_presence_updates # type: ignore + self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(PresenceStream, self).__init__(hs) @@ -213,7 +216,6 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - _QUERY_MASTER = True def __init__(self, hs): typing_handler = hs.get_typing_handler() @@ -221,7 +223,10 @@ def __init__(self, hs): self.current_token = typing_handler.get_current_token # type: ignore if hs.config.worker_app is None: - self.update_function = typing_handler.get_all_typing_updates # type: ignore + self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore + else: + # Query master process + self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore super(TypingStream, self).__init__(hs) @@ -245,7 +250,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_max_receipt_stream_id # type: ignore - self.update_function = store.get_all_updated_receipts # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -269,7 +274,13 @@ def current_token(self): async def update_function(self, from_token, to_token, limit): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - return [(row[0], row[2]) for row in rows] + + limited = False + if len(rows) == limit: + to_token = rows[-1][0] + limited = True + + return [(row[0], row[2]) for row in rows], to_token, limited class PushersStream(Stream): @@ -288,7 +299,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_pushers_stream_token # type: ignore - self.update_function = store.get_all_updated_pushers_rows # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore super(PushersStream, self).__init__(hs) @@ -320,7 +331,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_cache_stream_token # type: ignore - self.update_function = store.get_all_updated_caches # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore super(CachesStream, self).__init__(hs) @@ -346,7 +357,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_current_public_room_stream_id # type: ignore - self.update_function = store.get_all_new_public_rooms # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) 
# type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -367,7 +378,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -385,7 +396,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_to_device_stream_token # type: ignore - self.update_function = store.get_all_new_device_messages # type: ignore + self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -405,7 +416,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_max_account_data_stream_id # type: ignore - self.update_function = store.get_all_updated_tags # type: ignore + self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -425,10 +436,11 @@ def __init__(self, hs): self.store = hs.get_datastore() self.current_token = self.store.get_max_account_data_stream_id # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(AccountDataStream, self).__init__(hs) - async def update_function(self, from_token, to_token, limit): + async def _update_function(self, from_token, to_token, limit): global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) @@ -455,7 +467,7 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_group_stream_token # type: ignore - self.update_function = store.get_all_groups_changes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore super(GroupServerStream, self).__init__(hs) @@ -473,6 +485,6 @@ def __init__(self, hs): store = hs.get_datastore() self.current_token = store.get_device_stream_token # type: ignore - self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index b3afabb8cde3..c6a595629f7b 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -19,7 +19,7 @@ import attr -from ._base import Stream +from ._base import Stream, db_query_to_update_function """Handling of the 'events' replication stream @@ -117,10 +117,11 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() self.current_token = self._store.get_current_events_token # type: ignore + self.update_function = db_query_to_update_function(self._update_function) # type: ignore super(EventsStream, self).__init__(hs) - async def update_function(self, from_token, current_token, limit=None): + async def _update_function(self, from_token, current_token, limit=None): event_rows = await self._store.get_all_new_forward_event_rows( from_token, current_token, limit ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 67e0eaa262d0..48c1d4571824 100644 --- 
a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,7 +17,7 @@
 
 from twisted.internet import defer
 
-from synapse.replication.tcp.streams._base import Stream
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -44,9 +44,9 @@ def __init__(self, hs):
         if hs.config.worker_app is None or hs.should_send_federation():
             federation_sender = hs.get_federation_sender()
             self.current_token = federation_sender.get_current_token  # type: ignore
-            self.update_function = federation_sender.get_replication_rows  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
         else:
             self.current_token = lambda: 0  # type: ignore
-            self.update_function = lambda *args, **kwargs: defer.succeed([])  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore
 
         super(FederationStream, self).__init__(hs)

From bd64b8fcd5074ece79149805b9ecf9ecbd948566 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 24 Mar 2020 16:52:17 +0000
Subject: [PATCH 14/16] Fixup push rules stream

---
 synapse/replication/tcp/streams/_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index d64cbc5cc87b..2699e466bcfa 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -280,7 +280,7 @@ async def update_function(self, from_token, to_token, limit):
             to_token = rows[-1][0]
             limited = True
 
-        return [(row[0], row[2]) for row in rows], to_token, limited
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 

From f8038f4670cb562fcdb98f57f6bdf734745f1a8f Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 24 Mar 2020 17:31:51 +0000
Subject: [PATCH 15/16] Fix HTTP update_function

---
 synapse/replication/tcp/streams/_base.py | 26 ++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 2699e466bcfa..007b105d4df2 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -147,6 +147,28 @@ async def update_function(from_token, upto_token, limit):
     return update_function
 
 
+def make_http_update_function(
+    hs, stream_name: str
+) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
+
+    client = ReplicationGetStreamUpdates.make_client(hs)
+
+    async def update_function(
+        from_token: int, upto_token: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        return await client(
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+            limit=limit,
+        )
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event
     before or it went from being an outlier to not.
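The HTTP-backed variant satisfies the same `(updates, upto_token, limited)` contract as the database-backed wrapper, just answered by the master. A runnable sketch under that assumption, where `fake_client` is a hypothetical stand-in for the client returned by `ReplicationGetStreamUpdates.make_client(hs)`, and the dict unpacking is one way of adapting its documented JSON response shape:

```
import asyncio


def make_http_update_function(client, stream_name):
    # Same shape as the function added above, but taking the client directly
    # so the sketch runs without a HomeServer.
    async def update_function(from_token, upto_token, limit):
        result = await client(
            stream_name=stream_name,
            from_token=from_token,
            upto_token=upto_token,
            limit=limit,
        )
        return result["updates"], result["upto_token"], result["limited"]

    return update_function


async def fake_client(stream_name, from_token, upto_token, limit):
    # Mirrors the response documented on ReplicationGetStreamUpdates.
    return {
        "updates": [(upto_token, ["some", "row"])],
        "upto_token": upto_token,
        "limited": False,
    }


async def main():
    fn = make_http_update_function(fake_client, "typing")
    updates, upto_token, limited = await fn(0, 10, 100)
    assert upto_token == 10 and not limited


asyncio.run(main())
```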
@@ -204,7 +226,7 @@ def __init__(self, hs): self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore else: # Query master process - self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -226,7 +248,7 @@ def __init__(self, hs): self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore else: # Query master process - self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs) From 309c7eb1a197b940d11249bca4fd8c19b7e84a07 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Mar 2020 17:43:42 +0000 Subject: [PATCH 16/16] Add some type aliases --- synapse/replication/tcp/streams/_base.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 007b105d4df2..c14dff6c6484 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -29,6 +29,15 @@ MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -66,7 +75,7 @@ def discard_updates_and_advance(self): """ self.last_token = self.current_token() - async def get_updates(self) -> Tuple[List[Tuple[int, JsonDict]], int, bool]: + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). 
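For a concrete (and entirely made-up) picture of what the new aliases describe:

```
from typing import List, Tuple

Token = int
StreamRow = Tuple[Token, tuple]

# A hypothetical receipts-style row at stream position 54:
row: StreamRow = (54, ("!room:example.com", "m.read", "@user:example.com"))
updates: List[StreamRow] = [row]
```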
@@ -85,8 +94,8 @@ async def get_updates(self) -> Tuple[List[Tuple[int, JsonDict]], int, bool]:
         return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: Union[int, str], upto_token: int, limit: int = 100
-    ) -> Tuple[List[Tuple[int, JsonDict]], int, bool]:
+        self, from_token: Token, upto_token: Token, limit: int = 100
+    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -128,8 +137,8 @@ def update_function(self, from_token, current_token, limit):
 
 
 def db_query_to_update_function(
-    query_function: Callable[[int, int, int], Awaitable[List[tuple]]]
-) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
@@ -149,7 +158,7 @@ async def update_function(from_token, upto_token, limit):
 
 def make_http_update_function(
     hs, stream_name: str
-) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]:
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """
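Taken together, the series replaces per-stream subscription bookkeeping with a catch-up loop: on POSITION a client compares the advertised token with the last token it processed and pulls any gap through the stream's `update_function` (database-backed on the master, HTTP-backed on workers). A self-contained sketch of that loop, with toy names in place of the real protocol classes:

```
import asyncio


class ToyCatchUp:
    """Illustrative only; stands in for the client-side POSITION handling."""

    def __init__(self, update_function):
        self.last_token = 0
        self.update_function = update_function

    async def on_position(self, token):
        # Pull the gap between what we've seen and what the server advertises,
        # looping while the fetches come back limited.
        while self.last_token < token:
            updates, upto_token, limited = await self.update_function(
                self.last_token, token, 100
            )
            for row_token, row in updates:
                self.on_rdata(row_token, row)
            self.last_token = upto_token
            if not limited:
                break

    def on_rdata(self, token, row):
        print("applying", token, row)


async def fake_update_function(from_token, upto_token, limit):
    rows = [(t, ("row", t)) for t in range(from_token + 1, upto_token + 1)][:limit]
    limited = len(rows) == limit
    return rows, (rows[-1][0] if limited else upto_token), limited


asyncio.run(ToyCatchUp(fake_update_function).on_position(5))
```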