diff --git a/changelog.d/11967.feature b/changelog.d/11967.feature new file mode 100644 index 000000000000..d09320a2906c --- /dev/null +++ b/changelog.d/11967.feature @@ -0,0 +1 @@ +Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f05a803a7104..09d692d9a1ed 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs): self.msc2409_to_device_messages_enabled: bool = experimental.get( "msc2409_to_device_messages_enabled", False ) + + # MSC3706 (server-side support for partial state in /send_join responses) + self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9c44cccf0339..482bbdd86744 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -20,6 +20,7 @@ Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -64,7 +65,7 @@ ReplicationGetQueryRestServlet, ) from synapse.storage.databases.main.lock import Lock -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache @@ -645,19 +646,40 @@ async def on_invite_request( return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, + origin: str, + content: JsonDict, + room_id: str, + caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: event, context = await self._on_send_membership_event( origin, content, Membership.JOIN, room_id ) prev_state_ids = await context.get_prev_state_ids() - state_event_ids = prev_state_ids.values() + + state_event_ids: Collection[str] + servers_in_room: Optional[Collection[str]] + if caller_supports_partial_state: + state_event_ids = _get_event_ids_for_partial_state_join( + event, prev_state_ids + ) + servers_in_room = await self.state.get_hosts_in_room_at_events( + room_id, event_ids=event.prev_event_ids() + ) + else: + state_event_ids = prev_state_ids.values() + servers_in_room = None auth_chain_event_ids = await self.store.get_auth_chain_ids( room_id, state_event_ids ) + # if the caller has opted in, we can omit any auth_chain events which are + # already in state_event_ids + if caller_supports_partial_state: + auth_chain_event_ids.difference_update(state_event_ids) + auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids) state_events = await self.store.get_events_as_list(state_event_ids) @@ -671,7 +693,12 @@ async def on_send_join_request( "event": event_json, "state": [p.get_pdu_json(time_now) for p in state_events], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], + "org.matrix.msc3706.partial_state": caller_supports_partial_state, } + + if servers_in_room is not None: + resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room) + return resp async def on_make_leave_request( @@ -1347,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict: # error. logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) + + +def _get_event_ids_for_partial_state_join( + join_event: EventBase, + prev_state_ids: StateMap[str], +) -> Collection[str]: + """Calculate state to be retuned in a partial_state send_join + + Args: + join_event: the join event being send_joined + prev_state_ids: the event ids of the state before the join + + Returns: + the event ids to be returned + """ + + # return all non-member events + state_event_ids = { + event_id + for (event_type, state_key), event_id in prev_state_ids.items() + if event_type != EventTypes.Member + } + + # we also need the current state of the current user (it's going to + # be an auth event for the new join, so we may as well return it) + current_membership_event_id = prev_state_ids.get( + (EventTypes.Member, join_event.state_key) + ) + if current_membership_event_id is not None: + state_event_ids.add(current_membership_event_id) + + # TODO: return a few more members: + # - those with invites + # - those that are kicked? / banned + + return state_event_ids diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index d86dfede4e89..310733097da5 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._msc3706_enabled = hs.config.experimental.msc3706_enabled + async def on_PUT( self, origin: str, @@ -422,7 +432,15 @@ async def on_PUT( ) -> Tuple[int, JsonDict]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) + + partial_state = False + if self._msc3706_enabled: + partial_state = parse_boolean_from_args( + query, "org.matrix.msc3706.partial_state", default=False + ) + result = await self.handler.on_send_join_request( + origin, content, room_id, caller_supports_partial_state=partial_state + ) return 200, result diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index f766ee3f893e..d084919ef7d7 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -248,6 +248,57 @@ def test_send_join(self): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + @override_config({"experimental_features": {"msc3706_enabled": True}}) + def test_send_join_partial_state(self): + """When MSC3706 support is enabled, /send_join should return partial state""" + joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME + join_result = self._make_join(joining_user) + + join_event_dict = join_result["event"] + add_hashes_and_signatures( + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + join_event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", + content=join_event_dict, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + # expect a reduced room state + returned_state = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["state"] + ] + self.assertCountEqual( + returned_state, + [ + ("m.room.create", ""), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ("m.room.history_visibility", ""), + ], + ) + + # the auth chain should not include anything already in "state" + returned_auth_chain_events = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"] + ] + self.assertCountEqual( + returned_auth_chain_events, + [ + ("m.room.member", "@kermit:test"), + ], + ) + + # the room should show that the new user is a member + r = self.get_success( + self.hs.get_state_handler().get_current_state(self._room_id) + ) + self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + def _create_acl_event(content): return make_event_from_dict(