Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion cmd/collect_website.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package cmd
import (
"fmt"
"os"
"time"

"github.com/Snider/Borg/pkg/circuitbreaker"
"github.com/schollz/progressbar/v3"
"golang.org/x/exp/slog"
"github.com/Snider/Borg/pkg/compress"
"github.com/Snider/Borg/pkg/tim"
"github.com/Snider/Borg/pkg/trix"
Expand Down Expand Up @@ -38,6 +41,11 @@ func NewCollectWebsiteCmd() *cobra.Command {
format, _ := cmd.Flags().GetString("format")
compression, _ := cmd.Flags().GetString("compression")
password, _ := cmd.Flags().GetString("password")
noCircuitBreaker, _ := cmd.Flags().GetBool("no-circuit-breaker")
circuitFailures, _ := cmd.Flags().GetInt("circuit-failures")
circuitCooldown, _ := cmd.Flags().GetDuration("circuit-cooldown")
circuitSuccessThreshold, _ := cmd.Flags().GetInt("circuit-success-threshold")
circuitHalfOpenRequests, _ := cmd.Flags().GetInt("circuit-half-open-requests")

if format != "datanode" && format != "tim" && format != "trix" {
return fmt.Errorf("invalid format: %s (must be 'datanode', 'tim', or 'trix')", format)
Expand All @@ -51,7 +59,25 @@ func NewCollectWebsiteCmd() *cobra.Command {
bar = ui.NewProgressBar(-1, "Crawling website")
}

dn, err := website.DownloadAndPackageWebsite(websiteURL, depth, bar)
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))

opts := website.DownloadOptions{
URL: websiteURL,
MaxDepth: depth,
ProgressBar: bar,
EnableCircuitBreaker: !noCircuitBreaker,
CBSettings: circuitbreaker.Settings{
FailureThreshold: circuitFailures,
SuccessThreshold: circuitSuccessThreshold,
Cooldown: circuitCooldown,
HalfOpenRequests: circuitHalfOpenRequests,
Logger: logger,
},
}

dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
return fmt.Errorf("error downloading and packaging website: %w", err)
}
Expand Down Expand Up @@ -104,5 +130,10 @@ func NewCollectWebsiteCmd() *cobra.Command {
collectWebsiteCmd.PersistentFlags().String("format", "datanode", "Output format (datanode, tim, or trix)")
collectWebsiteCmd.PersistentFlags().String("compression", "none", "Compression format (none, gz, or xz)")
collectWebsiteCmd.PersistentFlags().String("password", "", "Password for encryption")
collectWebsiteCmd.Flags().Bool("no-circuit-breaker", false, "Disable the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-failures", 5, "Number of failures to trip the circuit breaker")
collectWebsiteCmd.Flags().Duration("circuit-cooldown", 30*time.Second, "Cooldown time for the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-success-threshold", 2, "Number of successes to close the circuit breaker")
collectWebsiteCmd.Flags().Int("circuit-half-open-requests", 1, "Number of test requests in half-open state")
return collectWebsiteCmd
}
5 changes: 2 additions & 3 deletions cmd/collect_website_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ import (

"github.com/Snider/Borg/pkg/datanode"
"github.com/Snider/Borg/pkg/website"
"github.com/schollz/progressbar/v3"
)

