Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions pkg/agent/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,9 @@ func (al *AgentLoop) runLLMIteration(
var response *providers.LLMResponse
var err error

callLLM := func() (*providers.LLMResponse, error) {
callLLMOnce := func(callCtx context.Context) (*providers.LLMResponse, error) {
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
fbResult, fbErr := al.fallback.Execute(callCtx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
"max_tokens": agent.MaxTokens,
Expand All @@ -534,16 +534,34 @@ func (al *AgentLoop) runLLMIteration(
}
return fbResult.Response, nil
}
return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
return agent.Provider.Chat(callCtx, messages, providerToolDefs, agent.Model, map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": agent.Temperature,
})
}

retryCfg := utils.RetryConfig{
Timeouts: []time.Duration{45 * time.Second, 90 * time.Second, 120 * time.Second},
Backoffs: []time.Duration{2 * time.Second, 5 * time.Second},
Notify: func(attempt, total int, decision utils.RetryDecision) {
if opts.Channel == "" || opts.ChatID == "" || constants.IsInternalChannel(opts.Channel) {
return
}
notice := utils.FormatLLMRetryNotice(attempt, total, decision)
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: notice,
})
},
}

// Retry loop for context/token errors
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = callLLM()
response, err = utils.DoWithRetry(ctx, retryCfg, func(attemptCtx context.Context) (*providers.LLMResponse, error) {
return callLLMOnce(attemptCtx)
})
if err == nil {
break
}
Expand Down
29 changes: 27 additions & 2 deletions pkg/tools/toolloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"time"

