Skip to content

Commit ce1a7c4

Browse files
committed
support error wrapping for io and context errors
1 parent 4d0db2b commit ce1a7c4

File tree

3 files changed

+286
-30
lines changed

3 files changed

+286
-30
lines changed

error.go

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,54 @@ type Error interface {
5252
var _ Error = proto.RedisError("")
5353

5454
func isContextError(err error) bool {
55-
switch err {
56-
case context.Canceled, context.DeadlineExceeded:
57-
return true
58-
default:
59-
return false
55+
// Check for wrapped context errors using errors.Is
56+
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
57+
}
58+
59+
// isTimeoutError checks if an error is a timeout error, even if wrapped.
60+
// Returns (isTimeout, shouldRetryOnTimeout) where:
61+
// - isTimeout: true if the error is any kind of timeout error
62+
// - shouldRetryOnTimeout: true if Timeout() method returns true
63+
func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) {
64+
// Check for timeoutError interface (works with wrapped errors)
65+
var te timeoutError
66+
if errors.As(err, &te) {
67+
return true, te.Timeout()
6068
}
69+
70+
// Check for net.Error specifically (common case for network timeouts)
71+
var netErr net.Error
72+
if errors.As(err, &netErr) {
73+
return true, netErr.Timeout()
74+
}
75+
76+
return false, false
6177
}
6278

6379
func shouldRetry(err error, retryTimeout bool) bool {
64-
switch err {
65-
case io.EOF, io.ErrUnexpectedEOF:
80+
if err == nil {
81+
return false
82+
}
83+
84+
// Check for EOF errors (works with wrapped errors)
85+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
6686
return true
67-
case nil, context.Canceled, context.DeadlineExceeded:
87+
}
88+
89+
// Check for context errors (works with wrapped errors)
90+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
6891
return false
69-
case pool.ErrPoolTimeout:
92+
}
93+
94+
// Check for pool timeout (works with wrapped errors)
95+
if errors.Is(err, pool.ErrPoolTimeout) {
7096
// connection pool timeout, increase retries. #3289
7197
return true
7298
}
7399

74-
if v, ok := err.(timeoutError); ok {
75-
if v.Timeout() {
100+
// Check for timeout errors (works with wrapped errors)
101+
if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout {
102+
if hasTimeoutFlag {
76103
return retryTimeout
77104
}
78105
return true
@@ -115,23 +142,37 @@ func shouldRetry(err error, retryTimeout bool) bool {
115142
if strings.HasPrefix(s, "TRYAGAIN ") {
116143
return true
117144
}
145+
if strings.HasPrefix(s, "MASTERDOWN ") {
146+
return true
147+
}
118148

119149
return false
120150
}
121151

122152
func isRedisError(err error) bool {
123-
_, ok := err.(proto.RedisError)
124-
return ok
153+
// Check if error implements the Error interface (works with wrapped errors)
154+
var redisErr Error
155+
if errors.As(err, &redisErr) {
156+
return true
157+
}
158+
// Also check for proto.RedisError specifically
159+
var protoRedisErr proto.RedisError
160+
return errors.As(err, &protoRedisErr)
125161
}
126162

127163
func isBadConn(err error, allowTimeout bool, addr string) bool {
128-
switch err {
129-
case nil:
130-
return false
131-
case context.Canceled, context.DeadlineExceeded:
132-
return true
133-
case pool.ErrConnUnusableTimeout:
134-
return true
164+
if err == nil {
165+
return false
166+
}
167+
168+
// Check for context errors (works with wrapped errors)
169+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
170+
return true
171+
}
172+
173+
// Check for pool timeout errors (works with wrapped errors)
174+
if errors.Is(err, pool.ErrConnUnusableTimeout) {
175+
return true
135176
}
136177

137178
if isRedisError(err) {
@@ -151,7 +192,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
151192
}
152193

153194
if allowTimeout {
154-
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
195+
// Check for network timeout errors (works with wrapped errors)
196+
var netErr net.Error
197+
if errors.As(err, &netErr) && netErr.Timeout() {
155198
return false
156199
}
157200
}

error_wrapping_test.go

Lines changed: 192 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"io"
8+
"strings"
79
"testing"
810

911
"github.com/redis/go-redis/v9"
@@ -443,17 +445,198 @@ func TestCustomErrorTypeWrapping(t *testing.T) {
443445
}
444446
}
445447

446-
// Helper function to check if a string contains a substring
447-
func contains(s, substr string) bool {
448-
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
448+
// TestTimeoutErrorWrapping tests that timeout errors work correctly when wrapped
449+
func TestTimeoutErrorWrapping(t *testing.T) {
450+
// Test 1: Wrapped timeoutError interface
451+
t.Run("Wrapped timeoutError with Timeout()=true", func(t *testing.T) {
452+
timeoutErr := &testTimeoutError{timeout: true, msg: "i/o timeout"}
453+
wrappedErr := fmt.Errorf("hook wrapper: %w", timeoutErr)
454+
doubleWrappedErr := fmt.Errorf("another wrapper: %w", wrappedErr)
455+
456+
// Should NOT retry when retryTimeout=false
457+
if redis.ShouldRetry(doubleWrappedErr, false) {
458+
t.Errorf("Should not retry timeout error when retryTimeout=false")
459+
}
460+
461+
// Should retry when retryTimeout=true
462+
if !redis.ShouldRetry(doubleWrappedErr, true) {
463+
t.Errorf("Should retry timeout error when retryTimeout=true")
464+
}
465+
})
466+
467+
// Test 2: Wrapped timeoutError with Timeout()=false
468+
t.Run("Wrapped timeoutError with Timeout()=false", func(t *testing.T) {
469+
timeoutErr := &testTimeoutError{timeout: false, msg: "connection error"}
470+
wrappedErr := fmt.Errorf("hook wrapper: %w", timeoutErr)
471+
472+
// Should always retry when Timeout()=false
473+
if !redis.ShouldRetry(wrappedErr, false) {
474+
t.Errorf("Should retry non-timeout error even when retryTimeout=false")
475+
}
476+
if !redis.ShouldRetry(wrappedErr, true) {
477+
t.Errorf("Should retry non-timeout error when retryTimeout=true")
478+
}
479+
})
480+
481+
// Test 3: Wrapped net.Error with Timeout()=true
482+
t.Run("Wrapped net.Error", func(t *testing.T) {
483+
netErr := &testNetError{timeout: true, temporary: true, msg: "network timeout"}
484+
wrappedErr := fmt.Errorf("hook context: %w", netErr)
485+
486+
// Should respect retryTimeout parameter
487+
if redis.ShouldRetry(wrappedErr, false) {
488+
t.Errorf("Should not retry network timeout when retryTimeout=false")
489+
}
490+
if !redis.ShouldRetry(wrappedErr, true) {
491+
t.Errorf("Should retry network timeout when retryTimeout=true")
492+
}
493+
})
494+
495+
// Test 4: Multiple levels of wrapping
496+
t.Run("Multiple levels of wrapping", func(t *testing.T) {
497+
timeoutErr := &testTimeoutError{timeout: true, msg: "timeout"}
498+
customErr := &AppError{
499+
Code: "TIMEOUT_ERROR",
500+
Message: "Operation timed out",
501+
RequestID: "req-timeout-123",
502+
Err: timeoutErr,
503+
}
504+
wrappedErr := fmt.Errorf("hook wrapper: %w", customErr)
505+
506+
// Should still detect timeout through multiple wrappers
507+
if redis.ShouldRetry(wrappedErr, false) {
508+
t.Errorf("Should not retry timeout through custom error when retryTimeout=false")
509+
}
510+
if !redis.ShouldRetry(wrappedErr, true) {
511+
t.Errorf("Should retry timeout through custom error when retryTimeout=true")
512+
}
513+
514+
// Should be able to extract custom error
515+
var appErr *AppError
516+
if !errors.As(wrappedErr, &appErr) {
517+
t.Errorf("Should be able to extract AppError from wrapped error")
518+
}
519+
})
520+
}
521+
522+
// testTimeoutError implements the timeoutError interface for testing
523+
type testTimeoutError struct {
524+
timeout bool
525+
msg string
526+
}
527+
528+
func (e *testTimeoutError) Error() string {
529+
return e.msg
530+
}
531+
532+
func (e *testTimeoutError) Timeout() bool {
533+
return e.timeout
534+
}
535+
536+
// testNetError implements net.Error for testing
537+
type testNetError struct {
538+
timeout bool
539+
temporary bool
540+
msg string
541+
}
542+
543+
func (e *testNetError) Error() string {
544+
return e.msg
449545
}
450546

451-
func findSubstring(s, substr string) bool {
452-
for i := 0; i <= len(s)-len(substr); i++ {
453-
if s[i:i+len(substr)] == substr {
454-
return true
547+
func (e *testNetError) Timeout() bool {
548+
return e.timeout
549+
}
550+
551+
func (e *testNetError) Temporary() bool {
552+
return e.temporary
553+
}
554+
555+
// TestContextErrorWrapping tests that context errors work correctly when wrapped
556+
func TestContextErrorWrapping(t *testing.T) {
557+
t.Run("Wrapped context.Canceled", func(t *testing.T) {
558+
wrappedErr := fmt.Errorf("operation failed: %w", context.Canceled)
559+
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)
560+
561+
// Should NOT retry
562+
if redis.ShouldRetry(doubleWrappedErr, false) {
563+
t.Errorf("Should not retry wrapped context.Canceled")
455564
}
456-
}
457-
return false
565+
if redis.ShouldRetry(doubleWrappedErr, true) {
566+
t.Errorf("Should not retry wrapped context.Canceled even with retryTimeout=true")
567+
}
568+
})
569+
570+
t.Run("Wrapped context.DeadlineExceeded", func(t *testing.T) {
571+
wrappedErr := fmt.Errorf("timeout: %w", context.DeadlineExceeded)
572+
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)
573+
574+
// Should NOT retry
575+
if redis.ShouldRetry(doubleWrappedErr, false) {
576+
t.Errorf("Should not retry wrapped context.DeadlineExceeded")
577+
}
578+
if redis.ShouldRetry(doubleWrappedErr, true) {
579+
t.Errorf("Should not retry wrapped context.DeadlineExceeded even with retryTimeout=true")
580+
}
581+
})
582+
}
583+
584+
// TestIOErrorWrapping tests that io errors work correctly when wrapped
585+
func TestIOErrorWrapping(t *testing.T) {
586+
t.Run("Wrapped io.EOF", func(t *testing.T) {
587+
wrappedErr := fmt.Errorf("read failed: %w", io.EOF)
588+
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)
589+
590+
// Should retry
591+
if !redis.ShouldRetry(doubleWrappedErr, false) {
592+
t.Errorf("Should retry wrapped io.EOF")
593+
}
594+
})
595+
596+
t.Run("Wrapped io.ErrUnexpectedEOF", func(t *testing.T) {
597+
wrappedErr := fmt.Errorf("read failed: %w", io.ErrUnexpectedEOF)
598+
599+
// Should retry
600+
if !redis.ShouldRetry(wrappedErr, false) {
601+
t.Errorf("Should retry wrapped io.ErrUnexpectedEOF")
602+
}
603+
})
604+
}
605+
606+
// TestPoolErrorWrapping tests that pool errors work correctly when wrapped
607+
func TestPoolErrorWrapping(t *testing.T) {
608+
t.Run("Wrapped pool.ErrPoolTimeout", func(t *testing.T) {
609+
wrappedErr := fmt.Errorf("connection failed: %w", redis.ErrPoolTimeout)
610+
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)
611+
612+
// Should retry
613+
if !redis.ShouldRetry(doubleWrappedErr, false) {
614+
t.Errorf("Should retry wrapped pool.ErrPoolTimeout")
615+
}
616+
})
617+
}
618+
619+
// TestRedisErrorWrapping tests that RedisError detection works with wrapped errors
620+
func TestRedisErrorWrapping(t *testing.T) {
621+
t.Run("Wrapped proto.RedisError", func(t *testing.T) {
622+
redisErr := proto.RedisError("ERR something went wrong")
623+
wrappedErr := fmt.Errorf("command failed: %w", redisErr)
624+
doubleWrappedErr := fmt.Errorf("hook wrapper: %w", wrappedErr)
625+
626+
// Create a command and set the wrapped error
627+
cmd := redis.NewStatusCmd(context.Background(), "GET", "key")
628+
cmd.SetErr(doubleWrappedErr)
629+
630+
// The error should still be recognized as a Redis error
631+
// This is tested indirectly through the typed error system
632+
if !strings.Contains(cmd.Err().Error(), "ERR something went wrong") {
633+
t.Errorf("Error message not preserved through wrapping")
634+
}
635+
})
636+
}
637+
638+
// Helper function to check if a string contains a substring
639+
func contains(s, substr string) bool {
640+
return strings.Contains(s, substr)
458641
}
459642

0 commit comments

Comments
 (0)