Skip to content

Commit 48ff0f0

Browse files
committed
Merge pull request redis#42 from go-redis/fix/rewrite_rate_limiter
Rewrite rate limiter.
2 parents 97695ed + 551257a commit 48ff0f0

File tree

3 files changed

+59
-38
lines changed

3 files changed

+59
-38
lines changed

pool.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,7 @@ func newConnPool(dial func() (*conn, error), opt *options) *connPool {
115115
}
116116

117117
func (p *connPool) new() (*conn, error) {
118-
select {
119-
case _, ok := <-p.rl.C:
120-
if !ok {
121-
return nil, errClosed
122-
}
123-
default:
118+
if !p.rl.Check() {
124119
return nil, errRateLimited
125120
}
126121
return p.dial()
@@ -263,7 +258,7 @@ func (p *connPool) Close() error {
263258
return nil
264259
}
265260
p.closed = true
266-
close(p.rl.C)
261+
p.rl.Close()
267262
var retErr error
268263
for {
269264
e := p.conns.Front()

rate_limit.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,52 @@
11
package redis
22

33
import (
4+
"sync/atomic"
45
"time"
56
)
67

78
type rateLimiter struct {
8-
C chan struct{}
9+
v int64
10+
11+
_closed int64
912
}
1013

11-
func newRateLimiter(limit time.Duration, chanSize int) *rateLimiter {
14+
func newRateLimiter(limit time.Duration, bucketSize int) *rateLimiter {
1215
rl := &rateLimiter{
13-
C: make(chan struct{}, chanSize),
14-
}
15-
for i := 0; i < chanSize; i++ {
16-
rl.C <- struct{}{}
16+
v: int64(bucketSize),
1717
}
18-
go rl.loop(limit)
18+
go rl.loop(limit, int64(bucketSize))
1919
return rl
2020
}
2121

22-
func (rl *rateLimiter) loop(limit time.Duration) {
23-
defer func() {
24-
recover()
25-
}()
22+
func (rl *rateLimiter) loop(limit time.Duration, bucketSize int64) {
2623
for {
27-
select {
28-
case rl.C <- struct{}{}:
29-
default:
24+
if rl.closed() {
25+
break
26+
}
27+
if v := atomic.LoadInt64(&rl.v); v < bucketSize {
28+
atomic.AddInt64(&rl.v, 1)
3029
}
3130
time.Sleep(limit)
3231
}
3332
}
33+
34+
func (rl *rateLimiter) Check() bool {
35+
for {
36+
if v := atomic.LoadInt64(&rl.v); v > 0 {
37+
if atomic.CompareAndSwapInt64(&rl.v, v, v-1) {
38+
return true
39+
}
40+
}
41+
return false
42+
}
43+
}
44+
45+
func (rl *rateLimiter) Close() error {
46+
atomic.StoreInt64(&rl._closed, 1)
47+
return nil
48+
}
49+
50+
func (rl *rateLimiter) closed() bool {
51+
return atomic.LoadInt64(&rl._closed) == 1
52+
}

redis_test.go

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,14 +2833,16 @@ func (t *RedisTest) transactionalIncr(c *C) ([]redis.Cmder, error) {
28332833
}
28342834

28352835
func (t *RedisTest) TestWatchUnwatch(c *C) {
2836-
const N = 10000
2836+
var n = 10000
2837+
if testing.Short() {
2838+
n = 1000
2839+
}
28372840

28382841
set := t.client.Set("key", "0")
28392842
c.Assert(set.Err(), IsNil)
2840-
c.Assert(set.Val(), Equals, "OK")
28412843

28422844
wg := &sync.WaitGroup{}
2843-
for i := 0; i < N; i++ {
2845+
for i := 0; i < n; i++ {
28442846
wg.Add(1)
28452847
go func() {
28462848
defer wg.Done()
@@ -2858,19 +2860,22 @@ func (t *RedisTest) TestWatchUnwatch(c *C) {
28582860
}
28592861
wg.Wait()
28602862

2861-
get := t.client.Get("key")
2862-
c.Assert(get.Err(), IsNil)
2863-
c.Assert(get.Val(), Equals, strconv.FormatInt(N, 10))
2863+
val, err := t.client.Get("key").Int64()
2864+
c.Assert(err, IsNil)
2865+
c.Assert(val, Equals, int64(n))
28642866
}
28652867

28662868
//------------------------------------------------------------------------------
28672869

28682870
func (t *RedisTest) TestRaceEcho(c *C) {
2869-
const N = 10000
2871+
var n = 10000
2872+
if testing.Short() {
2873+
n = 1000
2874+
}
28702875

28712876
wg := &sync.WaitGroup{}
2872-
wg.Add(N)
2873-
for i := 0; i < N; i++ {
2877+
wg.Add(n)
2878+
for i := 0; i < n; i++ {
28742879
go func(i int) {
28752880
msg := "echo" + strconv.Itoa(i)
28762881
echo := t.client.Echo(msg)
@@ -2883,14 +2888,16 @@ func (t *RedisTest) TestRaceEcho(c *C) {
28832888
}
28842889

28852890
func (t *RedisTest) TestRaceIncr(c *C) {
2886-
const N = 10000
2887-
key := "TestIncrFromGoroutines"
2891+
var n = 10000
2892+
if testing.Short() {
2893+
n = 1000
2894+
}
28882895

28892896
wg := &sync.WaitGroup{}
2890-
wg.Add(N)
2891-
for i := int64(0); i < N; i++ {
2897+
wg.Add(n)
2898+
for i := 0; i < n; i++ {
28922899
go func() {
2893-
incr := t.client.Incr(key)
2900+
incr := t.client.Incr("TestRaceIncr")
28942901
if err := incr.Err(); err != nil {
28952902
panic(err)
28962903
}
@@ -2899,9 +2906,9 @@ func (t *RedisTest) TestRaceIncr(c *C) {
28992906
}
29002907
wg.Wait()
29012908

2902-
get := t.client.Get(key)
2903-
c.Assert(get.Err(), IsNil)
2904-
c.Assert(get.Val(), Equals, strconv.Itoa(N))
2909+
val, err := t.client.Get("TestRaceIncr").Result()
2910+
c.Assert(err, IsNil)
2911+
c.Assert(val, Equals, strconv.Itoa(n))
29052912
}
29062913

29072914
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)