Skip to content

feat: client-side caching #2542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package redis

import (
"context"
"fmt"
"net"
"strconv"
"sync/atomic"
"time"

"github.com/redis/go-redis/v9/internal"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"
)

type Cache interface{}

type cache struct {
client *Client

// cluster? sentinel?
conn *Conn
prefix []string

closed int32 // atomic
}

func newCache() Cache {
// ?
return &cache{}
}

// ------------------------------------------------------------------------------------------

// extension method

func (c *Conn) readReply(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
return c.withConn(ctx, func(ctx context.Context, conn *pool.Conn) error {
return conn.WithReader(ctx, timeout, fn)
})
}

// ------------------------------------------------------------------------------------------

// Client-side caching command.

type trackingArgs struct {
redirect int
prefixes []string
broadcast bool
optIn bool
optOut bool
noLoop bool
}

func (c *cache) clientTracking(ctx context.Context, t *trackingArgs) *StringCmd {
args := make([]any, 0, 7+len(t.prefixes))
args = append(args, "CLIENT", "TRACKING", "ON")
if t.redirect > 0 {
args = append(args, "REDIRECT", t.redirect)
}
if len(t.prefixes) > 0 {
for _, prefix := range t.prefixes {
args = append(args, "PREFIX", prefix)
}
}
if t.optIn {
args = append(args, "OPTIN")
}
if t.optOut {
args = append(args, "OPTOUT")
}
if t.noLoop {
args = append(args, "NOLOOP")
}
cmd := NewStringCmd(ctx, args...)
_ = c.conn.Process(ctx, cmd)
return cmd
}

func (c *cache) trackingClose(ctx context.Context) error {
return c.conn.Process(ctx, NewStringCmd(ctx, "CLIENT", "TRACKING", "OFF"))
}

func (c *cache) cachingYes(ctx context.Context) error {
return c.conn.Process(ctx, NewStringCmd(ctx, "CLIENT", "CACHING", "YES"))
}

func (c *cache) cachingNo(ctx context.Context) error {
return c.conn.Process(ctx, NewStringCmd(ctx, "CLIENT", "CACHING", "NO"))
}

// ------------------------------------------------------------------------------------

// readInvalidate To read the expired message push from redis-server,
// we only read for invalidate messages, and consider any other data that is read as an error.
func (c *cache) readInvalidate(rd *proto.Reader) ([]string, error) {
line, err := rd.ReadLine()
if err != nil {
return nil, err
}

if line[0] != proto.RespPush {
return nil, fmt.Errorf("invalid data-%s", string(line))
}

n, err := strconv.Atoi(string(line[1:]))
if err != nil {
return nil, err
}
if n != 2 {
return nil, fmt.Errorf("got %d elements in the map, wanted %d", n, 2)
}

// read `invalidate`
s, err := rd.ReadString()
if err != nil {
return nil, err
}
if s != "invalidate" {
return nil, fmt.Errorf("not a client-side caching push message, data-%s", s)
}

n, err = rd.ReadArrayLen()
if err != nil {
return nil, err
}

keys := make([]string, 0, n)
for i := 0; i < n; i++ {
key, err := rd.ReadString()
if err != nil {
return nil, err
}
keys = append(keys, key)
}

return keys, nil
}

// ------------------------------------- Broadcasting -------------------------------------

func (c *cache) listen(timeout time.Duration) {
ctx := context.Background()
defer func() {
if err := recover(); err != nil {
internal.Logger.Printf(ctx, "redis cache: panic - %v", err)
}
}()

if timeout == 0 {
timeout = 30 * time.Second
}
internal.Logger.Printf(ctx, "redis cache: listen working, read timeout-%d second", int(timeout/time.Second))

// state, 0-normal, 1-need init track
const (
normal = 0
bad = 1
)
var state = normal
for {
if atomic.LoadInt32(&c.closed) == 1 {
_ = c.conn.Close()
internal.Logger.Printf(ctx, "redis cache: close, quit listen")
return
}

if state == bad {
internal.Logger.Printf(ctx, "redis cache: state bad")
if err := c.initTrack(ctx); err != nil {
internal.Logger.Printf(ctx, "redis cache: listen init track error-%s", err.Error())
time.Sleep(1 * time.Second)
continue
}
}

if err := c.conn.Ping(ctx).Err(); err != nil {
internal.Logger.Printf(ctx, "redis cache: listen ping error-%s", err.Error())
state = bad
continue
}
state = normal

var keys []string
err := c.conn.withConn(ctx, func(ctx context.Context, conn *pool.Conn) error {
return conn.WithReader(ctx, timeout, func(rd *proto.Reader) (err error) {
keys, err = c.readInvalidate(rd)

if err == nil {
return nil
}

// The timeout error is considered normal, and it is triggered when we fail
// to receive a notification. We handle it as nil.
// We cannot return the timeout error, as go-redis would consider it a network
// problem and close the network connection.
if isNetTimeout(err) {
err = nil
return err
}

// We only listen for redis-push notifications, so under normal circumstances,
// we should not receive any redis-error notifications.
// If we do, we need to handle them as errors; otherwise,
// go-redis may consider redis errors as normal occurrences.
if isRedisError(err) {
err = fmt.Errorf("redis cache: unexpected response redis-error-msg-%s", err.Error())
}

return err
})
})

// under normal circumstances, we should not receive any errors, including redis errors.
if err != nil {
state = bad

internal.Logger.Printf(ctx, "redis cache: read push data error-%s", err.Error())
continue
}

// it's possible that we may not receive any notifications for keys.
if len(keys) > 0 {
// handle keys
}
}
}

func (c *cache) initTrack(ctx context.Context) error {
internal.Logger.Printf(ctx, "redis cache: init track")
if c.conn != nil {
_ = c.conn.Close()
}
c.conn = c.client.Conn()

args := make([]any, 0, 3+2*len(c.prefix)+1)
args = append(args, "CLIENT", "TRACKING", "ON")
for _, prefix := range c.prefix {
args = append(args, "PREFIX", prefix)
}
args = append(args, "BCAST")
cmd := NewStringCmd(ctx, args...)

if err := c.conn.Process(ctx, cmd); err != nil {
_ = c.conn.Close()
return err
}

return nil
}

// isNetTimeout check err == net timeout
func isNetTimeout(err error) bool {
if err == nil {
return false
}
netErr, ok := err.(net.Error)
if !ok {
return false
}
return netErr.Timeout()
}