"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
Expand Down Expand Up @@ -42,6 +43,9 @@ func RunToolLoop(
iteration := 0
var finalContent string

perAttemptTimeouts := []time.Duration{45 * time.Second, 90 * time.Second, 120 * time.Second}
backoffs := []time.Duration{2 * time.Second, 5 * time.Second}

for iteration < config.MaxIterations {
iteration++

Expand All @@ -62,8 +66,17 @@ func RunToolLoop(
if llmOpts == nil {
llmOpts = map[string]any{}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
// 3. Call LLM (with bounded retries on timeouts and server errors)
retryCfg := utils.RetryConfig{
Timeouts: perAttemptTimeouts,
Backoffs: backoffs,
Notify: func(attempt, total int, decision utils.RetryDecision) {
sendRetryNotice(ctx, config.Tools, channel, chatID, attempt, total, decision)
},
}
response, err := utils.DoWithRetry(ctx, retryCfg, func(attemptCtx context.Context) (*providers.LLMResponse, error) {
return config.Provider.Chat(attemptCtx, messages, providerToolDefs, config.Model, llmOpts)
})
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
Expand Down Expand Up @@ -160,3 +173,15 @@ func RunToolLoop(
Iterations: iteration,
}, nil
}

// sendRetryNotice pushes a retry notice to the user's chat via the
// "message" tool. It is a no-op when the registry or the destination
// (channel/chat) is missing.
func sendRetryNotice(ctx context.Context, tools *ToolRegistry, channel, chatID string, attempt, total int, decision utils.RetryDecision) {
	if tools == nil {
		return
	}
	if channel == "" || chatID == "" {
		return
	}

	tools.ExecuteWithContext(ctx, "message", map[string]any{
		"content": utils.FormatLLMRetryNotice(attempt, total, decision),
	}, channel, chatID, nil)
}
141 changes: 141 additions & 0 deletions pkg/utils/llm_retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package utils

import (
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"
)

// RetryReason classifies why a failed LLM call is considered retryable.
type RetryReason string

const (
	// RetryReasonTimeout marks deadline/timeout failures.
	RetryReasonTimeout RetryReason = "timeout"
	// RetryReasonServerError marks HTTP 5xx responses parsed from error text.
	RetryReasonServerError RetryReason = "server_error"
)

// RetryDecision is the outcome of classifying an error: whether a retry
// should be attempted, the HTTP status parsed from the error message
// (0 when none was found), and the reason category when retryable.
type RetryDecision struct {
	Retryable bool
	Status    int
	Reason    RetryReason
}

// IsRetryableError classifies err for the LLM retry loop.
//
// Retryable cases:
//   - context deadlines and client timeouts (RetryReasonTimeout),
//     including net.Error values whose Timeout() reports true
//   - HTTP 5xx statuses parsed from the error text (RetryReasonServerError)
//
// Any other error — including parsed non-5xx statuses — yields
// Retryable: false. A nil err yields the zero RetryDecision.
func IsRetryableError(err error) RetryDecision {
	if err == nil {
		return RetryDecision{}
	}

	if errors.Is(err, context.DeadlineExceeded) {
		return RetryDecision{Retryable: true, Reason: RetryReasonTimeout}
	}

	// Transport-level timeouts (e.g. "read tcp ...: i/o timeout") surface
	// as net.Error rather than context.DeadlineExceeded and would slip
	// past the string probes below.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return RetryDecision{Retryable: true, Reason: RetryReasonTimeout}
	}

	msg := err.Error()
	if strings.Contains(msg, "context deadline exceeded") || strings.Contains(msg, "Client.Timeout") {
		return RetryDecision{Retryable: true, Reason: RetryReasonTimeout}
	}

	if s, ok := ParseHTTPStatusFromError(msg); ok {
		if s >= 500 && s <= 599 {
			return RetryDecision{Retryable: true, Status: s, Reason: RetryReasonServerError}
		}
		return RetryDecision{Retryable: false, Status: s}
	}

	return RetryDecision{}
}

// ParseHTTPStatusFromError extracts a numeric HTTP status code from an
// error message containing a "Status:" marker (e.g. provider errors of
// the form "API request failed:\n Status: 502\n Body: ...").
// It returns the code and true, or (0, false) when no digits follow the
// first "Status:" occurrence.
func ParseHTTPStatusFromError(msg string) (int, bool) {
	_, rest, found := strings.Cut(msg, "Status:")
	if !found {
		return 0, false
	}

	rest = strings.TrimSpace(rest)

	// Take the leading run of ASCII digits.
	n := 0
	for n < len(rest) && rest[n] >= '0' && rest[n] <= '9' {
		n++
	}

	code, err := strconv.Atoi(rest[:n])
	if err != nil {
		// Covers the empty-digit-run case (and any overflow).
		return 0, false
	}
	return code, true
}

// RetryNotifyFunc is invoked before each retry sleep with the 1-based
// attempt number that just failed, the total number of attempts, and the
// classification of the failure.
type RetryNotifyFunc func(attempt, total int, decision RetryDecision)

// RetryConfig controls DoWithRetry: one per-attempt timeout per entry in
// Timeouts, optional backoff sleeps between attempts (indexed by the
// failed attempt), and an optional notification hook fired before each
// retry.
type RetryConfig struct {
	Timeouts []time.Duration
	Backoffs []time.Duration
	Notify   RetryNotifyFunc
}

// DoWithRetry runs fn up to len(retry.Timeouts) times, giving each
// attempt its own timeout derived from ctx. A failed attempt is retried
// only when IsRetryableError classifies its error as retryable; before
// each retry, retry.Notify (if set) is invoked and the corresponding
// retry.Backoffs entry (if any) is slept, honoring ctx cancellation.
//
// With an empty Timeouts slice, fn runs exactly once with ctx as-is.
// On exhaustion or a non-retryable error, the last error is returned.
func DoWithRetry[T any](
	ctx context.Context,
	retry RetryConfig,
	fn func(context.Context) (T, error),
) (T, error) {
	var zero T
	if len(retry.Timeouts) == 0 {
		return fn(ctx)
	}

	total := len(retry.Timeouts)
	var lastErr error
	for attempt := 1; attempt <= total; attempt++ {
		attemptCtx, cancel := context.WithTimeout(ctx, retry.Timeouts[attempt-1])
		val, err := fn(attemptCtx)
		cancel()

		if err == nil {
			return val, nil
		}
		lastErr = err

		// If the parent context is already done, further attempts cannot
		// succeed: a parent deadline would otherwise classify as a
		// retryable timeout and spin a doomed retry (immediately, when no
		// backoff entry exists for this attempt).
		if ctxErr := ctx.Err(); ctxErr != nil {
			return zero, ctxErr
		}

		if attempt == total {
			break
		}

		decision := IsRetryableError(err)
		if !decision.Retryable {
			break
		}

		if retry.Notify != nil {
			retry.Notify(attempt, total, decision)
		}

		if attempt-1 < len(retry.Backoffs) {
			select {
			case <-ctx.Done():
				return zero, ctx.Err()
			case <-time.After(retry.Backoffs[attempt-1]):
			}
		}
	}

	return zero, lastErr
}

// FormatLLMRetryNotice builds the user-visible message announcing the
// upcoming retry — attempt+1 out of total — phrased according to the
// failure classification in decision.
func FormatLLMRetryNotice(attempt, total int, decision RetryDecision) string {
	next := attempt + 1

	if decision.Reason == RetryReasonTimeout {
		return fmt.Sprintf("LLM timed out, retrying (attempt %d/%d)", next, total)
	}

	if decision.Reason == RetryReasonServerError {
		if decision.Status > 0 {
			return fmt.Sprintf("LLM server error (%d), retrying (attempt %d/%d)", decision.Status, next, total)
		}
		return fmt.Sprintf("LLM server error, retrying (attempt %d/%d)", next, total)
	}

	// Unknown reason: generic wording.
	return fmt.Sprintf("LLM call failed, retrying (attempt %d/%d)", next, total)
}
145 changes: 145 additions & 0 deletions pkg/utils/llm_retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package utils

import (
"context"
"errors"
"testing"
"time"
)

type stubValueRunner struct {
errors []error
vals []string
calls int
}

func (s *stubValueRunner) Run(ctx context.Context) (string, error) {
s.calls++
idx := s.calls - 1
if idx < len(s.errors) && s.errors[idx] != nil {
return "", s.errors[idx]
}
if idx < len(s.vals) {
return s.vals[idx], nil
}
return "", errors.New("no value")
}

// TestLLMRetry_IsRetryableError checks the error classification table:
// timeouts and 5xx are retryable, 4xx and unknown errors are not.
func TestLLMRetry_IsRetryableError(t *testing.T) {
	type testCase struct {
		name string
		err  error
		want RetryDecision
	}

	cases := []testCase{
		{"deadline exceeded", context.DeadlineExceeded,
			RetryDecision{Retryable: true, Reason: RetryReasonTimeout}},
		{"client timeout string", errors.New("failed to read response: context deadline exceeded (Client.Timeout)"),
			RetryDecision{Retryable: true, Reason: RetryReasonTimeout}},
		{"server 502", errors.New("API request failed:\n Status: 502\n Body: bad"),
			RetryDecision{Retryable: true, Status: 502, Reason: RetryReasonServerError}},
		{"client 400", errors.New("API request failed:\n Status: 400\n Body: bad"),
			RetryDecision{Retryable: false, Status: 400}},
		{"other error", errors.New("something else"),
			RetryDecision{Retryable: false}},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := IsRetryableError(c.err)
			mismatch := got.Retryable != c.want.Retryable ||
				got.Status != c.want.Status ||
				got.Reason != c.want.Reason
			if mismatch {
				t.Fatalf("IsRetryableError(%v) = %+v, want %+v", c.err, got, c.want)
			}
		})
	}
}

