Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support pagination tokens from sync/messages in the relations API #11952

Merged
merged 9 commits into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11952.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
57 changes: 39 additions & 18 deletions synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,45 @@
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
from synapse.types import JsonDict, RoomStreamToken, StreamToken

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)


async def _parse_token(
store: "DataStore", token: Optional[str]
) -> Optional[StreamToken]:
"""
For backwards compatibility support RelationPaginationToken, but new pagination
tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
"""
if not token:
return None
# Luckily the format for StreamToken and RelationPaginationToken differ enough
# that they can easily be separated. An "_" appears in the serialization of
# RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
# "-" only for separators.
if "_" in token:
return await StreamToken.from_string(store, token)
else:
relation_token = RelationPaginationToken.from_string(token)
return StreamToken(
room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)


class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
Expand Down Expand Up @@ -88,13 +119,8 @@ async def on_GET(
pagination_chunk = PaginationChunk(chunk=[])
else:
# Return the relations
from_token = None
if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)

to_token = None
if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
from_token = await _parse_token(self.store, from_token_str)
to_token = await _parse_token(self.store, to_token_str)

pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
Expand Down Expand Up @@ -125,7 +151,7 @@ async def on_GET(
events, now, bundle_aggregations=aggregations
)

return_value = pagination_chunk.to_dict()
return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event

Expand Down Expand Up @@ -216,7 +242,7 @@ async def on_GET(
to_token=to_token,
)

return 200, pagination_chunk.to_dict()
return 200, await pagination_chunk.to_dict(self.store)


class RelationAggregationGroupPaginationServlet(RestServlet):
Expand Down Expand Up @@ -287,13 +313,8 @@ async def on_GET(
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")

from_token = None
if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)

to_token = None
if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
from_token = await _parse_token(self.store, from_token_str)
to_token = await _parse_token(self.store, to_token_str)

result = await self.store.get_relations_for_event(
event_id=parent_id,
Expand All @@ -313,7 +334,7 @@ async def on_GET(
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)

return_value = result.to_dict()
return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events

return 200, return_value
Expand Down
46 changes: 31 additions & 15 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,13 @@
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,8 +95,8 @@ async def get_relations_for_event(
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[RelationPaginationToken] = None,
to_token: Optional[RelationPaginationToken] = None,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.

Expand Down Expand Up @@ -138,8 +135,10 @@ async def get_relations_for_event(
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
from_token=from_token.room_key.as_historical_tuple()
if from_token
else None,
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)

Expand Down Expand Up @@ -177,12 +176,27 @@ def _get_recent_references_for_event_txn(
last_topo_id = row[1]
last_stream_id = row[2]

next_batch = None
# If there are more events, generate the next pagination key.
next_token = None
if len(events) > limit and last_topo_id and last_stream_id:
next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
next_token = from_token.copy_and_replace("room_key", next_key)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)

return PaginationChunk(
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)

return await self.db_pool.runInteraction(
Expand Down Expand Up @@ -676,13 +690,15 @@ async def _get_bundled_aggregation_for_event(

annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations.annotations = annotations.to_dict()
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)

references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = references.to_dict()
aggregations.references = await references.to_dict(cast("DataStore", self))

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
Expand Down
15 changes: 9 additions & 6 deletions synapse/storage/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import SynapseError
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore

logger = logging.getLogger(__name__)


Expand All @@ -39,14 +42,14 @@ class PaginationChunk:
next_batch: Optional[Any] = None
prev_batch: Optional[Any] = None

def to_dict(self) -> Dict[str, Any]:
async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
d = {"chunk": self.chunk}

if self.next_batch:
d["next_batch"] = self.next_batch.to_string()
d["next_batch"] = await self.next_batch.to_string(store)

if self.prev_batch:
d["prev_batch"] = self.prev_batch.to_string()
d["prev_batch"] = await self.prev_batch.to_string(store)

return d

Expand Down Expand Up @@ -75,7 +78,7 @@ def from_string(string: str) -> "RelationPaginationToken":
except ValueError:
raise SynapseError(400, "Invalid relation pagination token")

def to_string(self) -> str:
async def to_string(self, store: "DataStore") -> str:
reivilibre marked this conversation as resolved.
Show resolved Hide resolved
return "%d-%d" % (self.topological, self.stream)

def as_tuple(self) -> Tuple[Any, ...]:
Expand Down Expand Up @@ -105,7 +108,7 @@ def from_string(string: str) -> "AggregationPaginationToken":
except ValueError:
raise SynapseError(400, "Invalid aggregation pagination token")

def to_string(self) -> str:
async def to_string(self, store: "DataStore") -> str:
return "%d-%d" % (self.count, self.stream)

def as_tuple(self) -> Tuple[Any, ...]:
Expand Down
Loading