Skip to content

Commit

Permalink
Use a txn for accumulator.Accumulate and make one in Storage.Accumulate
Browse files Browse the repository at this point in the history
  • Loading branch information
kegsay committed Jun 8, 2023
1 parent 94da856 commit c2f4b53
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 117 deletions.
215 changes: 105 additions & 110 deletions state/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"

"github.com/getsentry/sentry-go"

"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -291,131 +292,125 @@ func (a *Accumulator) Initialise(roomID string, state []json.RawMessage) (Initia
// to exist in the database, and the sync stream is already linearised for us.
// - Else it creates a new room state snapshot if the timeline contains state events (as this now represents the current state)
// - It adds entries to the membership log for membership events.
func (a *Accumulator) Accumulate(roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(a.db, func(txn *sqlx.Tx) error {
// Insert the events. Check for duplicates which can happen in the real world when joining
// Matrix HQ on Synapse.
dedupedEvents := make([]Event, 0, len(timeline))
seenEvents := make(map[string]struct{})
for i := range timeline {
e := Event{
JSON: timeline[i],
RoomID: roomID,
}
if err := e.ensureFieldsSetOnEvent(); err != nil {
return fmt.Errorf("event malformed: %s", err)
}
if _, ok := seenEvents[e.ID]; ok {
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
)
continue
}
if i == 0 && prevBatch != "" {
// tag the first timeline event with the prev batch token
e.PrevBatch = sql.NullString{
String: prevBatch,
Valid: true,
}
}
dedupedEvents = append(dedupedEvents, e)
seenEvents[e.ID] = struct{}{}
func (a *Accumulator) Accumulate(txn *sqlx.Tx, roomID string, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
// Insert the events. Check for duplicates which can happen in the real world when joining
// Matrix HQ on Synapse.
dedupedEvents := make([]Event, 0, len(timeline))
seenEvents := make(map[string]struct{})
for i := range timeline {
e := Event{
JSON: timeline[i],
RoomID: roomID,
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return err
if err := e.ensureFieldsSetOnEvent(); err != nil {
return 0, nil, fmt.Errorf("event malformed: %s", err)
}
if len(eventIDToNID) == 0 {
// nothing to do, we already know about these events
return nil
if _, ok := seenEvents[e.ID]; ok {
logger.Warn().Str("event_id", e.ID).Str("room_id", roomID).Msg(
"Accumulator.Accumulate: seen the same event ID twice, ignoring",
)
continue
}
numNew = len(eventIDToNID)

var latestNID int64
newEvents := make([]Event, 0, len(eventIDToNID))
for _, ev := range dedupedEvents {
nid, ok := eventIDToNID[ev.ID]
if ok {
ev.NID = int64(nid)
if gjson.GetBytes(ev.JSON, "state_key").Exists() {
// XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response"
// it's important that we don't insert 'ev' at this point as this should be False in the DB.
ev.IsState = true
}
// assign the highest nid value to the latest nid.
// we'll return this to the caller so they can stay in-sync
if ev.NID > latestNID {
latestNID = ev.NID
}
newEvents = append(newEvents, ev)
timelineNIDs = append(timelineNIDs, ev.NID)
if i == 0 && prevBatch != "" {
// tag the first timeline event with the prev batch token
e.PrevBatch = sql.NullString{
String: prevBatch,
Valid: true,
}
}
dedupedEvents = append(dedupedEvents, e)
seenEvents[e.ID] = struct{}{}
}
eventIDToNID, err := a.eventsTable.Insert(txn, dedupedEvents, false)
if err != nil {
return 0, nil, err
}
if len(eventIDToNID) == 0 {
// nothing to do, we already know about these events
return 0, nil, nil
}
numNew = len(eventIDToNID)

// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return err
var latestNID int64
newEvents := make([]Event, 0, len(eventIDToNID))
for _, ev := range dedupedEvents {
nid, ok := eventIDToNID[ev.ID]
if ok {
ev.NID = int64(nid)
if gjson.GetBytes(ev.JSON, "state_key").Exists() {
// XXX: reusing this to mean "it's a state event" as well as "it's part of the state v2 response"
// it's important that we don't insert 'ev' at this point as this should be False in the DB.
ev.IsState = true
}
// assign the highest nid value to the latest nid.
// we'll return this to the caller so they can stay in-sync
if ev.NID > latestNID {
latestNID = ev.NID
}
newEvents = append(newEvents, ev)
timelineNIDs = append(timelineNIDs, ev.NID)
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
// as this is the before snapshot ID.
beforeSnapID := snapID
}

if ev.IsState {
// make a new snapshot and update the snapshot ID
var oldStripped StrippedEvents
if snapID != 0 {
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
if err != nil {
return fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
}
}
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
// Given a timeline of [E1, E2, S3, E4, S5, S6, E7] (E=message event, S=state event)
// And a prior state snapshot of SNAP0 then the BEFORE snapshot IDs are grouped as:
// E1,E2,S3 => SNAP0
// E4, S5 => (SNAP0 + S3)
// S6 => (SNAP0 + S3 + S5)
// E7 => (SNAP0 + S3 + S5 + S6)
// We can track this by loading the current snapshot ID (after snapshot) then rolling forward
// the timeline until we hit a state event, at which point we make a new snapshot but critically
// do NOT assign the new state event in the snapshot so as to represent the state before the event.
snapID, err := a.roomsTable.CurrentAfterSnapshotID(txn, roomID)
if err != nil {
return 0, nil, err
}
for _, ev := range newEvents {
var replacesNID int64
// the snapshot ID we assign to this event is unaffected by whether /this/ event is state or not,
// as this is the before snapshot ID.
beforeSnapID := snapID

if ev.IsState {
// make a new snapshot and update the snapshot ID
var oldStripped StrippedEvents
if snapID != 0 {
oldStripped, err = a.strippedEventsForSnapshot(txn, snapID)
if err != nil {
return fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
replacesNID = replacedNID
memNIDs, otherNIDs := newStripped.NIDs()
newSnapshot := &SnapshotRow{
RoomID: roomID,
MembershipEvents: memNIDs,
OtherEvents: otherNIDs,
return 0, nil, fmt.Errorf("failed to load stripped state events for snapshot %d: %s", snapID, err)
}
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
return fmt.Errorf("failed to insert new snapshot: %w", err)
}
snapID = newSnapshot.SnapshotID
}
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
return err
newStripped, replacedNID, err := a.calculateNewSnapshot(oldStripped, ev)
if err != nil {
return 0, nil, fmt.Errorf("failed to calculateNewSnapshot: %s", err)
}
replacesNID = replacedNID
memNIDs, otherNIDs := newStripped.NIDs()
newSnapshot := &SnapshotRow{
RoomID: roomID,
MembershipEvents: memNIDs,
OtherEvents: otherNIDs,
}
if err = a.snapshotTable.Insert(txn, newSnapshot); err != nil {
return 0, nil, fmt.Errorf("failed to insert new snapshot: %w", err)
}
snapID = newSnapshot.SnapshotID
}

if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
return fmt.Errorf("HandleSpaceUpdates: %s", err)
if err := a.eventsTable.UpdateBeforeSnapshotID(txn, ev.NID, beforeSnapID, replacesNID); err != nil {
return 0, nil, err
}
}

// the last fetched snapshot ID is the current one
info := a.roomInfoDelta(roomID, newEvents)
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
return fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
}
return nil
})
return numNew, timelineNIDs, err
if err = a.spacesTable.HandleSpaceUpdates(txn, newEvents); err != nil {
return 0, nil, fmt.Errorf("HandleSpaceUpdates: %s", err)
}

// the last fetched snapshot ID is the current one
info := a.roomInfoDelta(roomID, newEvents)
if err = a.roomsTable.Upsert(txn, info, snapID, latestNID); err != nil {
return 0, nil, fmt.Errorf("failed to UpdateCurrentSnapshotID to %d: %w", snapID, err)
}
return numNew, timelineNIDs, nil
}

// Delta returns a list of events of at most `limit` for the room not including `lastEventNID`.
Expand Down
42 changes: 36 additions & 6 deletions state/accumulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sort"
"testing"

"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/matrix-org/sliding-sync/testutils"
"github.com/tidwall/gjson"
Expand Down Expand Up @@ -115,7 +117,11 @@ func TestAccumulatorAccumulate(t *testing.T) {
}
var numNew int
var latestNIDs []int64
if numNew, latestNIDs, err = accumulator.Accumulate(roomID, "", newEvents); err != nil {
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
numNew, latestNIDs, err = accumulator.Accumulate(txn, roomID, "", newEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
if numNew != len(newEvents) {
Expand Down Expand Up @@ -185,7 +191,11 @@ func TestAccumulatorAccumulate(t *testing.T) {
}

// subsequent calls do nothing and are not an error
if _, _, err = accumulator.Accumulate(roomID, "", newEvents); err != nil {
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", newEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
}
Expand All @@ -207,7 +217,11 @@ func TestAccumulatorDelta(t *testing.T) {
[]byte(`{"event_id":"aH", "type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"public"}}`),
[]byte(`{"event_id":"aI", "type":"m.room.history_visibility", "state_key":"", "content":{"visibility":"public"}}`),
}
if _, _, err = accumulator.Accumulate(roomID, "", roomEvents); err != nil {
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}

Expand Down Expand Up @@ -266,7 +280,11 @@ func TestAccumulatorMembershipLogs(t *testing.T) {
// @me leaves the room
[]byte(`{"event_id":"` + roomEventIDs[7] + `", "type":"m.room.member", "state_key":"@me:localhost","unsigned":{"prev_content":{"membership":"join", "displayname":"Me"}}, "content":{"membership":"leave"}}`),
}
if _, _, err = accumulator.Accumulate(roomID, "", roomEvents); err != nil {
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", roomEvents)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
txn, err := accumulator.db.Beginx()
Expand Down Expand Up @@ -389,7 +407,10 @@ func TestAccumulatorDupeEvents(t *testing.T) {
t.Fatalf("failed to Initialise accumulator: %s", err)
}

_, _, err = accumulator.Accumulate(roomID, "", joinRoom.Timeline.Events)
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", joinRoom.Timeline.Events)
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
Expand Down Expand Up @@ -434,7 +455,10 @@ func TestAccumulatorMisorderedGraceful(t *testing.T) {
}

// Accumulate events D, A, B(msg).
_, _, err = accumulator.Accumulate(roomID, "", []json.RawMessage{eventD, eventA, eventBMsg})
err = sqlutil.WithTransaction(accumulator.db, func(txn *sqlx.Tx) error {
_, _, err = accumulator.Accumulate(txn, roomID, "", []json.RawMessage{eventD, eventA, eventBMsg})
return err
})
if err != nil {
t.Fatalf("failed to Accumulate: %s", err)
}
Expand Down Expand Up @@ -630,6 +654,12 @@ func TestCalculateNewSnapshotDupe(t *testing.T) {
}
}

// TestAccumulatorConcurrency is intended to test that you can accumulate the same room with the
// same partial sequence of timeline events and state is updated correctly. This relies on postgres
// blocking subsequent transactions sensibly.
//
// TODO: the body is currently empty, so this test passes without exercising any of the
// behaviour described above.
func TestAccumulatorConcurrency(t *testing.T) {

}

func currentSnapshotNIDs(t *testing.T, snapshotTable *SnapshotTable, roomID string) []int64 {
txn := snapshotTable.db.MustBeginTx(context.Background(), nil)
defer txn.Commit()
Expand Down
9 changes: 8 additions & 1 deletion state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,14 @@ func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventT
}

func (s *Storage) Accumulate(roomID, prevBatch string, timeline []json.RawMessage) (numNew int, timelineNIDs []int64, err error) {
return s.accumulator.Accumulate(roomID, prevBatch, timeline)
if len(timeline) == 0 {
return 0, nil, nil
}
err = sqlutil.WithTransaction(s.accumulator.db, func(txn *sqlx.Tx) error {
numNew, timelineNIDs, err = s.accumulator.Accumulate(txn, roomID, prevBatch, timeline)
return err
})
return
}

func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
Expand Down

0 comments on commit c2f4b53

Please sign in to comment.