diff --git a/circuit_breaker.go b/circuit_breaker.go
index 8b25251b..8592dba7 100644
--- a/circuit_breaker.go
+++ b/circuit_breaker.go
@@ -8,6 +8,7 @@ package resty
 import (
 	"errors"
 	"net/http"
+	"sync"
 	"sync/atomic"
 	"time"
 )
@@ -29,15 +30,15 @@ type CircuitBreaker struct {
 	failureThreshold uint32
 	successThreshold uint32
 	state            atomic.Value // circuitBreakerState
-	failureCount     atomic.Uint32
-	successCount     atomic.Uint32
-	lastFailureAt    time.Time
+	openStartAt      atomic.Value // time.Time
+	sw               *tfsw
 }
 
 // NewCircuitBreaker method creates a new [CircuitBreaker] with default settings.
 //
 // The default settings are:
 //   - Timeout: 10 seconds
+//   - SlidingWindowBucketSize: 10
 //   - FailThreshold: 3
 //   - SuccessThreshold: 1
 //   - Policies: CircuitBreaker5xxPolicy
@@ -48,6 +49,11 @@ func NewCircuitBreaker() *CircuitBreaker {
 		failureThreshold: 3,
 		successThreshold: 1,
 	}
+	cb.sw = newSlidingWindow(
+		func() totalAndFailures { return totalAndFailures{} },
+		cb.timeout,
+		10,
+	)
 	cb.state.Store(circuitBreakerStateClosed)
 	return cb
 }
@@ -75,6 +81,7 @@ func (cb *CircuitBreaker) SetPolicies(policies ...CircuitBreakerPolicy) *Circuit
 // timeout reaches, a single request is allowed to determine the state.
 func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) *CircuitBreaker {
 	cb.timeout = timeout
+	cb.sw.SetInterval(timeout)
 	return cb
 }
 