func TestCollectWebsiteCmd_Good(t *testing.T) {
// Mock the website downloader
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return datanode.New(), nil
}
defer func() {
Expand All @@ -35,7 +34,7 @@ func TestCollectWebsiteCmd_Good(t *testing.T) {
func TestCollectWebsiteCmd_Bad(t *testing.T) {
// Mock the website downloader to return an error
oldDownloadAndPackageWebsite := website.DownloadAndPackageWebsite
website.DownloadAndPackageWebsite = func(startURL string, maxDepth int, bar *progressbar.ProgressBar) (*datanode.DataNode, error) {
website.DownloadAndPackageWebsite = func(opts website.DownloadOptions) (*datanode.DataNode, error) {
return nil, fmt.Errorf("website error")
}
defer func() {
Expand Down
7 changes: 6 additions & 1 deletion examples/collect_website/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ func main() {
log.Println("Collecting website...")

// Download and package the website.
dn, err := website.DownloadAndPackageWebsite("https://example.com", 2, nil)
opts := website.DownloadOptions{
URL: "https://example.com",
MaxDepth: 2,
EnableCircuitBreaker: true,
}
dn, err := website.DownloadAndPackageWebsite(opts)
if err != nil {
log.Fatalf("Failed to collect website: %v", err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ require (
github.com/wailsapp/mimetype v1.4.1 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.37.0 // indirect
golang.org/x/text v0.31.0 // indirect
Expand Down
160 changes: 160 additions & 0 deletions pkg/circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@

package circuitbreaker

import (
"fmt"
"io"
"sync"
"time"

"golang.org/x/exp/slog"
)

// State identifies which phase of its life cycle a circuit breaker is in.
type State int

const (
	// ClosedState is the zero value: calls are passed through.
	ClosedState State = iota
	// OpenState means the breaker has tripped.
	OpenState
	// HalfOpenState means the breaker is probing for recovery.
	HalfOpenState
)

// String returns the upper-case label of the state ("CLOSED", "OPEN" or
// "HALF-OPEN"). Any value outside the declared constants yields "UNKNOWN".
func (s State) String() string {
	names := [...]string{
		ClosedState:   "CLOSED",
		OpenState:     "OPEN",
		HalfOpenState: "HALF-OPEN",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}

// Settings configures a CircuitBreaker. It is passed by value to New, so the
// breaker works on its own copy.
type Settings struct {
// FailureThreshold is the number of consecutive failed calls in the closed
// state at which the circuit opens. A successful call in the closed state
// resets the count.
FailureThreshold int
// SuccessThreshold is the number of successful test requests in the
// half-open state at which the circuit closes.
SuccessThreshold int
// Cooldown is how long the circuit stays open before the next call is let
// through as a half-open test request.
Cooldown time.Duration
// HalfOpenRequests caps how many test requests are let through during one
// half-open period.
HalfOpenRequests int
// Logger receives a message on every state change: warn level when the
// circuit opens, info level otherwise. When nil, New installs a logger that
// discards its output.
Logger *slog.Logger
}

// CircuitBreaker is a state machine that prevents repeated calls to a failing service.
// Create one with New and route calls through Execute. It contains a
// sync.Mutex, so it must be used through a pointer and not copied.
type CircuitBreaker struct {
// settings is the configuration captured by New.
settings Settings
// domain is the label included in state-change log messages.
domain string
// mu guards every field below; Execute holds it for the whole call.
mu sync.Mutex
// state is the current position in the CLOSED / OPEN / HALF-OPEN cycle.
state State
// failures counts consecutive failed calls while closed.
failures int
// successes counts successful test requests while half-open.
successes int
// halfOpenRequests counts test requests admitted in the current half-open period.
halfOpenRequests int
// lastError is the error recorded when the circuit last opened; it is
// wrapped into the errors returned for rejected calls.
lastError error
// expiry is the instant after which an open circuit may move to half-open.
expiry time.Time
}

// New builds a CircuitBreaker for domain that starts out in ClosedState.
//
// settings is copied into the breaker. If settings.Logger is nil, the copy is
// given a text logger that writes to io.Discard, so state changes can always be
// logged without a nil check.
func New(domain string, settings Settings) *CircuitBreaker {
	cb := &CircuitBreaker{
		domain:   domain,
		state:    ClosedState,
		settings: settings,
	}
	if cb.settings.Logger != nil {
		return cb
	}

	discardOpts := &slog.HandlerOptions{
		Level: slog.LevelError,
	}
	cb.settings.Logger = slog.New(slog.NewTextHandler(io.Discard, discardOpts))
	return cb
}

// Execute runs fn under the protection of the circuit breaker and returns
// fn's result when the call is let through.
//
// While the circuit is open and the cooldown has not elapsed, fn is not called
// and an error wrapping the error that opened the circuit is returned. Once
// the cooldown has elapsed the breaker moves to half-open and fn is handed to
// the half-open path as a test request. In any other state fn goes through
// the closed-state path.
//
// The breaker's mutex is held for the whole call, fn included, so calls made
// through one CircuitBreaker run one at a time.
func (cb *CircuitBreaker) Execute(fn func() (interface{}, error)) (interface{}, error) {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	if cb.state == OpenState {
		cooldownOver := time.Now().After(cb.expiry)
		if !cooldownOver {
			return nil, fmt.Errorf("circuit is open for: %w", cb.lastError)
		}
		// Cooldown elapsed: start probing for recovery.
		cb.setState(HalfOpenState)
	}

	if cb.state == HalfOpenState {
		return cb.executeHalfOpen(fn)
	}
	return cb.executeClosed(fn)
}

// executeClosed runs fn while the circuit is closed and keeps the
// consecutive-failure count up to date.
//
// A success clears the count and returns fn's result. A failure increments
// it; when the count reaches settings.FailureThreshold the error is recorded
// as the reason for opening and the circuit opens. The failure is always
// returned to the caller with a nil result.
//
// The caller must hold cb.mu.
func (cb *CircuitBreaker) executeClosed(fn func() (interface{}, error)) (interface{}, error) {
	result, callErr := fn()
	if callErr == nil {
		cb.failures = 0
		return result, nil
	}

	cb.failures++
	if tripped := cb.failures >= cb.settings.FailureThreshold; tripped {
		cb.lastError = callErr
		cb.setState(OpenState)
	}
	return nil, callErr
}

// executeHalfOpen runs fn as a recovery probe while the circuit is half-open.
//
// At most settings.HalfOpenRequests probes are admitted per half-open period
// (a non-positive setting is treated as 1, since otherwise no probe could ever
// run and the circuit could never recover). A failing probe records its error
// and reopens the circuit. A successful probe closes the circuit as soon as
// either settings.SuccessThreshold successes have been seen or every permitted
// probe has succeeded.
//
// The second condition matters when HalfOpenRequests is smaller than
// SuccessThreshold: without it the probe budget runs out before the threshold
// can be reached and the breaker would reject every later call while staying
// half-open forever.
//
// The caller must hold cb.mu.
func (cb *CircuitBreaker) executeHalfOpen(fn func() (interface{}, error)) (interface{}, error) {
	probeLimit := cb.settings.HalfOpenRequests
	if probeLimit < 1 {
		probeLimit = 1
	}

	// Defensive: every probe outcome below leaves the half-open state once the
	// budget is used up, so this only triggers if the counters were changed
	// some other way.
	if cb.halfOpenRequests >= probeLimit {
		return nil, fmt.Errorf("circuit is half-open, test requests exhausted: %w", cb.lastError)
	}
	cb.halfOpenRequests++

	res, err := fn()
	if err != nil {
		// Any failed probe means the service has not recovered.
		cb.lastError = err
		cb.setState(OpenState)
		return nil, err
	}

	cb.successes++
	probesExhausted := cb.halfOpenRequests >= probeLimit
	if cb.successes >= cb.settings.SuccessThreshold || probesExhausted {
		cb.setState(ClosedState)
	}

	return res, nil
}

// setState moves the breaker into state, logs the transition and resets the
// counters that belong to the new state. Requesting the state the breaker is
// already in does nothing: no log line, no counter change.
//
// A transition to OpenState is logged at warn level and starts the cooldown
// window; every other transition is logged at info level.
//
// The caller must hold cb.mu.
func (cb *CircuitBreaker) setState(state State) {
	if state == cb.state {
		return
	}
	cb.state = state

	logger := cb.settings.Logger
	msg := fmt.Sprintf("Circuit %s for %s", state, cb.domain)

	switch state {
	case OpenState:
		logger.Warn(msg)
		cb.expiry = time.Now().Add(cb.settings.Cooldown)
		cb.successes = 0
	case ClosedState:
		logger.Info(msg)
		cb.failures, cb.successes = 0, 0
	case HalfOpenState:
		logger.Info(msg)
		cb.failures, cb.halfOpenRequests = 0, 0
	default:
		logger.Info(msg)
	}
}
101 changes: 101 additions & 0 deletions pkg/circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@

package circuitbreaker

import (
"errors"
"testing"
"time"
)

// TestCircuitBreaker walks one breaker through its full life cycle:
// closed -> open -> half-open -> closed, then closed -> open -> half-open ->
// open again after a failed probe.
//
// Besides the returned errors it also checks that a rejected call never
// invokes the supplied function, and that successful calls hand back the
// function's result.
func TestCircuitBreaker(t *testing.T) {
	settings := Settings{
		FailureThreshold: 2,
		SuccessThreshold: 2,
		Cooldown:         100 * time.Millisecond,
		HalfOpenRequests: 2,
	}
	cb := New("test.com", settings)

	// expectSuccess runs a succeeding call and checks it is passed through.
	expectSuccess := func(stage string) {
		t.Helper()
		res, err := cb.Execute(func() (interface{}, error) {
			return "success", nil
		})
		if err != nil {
			t.Errorf("Expected success %s, got %v", stage, err)
			return
		}
		if res != "success" {
			t.Errorf("Expected result %q %s, got %v", "success", stage, res)
		}
	}

	// expectFailure runs a failing call and checks its error reaches the caller.
	expectFailure := func(msg string) {
		t.Helper()
		_, err := cb.Execute(func() (interface{}, error) {
			return nil, errors.New(msg)
		})
		if err == nil {
			t.Errorf("Expected failure %q, got nil", msg)
			return
		}
		if err.Error() != msg {
			t.Errorf("Expected error %q, got %v", msg, err)
		}
	}

	// expectRejected checks that the breaker short-circuits: the function
	// must not run and the returned error must match wantErr.
	expectRejected := func(wantErr string) {
		t.Helper()
		called := false
		_, err := cb.Execute(func() (interface{}, error) {
			called = true
			return "should not be called", nil
		})
		if called {
			t.Error("Function was invoked while the circuit was open")
		}
		if err == nil || err.Error() != wantErr {
			t.Errorf("Expected open circuit error, got %v", err)
		}
	}

	// Initially closed
	expectSuccess("in closed state")

	// Trip the breaker
	expectFailure("failure 1")
	expectFailure("failure 2")

	// Now open
	expectRejected("circuit is open for: failure 2")

	// Wait for cooldown
	time.Sleep(150 * time.Millisecond)

	// Half-open, should succeed
	expectSuccess("in half-open")

	// Still half-open, need another success
	expectSuccess("in half-open")

	// Now closed again
	expectSuccess("in closed state")

	// Trip again to test half-open failure
	expectFailure("failure 1")
	expectFailure("failure 2")

	time.Sleep(150 * time.Millisecond)

	// Half-open, but fail
	expectFailure("half-open failure")

	// Should be open again
	expectRejected("circuit is open for: half-open failure")
}
Loading
Loading