From fc024aeb5b41849805ccc2f4a27a40be171b244e Mon Sep 17 00:00:00 2001 From: Rueian Date: Sun, 28 Jul 2024 00:03:04 +0800 Subject: [PATCH] feat: add PubSub to rueidiscompat (#592) --- rueidiscompat/adapter.go | 22 ++ rueidiscompat/pubsub.go | 476 +++++++++++++++++++++++++++++++++++ rueidiscompat/pubsub_test.go | 474 ++++++++++++++++++++++++++++++++++ 3 files changed, 972 insertions(+) create mode 100644 rueidiscompat/pubsub.go create mode 100644 rueidiscompat/pubsub_test.go diff --git a/rueidiscompat/adapter.go b/rueidiscompat/adapter.go index 5a8df142..3a2ad120 100644 --- a/rueidiscompat/adapter.go +++ b/rueidiscompat/adapter.go @@ -469,6 +469,10 @@ type ProbabilisticCmdable interface { TDigestReset(ctx context.Context, key string) *StatusCmd TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd TDigestTrimmedMean(ctx context.Context, key string, lowCutQuantile, highCutQuantile float64) *FloatCmd + + Subscribe(ctx context.Context, channels ...string) PubSub + PSubscribe(ctx context.Context, patterns ...string) PubSub + SSubscribe(ctx context.Context, channels ...string) PubSub } // Align with go-redis @@ -4596,6 +4600,24 @@ func (c *Compat) JSONType(ctx context.Context, key, path string) *JSONSliceCmd { return newJSONSliceCmd(c.client.Do(ctx, cmd)) } +func (c *Compat) Subscribe(ctx context.Context, channels ...string) PubSub { + p := newPubSub(c.client) + _ = p.Subscribe(ctx, channels...) + return p +} + +func (c *Compat) SSubscribe(ctx context.Context, channels ...string) PubSub { + p := newPubSub(c.client) + _ = p.SSubscribe(ctx, channels...) + return p +} + +func (c *Compat) PSubscribe(ctx context.Context, patterns ...string) PubSub { + p := newPubSub(c.client) + _ = p.PSubscribe(ctx, patterns...) + return p +} + func (c CacheCompat) BitCount(ctx context.Context, key string, bitCount *BitCount) *IntCmd { var resp rueidis.RedisResult if bitCount == nil { diff --git a/rueidiscompat/pubsub.go b/rueidiscompat/pubsub.go new file mode 100644 index 00000000..e89a98d8 --- /dev/null +++ b/rueidiscompat/pubsub.go @@ -0,0 +1,476 @@ +// Copyright (c) 2013 The github.com/go-redis/redis Authors. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package rueidiscompat + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/redis/rueidis" +) + +type PubSub interface { + Close() error + Subscribe(ctx context.Context, channels ...string) error + PSubscribe(ctx context.Context, patterns ...string) error + SSubscribe(ctx context.Context, channels ...string) error + Unsubscribe(ctx context.Context, channels ...string) error + PUnsubscribe(ctx context.Context, patterns ...string) error + SUnsubscribe(ctx context.Context, channels ...string) error + Ping(ctx context.Context, payload ...string) error + ReceiveTimeout(ctx context.Context, timeout time.Duration) (any, error) + Receive(ctx context.Context) (any, error) + ReceiveMessage(ctx context.Context) (*Message, error) + Channel(opts ...ChannelOption) <-chan *Message + ChannelWithSubscriptions(opts ...ChannelOption) <-chan any + String() string +} + +type ChannelOption func(c *chopt) + +type chopt struct { + chanSize int +} + +// WithChannelSize specifies the Go chan size that is used to buffer incoming messages. +// The default is 1000 messages. +func WithChannelSize(size int) ChannelOption { + return func(c *chopt) { + c.chanSize = size + } +} + +// WithChannelHealthCheckInterval is an empty ChannelOption to keep compatibility +func WithChannelHealthCheckInterval(_ time.Duration) ChannelOption { + return func(c *chopt) {} +} + +// WithChannelSendTimeout is an empty ChannelOption to keep compatibility +func WithChannelSendTimeout(_ time.Duration) ChannelOption { + return func(c *chopt) {} +} + +// Subscription received after a successful subscription to channel. +type Subscription struct { + // Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". + Kind string + // Channel name we have subscribed to. + Channel string + // Number of channels we are currently subscribed to. + Count int +} + +func (m *Subscription) String() string { + return fmt.Sprintf("%s: %s", m.Kind, m.Channel) +} + +// Message received as result of a PUBLISH command issued by another client. +type Message struct { + Channel string + Pattern string + Payload string + PayloadSlice []string +} + +func (m *Message) String() string { + return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) +} + +func newPubSub(client rueidis.Client) *pubsub { + return &pubsub{ + rc: client, + channels: make(map[string]bool), + patterns: make(map[string]bool), + schannels: make(map[string]bool), + } +} + +type pubsub struct { + mu sync.Mutex + + rc rueidis.Client + mc rueidis.DedicatedClient + mcancel func() + + channels map[string]bool + patterns map[string]bool + schannels map[string]bool + + allCh chan any + msgCh chan *Message +} + +func (p *pubsub) mconn() rueidis.DedicatedClient { + if p.mc == nil { + p.mc, p.mcancel = p.rc.Dedicate() + } + return p.mc +} + +func (p *pubsub) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.mcancel != nil { + p.mc.SetPubSubHooks(rueidis.PubSubHooks{}) + p.mcancel() + p.mc = nil + p.mcancel = nil + } + p.channels = make(map[string]bool) + p.patterns = make(map[string]bool) + p.schannels = make(map[string]bool) + return nil +} + +func (p *pubsub) Subscribe(ctx context.Context, channels ...string) error { + if len(channels) == 0 { + return nil + } + + p.ChannelWithSubscriptions() + + p.mu.Lock() + defer p.mu.Unlock() + + for _, channel := range channels { + p.channels[channel] = true + } + + c := p.mconn() + return c.Do(ctx, c.B().Subscribe().Channel(channels...).Build()).Error() +} + +func (p *pubsub) PSubscribe(ctx context.Context, patterns ...string) error { + if len(patterns) == 0 { + return nil + } + + p.ChannelWithSubscriptions() + + p.mu.Lock() + defer p.mu.Unlock() + + for _, pattern := range patterns { + p.patterns[pattern] = true + } + + c := p.mconn() + return c.Do(ctx, c.B().Psubscribe().Pattern(patterns...).Build()).Error() +} + +func (p *pubsub) SSubscribe(ctx context.Context, channels ...string) error { + if len(channels) == 0 { + return nil + } + + p.ChannelWithSubscriptions() + + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.channels) != 0 || len(p.patterns) != 0 { + return errors.New("pubsub: cannot use SSubscribe after using Subscribe or PSubscribe") + } + + for _, channel := range channels { + p.schannels[channel] = true + } + + c := p.mconn() + return c.Do(ctx, c.B().Ssubscribe().Channel(channels...).Build()).Error() +} + +func (p *pubsub) Unsubscribe(ctx context.Context, channels ...string) error { + if len(channels) == 0 { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + for _, channel := range channels { + delete(p.channels, channel) + } + + c := p.mconn() + return c.Do(ctx, c.B().Unsubscribe().Channel(channels...).Build()).Error() +} + +func (p *pubsub) PUnsubscribe(ctx context.Context, patterns ...string) error { + if len(patterns) == 0 { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + for _, pattern := range patterns { + delete(p.patterns, pattern) + } + + c := p.mconn() + return c.Do(ctx, c.B().Punsubscribe().Pattern(patterns...).Build()).Error() +} + +func (p *pubsub) SUnsubscribe(ctx context.Context, channels ...string) error { + if len(channels) == 0 { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + for _, channel := range channels { + delete(p.schannels, channel) + } + + c := p.mconn() + return c.Do(ctx, c.B().Sunsubscribe().Channel(channels...).Build()).Error() +} + +func (p *pubsub) Ping(_ context.Context, _ ...string) error { + return nil // we already ping the connection periodically by default +} + +func (p *pubsub) reset() { + if p.mcancel != nil { + p.mcancel() + p.mc = nil + p.mcancel = nil + } + for channel := range p.channels { + p.channels[channel] = false + } + for pattern := range p.patterns { + p.patterns[pattern] = false + } + for schannel := range p.schannels { + p.schannels[schannel] = false + } +} + +func (p *pubsub) resubscribe(ctx context.Context) rueidis.DedicatedClient { + p.mu.Lock() + defer p.mu.Unlock() +retry: + c := p.mconn() + ok := false + if len(p.schannels) != 0 { + builder := c.B().Ssubscribe().Channel() + for channel, ok := range p.schannels { + if !ok { + builder = builder.Channel(channel) + p.schannels[channel] = true + } + } + if cmd := builder.Build(); len(cmd.Commands()) > 1 { + if err := c.Do(ctx, cmd).NonRedisError(); err != nil { + p.reset() + goto retry + } + ok = true + } + } + if len(p.channels) != 0 { + builder := c.B().Subscribe().Channel() + for channel, ok := range p.channels { + if !ok { + builder = builder.Channel(channel) + p.channels[channel] = true + } + } + if cmd := builder.Build(); len(cmd.Commands()) > 1 { + if err := c.Do(ctx, cmd).NonRedisError(); err != nil { + p.reset() + goto retry + } + ok = true + } + } + if len(p.patterns) != 0 { + builder := c.B().Psubscribe().Pattern() + for pattern, ok := range p.patterns { + if !ok { + builder = builder.Pattern(pattern) + p.patterns[pattern] = true + } + } + if cmd := builder.Build(); len(cmd.Commands()) > 1 { + if err := c.Do(ctx, cmd).NonRedisError(); err != nil { + p.reset() + goto retry + } + ok = true + } + } + if !ok { + if err := c.Do(ctx, c.B().Ping().Build()).NonRedisError(); err != nil { + p.reset() + goto retry + } + } + return c +} + +func (p *pubsub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (any, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + select { + case m := <-p.ChannelWithSubscriptions(): + if m == nil { + return nil, errors.New("redis: client is closed") + } + return m, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p *pubsub) Receive(_ context.Context) (any, error) { + m := <-p.ChannelWithSubscriptions() + if m == nil { + return nil, errors.New("redis: client is closed") + } + return m, nil +} + +func (p *pubsub) ReceiveMessage(_ context.Context) (*Message, error) { + m := <-p.Channel() + if m == nil { + return nil, errors.New("redis: client is closed") + } + return m, nil +} + +func (p *pubsub) Channel(opts ...ChannelOption) <-chan *Message { + ch := p.ChannelWithSubscriptions(opts...) + p.mu.Lock() + defer p.mu.Unlock() + if p.msgCh != nil { + return p.msgCh + } + msgCh := make(chan *Message) + p.msgCh = msgCh + go func() { + for m := range ch { + if msg, ok := m.(*Message); ok { + msgCh <- msg + } + } + p.mu.Lock() + if p.msgCh == msgCh { + p.msgCh = nil + } + p.mu.Unlock() + close(msgCh) + }() + return msgCh +} + +func (p *pubsub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan any { + p.mu.Lock() + if p.allCh != nil { + p.mu.Unlock() + return p.allCh + } + opt := &chopt{chanSize: 1000} + for _, fn := range opts { + fn(opt) + } + allCh := make(chan any, opt.chanSize) + p.allCh = allCh + p.mu.Unlock() + + resubscribe := func() <-chan error { + c := p.resubscribe(context.Background()) + return c.SetPubSubHooks(rueidis.PubSubHooks{ + OnMessage: func(m rueidis.PubSubMessage) { + msg := &Message{ + Channel: m.Channel, + Pattern: m.Pattern, + Payload: m.Message, + } + select { + case allCh <- msg: + default: + } + }, + OnSubscription: func(s rueidis.PubSubSubscription) { + sub := &Subscription{ + Kind: s.Kind, + Channel: s.Channel, + Count: int(s.Count), + } + select { + case allCh <- sub: + default: + } + }, + }) + } + go func(wait <-chan error) { + for { + if err := <-wait; err == nil { + p.mu.Lock() + if p.allCh == allCh { + p.allCh = nil + } + p.mu.Unlock() + close(allCh) + return + } + p.mu.Lock() + p.reset() + p.mu.Unlock() + wait = resubscribe() + } + }(resubscribe()) + + return allCh +} + +func (p *pubsub) String() string { + channels := mapKeys(p.channels) + channels = append(channels, mapKeys(p.patterns)...) + channels = append(channels, mapKeys(p.schannels)...) + return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) +} + +func mapKeys(m map[string]bool) []string { + s := make([]string, len(m)) + i := 0 + for k := range m { + s[i] = k + i++ + } + return s +} diff --git a/rueidiscompat/pubsub_test.go b/rueidiscompat/pubsub_test.go new file mode 100644 index 00000000..c0e40608 --- /dev/null +++ b/rueidiscompat/pubsub_test.go @@ -0,0 +1,474 @@ +// Copyright (c) 2013 The github.com/go-redis/redis Authors. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package rueidiscompat + +import ( + "bytes" + "context" + "sync" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("PubSub", func() { + var client Cmdable + + BeforeEach(func() { + client = adapterresp3 + }) + + It("implements Stringer", func() { + pubsub := client.PSubscribe(ctx, "mychannel*") + defer pubsub.Close() + + Expect(pubsub.String()).To(Equal("PubSub(mychannel*)")) + }) + + It("should support pattern matching", func() { + pubsub := client.PSubscribe(ctx, "mychannel*") + defer pubsub.Close() + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("psubscribe")) + Expect(subscr.Channel).To(Equal("mychannel*")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).To(MatchError(context.DeadlineExceeded)) + Expect(msgi).To(BeNil()) + } + + n, err := client.Publish(ctx, "mychannel1", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.PUnsubscribe(ctx, "mychannel*")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Message) + Expect(subscr.Channel).To(Equal("mychannel1")) + Expect(subscr.Pattern).To(Equal("mychannel*")) + Expect(subscr.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("punsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel*")) + Expect(subscr.Count).To(Equal(0)) + } + }) + + It("should pub/sub channels", func() { + channels, err := client.PubSubChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + pubsub := client.Subscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + channels, err = client.PubSubChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) + + channels, err = client.PubSubChannels(ctx, "").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + channels, err = client.PubSubChannels(ctx, "*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(channels)).To(BeNumerically(">=", 2)) + }) + + It("should sharded pub/sub channels", func() { + channels, err := client.PubSubShardChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + channels, err = client.PubSubShardChannels(ctx, "mychannel*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(ConsistOf([]string{"mychannel", "mychannel2"})) + + channels, err = client.PubSubShardChannels(ctx, "").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(BeEmpty()) + + channels, err = client.PubSubShardChannels(ctx, "*").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(len(channels)).To(BeNumerically(">=", 2)) + + nums, err := client.PubSubShardNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(nums).To(Equal(map[string]int64{ + "mychannel": 1, + "mychannel2": 1, + "mychannel3": 0, + })) + }) + + It("should return the numbers of subscribers", func() { + pubsub := client.Subscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + channels, err := client.PubSubNumSub(ctx, "mychannel", "mychannel2", "mychannel3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(channels).To(Equal(map[string]int64{ + "mychannel": 1, + "mychannel2": 1, + "mychannel3": 0, + })) + }) + + It("should return the numbers of subscribers by pattern", func() { + num, err := client.PubSubNumPat(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(0))) + + pubsub := client.PSubscribe(ctx, "*") + defer pubsub.Close() + + num, err = client.PubSubNumPat(ctx).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(num).To(Equal(int64(1))) + }) + + It("should pub/sub", func() { + pubsub := client.Subscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("subscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(2)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).To(MatchError(context.DeadlineExceeded)) + Expect(msgi).NotTo(HaveOccurred()) + } + + n, err := client.Publish(ctx, "mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + n, err = client.Publish(ctx, "mychannel2", "hello2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.Unsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*Message) + Expect(msg.Channel).To(Equal("mychannel2")) + Expect(msg.Payload).To(Equal("hello2")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("unsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("unsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(0)) + } + }) + + It("should sharded pub/sub", func() { + pubsub := client.SSubscribe(ctx, "mychannel", "mychannel2") + defer pubsub.Close() + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("ssubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("ssubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(2)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).To(MatchError(context.DeadlineExceeded)) + Expect(msgi).NotTo(HaveOccurred()) + } + + n, err := client.SPublish(ctx, "mychannel", "hello").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + n, err = client.SPublish(ctx, "mychannel2", "hello2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(1))) + + Expect(pubsub.SUnsubscribe(ctx, "mychannel", "mychannel2")).NotTo(HaveOccurred()) + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*Message) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + msg := msgi.(*Message) + Expect(msg.Channel).To(Equal("mychannel2")) + Expect(msg.Payload).To(Equal("hello2")) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("sunsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel")) + Expect(subscr.Count).To(Equal(1)) + } + + { + msgi, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + subscr := msgi.(*Subscription) + Expect(subscr.Kind).To(Equal("sunsubscribe")) + Expect(subscr.Channel).To(Equal("mychannel2")) + Expect(subscr.Count).To(Equal(0)) + } + }) + + It("should multi-ReceiveMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + subscr, err := pubsub.ReceiveTimeout(ctx, time.Second) + Expect(err).NotTo(HaveOccurred()) + Expect(subscr).To(Equal(&Subscription{ + Kind: "subscribe", + Channel: "mychannel", + Count: 1, + })) + + err = client.Publish(ctx, "mychannel", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + + err = client.Publish(ctx, "mychannel", "world").Err() + Expect(err).NotTo(HaveOccurred()) + + msg, err := pubsub.ReceiveMessage(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + + msg, err = pubsub.ReceiveMessage(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("world")) + }) + + It("should return on Close", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer GinkgoRecover() + + wg.Done() + defer wg.Done() + + _, err := pubsub.ReceiveMessage(ctx) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(SatisfyAny( + Equal("redis: client is closed"), + ContainSubstring("use of closed network connection"), + )) + }() + + wg.Wait() + wg.Add(1) + + Expect(pubsub.Close()).NotTo(HaveOccurred()) + + wg.Wait() + }) + + It("should ReceiveMessage without a subscription", func() { + timeout := 100 * time.Millisecond + + pubsub := client.Subscribe(ctx) + defer pubsub.Close() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer GinkgoRecover() + defer wg.Done() + + time.Sleep(timeout) + + err := pubsub.Subscribe(ctx, "mychannel") + Expect(err).NotTo(HaveOccurred()) + + time.Sleep(timeout) + + err = client.Publish(ctx, "mychannel", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + }() + + msg, err := pubsub.ReceiveMessage(ctx) + Expect(err).NotTo(HaveOccurred()) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal("hello")) + + wg.Wait() + }) + + It("handles big message payload", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ch := pubsub.Channel() + + bigVal := bigVal() + err := client.Publish(ctx, "mychannel", bigVal).Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *Message + Eventually(ch).WithTimeout(5 * time.Second).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal(string(bigVal))) + }) + + It("supports concurrent Publish and Receive", func() { + const N = 100 + + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + + for i := 0; i < N; i++ { + _, err := pubsub.ReceiveTimeout(ctx, 5*time.Second) + Expect(err).NotTo(HaveOccurred()) + } + close(done) + }() + + for i := 0; i < N; i++ { + err := client.Publish(ctx, "mychannel", "hello").Err() + Expect(err).NotTo(HaveOccurred()) + } + + select { + case <-done: + case <-time.After(30 * time.Second): + Fail("timeout") + } + }) + + It("should ChannelMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ch := pubsub.Channel( + WithChannelSize(10), + WithChannelHealthCheckInterval(time.Second), + ) + + text := "test channel message" + err := client.Publish(ctx, "mychannel", text).Err() + Expect(err).NotTo(HaveOccurred()) + + var msg *Message + Eventually(ch).Should(Receive(&msg)) + Expect(msg.Channel).To(Equal("mychannel")) + Expect(msg.Payload).To(Equal(text)) + }) +}) + +func bigVal() []byte { + return bytes.Repeat([]byte{'*'}, 1<<17) // 128kb +}