Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add an API for listing threads in a room. #13394

Merged
merged 18 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13394.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
2 changes: 2 additions & 0 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
Expand Down Expand Up @@ -206,6 +207,7 @@ class Store(
PusherWorkerStore,
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
RelationsWorkerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)

# MSC3856: Threads list API
self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)

# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)

Expand Down
86 changes: 85 additions & 1 deletion synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple

Expand All @@ -20,7 +21,7 @@
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
Expand All @@ -32,6 +33,13 @@
logger = logging.getLogger(__name__)


class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""

all = "all"
participated = "participated"


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
Expand Down Expand Up @@ -482,3 +490,79 @@ async def get_bundled_aggregations(
results.setdefault(event_id, BundledAggregations()).replace = edit

return results

async def get_threads(
self,
requester: Requester,
room_id: str,
include: ThreadsListInclude,
limit: int = 5,
from_token: Optional[ThreadsNextBatch] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.

Args:
requester: The user requesting the relations.
room_id: The room the event belongs to.
include: One of "all" or "participated" to indicate which threads should
be returned.
limit: Only fetch the most recent `limit` events.
from_token: Fetch rows from the given token, or from the start if None.

Returns:
The pagination chunk.
"""

user_id = requester.user.to_string()

# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, requester, allow_departed_users=True
)

# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_batch = await self._main_store.get_threads(
room_id=room_id, limit=limit, from_token=from_token
)

events = await self._main_store.get_events_as_list(thread_roots)

if include == ThreadsListInclude.participated:
# Pre-seed thread participation with whether the requester sent the event.
participated = {event.event_id: event.sender == user_id for event in events}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)
)

# Limit the returned threads to those the user has participated in.
events = [event for event in events if participated[event.event_id]]

events = await filter_events_for_client(
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)

aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)

now = self._clock.time_msec()
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

return_value: JsonDict = {"chunk": serialized_events}

if next_batch:
return_value["next_batch"] = str(next_batch)

return return_value
50 changes: 49 additions & 1 deletion synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple

from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict

Expand Down Expand Up @@ -78,5 +81,50 @@ async def on_GET(
return 200, result


class ThreadsServlet(RestServlet):
PATTERNS = (
re.compile(
"^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
),
)

def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()

async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
include = parse_string(
request,
"include",
default=ThreadsListInclude.all.value,
allowed_values=[v.value for v in ThreadsListInclude],
)

# Return the relations
from_token = None
if from_token_str:
from_token = ThreadsNextBatch.from_string(from_token_str)

result = await self._relations_handler.get_threads(
requester=requester,
room_id=room_id,
include=ThreadsListInclude(include),
limit=limit,
from_token=from_token,
)

return 200, result


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
if hs.config.experimental.msc3856_enabled:
ThreadsServlet(hs).register(http_server)
1 change: 1 addition & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def _invalidate_caches_for_event(
self._attempt_to_invalidate_cache(
"get_mutual_event_relations_for_rel_type", (relates_to,)
)
self._attempt_to_invalidate_cache("get_threads", (room_id,))

async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
Expand Down
36 changes: 33 additions & 3 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from prometheus_client import Counter

import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
Expand Down Expand Up @@ -1616,7 +1616,7 @@ def _update_metadata_tables_txn(
)

# Remove from relations table.
self._handle_redact_relations(txn, event.redacts)
self._handle_redact_relations(txn, event.room_id, event.redacts)

# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
Expand Down Expand Up @@ -1866,6 +1866,34 @@ def _handle_event_relations(
},
)

if relation.rel_type == RelationTypes.THREAD:
# Upsert into the threads table, but only overwrite the value if the
# new event is of a later topological order OR if the topological
# ordering is equal, but the stream ordering is later.
sql = """
INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT (room_id, thread_id)
DO UPDATE SET
latest_event_id = excluded.latest_event_id,
topological_ordering = excluded.topological_ordering,
stream_ordering = excluded.stream_ordering
WHERE
threads.topological_ordering <= excluded.topological_ordering AND
threads.stream_ordering < excluded.stream_ordering
"""

txn.execute(
sql,
(
event.room_id,
relation.parent_id,
event.event_id,
event.depth,
event.internal_metadata.stream_ordering,
),
)

def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
Expand Down Expand Up @@ -1989,13 +2017,14 @@ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None
txn.execute(sql, (batch_id,))

def _handle_redact_relations(
self, txn: LoggingTransaction, redacted_event_id: str
self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.

Args:
txn
room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""

Expand Down Expand Up @@ -2024,6 +2053,7 @@ def _handle_redact_relations(
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
txn.call_after(self.store.get_threads.invalidate, (room_id,))
self.store._invalidate_cache_and_stream(
txn,
self.store.get_mutual_event_relations_for_rel_type,
Expand Down
Loading