diff --git a/app/main_test.go b/app/main_test.go index aedb4e1..edc573b 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -474,6 +474,19 @@ func TestAllowRedisUnsupportedScheme(t *testing.T) { } } +func TestAllowRedisParseError(t *testing.T) { + old := *redisAddr + *redisAddr = "://bad" + rl := NewRateLimiter(1, time.Second, "") + t.Cleanup(func() { + *redisAddr = old + rl.Stop() + }) + if _, err := rl.allowRedis("k"); err == nil { + t.Fatal("expected parse error") + } +} + func TestAllowRedisTLSCAError(t *testing.T) { oldAddr, oldCA := *redisAddr, *redisCA *redisAddr = "rediss://localhost:1" @@ -631,6 +644,42 @@ func TestAllowRedisTokenBucket(t *testing.T) { <-done } +func TestAllowRedisTokenBucketRefillClamp(t *testing.T) { + rl := NewRateLimiter(2, time.Second, "token_bucket") + t.Cleanup(rl.Stop) + srv, cli := net.Pipe() + done := make(chan struct{}) + go func() { + defer func() { srv.Close(); close(done) }() + br := bufio.NewReader(srv) + if cmd, args := parseRedisCommand(t, br); cmd != "GET" || args[0] != "k" { + t.Errorf("unexpected command %s %v", cmd, args) + return + } + val := fmt.Sprintf("%f %d", float64(rl.limit)+5, time.Now().Add(-time.Minute).UnixNano()) + fmt.Fprintf(srv, "$%d\r\n%s\r\n", len(val), val) + if cmd, args := parseRedisCommand(t, br); cmd != "SET" || args[0] != "k" { + t.Errorf("unexpected command %s %v", cmd, args) + return + } + srv.Write([]byte("+OK\r\n")) + if cmd, _ := parseRedisCommand(t, br); cmd != "EXPIRE" && cmd != "PEXPIRE" { + t.Errorf("unexpected command %s", cmd) + return + } + srv.Write([]byte(":1\r\n")) + }() + + ok, err := rl.allowRedisTokenBucket(cli, "k") + if err != nil { + t.Fatalf("allowRedisTokenBucket error: %v", err) + } + if !ok { + t.Fatal("expected token bucket allow") + } + <-done +} + func TestAllowRedisTokenBucketReject(t *testing.T) { rl := NewRateLimiter(1, time.Hour, "token_bucket") t.Cleanup(rl.Stop) @@ -1060,6 +1109,45 @@ func TestAllowRedisDialError(t *testing.T) { } } +func TestAllowRedisAuthError(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + done := make(chan struct{}) + go func() { + defer close(done) + c, err := ln.Accept() + if err != nil { + return + } + defer c.Close() + br := bufio.NewReader(c) + if cmd, args := parseRedisCommand(t, br); cmd != "AUTH" || len(args) != 2 || args[0] != "user" || args[1] != "pw" { + t.Errorf("unexpected auth command %s %v", cmd, args) + return + } + c.Write([]byte("-ERR nope\r\n")) + }() + + oldAddr, oldTimeout := *redisAddr, *redisTimeout + *redisAddr = "redis://user:pw@" + ln.Addr().String() + *redisTimeout = time.Second + rl := NewRateLimiter(1, time.Second, "") + t.Cleanup(func() { + *redisAddr = oldAddr + *redisTimeout = oldTimeout + rl.Stop() + }) + + if _, err := rl.allowRedis("k"); err == nil { + t.Fatal("expected auth error") + } + <-done +} + // Helper process used for testing main(). func TestMainHelper(t *testing.T) { if os.Getenv("GO_WANT_MAIN_HELPER") != "1" {