Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 134 additions & 12 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
intl "github.com/redis/rueidis/internal/cmds"
)

func slot(key string) uint16 {
return intl.Slot(key)
}

// MGetCache is a helper that consults the client-side caches with multiple keys by grouping keys within the same slot into multiple GETs
func MGetCache(client Client, ctx context.Context, ttl time.Duration, keys []string) (ret map[string]RedisMessage, err error) {
if len(keys) == 0 {
Expand Down Expand Up @@ -50,12 +54,7 @@ func MGet(client Client, ctx context.Context, keys []string) (ret map[string]Red
return clientMGet(client, ctx, client.B().Mget().Key(keys...).Build(), keys)
}

cmds := mgetcmdsp.Get(len(keys), len(keys))
defer mgetcmdsp.Put(cmds)
for i := range cmds.s {
cmds.s[i] = client.B().Get().Key(keys[i]).Build()
}
return doMultiGet(client, ctx, cmds.s, keys)
return clusterMGet(client, ctx, keys)
}

// MSet is a helper that consults the redis directly with multiple keys by grouping keys within the same slot into MSETs or multiple SETs
Expand Down Expand Up @@ -139,12 +138,7 @@ func JsonMGet(client Client, ctx context.Context, keys []string, path string) (r
return clientMGet(client, ctx, client.B().JsonMget().Key(keys...).Path(path).Build(), keys)
}

cmds := mgetcmdsp.Get(len(keys), len(keys))
defer mgetcmdsp.Put(cmds)
for i := range cmds.s {
cmds.s[i] = client.B().JsonGet().Key(keys[i]).Path(path).Build()
}
return doMultiGet(client, ctx, cmds.s, keys)
return clusterJsonMGet(client, ctx, keys, path)
}

// JsonMSet is a helper that consults redis directly with multiple keys by grouping keys within the same slot into JSON.MSETs or multiple JSON.SETs
Expand Down Expand Up @@ -277,6 +271,134 @@ func arrayToKV(m map[string]RedisMessage, arr []RedisMessage, keys []string) map
return m
}

func clusterMGet(client Client, ctx context.Context, keys []string) (ret map[string]RedisMessage, err error) {
ret = make(map[string]RedisMessage, len(keys))
if len(keys) == 0 {
return ret, nil
}

// Map slot -> index in cmds.s
slotIdx := make(map[uint16]int, len(keys)/2)
cmds := mgetcmdsp.Get(0, len(keys))
defer mgetcmdsp.Put(cmds)

for _, key := range keys {
s := slot(key)
idx, ok := slotIdx[s]
if !ok {
// first key in this slot: create a new MGET
idx = len(cmds.s)
slotIdx[s] = idx
cmds.s = append(cmds.s, client.B().Mget().Key(key).Build().Pin())
continue
}

c := cmds.s[idx]
args := c.Commands() // ["MGET", k1, k2, ...]
mk := client.B().Mget().Key(args[1]) // first existing key
for i := 2; i < len(args); i++ { // remaining existing keys
mk = mk.Key(args[i])
}
mk = mk.Key(key) // add the new key

intl.PutCompletedForce(c)
cmds.s[idx] = mk.Build().Pin()
}

resps := client.DoMulti(ctx, cmds.s...)
defer resultsp.Put(&redisresults{s: resps})

// Convert each response to an array once
values := make([][]RedisMessage, len(cmds.s))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only using:

  1. slotIdx map[uint16]int for slot → command index, and
  2. the pooled cmds.s slice for the actual Completed commands.

for i, resp := range resps {
values[i], err = resp.ToArray()
if err != nil {
return nil, err
}
}

// Track per-command index as we walk keys in order
pos := make([]int, len(cmds.s))
for _, key := range keys {
idx := slotIdx[slot(key)]
ret[key] = values[idx][pos[idx]]
pos[idx]++
}

for i := range cmds.s {
intl.PutCompletedForce(cmds.s[i])
}
return ret, nil
}