@@ -142,30 +149,25 @@ func (cb *CircuitBreaker) applyPolicies(resp *http.Response) {
 	}
 
 	if failed {
-		if cb.failureCount.Load() > 0 && time.Since(cb.lastFailureAt) > cb.timeout {
-			cb.failureCount.Store(0)
+		cb.sw.Add(totalAndFailures{total: 1, failures: 1})
+	} else {
+		cb.sw.Add(totalAndFailures{total: 1, failures: 0})
+	}
+	switch cb.getState() {
+	case circuitBreakerStateClosed:
+		if cb.sw.Get().failures >= int(cb.failureThreshold) {
+			cb.open()
 		}
-
-		switch cb.getState() {
-		case circuitBreakerStateClosed:
-			failCount := cb.failureCount.Add(1)
-			if failCount >= cb.failureThreshold {
-				cb.open()
-			} else {
-				cb.lastFailureAt = time.Now()
-			}
-		case circuitBreakerStateHalfOpen:
+	case circuitBreakerStateHalfOpen:
+		totalAndFailure := cb.sw.Get()
+		if totalAndFailure.total-totalAndFailure.failures >= int(cb.successThreshold) {
+			cb.changeState(circuitBreakerStateClosed)
+		} else {
 			cb.open()
 		}
-	} else {
-		switch cb.getState() {
-		case circuitBreakerStateClosed:
-			return
-		case circuitBreakerStateHalfOpen:
-			successCount := cb.successCount.Add(1)
-			if successCount >= cb.successThreshold {
-				cb.changeState(circuitBreakerStateClosed)
-			}
+	case circuitBreakerStateOpen:
+		if time.Since(cb.openStartAt.Load().(time.Time)) >= cb.timeout {
+			cb.changeState(circuitBreakerStateHalfOpen)
 		}
 	}
 }
@@ -179,7 +181,146 @@ func (cb *CircuitBreaker) open() {
 }
 
+// changeState moves the breaker to state and records the time of the
+// transition in openStartAt.
 func (cb *CircuitBreaker) changeState(state circuitBreakerState) {
-	cb.failureCount.Store(0)
-	cb.successCount.Store(0)
+	// Publish the timestamp before the state: applyPolicies type-asserts
+	// openStartAt as soon as it observes the open state, so the value must
+	// already be stored by then.
+	cb.openStartAt.Store(time.Now())
 	cb.state.Store(state)
+}
+
+// tfsw is the sliding window instantiation used by CircuitBreaker.
+type tfsw = slidingWindow[totalAndFailures]
+
+// newSlidingWindow returns a sliding window spanning interval, divided into
+// bucketSize buckets that each start out as newEmpty().
+//
+// A bucketSize below 1 is treated as 1, and the per-bucket duration is never
+// less than one nanosecond, so construction cannot divide by zero and Add
+// always makes progress when it advances the window.
+func newSlidingWindow[G group[G]](
+	newEmpty func() G,
+	interval time.Duration,
+	bucketSize int,
+) *slidingWindow[G] {
+	if bucketSize < 1 {
+		bucketSize = 1
+	}
+	values := make([]G, 0, bucketSize)
+	for i := 0; i < bucketSize; i++ {
+		values = append(values, newEmpty())
+	}
+	return &slidingWindow[G]{
+		total:     newEmpty(),
+		values:    values,
+		lastStart: time.Now(),
+		interval:  bucketInterval(interval, bucketSize),
+	}
+}
+
+// bucketInterval returns the duration of one of n buckets (n must be
+// positive) in a window spanning interval, clamped to at least one
+// nanosecond. Add steps through the window by this duration, so a zero or
+// negative value would keep its catch-up loop from terminating.
+func bucketInterval(interval time.Duration, n int) time.Duration {
+	if d := interval / time.Duration(n); d >= time.Nanosecond {
+		return d
+	}
+	return time.Nanosecond
+}
+
+// slidingWindow aggregates values of type G over a time window that is split
+// into fixed-duration buckets kept in a ring buffer.
+type slidingWindow[G group[G]] struct {
+	mutex  sync.RWMutex // guards every field below
+	total  G            // op-combination of all buckets in values
+	values []G          // ring buffer of buckets
+
+	idx       int           // bucket currently being filled
+	lastStart time.Time     // start time of the bucket at idx
+	interval  time.Duration // duration covered by one bucket
+}
+
+// group is the constraint for values stored in a slidingWindow. It follows
+// the mathematical notion of a group: op combines two values, empty returns
+// the identity element, and inverse returns the element that cancels the
+// receiver under op.
+type group[T any] interface {
+	op(T) T
+	empty() T
+	inverse() T
+}
+
+// Add records val in the current bucket after expiring the buckets that have
+// aged out of the window since the previous call.
+func (sw *slidingWindow[G]) Add(val G) {
+	sw.mutex.Lock()
+	defer sw.mutex.Unlock()
+
+	now := time.Now()
+	elapsed := now.Sub(sw.lastStart)
+	if elapsed > sw.interval*time.Duration(len(sw.values)) {
+		// The whole window has expired. Reset it in one pass instead of
+		// rotating bucket by bucket, which would hold the lock for a time
+		// proportional to how long the window sat idle.
+		for i := range sw.values {
+			sw.values[i] = sw.values[i].empty()
+		}
+		sw.total = sw.total.empty()
+		sw.lastStart = now
+		elapsed = 0
+	}
+	for ; elapsed > sw.interval; elapsed -= sw.interval {
+		// Advance to the next bucket and remove its stale contents from
+		// the running total before reusing it.
+		sw.idx = (sw.idx + 1) % len(sw.values)
+		sw.lastStart = sw.lastStart.Add(sw.interval)
+		sw.total = sw.total.op(sw.values[sw.idx].inverse())
+		sw.values[sw.idx] = sw.values[sw.idx].empty()
+	}
+	sw.total = sw.total.op(val)
+	sw.values[sw.idx] = sw.values[sw.idx].op(val)
+}
+
+// Get returns the aggregate of all buckets as of the most recent Add. It does
+// not expire buckets itself.
+func (sw *slidingWindow[G]) Get() G {
+	sw.mutex.RLock()
+	defer sw.mutex.RUnlock()
+	return sw.total
+}
+
+// SetInterval changes the duration spanned by the whole window. Bucket
+// contents are kept; only the per-bucket duration is recomputed.
+func (sw *slidingWindow[G]) SetInterval(interval time.Duration) {
+	sw.mutex.Lock()
+	defer sw.mutex.Unlock()
+	sw.interval = bucketInterval(interval, len(sw.values))
+}
+
+// totalAndFailures counts observed responses and how many of them failed.
+type totalAndFailures struct {
+	total    int
+	failures int
+}
+
+// op returns the field-wise sum of tf and g.
+func (tf totalAndFailures) op(g totalAndFailures) totalAndFailures {
+	tf.total += g.total
+	tf.failures += g.failures
+	return tf
+}
+
+// empty returns the zero counts, the identity element for op.
+func (tf totalAndFailures) empty() totalAndFailures {
+	return totalAndFailures{}
+}
+
+// inverse returns tf with both counts negated, so that tf.op(tf.inverse())
+// equals empty().
+func (tf totalAndFailures) inverse() totalAndFailures {
+	tf.total = -tf.total
+	tf.failures = -tf.failures
+	return tf
 }
diff --git a/client_test.go b/client_test.go
index 8a5eef50..6c064ecc 100644
--- a/client_test.go
+++ b/client_test.go
@@ -1507,11 +1507,11 @@ func TestClientCircuitBreaker(t *testing.T) {
 
 	_, err = c.R().Get(ts.URL + "/500")
 	assertError(t, err)
-	assertEqual(t, uint32(1), c.circuitBreaker.failureCount.Load())
+	assertEqual(t, 1, c.circuitBreaker.sw.Get().failures)
 
 	time.Sleep(timeout)
 
 	_, err = c.R().Get(ts.URL + "/500")
 	assertError(t, err)
-	assertEqual(t, uint32(1), c.circuitBreaker.failureCount.Load())
+	assertEqual(t, 1, c.circuitBreaker.sw.Get().failures)
 }