Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. #12672

Merged
merged 12 commits into from
May 19, 2022
1 change: 1 addition & 0 deletions changelog.d/12672.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
34 changes: 32 additions & 2 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,6 +101,9 @@ def __init__(self, hs: "HomeServer"):
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()

# Additional Redis channel suffixes to subscribe to.
self._channels_to_subscribe_to: List[str] = []

self._is_presence_writer = (
hs.get_instance_name() in hs.config.worker.writers.presence
)
Expand Down Expand Up @@ -243,6 +246,31 @@ def __init__(self, hs: "HomeServer"):
# If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master"

if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")

def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.

(The name will later be prefixed with the server name; i.e. subscribing
to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)

Raises:
- If replication has already started, then it's too late to subscribe
to new channels.
"""

if self._factory is not None:
# We don't allow subscribing after the fact to avoid the chance
# of missing an important message because we didn't subscribe in time.
raise RuntimeError(
"Cannot subscribe to more channels after replication started."
)

if channel_name not in self._channels_to_subscribe_to:
self._channels_to_subscribe_to.append(channel_name)

def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
Expand Down Expand Up @@ -321,7 +349,9 @@ def start_replication(self, hs: "HomeServer") -> None:

# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
hs, outbound_redis_connection
hs,
outbound_redis_connection,
channel_names=self._channels_to_subscribe_to,
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
Expand Down
35 changes: 25 additions & 10 deletions synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast

import attr
import txredisapi
Expand Down Expand Up @@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):

Attributes:
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""

synapse_handler: "ReplicationCommandHandler"
synapse_stream_name: str
synapse_stream_prefix: str
synapse_channel_names: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler

def __init__(self, *args: Any, **kwargs: Any):
Expand All @@ -117,8 +118,13 @@ async def _send_subscribe(self) -> None:
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
fully_qualified_stream_names = [
f"{self.synapse_stream_prefix}/{stream_suffix}"
for stream_suffix in self.synapse_channel_names
] + [self.synapse_stream_prefix]
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))

logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
Expand Down Expand Up @@ -217,7 +223,7 @@ async def _async_send_command(self, cmd: Command) -> None:

await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
self.synapse_stream_prefix, encoded_string
)
)

Expand Down Expand Up @@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:

class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
subscribes to some streams.

Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate to the redis connection
used to subscribe).
channel_names: A list of channel names to append to the base channel name
to additionally subscribe to.
e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
example.com; example.com/ABC; and example.com/DEF.
"""

maxDelay = 5
protocol = RedisSubscriber

def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
self,
hs: "HomeServer",
outbound_redis_connection: txredisapi.ConnectionHandler,
channel_names: List[str],
):

super().__init__(
Expand All @@ -326,7 +339,8 @@ def __init__(
)

self.synapse_handler = hs.get_replication_command_handler()
self.synapse_stream_name = hs.hostname
self.synapse_stream_prefix = hs.hostname
self.synapse_channel_names = channel_names

self.synapse_outbound_redis_connection = outbound_redis_connection

Expand All @@ -340,7 +354,8 @@ def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
# protocol.
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name
p.synapse_stream_prefix = self.synapse_stream_prefix
p.synapse_channel_names = self.synapse_channel_names

return p

Expand Down
54 changes: 42 additions & 12 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple

from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
Expand All @@ -32,6 +33,7 @@

from tests import unittest
from tests.server import FakeTransport
from tests.utils import USE_POSTGRES_FOR_TESTS

try:
import hiredis
Expand Down Expand Up @@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""

def __init__(self):
self._subscribers = set()
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)

def add_subscriber(self, conn):
def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
self._subscribers_by_channel[channel].add(conn)

def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)

def publish(self, conn, channel, msg) -> int:
def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers:
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])

return len(self._subscribers)
return len(self._subscribers_by_channel)

def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
Expand Down Expand Up @@ -531,9 +536,10 @@ def handle_command(self, command, *args):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
for idx, channel in enumerate(args):
num_channels = idx + 1
self._server.add_subscriber(self, channel)
self.send(["subscribe", channel, num_channels])

# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
Expand Down Expand Up @@ -576,3 +582,27 @@ def encode(self, obj):

def connectionLost(self, reason):
self._server.remove_subscriber(self)


class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
"""
A test case that enables Redis, providing a fake Redis server.
"""

if not hiredis:
skip = "Requires hiredis"

if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"

def default_config(self) -> Dict[str, Any]:
"""
Overrides the default config to enable Redis.
Even if the test only uses make_worker_hs, the main process needs Redis
enabled otherwise it won't create a Fake Redis server to listen on the
Redis port and accept fake TCP connections.
"""
base = super().default_config()
base["redis"] = {"enabled": True}
return base
73 changes: 73 additions & 0 deletions tests/replication/tcp/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests.replication._base import RedisMultiWorkerStreamTestCase


class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
["USER_IP"],
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test that checks we do in fact connect to these channels? Presumably we can look at what's in FakeRedisPubSubServer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

I've added a new base test case class that enables Redis (because even if you're not using it with the main process for anything, you must enable it in order for the Fake Redis server to listen on a fake TCP port) and skips Postgres (since it turns out nothing subscribes to Redis when Postgres is not in use...).


def test_background_worker_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker1 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker1",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertIn(
"USER_IP",
worker1.get_replication_command_handler()._channels_to_subscribe_to,
)

# Advance so the Redis subscription gets processed
self.pump(0.1)

# The counts are 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
)

def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker2 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker2",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertNotIn(
"USER_IP",
worker2.get_replication_command_handler()._channels_to_subscribe_to,
)

# Advance so the Redis subscription gets processed
self.pump(0.1)

# The count is 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
# For USER_IP, the count is 1 because only the main process is subscribed.
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
)