Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add unstable /keys/claim endpoint which always returns fallback keys #15462

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15462.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
6 changes: 4 additions & 2 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,15 +1005,17 @@ async def on_query_user_devices(

@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))

log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=always_include_fallback_keys
)

json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
Expand Down
6 changes: 6 additions & 0 deletions synapse/federation/transport/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
Expand Down Expand Up @@ -298,6 +299,11 @@ def register_servlets(
and not hs.config.experimental.msc3720_enabled
):
continue
if (
servletclass == FederationUnstableClientKeysClaimServlet
and not hs.config.experimental.msc3983_appservice_otk_claims
):
continue

servletclass(
hs=hs,
Expand Down
23 changes: 22 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(origin, content)
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
)
return 200, response


class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
"""

PREFIX = FEDERATION_UNSTABLE_PREFIX
PATH = "/user/keys/claim"
CATEGORY = "Federation requests"

async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
)
return 200, response


Expand Down
13 changes: 5 additions & 8 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,7 @@ async def _check_user_exists(self, user_id: str) -> bool:

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from application services.

Users which are exclusively owned by an application service are sent a
Expand All @@ -856,7 +854,7 @@ async def claim_e2e_one_time_keys(

Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A map of user ID -> a map device ID -> a map of key ID -> JSON.

A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
Expand Down Expand Up @@ -897,12 +895,11 @@ async def claim_e2e_one_time_keys(
)

# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.append(result[0])
claimed_keys.update(result[0])
missing.extend(result[1])

return claimed_keys, missing
Expand Down
70 changes: 63 additions & 7 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ async def on_federation_query_client_keys(
return ret

async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
self,
local_query: List[Tuple[str, str, str]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.

Expand All @@ -573,6 +575,7 @@ async def claim_local_one_time_keys(

Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.

Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
Expand All @@ -583,24 +586,73 @@ async def claim_local_one_time_keys(
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
appservice_results = {}

# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []

# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break

else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
fallback_query.append((user_id, device_id, algorithm, mark_as_used))

else:
# All fallback keys get marked as used.
fallback_query = [
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
]

# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)

# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
return (otk_results, appservice_results, fallback_results)

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self,
query: Dict[str, Dict[str, Dict[str, str]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
Expand All @@ -617,15 +669,19 @@ async def claim_one_time_keys(
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.claim_local_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)

# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})

# Remote failures.
failures: Dict[str, JsonDict] = {}
Expand Down
31 changes: 30 additions & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple

from synapse.api.errors import InvalidAPICallError, SynapseError
Expand Down Expand Up @@ -288,7 +289,33 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=False
)
return 200, result


class UnstableOneTimeKeyServlet(RestServlet):
"""
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
fallback keys in the response.
"""

PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
CATEGORY = "Encryption requests"

def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(
body, timeout, always_include_fallback_keys=True
)
return 200, result


Expand Down Expand Up @@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
9 changes: 5 additions & 4 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,18 +1149,19 @@ def _claim_e2e_one_time_key_returning(
return results, missing

async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.

Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of
(user ID, device ID, algorithm, whether the key should be marked as used).

Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand All @@ -1180,7 +1181,7 @@ async def claim_e2e_fallback_keys(
used = row["used"]

# Mark fallback key as used if not already.
if not used:
if not used and mark_as_used:
await self.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand Down
Loading