diff --git a/state/receipt_table.go b/state/receipt_table.go index 336b3e15..ccb48095 100644 --- a/state/receipt_table.go +++ b/state/receipt_table.go @@ -11,11 +11,13 @@ import ( ) type receiptEDU struct { - Type string `json:"type"` - Content map[string]struct { - Read map[string]receiptInfo `json:"m.read,omitempty"` - ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"` - } `json:"content"` + Type string `json:"type"` + Content map[string]receiptContent `json:"content"` +} + +type receiptContent struct { + Read map[string]receiptInfo `json:"m.read,omitempty"` + ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"` } type receiptInfo struct { @@ -164,29 +166,35 @@ func (t *ReceiptTable) bulkInsert(tableName string, txn *sqlx.Tx, receipts []int // client connections. func PackReceiptsIntoEDU(receipts []internal.Receipt) (json.RawMessage, error) { newReceiptEDU := receiptEDU{ - Type: "m.receipt", - Content: make(map[string]struct { - Read map[string]receiptInfo `json:"m.read,omitempty"` - ReadPrivate map[string]receiptInfo `json:"m.read.private,omitempty"` - }), + Type: "m.receipt", + Content: make(map[string]receiptContent), } for _, r := range receipts { + thisReceiptIsUnthreaded := r.ThreadID == "" receiptsForEvent := newReceiptEDU.Content[r.EventID] if r.IsPrivate { if receiptsForEvent.ReadPrivate == nil { receiptsForEvent.ReadPrivate = make(map[string]receiptInfo) } - receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{ - TS: r.TS, - ThreadID: r.ThreadID, + // MSC4102: always replace threaded receipts with unthreaded ones if there is a clash + _, receiptAlreadyExists := receiptsForEvent.ReadPrivate[r.UserID] + if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) { + receiptsForEvent.ReadPrivate[r.UserID] = receiptInfo{ + TS: r.TS, + ThreadID: r.ThreadID, + } } } else { if receiptsForEvent.Read == nil { receiptsForEvent.Read = make(map[string]receiptInfo) } - receiptsForEvent.Read[r.UserID] = receiptInfo{ - TS: r.TS, - ThreadID: r.ThreadID, + // MSC4102: always replace threaded receipts with unthreaded ones if there is a clash + _, receiptAlreadyExists := receiptsForEvent.Read[r.UserID] + if !receiptAlreadyExists || (receiptAlreadyExists && thisReceiptIsUnthreaded) { + receiptsForEvent.Read[r.UserID] = receiptInfo{ + TS: r.TS, + ThreadID: r.ThreadID, + } } } newReceiptEDU.Content[r.EventID] = receiptsForEvent diff --git a/state/receipt_table_test.go b/state/receipt_table_test.go index 63ac3bf9..cb0ea6e8 100644 --- a/state/receipt_table_test.go +++ b/state/receipt_table_test.go @@ -31,6 +31,182 @@ func parsedReceiptsEqual(t *testing.T, got, want []internal.Receipt) { } } +func TestReceiptPacking(t *testing.T) { + testCases := []struct { + receipts []internal.Receipt + wantEDU receiptEDU + name string + }{ + { + name: "single receipt", + receipts: []internal.Receipt{ + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 42, + }, + }, + wantEDU: receiptEDU{ + Type: "m.receipt", + Content: map[string]receiptContent{ + "$bar": { + Read: map[string]receiptInfo{ + "@baz": { + TS: 42, + }, + }, + }, + }, + }, + }, + { + name: "two distinct receipt", + receipts: []internal.Receipt{ + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 42, + }, + { + RoomID: "!foo2", + EventID: "$bar2", + UserID: "@baz2", + TS: 422, + }, + }, + wantEDU: receiptEDU{ + Type: "m.receipt", + Content: map[string]receiptContent{ + "$bar": { + Read: map[string]receiptInfo{ + "@baz": { + TS: 42, + }, + }, + }, + "$bar2": { + Read: map[string]receiptInfo{ + "@baz2": { + TS: 422, + }, + }, + }, + }, + }, + }, + { + name: "MSC4102: unthreaded wins when threaded first", + receipts: []internal.Receipt{ + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 42, + ThreadID: "thread_id", + }, + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 420, + }, + }, + wantEDU: receiptEDU{ + Type: "m.receipt", + Content: map[string]receiptContent{ + "$bar": { + Read: map[string]receiptInfo{ + "@baz": { + TS: 420, + }, + }, + }, + }, + }, + }, + { + name: "MSC4102: unthreaded wins when unthreaded first", + receipts: []internal.Receipt{ + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 420, + }, + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 42, + ThreadID: "thread_id", + }, + }, + wantEDU: receiptEDU{ + Type: "m.receipt", + Content: map[string]receiptContent{ + "$bar": { + Read: map[string]receiptInfo{ + "@baz": { + TS: 420, + }, + }, + }, + }, + }, + }, + { + name: "MSC4102: unthreaded wins in private receipts when unthreaded first", + receipts: []internal.Receipt{ + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 420, + IsPrivate: true, + }, + { + RoomID: "!foo", + EventID: "$bar", + UserID: "@baz", + TS: 42, + ThreadID: "thread_id", + IsPrivate: true, + }, + }, + wantEDU: receiptEDU{ + Type: "m.receipt", + Content: map[string]receiptContent{ + "$bar": { + ReadPrivate: map[string]receiptInfo{ + "@baz": { + TS: 420, + }, + }, + }, + }, + }, + }, + } + for _, tc := range testCases { + edu, err := PackReceiptsIntoEDU(tc.receipts) + if err != nil { + t.Fatalf("%s: PackReceiptsIntoEDU: %s", tc.name, err) + } + gotEDU := receiptEDU{ + Type: "m.receipt", + Content: make(map[string]receiptContent), + } + if err := json.Unmarshal(edu, &gotEDU); err != nil { + t.Fatalf("%s: json.Unmarshal: %s", tc.name, err) + } + if !reflect.DeepEqual(gotEDU, tc.wantEDU) { + t.Errorf("%s: EDU mismatch, got %+v\n want %+v", tc.name, gotEDU, tc.wantEDU) + } + } +} + func TestReceiptTable(t *testing.T) { db, close := connectToDB(t) defer close()