func clusterJsonMGet(client Client, ctx context.Context, keys []string, path string) (ret map[string]RedisMessage, err error) {
ret = make(map[string]RedisMessage, len(keys))
if len(keys) == 0 {
return ret, nil
}

// Map slot -> index in cmds.s
slotIdx := make(map[uint16]int, len(keys)/2)
cmds := mgetcmdsp.Get(0, len(keys))
defer mgetcmdsp.Put(cmds)

for _, key := range keys {
s := slot(key)
idx, ok := slotIdx[s]
if !ok {
// first key in this slot: create a new JSON.MGET
idx = len(cmds.s)
slotIdx[s] = idx
cmds.s = append(cmds.s, client.B().JsonMget().Key(key).Path(path).Build().Pin())
continue
}

// extend existing JSON.MGET for this slot by rebuilding the command with the new key
c := cmds.s[idx]
args := c.Commands() // ["JSON.MGET", k1, k2, ..., path]
if len(args) < 3 {
// Shouldn't happen, but guard anyway
intl.PutCompletedForce(c)
cmds.s[idx] = client.B().JsonMget().Key(key).Path(path).Build().Pin()
continue
}

// existing keys are args[1 : len(args)-1], final arg is path
jm := client.B().JsonMget().Key(args[1])
for i := 2; i < len(args)-1; i++ {
jm = jm.Key(args[i])
}
jm = jm.Key(key)
jp := jm.Path(path)

intl.PutCompletedForce(c)
cmds.s[idx] = jp.Build().Pin()
}

resps := client.DoMulti(ctx, cmds.s...)
defer resultsp.Put(&redisresults{s: resps})

values := make([][]RedisMessage, len(cmds.s))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rueian
Sorry I've been caught up with other things and missed focused work on this
Do you think this allocation is still alright?
I cannot think of a better way for this

for i, resp := range resps {
values[i], err = resp.ToArray()
if err != nil {
return nil, err
}
}

pos := make([]int, len(cmds.s))
for _, key := range keys {
idx := slotIdx[slot(key)]
ret[key] = values[idx][pos[idx]]
pos[idx]++
}

for i := range cmds.s {
intl.PutCompletedForce(cmds.s[i])
}
return ret, nil
}

// ErrMSetNXNotSet is used in the MSetNX helper when the underlying MSETNX response is 0.
// Ref: https://redis.io/commands/msetnx/
var ErrMSetNXNotSet = errors.New("MSETNX: no key was set")
Expand Down
70 changes: 43 additions & 27 deletions helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,22 @@ func TestMGetCache(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate DisabledCache DoCache", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be MGET with keys from the same slot
commands := c.Commands()
if commands[0] != "MGET" {
t.Fatalf("expected MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
// Build response array with values matching the keys
values := make([]RedisMessage, len(commands)-1)
for j := 1; j < len(commands); j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -200,7 +204,7 @@ func TestMGetCache(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down Expand Up @@ -358,18 +362,22 @@ func TestMGet(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate Do", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be MGET with keys from the same slot
commands := c.Commands()
if commands[0] != "MGET" {
t.Fatalf("expected MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
// Build response array with values matching the keys
values := make([]RedisMessage, len(commands)-1)
for j := 1; j < len(commands); j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -379,7 +387,7 @@ func TestMGet(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down Expand Up @@ -1162,18 +1170,26 @@ func TestJsonMGet(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate Do", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"JSON.GET", key, "$"}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be JSON.MGET with keys from the same slot and path at the end
commands := c.Commands()
if commands[0] != "JSON.MGET" {
t.Fatalf("expected JSON.MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
if commands[len(commands)-1] != "$" {
t.Fatalf("expected $ as last parameter, got %v", commands)
return nil
}
// Build response array with values matching the keys (exclude the path)
values := make([]RedisMessage, len(commands)-2)
for j := 1; j < len(commands)-1; j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -1183,7 +1199,7 @@ func TestJsonMGet(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down
Loading