diff --git a/state/accumulator.go b/state/accumulator.go index f2230438..75ecfc2c 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -139,69 +139,62 @@ type InitialiseResult struct { // AddedEvents is true iff this call to Initialise added new state events to the DB. AddedEvents bool // SnapshotID is the ID of the snapshot which incorporates all added events. - // It has no meaning if AddedEvents is False. + // It has no meaning if AddedEvents is false. SnapshotID int64 - // PrependTimelineEvents is empty if the room was not initialised prior to this call. - // Otherwise, it is an order-preserving subset of the `state` argument to Initialise - // containing all events that were not persisted prior to the Initialise call. These - // should be prepended to the room timeline by the caller. - PrependTimelineEvents []json.RawMessage + // ReplacedExistingSnapshot is true when we created a new snapshot for the room and + // there a pre-existing room snapshot. It has no meaning if AddedEvents is false. + ReplacedExistingSnapshot bool } -// Initialise starts a new sync accumulator for the given room using the given state as a baseline. +// Initialise processes the state block of a V2 sync response for a particular room. If +// the state of the room has changed, we persist any new state events and create a new +// "snapshot" of its entire state. // -// This will only take effect if this is the first time the v3 server has seen this room, and it wasn't -// possible to get all events up to the create event (e.g Matrix HQ). -// This function: -// - Stores these events -// - Sets up the current snapshot based on the state list given. +// Summary of the logic: // -// If the v3 server has seen this room before, this function -// - queries the DB to determine which state events are known to th server, -// - returns (via InitialiseResult.PrependTimelineEvents) a slice of unknown state events, +// 0. Ensure the state block is not empty. // -// and otherwise does nothing. +// 1. Capture the current snapshot ID, possibly zero. If it is zero, ensure that the +// state block contains a `create event`. +// +// 2. Insert the events. If there are no newly inserted events, bail. If there are new +// events, then the state block has definitely changed. Note: we ignore cases where +// the state has only changed to a known subset of state events (i.e in the case of +// state resets, slow pollers) as it is impossible to then reconcile that state with +// any new events, as any "catchup" state will be ignored due to the events already +// existing. +// +// 3. Fetch the current state of the room, as a map from (type, state_key) to event. +// If there is no existing state snapshot, this map is the empty map. +// If the state hasn't altered, bail. +// +// 4. Create new snapshot. Update the map from (3) with the events in `state`. +// (There is similar logic for this in Accumulate.) +// Store the snapshot. Mark the room's current state as being this snapshot. +// +// 5. Any other processing of the new state events. +// +// 6. Return an "AddedEvents" bool (if true, emit an Initialise payload) and a +// "ReplacedSnapshot" bool (if true, emit a cache invalidation payload). + func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) { var res InitialiseResult + var startingSnapshotID int64 + + // 0. Ensure the state block is not empty. if len(state) == 0 { return res, nil } - err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error { + err := sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) (err error) { + // 1. Capture the current snapshot ID, checking for a create event if this is our first snapshot. + // Attempt to short-circuit. This has to be done inside a transaction to make sure // we don't race with multiple calls to Initialise with the same room ID. - snapshotID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID) + startingSnapshotID, err = a.roomsTable.CurrentAfterSnapshotID(txn, roomID) if err != nil { - return fmt.Errorf("error fetching snapshot id for room %s: %s", roomID, err) - } - if snapshotID > 0 { - // Poller A has received a gappy sync v2 response with a state block, and - // we have seen this room before. If we knew for certain that there is some - // other active poller B in this room then we could safely skip this logic. - - // Log at debug for now. If we find an unknown event, we'll return it so - // that the poller can log a warning. - logger.Debug().Str("room_id", roomID).Int64("snapshot_id", snapshotID).Msg("Accumulator.Initialise called with incremental state but current snapshot already exists.") - eventIDs := make([]string, len(state)) - eventIDToRawEvent := make(map[string]json.RawMessage, len(state)) - for i := range state { - eventID := gjson.ParseBytes(state[i]).Get("event_id") - if !eventID.Exists() || eventID.Type != gjson.String { - return fmt.Errorf("Event %d lacks an event ID", i) - } - eventIDToRawEvent[eventID.Str] = state[i] - eventIDs[i] = eventID.Str - } - unknownEventIDs, err := a.eventsTable.SelectUnknownEventIDs(txn, eventIDs) - if err != nil { - return fmt.Errorf("error determing which event IDs are unknown: %s", err) - } - for unknownEventID := range unknownEventIDs { - res.PrependTimelineEvents = append(res.PrependTimelineEvents, eventIDToRawEvent[unknownEventID]) - } - return nil + return fmt.Errorf("error fetching snapshot id for room %s: %w", roomID, err) } - - // We don't have a snapshot for this room. Parse the events first. + // Start by parsing the events in the state block. events := make([]Event, len(state)) for i := range events { events[i] = Event{ @@ -212,71 +205,68 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia } events = filterAndEnsureFieldsSet(events) if len(events) == 0 { - return fmt.Errorf("failed to insert events, all events were filtered out: %w", err) + return fmt.Errorf("failed to parse state block, all events were filtered out: %w", err) } - // Before proceeding further, ensure that we have "proper" state and not just a - // single stray event by looking for the create event. - hasCreate := false - for _, e := range events { - if e.Type == "m.room.create" && e.StateKey == "" { - hasCreate = true - break + if startingSnapshotID == 0 { + // Ensure that we have "proper" state and not "stray" events from Synapse. + if err = ensureStateHasCreateEvent(events); err != nil { + return err } } - if !hasCreate { - const errMsg = "cannot create first snapshot without a create event" - sentry.WithScope(func(scope *sentry.Scope) { - scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ - "room_id": roomID, - "len_state": len(events), - }) - sentry.CaptureMessage(errMsg) - }) - logger.Warn(). - Str("room_id", roomID). - Int("len_state", len(events)). - Msg(errMsg) - // the HS gave us bad data so there's no point retrying => return DataError - return internal.NewDataError(errMsg) - } - // Insert the events. - eventIDToNID, err := a.eventsTable.Insert(txn, events, false) + // 2. Insert the events and determine which ones are new. + newEventIDToNID, err := a.eventsTable.Insert(txn, events, false) if err != nil { return fmt.Errorf("failed to insert events: %w", err) } - if len(eventIDToNID) == 0 { - // we don't have a current snapshot for this room but yet no events are new, - // no idea how this should be handled. - const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug." - logger.Error().Str("room_id", roomID).Msg(errMsg) - sentry.CaptureException(fmt.Errorf(errMsg)) + if len(newEventIDToNID) == 0 { + if startingSnapshotID == 0 { + // we don't have a current snapshot for this room but yet no events are new, + // no idea how this should be handled. + const errMsg = "Accumulator.Initialise: room has no current snapshot but also no new inserted events, doing nothing. This is probably a bug." + logger.Error().Str("room_id", roomID).Msg(errMsg) + sentry.CaptureException(fmt.Errorf(errMsg)) + } + // Note: we otherwise ignore cases where the state has only changed to a + // known subset of state events (i.e in the case of state resets, slow + // pollers) as it is impossible to then reconcile that state with + // any new events, as any "catchup" state will be ignored due to the events + // already existing. return nil } - - // pull out the event NIDs we just inserted - membershipEventIDs := make(map[string]struct{}, len(events)) + newEvents := make([]Event, 0, len(newEventIDToNID)) for _, event := range events { - if event.Type == "m.room.member" { - membershipEventIDs[event.ID] = struct{}{} + newNid, isNew := newEventIDToNID[event.ID] + if isNew { + event.NID = newNid + newEvents = append(newEvents, event) } } - memberNIDs := make([]int64, 0, len(eventIDToNID)) - otherNIDs := make([]int64, 0, len(eventIDToNID)) - for evID, nid := range eventIDToNID { - if _, exists := membershipEventIDs[evID]; exists { - memberNIDs = append(memberNIDs, int64(nid)) - } else { - otherNIDs = append(otherNIDs, int64(nid)) + + // 3. Fetch the current state of the room. + var currentState stateMap + if startingSnapshotID > 0 { + currentState, err = a.stateMapAtSnapshot(txn, startingSnapshotID) + if err != nil { + return fmt.Errorf("failed to load state map: %w", err) + } + } else { + currentState = stateMap{ + Memberships: make(map[string]int64, len(events)), + Other: make(map[[2]string]int64, len(events)), } } - // Make a current snapshot + // 4. Update the map from (3) with the new events to create a new snapshot. + for _, ev := range newEvents { + currentState.Ingest(ev) + } + memberNIDs, otherNIDs := currentState.NIDs() snapshot := &SnapshotRow{ RoomID: roomID, - MembershipEvents: pq.Int64Array(memberNIDs), - OtherEvents: pq.Int64Array(otherNIDs), + MembershipEvents: memberNIDs, + OtherEvents: otherNIDs, } err = a.snapshotTable.Insert(txn, snapshot) if err != nil { @@ -307,8 +297,16 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia // will have an associated state snapshot ID on the event. // Set the snapshot ID as the current state + err = a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID) + if err != nil { + return err + } + + // 6. Tell the caller what happened, so they know what payloads to emit. res.SnapshotID = snapshot.SnapshotID - return a.roomsTable.Upsert(txn, info, snapshot.SnapshotID, latestNID) + res.AddedEvents = true + res.ReplacedExistingSnapshot = startingSnapshotID > 0 + return nil }) return res, err } @@ -641,3 +639,82 @@ func (a *Accumulator) filterToNewTimelineEvents(txn *sqlx.Tx, dedupedEvents []Ev // A is seen event s[A,B,C] => s[0+1:] => [B,C] return dedupedEvents[seenIndex+1:], nil } + +func ensureStateHasCreateEvent(events []Event) error { + hasCreate := false + for _, e := range events { + if e.Type == "m.room.create" && e.StateKey == "" { + hasCreate = true + break + } + } + if !hasCreate { + const errMsg = "cannot create first snapshot without a create event" + sentry.WithScope(func(scope *sentry.Scope) { + scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ + "room_id": events[0].RoomID, + "len_state": len(events), + }) + sentry.CaptureMessage(errMsg) + }) + logger.Warn(). + Str("room_id", events[0].RoomID). + Int("len_state", len(events)). + Msg(errMsg) + // the HS gave us bad data so there's no point retrying => return DataError + return internal.NewDataError(errMsg) + } + return nil +} + +type stateMap struct { + // state_key (user id) -> NID + Memberships map[string]int64 + // type, state_key -> NID + Other map[[2]string]int64 +} + +func (s *stateMap) Ingest(e Event) (replacedNID int64) { + if e.Type == "m.room.member" { + replacedNID = s.Memberships[e.StateKey] + s.Memberships[e.StateKey] = e.NID + } else { + key := [2]string{e.Type, e.StateKey} + replacedNID = s.Other[key] + s.Other[key] = e.NID + } + return +} + +func (s *stateMap) NIDs() (membershipNIDs, otherNIDs []int64) { + membershipNIDs = make([]int64, 0, len(s.Memberships)) + otherNIDs = make([]int64, 0, len(s.Other)) + for _, nid := range s.Memberships { + membershipNIDs = append(membershipNIDs, nid) + } + for _, nid := range s.Other { + otherNIDs = append(otherNIDs, nid) + } + return +} + +func (a *Accumulator) stateMapAtSnapshot(txn *sqlx.Tx, snapID int64) (stateMap, error) { + snapshot, err := a.snapshotTable.Select(txn, snapID) + if err != nil { + return stateMap{}, err + } + // pull stripped events as this may be huge (think Matrix HQ) + events, err := a.eventsTable.SelectStrippedEventsByNIDs(txn, true, append(snapshot.MembershipEvents, snapshot.OtherEvents...)) + if err != nil { + return stateMap{}, err + } + + state := stateMap{ + Memberships: make(map[string]int64, len(snapshot.MembershipEvents)), + Other: make(map[[2]string]int64, len(snapshot.OtherEvents)), + } + for _, e := range events { + state.Ingest(e) + } + return state, nil +} diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 0a73febb..e6258da5 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -35,9 +35,8 @@ func TestAccumulatorInitialise(t *testing.T) { if err != nil { t.Fatalf("falied to Initialise accumulator: %s", err) } - if !res.AddedEvents { - t.Fatalf("didn't add events, wanted it to") - } + assertValue(t, "res.AddedEvents", res.AddedEvents, true) + assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, false) txn, err := accumulator.db.Beginx() if err != nil { @@ -46,21 +45,21 @@ func TestAccumulatorInitialise(t *testing.T) { defer txn.Rollback() // There should be one snapshot on the current state - snapID, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) + snapID1, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) if err != nil { t.Fatalf("failed to select current snapshot: %s", err) } - if snapID == 0 { + if snapID1 == 0 { t.Fatalf("Initialise did not store a current snapshot") } - if snapID != res.SnapshotID { - t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID) + if snapID1 != res.SnapshotID { + t.Fatalf("Initialise returned wrong snapshot ID, got %v want %v", res.SnapshotID, snapID1) } // this snapshot should have 1 member event and 2 other events in it - row, err := accumulator.snapshotTable.Select(txn, snapID) + row, err := accumulator.snapshotTable.Select(txn, snapID1) if err != nil { - t.Fatalf("failed to select snapshot %d: %s", snapID, err) + t.Fatalf("failed to select snapshot %d: %s", snapID1, err) } if len(row.MembershipEvents) != 1 { t.Fatalf("got %d membership events, want %d in current state snapshot", len(row.MembershipEvents), 1) @@ -87,7 +86,7 @@ func TestAccumulatorInitialise(t *testing.T) { } } - // Subsequent calls do nothing and are not an error + // Subsequent calls with the same set of the events do nothing and are not an error. res, err = accumulator.Initialise(roomID, roomEvents) if err != nil { t.Fatalf("falied to Initialise accumulator: %s", err) @@ -95,6 +94,37 @@ func TestAccumulatorInitialise(t *testing.T) { if res.AddedEvents { t.Fatalf("added events when it shouldn't have") } + + // Subsequent calls with a subset of events do nothing and are not an error + res, err = accumulator.Initialise(roomID, roomEvents[:2]) + if err != nil { + t.Fatalf("falied to Initialise accumulator: %s", err) + } + if res.AddedEvents { + t.Fatalf("added events when it shouldn't have") + } + + // Subsequent calls with at least one new event expand or replace existing state. + // C, D, E + roomEvents2 := append(roomEvents[2:3], + []byte(`{"event_id":"D", "type":"m.room.topic", "state_key":"", "content":{"topic":"Dr Rick Dagless MD"}}`), + []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join", "displayname": "Garth""}}`), + ) + res, err = accumulator.Initialise(roomID, roomEvents2) + assertNoError(t, err) + assertValue(t, "res.AddedEvents", res.AddedEvents, true) + assertValue(t, "res.ReplacedExistingSnapshot", res.ReplacedExistingSnapshot, true) + + snapID2, err := accumulator.roomsTable.CurrentAfterSnapshotID(txn, roomID) + assertNoError(t, err) + if snapID2 == snapID1 || snapID2 == 0 { + t.Errorf("Expected snapID2 (%d) to be neither snapID1 (%d) nor 0", snapID2, snapID1) + } + + row, err = accumulator.snapshotTable.Select(txn, snapID2) + assertNoError(t, err) + assertValue(t, "len(row.MembershipEvents)", len(row.MembershipEvents), 1) + assertValue(t, "len(row.OtherEvents)", len(row.OtherEvents), 3) } // Test that an unknown room shouldn't initialise if given state without a create event. @@ -115,9 +145,9 @@ func TestAccumulatorInitialiseBadInputs(t *testing.T) { func TestAccumulatorAccumulate(t *testing.T) { roomID := "!TestAccumulatorAccumulate:localhost" roomEvents := []json.RawMessage{ - []byte(`{"event_id":"D", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`), - []byte(`{"event_id":"E", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), - []byte(`{"event_id":"F", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), + []byte(`{"event_id":"G", "type":"m.room.create", "state_key":"", "content":{"creator":"@me:localhost"}}`), + []byte(`{"event_id":"H", "type":"m.room.member", "state_key":"@me:localhost", "content":{"membership":"join"}}`), + []byte(`{"event_id":"I", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), } db, close := connectToDB(t) defer close() @@ -130,11 +160,11 @@ func TestAccumulatorAccumulate(t *testing.T) { // accumulate new state makes a new snapshot and removes the old snapshot newEvents := []json.RawMessage{ // non-state event does nothing - []byte(`{"event_id":"G", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`), + []byte(`{"event_id":"J", "type":"m.room.message","content":{"body":"Hello World","msgtype":"m.text"}}`), // join_rules should clobber the one from initialise - []byte(`{"event_id":"H", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), + []byte(`{"event_id":"K", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`), // new state event should be added to the snapshot - []byte(`{"event_id":"I", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), + []byte(`{"event_id":"L", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`), } var result AccumulateResult err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error { diff --git a/state/rooms_table.go b/state/rooms_table.go index e87f692e..f926f1a7 100644 --- a/state/rooms_table.go +++ b/state/rooms_table.go @@ -40,6 +40,15 @@ func (t *RoomsTable) SelectRoomInfos(txn *sqlx.Tx) (infos []RoomInfo, err error) return } +func (t *RoomsTable) SelectRoomInfo(txn *sqlx.Tx, roomID string) (info RoomInfo, err error) { + err = txn.Get(&info, ` + SELECT room_id, is_encrypted, upgraded_room_id, predecessor_room_id, type + FROM syncv3_rooms + WHERE room_id = $1 + `, roomID) + return +} + func (t *RoomsTable) Upsert(txn *sqlx.Tx, info RoomInfo, snapshotID, latestNID int64) (err error) { // This is a bit of a wonky query to ensure that you cannot set is_encrypted=false after it has been // set to true. diff --git a/state/storage.go b/state/storage.go index faf83937..f5bb325f 100644 --- a/state/storage.go +++ b/state/storage.go @@ -296,37 +296,54 @@ func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result return nil } -// ResetMetadataState updates the given metadata in-place to reflect the current state -// of the room. This is only safe to call from the subscriber goroutine; it is not safe -// to call from the connection goroutines. -// TODO: could have this create a new RoomMetadata and get the caller to assign it. -func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error { - var events []Event - err := s.DB.Select(&events, ` - WITH snapshot(events, membership_events) AS ( - SELECT events, membership_events - FROM syncv3_snapshots - JOIN syncv3_rooms ON snapshot_id = current_snapshot_id - WHERE syncv3_rooms.room_id = $1 - ) - SELECT event_id, event_type, state_key, event, membership - FROM syncv3_events JOIN snapshot ON ( - event_nid = ANY (ARRAY_CAT(events, membership_events)) - ) - WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias') AND state_key = '') - OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite')) - ORDER BY event_nid ASC - ;`, metadata.RoomID) +// FetchRoomMetadata is Like AllJoinedMembers and MetadataForAllRooms combined, but just +// for a single room. +func (s *Storage) FetchRoomMetadata(roomID string) (*internal.RoomMetadata, error) { + var stateEvents []Event + var latestEvents []Event + var roomInfo RoomInfo + err := sqlutil.WithTransaction(s.DB, func(txn *sqlx.Tx) error { + err := txn.Select(&stateEvents, ` + WITH snapshot(events, membership_events) AS ( + SELECT events, membership_events + FROM syncv3_snapshots + JOIN syncv3_rooms ON snapshot_id = current_snapshot_id + WHERE syncv3_rooms.room_id = $1 + ) + SELECT event_id, event_type, state_key, event, membership + FROM syncv3_events JOIN snapshot ON ( + event_nid = ANY (ARRAY_CAT(events, membership_events)) + ) + WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias') AND state_key = '') + OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite')) + ORDER BY event_nid ASC + ;`, roomID) + if err != nil { + return err + } + + latestEvents, err = s.Accumulator.eventsTable.selectLatestEventByTypeInAllRooms(txn) + if err != nil { + return err + } + + roomInfo, err = s.Accumulator.roomsTable.SelectRoomInfo(txn, roomID) + if err != nil { + return err + } + return nil + }) if err != nil { - return fmt.Errorf("ResetMetadataState[%s]: %w", metadata.RoomID, err) + return nil, fmt.Errorf("ResetMetadataState[%s]: %w", roomID, err) } + metadata := internal.NewRoomMetadata(roomID) + + // First, set fields by sweeping over state events. heroMemberships := circularSlice[*Event]{max: 6} - metadata.JoinCount = 0 - metadata.InviteCount = 0 metadata.ChildSpaceRooms = make(map[string]struct{}) - for i, ev := range events { + for i, ev := range stateEvents { switch ev.Type { case "m.room.name": metadata.NameEvent = gjson.GetBytes(ev.JSON, "content.name").Str @@ -335,7 +352,7 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error { case "m.room.canonical_alias": metadata.CanonicalAlias = gjson.GetBytes(ev.JSON, "content.alias").Str case "m.room.member": - heroMemberships.append(&events[i]) + heroMemberships.append(&stateEvents[i]) switch ev.Membership { case "join": fallthrough @@ -362,9 +379,64 @@ func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error { metadata.Heroes = append(metadata.Heroes, hero) } - // For now, don't bother reloading Encrypted, PredecessorID and UpgradedRoomID. - // These shouldn't be changing during a room's lifetime in normal operation. - return nil + // Second, set fields based on the latest events query. + for _, ev := range latestEvents { + parsed := gjson.ParseBytes(ev.JSON) + ts := parsed.Get("origin_server_ts").Uint() + if ts > metadata.LastMessageTimestamp { + metadata.LastMessageTimestamp = ts + } + metadata.LatestEventsByType[parsed.Get("type").Str] = internal.EventMetadata{ + NID: ev.NID, + Timestamp: ts, + } + } + + // Lastly, set fields based on the RoomInfo struct/query. + metadata.Encrypted = roomInfo.IsEncrypted + metadata.PredecessorRoomID = roomInfo.PredecessorRoomID + metadata.UpgradedRoomID = roomInfo.UpgradedRoomID + metadata.RoomType = roomInfo.Type + + // Don't care about the TypingEvent field. + return metadata, nil +} + +// TODO: there is a very similar query in FetchRoomMetadata which also selects events +// rows for their memberships. It is a shame to have to do this twice---can we query +// once and pass the data around? +func (s *Storage) FetchJoinedAndInvited(roomID string) (joined, invited []string, err error) { + var memberships []Event + err = s.DB.Select(&memberships, ` + WITH snapshot(membership_nids) AS ( + SELECT membership_events + FROM syncv3_snapshots + JOIN syncv3_rooms ON snapshot_id = current_snapshot_id + WHERE syncv3_rooms.room_id = $1 + ) + SELECT state_key, membership + FROM syncv3_events JOIN snapshot ON ( + event_nid = ANY( membership_nids ) + ) + WHERE membership IN ('join', '_join', 'invite', '_invite') + `, roomID) + if err != nil { + return nil, nil, err + } + + for _, membership := range memberships { + switch membership.Membership { + case "_join": + fallthrough + case "join": + joined = append(joined, membership.StateKey) + case "_invite": + fallthrough + case "invite": + invited = append(invited, membership.StateKey) + } + } + return } // Returns all current NOT MEMBERSHIP state events matching the event types given in all rooms. Returns a map of diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 3fea03de..1318d3cf 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -370,20 +370,24 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID strin return nil } -func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (h *Handler) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { res, err := h.Store.Initialise(roomID, state) if err != nil { logger.Err(err).Int("state", len(state)).Str("room", roomID).Msg("V2: failed to initialise room") internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) - return nil, err + return err } - if res.AddedEvents { + if res.ReplacedExistingSnapshot { + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{ + RoomID: roomID, + }) + } else if res.AddedEvents { h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2Initialise{ RoomID: roomID, SnapshotNID: res.SnapshotID, }) } - return res.PrependTimelineEvents, nil + return nil } func (h *Handler) SetTyping(ctx context.Context, pollerID sync2.PollerID, roomID string, ephEvent json.RawMessage) { diff --git a/sync2/poller.go b/sync2/poller.go index 4902bf69..6a20e659 100644 --- a/sync2/poller.go +++ b/sync2/poller.go @@ -40,7 +40,7 @@ type V2DataReceiver interface { // Initialise the room, if it hasn't been already. This means the state section of the v2 response. // If given a state delta from an incremental sync, returns the slice of all state events unknown to the DB. // Return an error to stop the since token advancing. - Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) // snapshot ID? + Initialise(ctx context.Context, roomID string, state []json.RawMessage) error // snapshot ID? // SetTyping indicates which users are typing. SetTyping(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) // Sent when there is a new receipt @@ -326,11 +326,11 @@ func (h *PollerMap) Accumulate(ctx context.Context, userID, deviceID, roomID str wg.Wait() return } -func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (result []json.RawMessage, err error) { +func (h *PollerMap) Initialise(ctx context.Context, roomID string, state []json.RawMessage) (err error) { var wg sync.WaitGroup wg.Add(1) h.executor <- func() { - result, err = h.callbacks.Initialise(ctx, roomID, state) + err = h.callbacks.Initialise(ctx, roomID, state) wg.Done() } wg.Wait() @@ -781,30 +781,11 @@ func (p *poller) parseRoomsResponse(ctx context.Context, res *SyncResponse) erro for roomID, roomData := range res.Rooms.Join { if len(roomData.State.Events) > 0 { stateCalls++ - prependStateEvents, err := p.receiver.Initialise(ctx, roomID, roomData.State.Events) + err := p.receiver.Initialise(ctx, roomID, roomData.State.Events) if err != nil { lastErrs = append(lastErrs, fmt.Errorf("Initialise[%s]: %w", roomID, err)) continue } - if len(prependStateEvents) > 0 { - // The poller has just learned of these state events due to an - // incremental poller sync; we must have missed the opportunity to see - // these down /sync in a timeline. As a workaround, inject these into - // the timeline now so that future events are received under the - // correct room state. - const warnMsg = "parseRoomsResponse: prepending state events to timeline after gappy poll" - logger.Warn().Str("room_id", roomID).Int("prependStateEvents", len(prependStateEvents)).Msg(warnMsg) - hub := internal.GetSentryHubFromContextOrDefault(ctx) - hub.WithScope(func(scope *sentry.Scope) { - scope.SetContext(internal.SentryCtxKey, map[string]interface{}{ - "room_id": roomID, - "num_prepend_state_events": len(prependStateEvents), - }) - hub.CaptureMessage(warnMsg) - }) - p.trackGappyStateSize(len(prependStateEvents)) - roomData.Timeline.Events = append(prependStateEvents, roomData.Timeline.Events...) - } } // process typing/receipts before events so we seed the caches correctly for when we return the room for _, ephEvent := range roomData.Ephemeral.Events { diff --git a/sync2/poller_test.go b/sync2/poller_test.go index d2056f3c..e51771af 100644 --- a/sync2/poller_test.go +++ b/sync2/poller_test.go @@ -770,8 +770,8 @@ func TestPollerResendsOnCallbackError(t *testing.T) { // generate a receiver which errors for the right callback generateReceiver: func() V2DataReceiver { return &overrideDataReceiver{ - initialise: func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { - return nil, fmt.Errorf("initialise error") + initialise: func(ctx context.Context, roomID string, state []json.RawMessage) error { + return fmt.Errorf("initialise error") }, } }, @@ -1213,7 +1213,7 @@ func (a *mockDataReceiver) Accumulate(ctx context.Context, userID, deviceID, roo a.timelines[roomID] = append(a.timelines[roomID], timeline.Events...) return nil } -func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { a.states[roomID] = state if a.incomingProcess != nil { a.incomingProcess <- struct{}{} @@ -1223,7 +1223,7 @@ func (a *mockDataReceiver) Initialise(ctx context.Context, roomID string, state } // The return value is a list of unknown state events to be prepended to the room // timeline. Untested here---return nil for now. - return nil, nil + return nil } func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, deviceID, since string) { s.mu.Lock() @@ -1236,7 +1236,7 @@ func (s *mockDataReceiver) UpdateDeviceSince(ctx context.Context, userID, device type overrideDataReceiver struct { accumulate func(ctx context.Context, userID, deviceID, roomID, prevBatch string, timeline []json.RawMessage) error - initialise func(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) + initialise func(ctx context.Context, roomID string, state []json.RawMessage) error setTyping func(ctx context.Context, pollerID PollerID, roomID string, ephEvent json.RawMessage) updateDeviceSince func(ctx context.Context, userID, deviceID, since string) addToDeviceMessages func(ctx context.Context, userID, deviceID string, msgs []json.RawMessage) error @@ -1256,9 +1256,9 @@ func (s *overrideDataReceiver) Accumulate(ctx context.Context, userID, deviceID, } return s.accumulate(ctx, userID, deviceID, roomID, timeline.PrevBatch, timeline.Events) } -func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) ([]json.RawMessage, error) { +func (s *overrideDataReceiver) Initialise(ctx context.Context, roomID string, state []json.RawMessage) error { if s.initialise == nil { - return nil, nil + return nil } return s.initialise(ctx, roomID, state) } diff --git a/sync3/caches/global.go b/sync3/caches/global.go index d14adedc..49f4d27c 100644 --- a/sync3/caches/global.go +++ b/sync3/caches/global.go @@ -3,6 +3,7 @@ package caches import ( "context" "encoding/json" + "fmt" "os" "sort" "sync" @@ -386,19 +387,14 @@ func (c *GlobalCache) OnNewEvent( c.roomIDToMetadata[ed.RoomID] = metadata } -func (c *GlobalCache) OnInvalidateRoom(ctx context.Context, roomID string) { +func (c *GlobalCache) ReloadRoom(roomID string) (*internal.RoomMetadata, error) { c.roomIDToMetadataMu.Lock() defer c.roomIDToMetadataMu.Unlock() - metadata, ok := c.roomIDToMetadata[roomID] - if !ok { - logger.Warn().Str("room_id", roomID).Msg("OnInvalidateRoom: room not in global cache") - return - } - - err := c.store.ResetMetadataState(metadata) + metadata, err := c.store.FetchRoomMetadata(roomID) if err != nil { - internal.GetSentryHubFromContextOrDefault(ctx).CaptureException(err) - logger.Warn().Err(err).Msg("OnInvalidateRoom: failed to reset metadata") + return nil, fmt.Errorf("Failed to FetchRoomMetadata: %w", err) } + c.roomIDToMetadata[roomID] = metadata + return metadata, nil } diff --git a/sync3/caches/update.go b/sync3/caches/update.go index d6500f24..a5a4c6d9 100644 --- a/sync3/caches/update.go +++ b/sync3/caches/update.go @@ -18,6 +18,14 @@ type RoomUpdate interface { UserRoomMetadata() *UserRoomData } +type RoomCacheInvalidationUpdate struct { + RoomUpdate +} + +func (u *RoomCacheInvalidationUpdate) Type() string { + return fmt.Sprintf("RoomCacheInvalidationUpdate[%s]", u.RoomID()) +} + // RoomEventUpdate corresponds to a single event seen in a joined room's timeline under sync v2. type RoomEventUpdate struct { RoomUpdate diff --git a/sync3/caches/user.go b/sync3/caches/user.go index 80d880e8..a7c3e631 100644 --- a/sync3/caches/user.go +++ b/sync3/caches/user.go @@ -761,9 +761,17 @@ func (u *UserCache) ShouldIgnore(userID string) bool { return ignored } -func (u *UserCache) OnInvalidateRoom(ctx context.Context, roomID string) { +func (u *UserCache) OnInvalidateRoom(ctx context.Context, metadata *internal.RoomMetadata) { + urd := u.LoadRoomData(metadata.RoomID) // Nothing for now. In UserRoomData the fields dependant on room state are // IsDM, IsInvite, HasLeft, Invite, CanonicalisedName, ResolvedAvatarURL, Spaces. // Not clear to me if we need to reload these or if we will inherit any changes from // the global cache. + u.emitOnRoomUpdate(ctx, &RoomCacheInvalidationUpdate{ + RoomUpdate: &roomUpdateCache{ + roomID: metadata.RoomID, + globalRoomData: metadata, + userRoomData: &urd, + }, + }) } diff --git a/sync3/dispatcher.go b/sync3/dispatcher.go index 6c71d7e5..9eb0a2d4 100644 --- a/sync3/dispatcher.go +++ b/sync3/dispatcher.go @@ -24,7 +24,6 @@ type Receiver interface { OnNewEvent(ctx context.Context, event *caches.EventData) OnReceipt(ctx context.Context, receipt internal.Receipt) OnEphemeralEvent(ctx context.Context, roomID string, ephEvent json.RawMessage) - OnInvalidateRoom(ctx context.Context, roomID string) // OnRegistered is called after a successful call to Dispatcher.Register OnRegistered(ctx context.Context) error } @@ -287,22 +286,27 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData, } } -func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) { - // First dispatch to the global cache. - receiver, ok := d.userToReceiver[DispatcherAllUsers] - if !ok { - logger.Error().Msgf("No receiver for global cache") - } - receiver.OnInvalidateRoom(ctx, roomID) +func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, metadata *internal.RoomMetadata, joined, invited []string) { + // First update the JoinedRoomsTracker. + left := d.jrt.ReloadMembershipsForRoom(metadata.RoomID, joined, invited) // Then dispatch to any users who are joined to that room. - joinedUsers, _ := d.jrt.JoinedUsersForRoom(roomID, nil) d.userToReceiverMu.RLock() defer d.userToReceiverMu.RUnlock() - for _, userID := range joinedUsers { - receiver = d.userToReceiver[userID] - if receiver != nil { - receiver.OnInvalidateRoom(ctx, roomID) + + pokeUsers := func(users []string) { + for _, userID := range users { + rec := d.userToReceiver[userID] + if rec == nil { + continue + } + uc := rec.(*caches.UserCache) + if uc != nil { + uc.OnInvalidateRoom(ctx, metadata) + } } } + pokeUsers(joined) + pokeUsers(invited) + pokeUsers(left) } diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index 5e6b80e3..077058c3 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -807,7 +807,31 @@ func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) { ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom") defer task.End() - h.Dispatcher.OnInvalidateRoom(ctx, p.RoomID) + hub := internal.GetSentryHubFromContextOrDefault(ctx) + hub.ConfigureScope(func(scope *sentry.Scope) { + scope.SetContext(internal.SentryCtxKey, map[string]any{ + "room_id": p.RoomID, + }) + }) + + // TODO: the only consumer actually wants a set---could make this return a set directly? + joined, invited, err := h.Storage.FetchJoinedAndInvited(p.RoomID) + if err != nil { + hub.CaptureException(err) + logger.Err(err). + Str("room_id", p.RoomID). + Msg("Failed to fetch joined and invited members after cache invalidation") + return + } + + metadata, err := h.GlobalCache.ReloadRoom(p.RoomID) + if err != nil { + hub.CaptureException(err) + logger.Err(err).Str("room_id", p.RoomID).Msg("Failed to fetch metadata after cache invalidation") + return + } + + h.Dispatcher.OnInvalidateRoom(ctx, metadata, joined, invited) } func parseIntFromQuery(u *url.URL, param string) (result int64, err *internal.HandlerError) { diff --git a/sync3/tracker.go b/sync3/tracker.go index 3a1c73cf..6cf5267f 100644 --- a/sync3/tracker.go +++ b/sync3/tracker.go @@ -184,3 +184,62 @@ func (t *JoinedRoomsTracker) NumInvitedUsersForRoom(roomID string) int { defer t.mu.RUnlock() return len(t.roomIDToInvitedUsers[roomID]) } + +// ReloadMembershipsForRoom overwrites the JoinedRoomsTracker state for one room to the +// given list of joined and invited users. It returns the list of users who were joined +// or invited prior to this call, but are no longer joined nor invited. +func (t *JoinedRoomsTracker) ReloadMembershipsForRoom(roomID string, joined, invited []string) (left []string) { + newJoined := make(set, len(joined)) + newInvited := make(set, len(invited)) + for _, member := range joined { + newJoined[member] = struct{}{} + } + for _, member := range invited { + newInvited[member] = struct{}{} + } + + t.mu.Lock() + + // 1. Overwrite the room's memberships with the given arguments. + oldJoined := t.roomIDToJoinedUsers[roomID] + oldInvited := t.roomIDToInvitedUsers[roomID] + t.roomIDToJoinedUsers[roomID] = newJoined + t.roomIDToInvitedUsers[roomID] = newInvited + + // 2. Mark the joined users as being joined to this room. + for userID := range newJoined { + _, userAlreadyTracked := t.userIDToJoinedRooms[userID] + if !userAlreadyTracked { + t.userIDToJoinedRooms[userID] = make(set) + } + t.userIDToJoinedRooms[userID][roomID] = struct{}{} + } + + // 3. Scan the old joined list for users who are no longer joined, and mark them as such. + // Also scan for those who have left (i.e. were joined and have not been reinvited). + for userID := range oldJoined { + _, stillJoined := newJoined[userID] + if !stillJoined { + delete(t.userIDToJoinedRooms[userID], roomID) + _, nowInvited := newInvited[userID] + if !nowInvited { + left = append(left, userID) + } + } + } + + t.mu.Unlock() + + // 4. Scan the old invited list for users who have left. + for userID := range oldInvited { + _, stillInvited := newInvited[userID] + if !stillInvited { + _, nowJoined := newJoined[userID] + if !nowJoined { + left = append(left, userID) + } + } + } + + return +} diff --git a/sync3/tracker_test.go b/sync3/tracker_test.go index 7be2336a..39b12d77 100644 --- a/sync3/tracker_test.go +++ b/sync3/tracker_test.go @@ -82,6 +82,45 @@ func TestTrackerStartup(t *testing.T) { assertInt(t, jrt.NumInvitedUsersForRoom(roomC), 0) } +func TestTrackerReload(t *testing.T) { + roomA := "!a" + roomB := "!b" + roomC := "!c" + alice := "@alice" + bob := "@bob" + chris := "@chris" + jrt := NewJoinedRoomsTracker() + jrt.Startup(map[string][]string{ + roomA: {alice, bob}, + roomB: {bob}, + roomC: {alice}, + }) + + t.Log("Chris joins room C.") + jrt.ReloadMembershipsForRoom(roomC, []string{alice, chris}, nil) + members, _ := jrt.JoinedUsersForRoom(roomC, nil) + assertEqualSlices(t, "roomC joined members", members, []string{alice, chris}) + assertEqualSlices(t, "alice's rooms", jrt.JoinedRoomsForUser(alice), []string{roomA, roomC}) + assertEqualSlices(t, "chris's rooms", jrt.JoinedRoomsForUser(chris), []string{roomC}) + assertInt(t, jrt.NumInvitedUsersForRoom(roomC), 0) + + t.Log("Bob leaves room B.") + jrt.ReloadMembershipsForRoom(roomB, nil, nil) + members, _ = jrt.JoinedUsersForRoom(roomB, nil) + assertEqualSlices(t, "roomB joined members", members, nil) + assertEqualSlices(t, "bob's rooms", jrt.JoinedRoomsForUser(bob), []string{roomA}) + assertInt(t, jrt.NumInvitedUsersForRoom(roomB), 0) + + t.Log("Chris joins room A. Alice and Bob leave it, but Chris reinvites Bob.") + jrt.ReloadMembershipsForRoom(roomA, []string{chris}, []string{bob}) + members, _ = jrt.JoinedUsersForRoom(roomA, nil) + assertEqualSlices(t, "roomA joined members", members, []string{chris}) + assertEqualSlices(t, "alice's rooms", jrt.JoinedRoomsForUser(alice), []string{roomC}) + assertEqualSlices(t, "bob's rooms", jrt.JoinedRoomsForUser(bob), nil) + assertEqualSlices(t, "chris's rooms", jrt.JoinedRoomsForUser(chris), []string{roomA, roomC}) + assertInt(t, jrt.NumInvitedUsersForRoom(roomA), 1) +} + func assertBool(t *testing.T, msg string, got, want bool) { t.Helper() if got != want { diff --git a/tests-e2e/gappy_state_test.go b/tests-e2e/gappy_state_test.go index a8004a12..7a37b73c 100644 --- a/tests-e2e/gappy_state_test.go +++ b/tests-e2e/gappy_state_test.go @@ -1,6 +1,7 @@ package syncv3_test import ( + "encoding/json" "fmt" "github.com/matrix-org/sliding-sync/sync3" "github.com/matrix-org/sliding-sync/testutils/m" @@ -59,10 +60,17 @@ func TestGappyState(t *testing.T) { nameContent := map[string]interface{}{"name": "potato"} alice.SetState(t, roomID, "m.room.name", "", nameContent) - t.Log("Alice sends lots of message events (more than the poller will request in a timeline.") - var latestMessageID string - for i := 0; i < 51; i++ { - latestMessageID = alice.SendEventUnsynced(t, roomID, Event{ + t.Log("Alice sends lots of other state events.") + const numOtherState = 40 + for i := 0; i < numOtherState; i++ { + alice.SetState(t, roomID, "com.example.dummy", fmt.Sprintf("%d", i), map[string]any{}) + } + + t.Log("Alice sends a batch of message events.") + const numMessages = 20 + var lastMsgID string + for i := 0; i < numMessages; i++ { + lastMsgID = alice.SendEventUnsynced(t, roomID, Event{ Type: "m.room.message", Content: map[string]interface{}{ "msgtype": "m.text", @@ -71,28 +79,50 @@ func TestGappyState(t *testing.T) { }) } - t.Log("Alice requests an initial sliding sync on device 2.") + t.Logf("The proxy is now %d events behind the HS, which should trigger a limited sync", 1+numOtherState+numMessages) + + t.Log("Alice requests an initial sliding sync on device 2, with timeline limit big enough to see her first message at the start of the test.") syncResp = alice.SlidingSync(t, sync3.Request{ Lists: map[string]sync3.RequestList{ "a": { Ranges: [][2]int64{{0, 20}}, RoomSubscription: sync3.RoomSubscription{ - TimelineLimit: 10, + TimelineLimit: 100, }, }, }, }, ) - t.Log("She should see her latest message with the room name updated") + // We're testing here that the state events from the gappy poll are NOT injected + // into the timeline. The poll is only going to use timeline limit 1 because it's + // the first poll on a new device. See integration test for a "proper" gappy poll. + t.Log("She should see the updated room name, her most recent message, but NOT the state events in the gap nor messages from before the gap.") m.MatchResponse( t, syncResp, m.MatchRoomSubscription( roomID, m.MatchRoomName("potato"), - MatchRoomTimelineMostRecent(1, []Event{{ID: latestMessageID}}), + MatchRoomTimelineMostRecent(1, []Event{{ID: lastMsgID}}), + func(r sync3.Room) error { + for _, rawEv := range r.Timeline { + var ev Event + err := json.Unmarshal(rawEv, &ev) + if err != nil { + t.Fatal(err) + } + // Shouldn't see the state events, only messages + if ev.Type != "m.room.message" { + return fmt.Errorf("timeline contained event %s of type %s (expected m.room.message)", ev.ID, ev.Type) + } + if ev.ID == firstMessageID { + return fmt.Errorf("timeline contained first message from before the gap") + } + } + return nil + }, ), ) } diff --git a/tests-integration/poller_test.go b/tests-integration/poller_test.go index ed66ed26..9db77e25 100644 --- a/tests-integration/poller_test.go +++ b/tests-integration/poller_test.go @@ -600,3 +600,209 @@ func TestTimelineStopsLoadingWhenMissingPrevious(t *testing.T) { m.MatchRoomPrevBatch("dummyPrevBatch"), )) } + +// The "prepend state events" mechanism added in +// https://github.com/matrix-org/sliding-sync/pull/71 ensured that the proxy +// communicated state events in "gappy syncs" to users. But it did so via Accumulate, +// which made one snapshot for each state event. This was not an accurate model of the +// room's history (the state block comes in no particular order) and had awful +// performance for large gappy states. +// +// We now want to handle these in Initialise, making a single snapshot for the state +// block. This test ensures that is the case. The logic is very similar to the e2e test +// TestGappyState. +func TestGappyStateDoesNotAccumulateTheStateBlock(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + v2 := runTestV2Server(t) + defer v2.close() + v3 := runTestServer(t, v2, pqString) + defer v3.close() + + v2.addAccount(t, alice, aliceToken) + v2.addAccount(t, bob, bobToken) + + t.Log("Alice creates a room, sets its name and sends a message.") + const roomID = "!unimportant" + name1 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{ + "name": "wonderland", + }) + msg1 := testutils.NewMessageEvent(t, alice, "0118 999 881 999 119 7253") + + joinTimeline := v2JoinTimeline(roomEvents{ + roomID: roomID, + events: append( + createRoomState(t, alice, time.Now()), + name1, + msg1, + ), + }) + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: joinTimeline, + }, + }) + + t.Log("Alice sliding syncs with a huge timeline limit, subscribing to the room she just created.") + res := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + RoomSubscriptions: map[string]sync3.RoomSubscription{ + roomID: {TimelineLimit: 100}, + }, + }) + + t.Log("Alice sees the room with the expected name, with the name event and message at the end of the timeline.") + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, + m.MatchRoomName("wonderland"), + m.MatchRoomTimelineMostRecent(2, []json.RawMessage{name1, msg1}), + )) + + t.Log("Alice's poller receives a gappy sync, including a room name change, bob joining, and two messages.") + stateBlock := make([]json.RawMessage, 0) + for i := 0; i < 10; i++ { + statePiece := testutils.NewStateEvent(t, "com.example.custom", fmt.Sprintf("%d", i), alice, map[string]any{}) + stateBlock = append(stateBlock, statePiece) + } + name2 := testutils.NewStateEvent(t, "m.room.name", "", alice, map[string]any{ + "name": "not wonderland", + }) + bobJoin := testutils.NewJoinEvent(t, bob) + stateBlock = append(stateBlock, name2, bobJoin) + + msg2 := testutils.NewMessageEvent(t, alice, "Good morning!") + msg3 := testutils.NewMessageEvent(t, alice, "That's a nice tnetennba.") + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomID: { + State: sync2.EventsResponse{ + Events: stateBlock, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{msg2, msg3}, + Limited: true, + PrevBatch: "dummyPrevBatch", + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, aliceToken) + + t.Log("Alice should see the two most recent message in the timeline only. The room name should change too.") + res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{}) + m.MatchResponse(t, res, m.MatchRoomSubscription(roomID, + m.MatchRoomName("not wonderland"), + // In particular, we shouldn't see state here because it's not part of the timeline. + // Nor should we see msg1, as that comes before a gap. + m.MatchRoomTimeline([]json.RawMessage{msg2, msg3}), + )) +} + +func TestJoinedRoomsTrackerUpdatedAfterGappyState(t *testing.T) { + pqString := testutils.PrepareDBConnectionString() + v2 := runTestV2Server(t) + defer v2.close() + v3 := runTestServer(t, v2, pqString) + defer v3.close() + + const roomID = "!unimportant" + + v2.addAccount(t, alice, aliceToken) + v2.addAccount(t, bob, bobToken) + v2.addAccount(t, chris, chrisToken) + + t.Log("Queue up an empty poller response for Chris, so the proxy considers him to be polling.") + v2.queueResponse(chrisToken, sync2.SyncResponse{ + NextBatch: "chris1", + }) + chrisRes := v3.mustDoV3Request(t, chrisToken, sync3.Request{}) + v2.waitUntilEmpty(t, chrisToken) + + initialEvents := append( + createRoomState(t, alice, time.Now()), + testutils.NewStateEvent(t, "m.room.member", bob, alice, map[string]any{"membership": "invite"}), + ) + t.Log("Alice creates a room. Bob is invited, but Chris isn't.") + v2.queueResponse(aliceToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Join: v2JoinTimeline(roomEvents{ + roomID: roomID, + events: initialEvents, + }), + }, + NextBatch: "alice1", + }) + + t.Log("Alice sliding syncs and sees herself joined to the room.") + aliceRes := v3.mustDoV3Request(t, aliceToken, sync3.Request{ + RoomSubscriptions: map[string]sync3.RoomSubscription{ + roomID: {TimelineLimit: 20}, + }, + }) + m.MatchResponse(t, aliceRes, m.MatchRoomSubscription(roomID, + m.MatchJoinCount(1), + m.MatchInviteCount(1)), + ) + + t.Log("Bob's poller sees his invite.") + v2.queueResponse(bobToken, sync2.SyncResponse{ + Rooms: sync2.SyncRoomsResponse{ + Invite: map[string]sync2.SyncV2InviteResponse{ + roomID: { + InviteState: sync2.EventsResponse{ + Events: initialEvents, + }, + }, + }}, + NextBatch: "bob1", + }) + + t.Log("Bob sliding syncs sees himself invited to the room.") + bobRes := v3.mustDoV3Request(t, bobToken, sync3.Request{ + Lists: map[string]sync3.RequestList{ + "a": { + Ranges: sync3.SliceRanges{{0, 10}}, + }, + }, + }) + m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, m.MatchInviteCount(1))) + + t.Log("Alice's poller gets a gappy sync response in which Bob joins, Chris joins, and Alice sends a message.") + aliceMsg := testutils.NewMessageEvent(t, alice, "hellooooooooo") + v2.queueResponse(aliceToken, sync2.SyncResponse{ + NextBatch: "alice2", + Rooms: sync2.SyncRoomsResponse{ + Join: map[string]sync2.SyncV2JoinResponse{ + roomID: { + State: sync2.EventsResponse{ + Events: []json.RawMessage{ + testutils.NewStateEvent(t, "m.room.member", bob, bob, map[string]any{"membership": "join"}), + testutils.NewStateEvent(t, "m.room.member", chris, chris, map[string]any{"membership": "join"}), + }, + }, + Timeline: sync2.TimelineResponse{ + Events: []json.RawMessage{aliceMsg}, + Limited: true, + }, + }, + }, + }, + }) + v2.waitUntilEmpty(t, aliceToken) + + t.Log("Bob syncs. He should see himself as having joined the room, and see Alice's message.") + bobRes = v3.mustDoV3RequestWithPos(t, bobToken, bobRes.Pos, sync3.Request{}) + m.MatchResponse(t, bobRes, m.MatchRoomSubscription(roomID, + m.MatchJoinCount(3), + m.MatchInviteCount(0), + m.MatchRoomTimeline([]json.RawMessage{aliceMsg}), + )) + + t.Log("Ditto for Chris.") + chrisRes = v3.mustDoV3RequestWithPos(t, chrisToken, chrisRes.Pos, sync3.Request{}) + m.MatchResponse(t, chrisRes, m.MatchRoomSubscription(roomID, + m.MatchJoinCount(3), + m.MatchInviteCount(0), + m.MatchRoomTimeline([]json.RawMessage{aliceMsg}), + )) + +} diff --git a/tests-integration/v3_test.go b/tests-integration/v3_test.go index 80ecaf5f..fcccbd56 100644 --- a/tests-integration/v3_test.go +++ b/tests-integration/v3_test.go @@ -42,6 +42,8 @@ const ( aliceToken = "ALICE_BEARER_TOKEN" bob = "@bob:localhost" bobToken = "BOB_BEARER_TOKEN" + chris = "@chris:localhost" + chrisToken = "CHRIS_BEARER_TOKEN" ) var (