diff --git a/helper.go b/helper.go index cb59f7ba..1ef39882 100644 --- a/helper.go +++ b/helper.go @@ -9,6 +9,10 @@ import ( intl "github.com/redis/rueidis/internal/cmds" ) +func slot(key string) uint16 { + return intl.Slot(key) +} + // MGetCache is a helper that consults the client-side caches with multiple keys by grouping keys within the same slot into multiple GETs func MGetCache(client Client, ctx context.Context, ttl time.Duration, keys []string) (ret map[string]RedisMessage, err error) { if len(keys) == 0 { @@ -50,12 +54,7 @@ func MGet(client Client, ctx context.Context, keys []string) (ret map[string]Red return clientMGet(client, ctx, client.B().Mget().Key(keys...).Build(), keys) } - cmds := mgetcmdsp.Get(len(keys), len(keys)) - defer mgetcmdsp.Put(cmds) - for i := range cmds.s { - cmds.s[i] = client.B().Get().Key(keys[i]).Build() - } - return doMultiGet(client, ctx, cmds.s, keys) + return clusterMGet(client, ctx, keys) } // MSet is a helper that consults the redis directly with multiple keys by grouping keys within the same slot into MSETs or multiple SETs @@ -139,12 +138,7 @@ func JsonMGet(client Client, ctx context.Context, keys []string, path string) (r return clientMGet(client, ctx, client.B().JsonMget().Key(keys...).Path(path).Build(), keys) } - cmds := mgetcmdsp.Get(len(keys), len(keys)) - defer mgetcmdsp.Put(cmds) - for i := range cmds.s { - cmds.s[i] = client.B().JsonGet().Key(keys[i]).Path(path).Build() - } - return doMultiGet(client, ctx, cmds.s, keys) + return clusterJsonMGet(client, ctx, keys, path) } // JsonMSet is a helper that consults redis directly with multiple keys by grouping keys within the same slot into JSON.MSETs or multiple JSON.SETs @@ -277,6 +271,75 @@ func arrayToKV(m map[string]RedisMessage, arr []RedisMessage, keys []string) map return m } +func clusterMGet(client Client, ctx context.Context, keys []string) (ret map[string]RedisMessage, err error) { + ret = make(map[string]RedisMessage, len(keys)) + slots := make(map[uint16][]int, len(keys)/2) + for i, key := range keys { + s := slot(key) + slots[s] = append(slots[s], i) + } + cmds := mgetcmdsp.Get(0, len(slots)) + defer mgetcmdsp.Put(cmds) + groups := make([][]string, 0, len(slots)) + for _, group := range slots { + gkeys := make([]string, 0, len(group)) + for _, i := range group { + gkeys = append(gkeys, keys[i]) + } + cmds.s = append(cmds.s, client.B().Mget().Key(gkeys...).Build().Pin()) + groups = append(groups, gkeys) + } + resps := client.DoMulti(ctx, cmds.s...) + defer resultsp.Put(&redisresults{s: resps}) + for i, resp := range resps { + arr, err := resp.ToArray() + if err != nil { + return nil, err + } + ret = arrayToKV(ret, arr, groups[i]) + } + for i := range cmds.s { + intl.PutCompletedForce(cmds.s[i]) + } + return ret, nil +} + +func clusterJsonMGet(client Client, ctx context.Context, keys []string, path string) (ret map[string]RedisMessage, err error) { + ret = make(map[string]RedisMessage, len(keys)) + slots := make(map[uint16][]int, len(keys)/2) + for i, key := range keys { + s := slot(key) + slots[s] = append(slots[s], i) + } + if len(slots) == 0 { + return ret, nil + } + cmds := mgetcmdsp.Get(0, len(slots)) + defer mgetcmdsp.Put(cmds) + groups := make([][]string, 0, len(slots)) + for _, group := range slots { + gkeys := make([]string, 0, len(group)) + for _, i := range group { + gkeys = append(gkeys, keys[i]) + } + cmds.s = append(cmds.s, client.B().JsonMget().Key(gkeys...).Path(path).Build().Pin()) + groups = append(groups, gkeys) + } + resps := client.DoMulti(ctx, cmds.s...) + defer resultsp.Put(&redisresults{s: resps}) + for i, resp := range resps { + arr, err := resp.ToArray() + if err != nil { + return nil, err + } + ret = arrayToKV(ret, arr, groups[i]) + } + for i := range cmds.s { + intl.PutCompletedForce(cmds.s[i]) + } + return ret, nil +} + // ErrMSetNXNotSet is used in the MSetNX helper when the underlying MSETNX response is 0. // Ref: https://redis.io/commands/msetnx/ var ErrMSetNXNotSet = errors.New("MSETNX: no key was set") diff --git a/helper_test.go b/helper_test.go index f23bdb06..e3c6945c 100644 --- a/helper_test.go +++ b/helper_test.go @@ -179,18 +179,22 @@ func TestMGetCache(t *testing.T) { t.Fatalf("unexpected err %v", err) } t.Run("Delegate DisabledCache DoCache", func(t *testing.T) { - keys := make([]string, 100) - for i := range keys { - keys[i] = strconv.Itoa(i) - } + keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"} m.DoMultiFn = func(cmd ...Completed) *redisresults { result := make([]RedisResult, len(cmd)) - for i, key := range keys { - if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) { - t.Fatalf("unexpected command %v", cmd) + for i, c := range cmd { + // Each command should be MGET with keys from the same slot + commands := c.Commands() + if commands[0] != "MGET" { + t.Fatalf("expected MGET command, got %v", commands) return nil } - result[i] = newResult(strmsg('+', key), nil) + // Build response array with values matching the keys + values := make([]RedisMessage, len(commands)-1) + for j := 1; j < len(commands); j++ { + values[j-1] = strmsg('+', commands[j]) + } + result[i] = newResult(slicemsg('*', values), nil) } return &redisresults{s: result} } @@ -200,7 +204,7 @@ func TestMGetCache(t *testing.T) { } for _, key := range keys { if vKey, ok := v[key]; !ok || vKey.string() != key { - t.Fatalf("unexpected response %v", v) + t.Fatalf("unexpected response for key %s: %v", key, v) } } }) @@ -358,18 +362,22 @@ func TestMGet(t *testing.T) { t.Fatalf("unexpected err %v", err) } t.Run("Delegate Do", func(t *testing.T) { - keys := make([]string, 100) - for i := range keys { - keys[i] = strconv.Itoa(i) - } + keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"} m.DoMultiFn = func(cmd ...Completed) *redisresults { result := make([]RedisResult, len(cmd)) - for i, key := range keys { - if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) { - t.Fatalf("unexpected command %v", cmd) + for i, c := range cmd { + // Each command should be MGET with keys from the same slot + commands := c.Commands() + if commands[0] != "MGET" { + t.Fatalf("expected MGET command, got %v", commands) return nil } - result[i] = newResult(strmsg('+', key), nil) + // Build response array with values matching the keys + values := make([]RedisMessage, len(commands)-1) + for j := 1; j < len(commands); j++ { + values[j-1] = strmsg('+', commands[j]) + } + result[i] = newResult(slicemsg('*', values), nil) } return &redisresults{s: result} } @@ -379,7 +387,7 @@ func TestMGet(t *testing.T) { } for _, key := range keys { if vKey, ok := v[key]; !ok || vKey.string() != key { - t.Fatalf("unexpected response %v", v) + t.Fatalf("unexpected response for key %s: %v", key, v) } } }) @@ -1162,18 +1170,26 @@ func TestJsonMGet(t *testing.T) { t.Fatalf("unexpected err %v", err) } t.Run("Delegate Do", func(t *testing.T) { - keys := make([]string, 100) - for i := range keys { - keys[i] = strconv.Itoa(i) - } + keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"} m.DoMultiFn = func(cmd ...Completed) *redisresults { result := make([]RedisResult, len(cmd)) - for i, key := range keys { - if !reflect.DeepEqual(cmd[i].Commands(), []string{"JSON.GET", key, "$"}) { - t.Fatalf("unexpected command %v", cmd) + for i, c := range cmd { + // Each command should be JSON.MGET with keys from the same slot and path at the end + commands := c.Commands() + if commands[0] != "JSON.MGET" { + t.Fatalf("expected JSON.MGET command, got %v", commands) return nil } - result[i] = newResult(strmsg('+', key), nil) + if commands[len(commands)-1] != "$" { + t.Fatalf("expected $ as last parameter, got %v", commands) + return nil + } + // Build response array with values matching the keys (exclude the path) + values := make([]RedisMessage, len(commands)-2) + for j := 1; j < len(commands)-1; j++ { + values[j-1] = strmsg('+', commands[j]) + } + result[i] = newResult(slicemsg('*', values), nil) } return &redisresults{s: result} } @@ -1183,7 +1199,7 @@ func TestJsonMGet(t *testing.T) { } for _, key := range keys { if vKey, ok := v[key]; !ok || vKey.string() != key { - t.Fatalf("unexpected response %v", v) + t.Fatalf("unexpected response for key %s: %v", key, v) } } })