diff --git a/integration_tests/commands/async/getex_test.go b/integration_tests/commands/async/getex_test.go index a02597dad..8266cc671 100644 --- a/integration_tests/commands/async/getex_test.go +++ b/integration_tests/commands/async/getex_test.go @@ -15,8 +15,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/json_test.go b/integration_tests/commands/async/json_test.go index a92e60b65..59f5cbc85 100644 --- a/integration_tests/commands/async/json_test.go +++ b/integration_tests/commands/async/json_test.go @@ -862,8 +862,8 @@ func TestJsonNummultby(t *testing.T) { invalidArgMessage := "ERR wrong number of arguments for 'json.nummultby' command" testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string }{ @@ -1021,9 +1021,9 @@ func TestJSONNumIncrBy(t *testing.T) { defer conn.Close() invalidArgMessage := "ERR wrong number of arguments for 'json.numincrby' command" testCases := []struct { - name string - setupData string - commands []string + name string + setupData string + commands []string expected []interface{} assertType []string cleanUp []string diff --git a/integration_tests/commands/async/set_data_cmd_test.go b/integration_tests/commands/async/set_data_cmd_test.go index d6ffaa4f6..cb03863d8 100644 --- a/integration_tests/commands/async/set_data_cmd_test.go +++ b/integration_tests/commands/async/set_data_cmd_test.go @@ -30,8 +30,8 @@ func TestSetDataCommand(t *testing.T) { defer conn.Close() testCases := []struct { - name string - cmd []string + name string + cmd []string expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/async/set_test.go b/integration_tests/commands/async/set_test.go index 611b937c1..98f1b5082 100644 --- a/integration_tests/commands/async/set_test.go +++ b/integration_tests/commands/async/set_test.go @@ -20,12 +20,12 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Value", + name: "Set and Get Simple Cmd", commands: []string{"SET k v", "GET k"}, expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Value", + name: "Set and Get Integer Cmd", commands: []string{"SET k 123456789", "GET k"}, expected: []interface{}{"OK", int64(123456789)}, }, @@ -146,31 +146,31 @@ func TestSetWithExat(t *testing.T) { func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Value mismatch for cmd SET k v EXAT "+Etime) - assert.Equal(t, "v", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Value mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Cmd mismatch for cmd SET k v EXAT "+Etime) + assert.Equal(t, "v", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Cmd mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Value mismatch for cmd TTL k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Cmd mismatch for cmd TTL k") time.Sleep(3 * time.Second) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") }) t.Run("SET with invalid EXAT expires key immediately", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") - assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Cmd mismatch for cmd SET k v EXAT "+BadTime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k") }) t.Run("SET with EXAT and PXAT returns syntax error", func(t *testing.T) { // deleteTestKeys([]string{"k"}, store) FireCommand(conn, "DEL k") - assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) - assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Cmd mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k") }) } @@ -187,7 +187,7 @@ func TestWithKeepTTLFlag(t *testing.T) { for i := 0; i < len(tcase.commands); i++ { cmd := tcase.commands[i] out := tcase.expected[i] - assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) } } @@ -196,5 +196,5 @@ func TestWithKeepTTLFlag(t *testing.T) { cmd := "GET k" out := "(nil)" - assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd) } diff --git a/integration_tests/commands/async/touch_test.go b/integration_tests/commands/async/touch_test.go index b1c7c98d9..963cb20cd 100644 --- a/integration_tests/commands/async/touch_test.go +++ b/integration_tests/commands/async/touch_test.go @@ -12,14 +12,14 @@ func TestTouch(t *testing.T) { defer conn.Close() testCases := []struct { - name string - commands []string + name string + commands []string expected []interface{} assertType []string delay []time.Duration }{ { - name: "Touch Simple Value", + name: "Touch Simple Cmd", commands: []string{"SET foo bar", "OBJECT IDLETIME foo", "TOUCH foo", "OBJECT IDLETIME foo"}, expected: []interface{}{"OK", int64(2), int64(1), int64(0)}, assertType: []string{"equal", "assert", "equal", "assert"}, diff --git a/integration_tests/commands/http/getex_test.go b/integration_tests/commands/http/getex_test.go index 9bbd42fee..e93552755 100644 --- a/integration_tests/commands/http/getex_test.go +++ b/integration_tests/commands/http/getex_test.go @@ -14,8 +14,8 @@ func TestGetEx(t *testing.T) { Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10) testCases := []struct { - name string - commands []HTTPCommand + name string + commands []HTTPCommand expected []interface{} assertType []string delay []time.Duration diff --git a/integration_tests/commands/http/set_test.go b/integration_tests/commands/http/set_test.go index bf5d8d6d0..ed367a1fd 100644 --- a/integration_tests/commands/http/set_test.go +++ b/integration_tests/commands/http/set_test.go @@ -20,7 +20,7 @@ func TestSet(t *testing.T) { testCases := []TestCase{ { - name: "Set and Get Simple Value", + name: "Set and Get Simple Cmd", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v"}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, @@ -28,7 +28,7 @@ func TestSet(t *testing.T) { expected: []interface{}{"OK", "v"}, }, { - name: "Set and Get Integer Value", + name: "Set and Get Integer Cmd", commands: []HTTPCommand{ {Command: "SET", Body: map[string]interface{}{"key": "k", "value": 123456789}}, {Command: "GET", Body: map[string]interface{}{"key": "k"}}, diff --git a/internal/clientio/client_identifier.go b/internal/clientio/client_identifier.go new file mode 100644 index 000000000..91ea3bce0 --- /dev/null +++ b/internal/clientio/client_identifier.go @@ -0,0 +1,13 @@ +package clientio + +type ClientIdentifier struct { + ClientIdentifierID int + IsHTTPClient bool +} + +func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { + return ClientIdentifier{ + ClientIdentifierID: clientIdentifierID, + IsHTTPClient: isHTTPClient, + } +} diff --git a/internal/cmd/cmds.go b/internal/cmd/cmds.go index d0dfae6fa..d677e685f 100644 --- a/internal/cmd/cmds.go +++ b/internal/cmd/cmds.go @@ -1,5 +1,11 @@ package cmd +import ( + "fmt" + "github.com/dgryski/go-farm" + "strings" +) + type RedisCmd struct { RequestID uint32 Cmd string @@ -10,3 +16,17 @@ type RedisCmds struct { Cmds []*RedisCmd RequestID uint32 } + +// GetFingerprint returns a 32-bit fingerprint of the command and its arguments. +func (cmd *RedisCmd) GetFingerprint() uint32 { + return farm.Fingerprint32([]byte(fmt.Sprintf("%s-%s", cmd.Cmd, strings.Join(cmd.Args, " ")))) +} + +// GetKey Returns the key which the command operates on. +// +// TODO: This is a naive implementation which assumes that the first argument is the key. +// This is not true for all commands, however, for now this is only used by the watch manager, +// which as of now only supports a small subset of commands (all of which fit this implementation). +func (cmd *RedisCmd) GetKey() string { + return cmd.Args[0] +} diff --git a/internal/comm/client.go b/internal/comm/client.go index 33601ca67..1cadc533e 100644 --- a/internal/comm/client.go +++ b/internal/comm/client.go @@ -8,6 +8,12 @@ import ( "github.com/dicedb/dice/internal/cmd" ) +type CmdWatchResponse struct { + ClientIdentifierID uint32 + Result interface{} + Error error +} + type QwatchResponse struct { ClientIdentifierID uint32 Result interface{} diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index d55659b1e..07a858f44 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -201,7 +201,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { // Decode and return the value based on its encoding switch _, oEnc := object.ExtractTypeEncoding(obj); oEnc { case object.ObjEncodingInt: - // Value is stored as an int64, so use type assertion + // Cmd is stored as an int64, so use type assertion if val, ok := obj.Value.(int64); ok { return &EvalResponse{ Result: val, @@ -215,7 +215,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { } case object.ObjEncodingEmbStr, object.ObjEncodingRaw: - // Value is stored as a string, use type assertion + // Cmd is stored as a string, use type assertion if val, ok := obj.Value.(string); ok { return &EvalResponse{ Result: val, @@ -228,7 +228,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse { } case object.ObjEncodingByteArray: - // Value is stored as a bytearray, use type assertion + // Cmd is stored as a bytearray, use type assertion if val, ok := obj.Value.(*ByteArray); ok { return &EvalResponse{ Result: string(val.data), diff --git a/internal/querymanager/query_manager.go b/internal/querymanager/query_manager.go index a1c61b246..f91fa3f81 100644 --- a/internal/querymanager/query_manager.go +++ b/internal/querymanager/query_manager.go @@ -64,11 +64,6 @@ type ( Query string `json:"query"` Data []any `json:"data"` } - - ClientIdentifier struct { - ClientIdentifierID int - IsHTTPClient bool - } ) var ( @@ -79,13 +74,6 @@ var ( AdhocQueryChan chan AdhocQuery ) -func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier { - return ClientIdentifier{ - ClientIdentifierID: clientIdentifierID, - IsHTTPClient: isHTTPClient, - } -} - // NewQueryManager initializes a new Manager. func NewQueryManager(logger *slog.Logger) *Manager { QuerySubscriptionChan = make(chan QuerySubscription) @@ -130,11 +118,11 @@ func (m *Manager) listenForSubscriptions(ctx context.Context) { for { select { case event := <-QuerySubscriptionChan: - var client ClientIdentifier + var client clientio.ClientIdentifier if event.QwatchClientChan != nil { - client = NewClientIdentifier(int(event.ClientIdentifierID), true) + client = clientio.NewClientIdentifier(int(event.ClientIdentifierID), true) } else { - client = NewClientIdentifier(event.ClientFD, false) + client = clientio.NewClientIdentifier(event.ClientFD, false) } if event.Subscribe { @@ -224,7 +212,7 @@ func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryRe clients.Range(func(clientKey, clientVal interface{}) bool { // Identify the type of client and respond accordingly - switch clientIdentifier := clientKey.(ClientIdentifier); { + switch clientIdentifier := clientKey.(clientio.ClientIdentifier); { case clientIdentifier.IsHTTPClient: qwatchClientResponseChannel := clientVal.(chan comm.QwatchResponse) qwatchClientResponseChannel <- comm.QwatchResponse{ @@ -274,7 +262,7 @@ func (m *Manager) sendWithRetry(query *sql.DSQLQuery, clientFD int, data []byte) slog.Int("client", clientFD), slog.Any("error", err), ) - m.removeWatcher(query, NewClientIdentifier(clientFD, false), nil) + m.removeWatcher(query, clientio.NewClientIdentifier(clientFD, false), nil) return } } @@ -297,7 +285,7 @@ func (m *Manager) serveAdhocQueries(ctx context.Context) { } // addWatcher adds a client as a watcher to a query. -func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, +func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, qwatchClientChan chan comm.QwatchResponse, cacheChan chan *[]struct { Key string Value *object.Obj @@ -327,7 +315,7 @@ func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdenti } // removeWatcher removes a client from the watchlist for a query. -func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier, +func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier, qwatchClientChan chan comm.QwatchResponse) { queryString := query.String() if clients, ok := m.WatchList.Load(queryString); ok { diff --git a/internal/store/store.go b/internal/store/store.go index 46fbe2bbb..bb0119dd3 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -22,6 +22,11 @@ type QueryWatchEvent struct { Value object.Obj } +type CmdWatchEvent struct { + Cmd string + AffectedKey string +} + type Store struct { store *swiss.Map[string, *object.Obj] expires *swiss.Map[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread. diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go new file mode 100644 index 000000000..322cbf6d8 --- /dev/null +++ b/internal/watchmanager/watch_manager.go @@ -0,0 +1,182 @@ +package watchmanager + +import ( + "context" + "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/cmd" + "github.com/dicedb/dice/internal/comm" + dstore "github.com/dicedb/dice/internal/store" + "log/slog" + "sync" +) + +type ( + WatchSubscription struct { + Subscribe bool // Subscribe is true for subscribe, false for unsubscribe + WatchCmd cmd.RedisCmd // WatchCmd Represents a unique key for each watch artifact, only populated for subscriptions. + Fingerprint uint32 // Fingerprint is a unique identifier for each watch artifact, only populated for unsubscriptions. + ClientFD int // ClientFD is the file descriptor of the client connection + CmdWatchClientChan chan comm.CmdWatchResponse // CmdWatchClientChan is the generic channel for HTTP/Websockets etc. + ClientIdentifierID uint32 // ClientIdentifierID Helps identify CmdWatch client on httpserver side + } + + Manager struct { + querySubscriptionMap map[string]map[uint32]bool // querySubscriptionMap is a map of Key -> [fingerprint1, fingerprint2, ...] + tcpSubscriptionMap map[uint32]map[clientio.ClientIdentifier]bool // tcpSubscriptionMap is a map of fingerprint -> [client1, client2, ...] + fingerprintCmdMap map[uint32]cmd.RedisCmd // fingerprintCmdMap is a map of fingerprint -> RedisCmd + mu sync.RWMutex + logger *slog.Logger + } +) + +var ( + CmdWatchSubscriptionChan chan WatchSubscription +) + +func NewManager(logger *slog.Logger) *Manager { + CmdWatchSubscriptionChan = make(chan WatchSubscription) + return &Manager{ + querySubscriptionMap: make(map[string]map[uint32]bool), + tcpSubscriptionMap: make(map[uint32]map[clientio.ClientIdentifier]bool), + fingerprintCmdMap: make(map[uint32]cmd.RedisCmd), + logger: logger, + } +} + +// Run starts the watch manager, listening for subscription requests and events +func (m *Manager) Run(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { + var wg sync.WaitGroup + + wg.Add(2) + go func() { + defer wg.Done() + m.listenForSubscriptions(ctx) + }() + + go func() { + defer wg.Done() + m.listenForEvents(ctx, eventChan) + }() + + wg.Wait() +} + +// listenForSubscriptions handles incoming subscription requests +func (m *Manager) listenForSubscriptions(ctx context.Context) { + for { + select { + case sub := <-CmdWatchSubscriptionChan: + if sub.Subscribe { + m.handleSubscription(sub) + } else { + m.handleUnsubscription(sub) + } + case <-ctx.Done(): + return + } + } +} + +// handleSubscription processes a new subscription request +func (m *Manager) handleSubscription(sub WatchSubscription) { + fingerprint := sub.WatchCmd.GetFingerprint() + key := sub.WatchCmd.GetKey() + + client := clientio.NewClientIdentifier(sub.ClientFD, false) + + m.mu.Lock() + defer m.mu.Unlock() + + // Add fingerprint to querySubscriptionMap + if m.querySubscriptionMap[key] == nil { + m.querySubscriptionMap[key] = make(map[uint32]bool) + } + m.querySubscriptionMap[key][fingerprint] = true + + // Add RedisCmd to fingerprintCmdMap + m.fingerprintCmdMap[fingerprint] = sub.WatchCmd + + // Add clientID to tcpSubscriptionMap + if m.tcpSubscriptionMap[fingerprint] == nil { + m.tcpSubscriptionMap[fingerprint] = make(map[clientio.ClientIdentifier]bool) + } + m.tcpSubscriptionMap[fingerprint][client] = true +} + +// handleUnsubscription processes an unsubscription request +func (m *Manager) handleUnsubscription(sub WatchSubscription) { + fingerprint := sub.Fingerprint + client := clientio.NewClientIdentifier(sub.ClientFD, false) + + m.mu.Lock() + defer m.mu.Unlock() + + // Remove clientID from tcpSubscriptionMap + if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { + delete(clients, client) + // If there are no more clients listening to this fingerprint, remove it from the map + if len(clients) == 0 { + // Remove the fingerprint from tcpSubscriptionMap + delete(m.tcpSubscriptionMap, fingerprint) + // Also remove the fingerprint from fingerprintCmdMap + delete(m.fingerprintCmdMap, fingerprint) + } else { + // Update the map with the new set of clients + m.tcpSubscriptionMap[fingerprint] = clients + } + } + + // Remove fingerprint from querySubscriptionMap + if redisCmd, ok := m.fingerprintCmdMap[fingerprint]; ok { + key := redisCmd.GetKey() + if fingerprints, ok := m.querySubscriptionMap[key]; ok { + // Remove the fingerprint from the list of fingerprints listening to this key + delete(fingerprints, fingerprint) + // If there are no more fingerprints listening to this key, remove it from the map + if len(fingerprints) == 0 { + delete(m.querySubscriptionMap, key) + } else { + // Update the map with the new set of fingerprints + m.querySubscriptionMap[key] = fingerprints + } + } + } +} + +func (m *Manager) listenForEvents(ctx context.Context, eventChan chan dstore.CmdWatchEvent) { + affectedCmdMap := map[string]map[string]bool{"SET": {"GET": true}} + for { + select { + case <-ctx.Done(): + return + case event := <-eventChan: + m.mu.RLock() + + // Check if any watch commands are listening to updates on this key. + if _, ok := m.querySubscriptionMap[event.AffectedKey]; ok { + // iterate through all command fingerprints that are listening to this key + for fingerprint := range m.querySubscriptionMap[event.AffectedKey] { + // Check if the command associated with this fingerprint actually needs to be executed for this event. + // For instance, if the event is a SET, only execute GET commands need to be executed. This also + // helps us handle cases where a key might get updated by an unrelated command which makes it + // incompatible with the watched command. + if affectedCommands, ok := affectedCmdMap[event.Cmd]; ok { + if _, ok := affectedCommands[m.fingerprintCmdMap[fingerprint].Cmd]; ok { + // TODO: execute the command, store the result, send to clients + if clients, ok := m.tcpSubscriptionMap[fingerprint]; ok { + for client := range clients { + notifyClient(client, result) + } + } + } + } else { + m.logger.Error("Received a watch event for an unknown command type", + slog.String("cmd", event.Cmd)) + } + } + } + + m.mu.RUnlock() + } + } +}