diff --git a/app/main_test.go b/app/main_test.go index edc573b..fea22b3 100644 --- a/app/main_test.go +++ b/app/main_test.go @@ -716,6 +716,23 @@ func TestAllowRedisTokenBucketReject(t *testing.T) { <-done } +func TestAllowRedisTokenBucketError(t *testing.T) { + rl := NewRateLimiter(1, time.Second, "token_bucket") + t.Cleanup(rl.Stop) + srv, cli := net.Pipe() + done := make(chan struct{}) + go func() { + defer func() { srv.Close(); close(done) }() + br := bufio.NewReader(srv) + parseRedisCommand(t, br) + srv.Write([]byte("-ERR nope\r\n")) + }() + if _, err := rl.allowRedisTokenBucket(cli, "k"); err == nil { + t.Fatal("expected redis error") + } + <-done +} + func TestAllowRedisLeakyBucketPoolFullClosesConnection(t *testing.T) { old := *redisAddr *redisAddr = "dummy" @@ -815,6 +832,23 @@ func TestAllowRedisLeakyBucketReject(t *testing.T) { <-done } +func TestAllowRedisLeakyBucketError(t *testing.T) { + rl := NewRateLimiter(1, time.Second, "leaky_bucket") + t.Cleanup(rl.Stop) + srv, cli := net.Pipe() + done := make(chan struct{}) + go func() { + defer func() { srv.Close(); close(done) }() + br := bufio.NewReader(srv) + parseRedisCommand(t, br) + srv.Write([]byte("-ERR boom\r\n")) + }() + if _, err := rl.allowRedisLeakyBucket(cli, "k"); err == nil { + t.Fatal("expected redis GET error") + } + <-done +} + func TestAllowRedisLeakyBucketAllow(t *testing.T) { rl := NewRateLimiter(1, time.Hour, "leaky_bucket") t.Cleanup(rl.Stop) diff --git a/app/redis_test.go b/app/redis_test.go index bb8ea50..2a653ad 100644 --- a/app/redis_test.go +++ b/app/redis_test.go @@ -5,10 +5,41 @@ import ( "fmt" "io" "net" + "strconv" + "strings" "testing" "time" ) +func readRESPCommand(t *testing.T, br *bufio.Reader) []string { + t.Helper() + + line, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read command length: %v", err) + } + if !strings.HasPrefix(line, "*") { + t.Fatalf("unexpected command prefix %q", line) + } + n, err := strconv.Atoi(strings.TrimSpace(line[1:])) + if err != nil { + t.Fatalf("parse command length: %v", err) + } + + args := make([]string, n) + for i := 0; i < n; i++ { + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("read bulk len: %v", err) + } + arg, err := br.ReadString('\n') + if err != nil { + t.Fatalf("read bulk data: %v", err) + } + args[i] = strings.TrimSpace(arg) + } + return args +} + func TestRedisCmdInt(t *testing.T) { srv, cli := net.Pipe() defer srv.Close() @@ -268,6 +299,23 @@ func TestRedisCmdStringReadError(t *testing.T) { } } +func TestRedisCmdStringBulkReadError(t *testing.T) { + srv, cli := net.Pipe() + defer cli.Close() + + go func() { + br := bufio.NewReader(srv) + br.ReadBytes('\n') + br.ReadBytes('\n') + srv.Write([]byte("$4\r\n")) + srv.Close() + }() + + if _, err := redisCmdString(cli, "GET", "key"); err == nil { + t.Fatal("expected bulk read error") + } +} + func TestRedisCmdStringSimple(t *testing.T) { srv, cli := net.Pipe() defer srv.Close() @@ -309,3 +357,94 @@ func TestRedisCmdStringInteger(t *testing.T) { t.Fatalf("expected 5, got %q", val) } } + +func TestAllowRedisTokenBucketEmpty(t *testing.T) { + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "token_bucket") + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + now := time.Now().UnixNano() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("0 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisTokenBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is empty") + } +} + +func TestAllowRedisLeakyBucketOverLimit(t *testing.T) { + oldAddr := *redisAddr + *redisAddr = "redis://example:6379" + t.Cleanup(func() { *redisAddr = oldAddr }) + + rl := NewRateLimiter(1, time.Second, "leaky_bucket") + now := time.Now().UnixNano() + srv, cli := net.Pipe() + defer srv.Close() + defer cli.Close() + + go func() { + br := bufio.NewReader(srv) + + // GET k + readRESPCommand(t, br) + payload := fmt.Sprintf("2 %d", now) + srv.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(payload), payload))) + + // SET k + readRESPCommand(t, br) + srv.Write([]byte("+OK\r\n")) + + // PEXPIRE k + readRESPCommand(t, br) + srv.Write([]byte(":1\r\n")) + }() + + allowed, err := rl.allowRedisLeakyBucket(cli, "k") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if allowed { + t.Fatal("expected request to be rate limited when bucket is over limit") + } +} + +func TestRetryAfterRedisTLSMissingCA(t *testing.T) { + oldAddr := *redisAddr + oldCA := *redisCA + *redisAddr = "rediss://example.com:6379" + *redisCA = "does-not-exist" + t.Cleanup(func() { + *redisAddr = oldAddr + *redisCA = oldCA + }) + + rl := NewRateLimiter(1, time.Second, "fixed_window") + if _, err := rl.retryAfterRedis("key"); err == nil { + t.Fatal("expected error when CA file cannot be read") + } +}