Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds WatchManager to handle reactive commands and support for GET.WATCH #924

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integration_tests/commands/async/getex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ func TestGetEx(t *testing.T) {
Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10)

testCases := []struct {
name string
commands []string
name string
commands []string
expected []interface{}
assertType []string
delay []time.Duration
Expand Down
10 changes: 5 additions & 5 deletions integration_tests/commands/async/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,8 +862,8 @@ func TestJsonNummultby(t *testing.T) {
invalidArgMessage := "ERR wrong number of arguments for 'json.nummultby' command"

testCases := []struct {
name string
commands []string
name string
commands []string
expected []interface{}
assertType []string
}{
Expand Down Expand Up @@ -1021,9 +1021,9 @@ func TestJSONNumIncrBy(t *testing.T) {
defer conn.Close()
invalidArgMessage := "ERR wrong number of arguments for 'json.numincrby' command"
testCases := []struct {
name string
setupData string
commands []string
name string
setupData string
commands []string
expected []interface{}
assertType []string
cleanUp []string
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/commands/async/set_data_cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ func TestSetDataCommand(t *testing.T) {
defer conn.Close()

testCases := []struct {
name string
cmd []string
name string
cmd []string
expected []interface{}
assertType []string
delay []time.Duration
Expand Down
30 changes: 15 additions & 15 deletions integration_tests/commands/async/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ func TestSet(t *testing.T) {

testCases := []TestCase{
{
name: "Set and Get Simple Value",
name: "Set and Get Simple Cmd",
commands: []string{"SET k v", "GET k"},
expected: []interface{}{"OK", "v"},
},
{
name: "Set and Get Integer Value",
name: "Set and Get Integer Cmd",
commands: []string{"SET k 123456789", "GET k"},
expected: []interface{}{"OK", int64(123456789)},
},
Expand Down Expand Up @@ -146,31 +146,31 @@ func TestSetWithExat(t *testing.T) {
func(t *testing.T) {
// deleteTestKeys([]string{"k"}, store)
FireCommand(conn, "DEL k")
assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Value mismatch for cmd SET k v EXAT "+Etime)
assert.Equal(t, "v", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k")
assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Value mismatch for cmd TTL k")
assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Cmd mismatch for cmd SET k v EXAT "+Etime)
assert.Equal(t, "v", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k")
assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Cmd mismatch for cmd TTL k")
time.Sleep(3 * time.Second)
assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Value mismatch for cmd TTL k")
assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Cmd mismatch for cmd TTL k")
time.Sleep(3 * time.Second)
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k")
assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k")
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k")
assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k")
})

t.Run("SET with invalid EXAT expires key immediately",
func(t *testing.T) {
// deleteTestKeys([]string{"k"}, store)
FireCommand(conn, "DEL k")
assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime)
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k")
assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k")
assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Cmd mismatch for cmd SET k v EXAT "+BadTime)
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k")
assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Cmd mismatch for cmd TTL k")
})

t.Run("SET with EXAT and PXAT returns syntax error",
func(t *testing.T) {
// deleteTestKeys([]string{"k"}, store)
FireCommand(conn, "DEL k")
assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime)
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k")
assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Cmd mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime)
assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Cmd mismatch for cmd GET k")
})
}

Expand All @@ -187,7 +187,7 @@ func TestWithKeepTTLFlag(t *testing.T) {
for i := 0; i < len(tcase.commands); i++ {
cmd := tcase.commands[i]
out := tcase.expected[i]
assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd)
assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd)
}
}

Expand All @@ -196,5 +196,5 @@ func TestWithKeepTTLFlag(t *testing.T) {
cmd := "GET k"
out := "(nil)"

assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd)
assert.Equal(t, out, FireCommand(conn, cmd), "Cmd mismatch for cmd %s\n.", cmd)
}
6 changes: 3 additions & 3 deletions integration_tests/commands/async/touch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ func TestTouch(t *testing.T) {
defer conn.Close()

testCases := []struct {
name string
commands []string
name string
commands []string
expected []interface{}
assertType []string
delay []time.Duration
}{
{
name: "Touch Simple Value",
name: "Touch Simple Cmd",
commands: []string{"SET foo bar", "OBJECT IDLETIME foo", "TOUCH foo", "OBJECT IDLETIME foo"},
expected: []interface{}{"OK", int64(2), int64(1), int64(0)},
assertType: []string{"equal", "assert", "equal", "assert"},
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/commands/http/getex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ func TestGetEx(t *testing.T) {
Etime10 := strconv.FormatInt(time.Now().Unix()+10, 10)

testCases := []struct {
name string
commands []HTTPCommand
name string
commands []HTTPCommand
expected []interface{}
assertType []string
delay []time.Duration
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/commands/http/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ func TestSet(t *testing.T) {

testCases := []TestCase{
{
name: "Set and Get Simple Value",
name: "Set and Get Simple Cmd",
commands: []HTTPCommand{
{Command: "SET", Body: map[string]interface{}{"key": "k", "value": "v"}},
{Command: "GET", Body: map[string]interface{}{"key": "k"}},
},
expected: []interface{}{"OK", "v"},
},
{
name: "Set and Get Integer Value",
name: "Set and Get Integer Cmd",
commands: []HTTPCommand{
{Command: "SET", Body: map[string]interface{}{"key": "k", "value": 123456789}},
{Command: "GET", Body: map[string]interface{}{"key": "k"}},
Expand Down
13 changes: 13 additions & 0 deletions internal/clientio/client_identifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package clientio

type ClientIdentifier struct {
ClientIdentifierID int
IsHTTPClient bool
}

func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier {
return ClientIdentifier{
ClientIdentifierID: clientIdentifierID,
IsHTTPClient: isHTTPClient,
}
}
20 changes: 20 additions & 0 deletions internal/cmd/cmds.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package cmd

import (
"fmt"
"github.com/dgryski/go-farm"
"strings"
)

type RedisCmd struct {
RequestID uint32
Cmd string
Expand All @@ -10,3 +16,17 @@ type RedisCmds struct {
Cmds []*RedisCmd
RequestID uint32
}

// GetFingerprint returns a 32-bit fingerprint of the command and its arguments.
func (cmd *RedisCmd) GetFingerprint() uint32 {
return farm.Fingerprint32([]byte(fmt.Sprintf("%s-%s", cmd.Cmd, strings.Join(cmd.Args, " "))))
}

// GetKey Returns the key which the command operates on.
//
// TODO: This is a naive implementation which assumes that the first argument is the key.
// This is not true for all commands, however, for now this is only used by the watch manager,
// which as of now only supports a small subset of commands (all of which fit this implementation).
func (cmd *RedisCmd) GetKey() string {
return cmd.Args[0]
}
6 changes: 6 additions & 0 deletions internal/comm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import (
"github.com/dicedb/dice/internal/cmd"
)

type CmdWatchResponse struct {
ClientIdentifierID uint32
Result interface{}
Error error
}

type QwatchResponse struct {
ClientIdentifierID uint32
Result interface{}
Expand Down
6 changes: 3 additions & 3 deletions internal/eval/store_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse {
// Decode and return the value based on its encoding
switch _, oEnc := object.ExtractTypeEncoding(obj); oEnc {
case object.ObjEncodingInt:
// Value is stored as an int64, so use type assertion
// Cmd is stored as an int64, so use type assertion
if val, ok := obj.Value.(int64); ok {
return &EvalResponse{
Result: val,
Expand All @@ -215,7 +215,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse {
}

case object.ObjEncodingEmbStr, object.ObjEncodingRaw:
// Value is stored as a string, use type assertion
// Cmd is stored as a string, use type assertion
if val, ok := obj.Value.(string); ok {
return &EvalResponse{
Result: val,
Expand All @@ -228,7 +228,7 @@ func evalGET(args []string, store *dstore.Store) *EvalResponse {
}

case object.ObjEncodingByteArray:
// Value is stored as a bytearray, use type assertion
// Cmd is stored as a bytearray, use type assertion
if val, ok := obj.Value.(*ByteArray); ok {
return &EvalResponse{
Result: string(val.data),
Expand Down
26 changes: 7 additions & 19 deletions internal/querymanager/query_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ type (
Query string `json:"query"`
Data []any `json:"data"`
}

ClientIdentifier struct {
ClientIdentifierID int
IsHTTPClient bool
}
)

var (
Expand All @@ -79,13 +74,6 @@ var (
AdhocQueryChan chan AdhocQuery
)

func NewClientIdentifier(clientIdentifierID int, isHTTPClient bool) ClientIdentifier {
return ClientIdentifier{
ClientIdentifierID: clientIdentifierID,
IsHTTPClient: isHTTPClient,
}
}

// NewQueryManager initializes a new Manager.
func NewQueryManager(logger *slog.Logger) *Manager {
QuerySubscriptionChan = make(chan QuerySubscription)
Expand Down Expand Up @@ -130,11 +118,11 @@ func (m *Manager) listenForSubscriptions(ctx context.Context) {
for {
select {
case event := <-QuerySubscriptionChan:
var client ClientIdentifier
var client clientio.ClientIdentifier
if event.QwatchClientChan != nil {
client = NewClientIdentifier(int(event.ClientIdentifierID), true)
client = clientio.NewClientIdentifier(int(event.ClientIdentifierID), true)
} else {
client = NewClientIdentifier(event.ClientFD, false)
client = clientio.NewClientIdentifier(event.ClientFD, false)
}

if event.Subscribe {
Expand Down Expand Up @@ -224,7 +212,7 @@ func (m *Manager) notifyClients(query *sql.DSQLQuery, clients *sync.Map, queryRe

clients.Range(func(clientKey, clientVal interface{}) bool {
// Identify the type of client and respond accordingly
switch clientIdentifier := clientKey.(ClientIdentifier); {
switch clientIdentifier := clientKey.(clientio.ClientIdentifier); {
case clientIdentifier.IsHTTPClient:
qwatchClientResponseChannel := clientVal.(chan comm.QwatchResponse)
qwatchClientResponseChannel <- comm.QwatchResponse{
Expand Down Expand Up @@ -274,7 +262,7 @@ func (m *Manager) sendWithRetry(query *sql.DSQLQuery, clientFD int, data []byte)
slog.Int("client", clientFD),
slog.Any("error", err),
)
m.removeWatcher(query, NewClientIdentifier(clientFD, false), nil)
m.removeWatcher(query, clientio.NewClientIdentifier(clientFD, false), nil)
return
}
}
Expand All @@ -297,7 +285,7 @@ func (m *Manager) serveAdhocQueries(ctx context.Context) {
}

// addWatcher adds a client as a watcher to a query.
func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier,
func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier,
qwatchClientChan chan comm.QwatchResponse, cacheChan chan *[]struct {
Key string
Value *object.Obj
Expand Down Expand Up @@ -327,7 +315,7 @@ func (m *Manager) addWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdenti
}

// removeWatcher removes a client from the watchlist for a query.
func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier ClientIdentifier,
func (m *Manager) removeWatcher(query *sql.DSQLQuery, clientIdentifier clientio.ClientIdentifier,
qwatchClientChan chan comm.QwatchResponse) {
queryString := query.String()
if clients, ok := m.WatchList.Load(queryString); ok {
Expand Down
5 changes: 5 additions & 0 deletions internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ type QueryWatchEvent struct {
Value object.Obj
}

type CmdWatchEvent struct {
Cmd string
AffectedKey string
}

type Store struct {
store *swiss.Map[string, *object.Obj]
expires *swiss.Map[*object.Obj, uint64] // Does not need to be thread-safe as it is only accessed by a single thread.
Expand Down
Loading
Loading