// TestLLMRetry_DoWithRetry_TimeoutThenSuccess verifies that a timeout on
// the first attempt triggers exactly one notification and a successful
// second attempt returns its value.
func TestLLMRetry_DoWithRetry_TimeoutThenSuccess(t *testing.T) {
	stub := &stubValueRunner{
		errors: []error{context.DeadlineExceeded, nil},
		vals:   []string{"", "ok"},
	}

	var notices int
	cfg := RetryConfig{
		Timeouts: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond},
		Backoffs: []time.Duration{},
		Notify: func(attempt, total int, decision RetryDecision) {
			notices++
			if decision.Reason != RetryReasonTimeout {
				t.Fatalf("expected timeout reason, got %v", decision.Reason)
			}
		},
	}

	got, err := DoWithRetry(context.Background(), cfg, stub.Run)
	if err != nil {
		t.Fatalf("DoWithRetry error: %v", err)
	}
	if got != "ok" {
		t.Fatalf("val = %q, want ok", got)
	}
	if stub.calls != 2 {
		t.Fatalf("runner.calls = %d, want 2", stub.calls)
	}
	if notices != 1 {
		t.Fatalf("notices = %d, want 1", notices)
	}
}

// TestLLMRetry_DoWithRetry_ServerErrorThenSuccess verifies that a 5xx
// failure is retried and the second attempt's value is returned.
func TestLLMRetry_DoWithRetry_ServerErrorThenSuccess(t *testing.T) {
	stub := &stubValueRunner{
		errors: []error{errors.New("API request failed:\n Status: 502\n Body: bad"), nil},
		vals:   []string{"", "ok"},
	}

	got, err := DoWithRetry(
		context.Background(),
		RetryConfig{
			Timeouts: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond},
			Backoffs: []time.Duration{},
		},
		stub.Run,
	)
	if err != nil {
		t.Fatalf("DoWithRetry error: %v", err)
	}
	if got != "ok" {
		t.Fatalf("val = %q, want ok", got)
	}
	if stub.calls != 2 {
		t.Fatalf("runner.calls = %d, want 2", stub.calls)
	}
}

// TestLLMRetry_DoWithRetry_NoRetryOnClientError verifies that a 4xx
// failure is NOT retried: the error surfaces after a single attempt.
func TestLLMRetry_DoWithRetry_NoRetryOnClientError(t *testing.T) {
	stub := &stubValueRunner{
		errors: []error{errors.New("API request failed:\n Status: 400\n Body: bad")},
		vals:   []string{""},
	}

	_, err := DoWithRetry(
		context.Background(),
		RetryConfig{
			Timeouts: []time.Duration{5 * time.Millisecond, 5 * time.Millisecond},
			Backoffs: []time.Duration{},
		},
		stub.Run,
	)
	if err == nil {
		t.Fatalf("expected error, got nil")
	}
	if stub.calls != 1 {
		t.Fatalf("runner.calls = %d, want 1", stub.calls)
	}
}