diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 00e1b4f..064e230 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -57,7 +57,7 @@ DSGo modules are higher-level behaviors built on top of the core `Signature` + ` - **Predict**: Basic structured input → output prediction. - **ChainOfThought**: Adds a reasoning step and stores it in `Prediction.Rationale` (plus any explicit `reasoning` output fields). -- **ReAct**: A tool-calling loop agent (think → act → observe), suitable for external integrations. +- **ReAct**: A native tool-calling loop agent with an auto-injected `finish` tool and a post-loop extractor fallback to produce signature-valid outputs. - **Refine**: Iteratively improves a prediction when `inputs["feedback"]` (or a custom refinement field) is provided. - **BestOfN**: Runs a wrapped module `N` times and selects the best result using a scorer; can parallelize and optionally return all completions. - **Program**: Sequential pipeline; merges previous outputs into the next step's inputs. @@ -93,13 +93,18 @@ graph TD ```mermaid flowchart TD - Q[User query] --> T[Thought] - T --> A[Act: tool call] - A --> O[Observation] - O --> T - T --> F[Final answer] + Q[Inputs] --> M[LM step] + M -->|"tool call(s)"| A[Act: tool execution] + A --> O["Observation(s)"] + O --> M + M -->|"finish(args) or direct answer"| F["Final outputs (validated)"] + F -->|if invalid or limits hit| X["Extractor (schema enforced)"] ``` +Notes: +- DSGo auto-injects a synthetic `finish` tool that mirrors the output `Signature`. +- If the loop can’t produce a valid final output (or exceeds limits), DSGo runs an extraction step over the full trajectory to salvage a schema-valid answer. 
+ ## Streaming pipeline ```mermaid diff --git a/QUICKSTART.md b/QUICKSTART.md index 74faedd..d3d25bf 100644 --- a/QUICKSTART.md +++ b/QUICKSTART.md @@ -224,33 +224,78 @@ fmt.Println("Answer:", result.GetString("answer")) ### ReAct - Tool-Using Agents -For tasks requiring external tools: +ReAct is DSGo’s native tool-calling agent loop. It’s designed to keep tool trajectories bounded and to reliably produce outputs that match your `Signature`. ```go -// Define a tool -func searchWeb(ctx context.Context, args map[string]interface{}) (string, error) { +// Tool: keep result sizes bounded (cost + context control). +func searchWeb(ctx context.Context, args map[string]any) (any, error) { query := args["query"].(string) - return fmt.Sprintf("Search results for '%s': Wikipedia, Google, etc.", query), nil + + // Size control: default to a small summary. + maxChars := 800 + if v, ok := args["max_chars"].(float64); ok { // JSON numbers decode as float64 + maxChars = int(v) + } + + result := fmt.Sprintf("Search results for %q: Wikipedia, news, etc...", query) + if len(result) > maxChars { + result = result[:maxChars] + "...(truncated)" + } + return result, nil } searchTool := dsgo.NewTool( "search", - "Search the web for information", + "Search the web (return a short summary)", searchWeb, -).AddParameter("query", dsgo.FieldTypeString, "Search query", true) +). + AddParameter("query", "string", "Search query", true). + AddParameter("max_chars", "integer", "Max chars to return (size control)", false) sig := dsgo.NewSignature("Answer questions using tools"). AddInput("question", dsgo.FieldTypeString, "Question to answer"). AddOutput("answer", dsgo.FieldTypeString, "Final answer based on tool results") -agent := dsgo.NewReAct(sig, lm, []dsgo.Tool{*searchTool}) +history := dsgo.NewHistoryWithLimit(20) // trajectory limit for multi-turn chats -result, _ := agent.Forward(ctx, map[string]any{ +agent := dsgo.NewReAct(sig, lm, []dsgo.Tool{*searchTool}). + WithHistory(history). 
+ WithMaxIterations(8) // trajectory limit for a single run + +pred, err := agent.Forward(ctx, map[string]any{ "question": "Who is the current president of France?", }) -fmt.Println(result.GetString("answer")) +if err != nil { + fmt.Println("error:", err) + return +} +fmt.Println(pred.GetString("answer")) +``` + +#### ReAct output guarantees + +- **Native tool calling first**: Tools execute only when the selected model/provider supports native tool calls. For non-tool models, use `Predict`/`ChainOfThought` or switch models. +- **Structured finish (recommended)**: DSGo auto-injects a synthetic `finish` tool (unless you provide one) whose arguments mirror your output fields. The model can end the loop by calling `finish(answer=...)`. +- **Forced final mode**: On the last iteration DSGo requests a final JSON object matching your output signature and disallows further tool calls. +- **Extractor fallback**: If parsing/validation fails or the loop hits limits, DSGo runs a post-loop extraction step that synthesizes a valid final JSON answer from the full tool trajectory. + +Structured output enforcement is enabled by default and controls the extractor’s retry behavior: + +```go +dsgo.Configure( + dsgo.WithStructuredOutputEnabled(true), + dsgo.WithStructuredOutputMaxAttempts(3), + dsgo.WithStructuredOutputTemperature(0.0), +) ``` +#### ReAct behavioral details + +- **Bounded trajectories**: Tool results are truncated to `MaxToolResultBytes` (default 16KB) and encoded as JSON envelopes. The trajectory is rendered within a soft prompt budget (`MaxPromptBytes`, default 256KB), keeping only the newest steps that fit. +- **Context overflow recovery**: If a provider returns a context-length error, ReAct drops the oldest trajectory steps and retries (up to 3 times). +- **Strict tool schemas**: Tool parameter schemas include `additionalProperties: false`, so providers will reject tool calls with extra keys not in your schema. 
+- **Termination policies**: The loop terminates early on repeated identical tool calls (3 in a row), repeated identical observations (stagnation), or consecutive tool errors (2 in a row). Termination reason is available in `Prediction.Metadata["termination_reason"]`. + ### Refine - Iterative Improvement For improving outputs through iteration: @@ -326,31 +371,31 @@ fmt.Println("Answer:", result.GetString("answer")) ### ReAct - Tool-Using Agents -For tasks requiring external tools: +Same API, but constructed from the `module` package: ```go -// Define a tool -func searchWeb(ctx context.Context, args map[string]interface{}) (string, error) { - query := args["query"].(string) - return fmt.Sprintf("Search results for '%s': Wikipedia, Google, etc.", query), nil -} - searchTool := dsgo.NewTool( "search", - "Search the web for information", + "Search the web (return a short summary)", searchWeb, -).AddParameter("query", dsgo.FieldTypeString, "Search query", true) +). + AddParameter("query", "string", "Search query", true). + AddParameter("max_chars", "integer", "Max chars to return (size control)", false) sig := dsgo.NewSignature("Answer questions using tools"). AddInput("question", dsgo.FieldTypeString, "Question to answer"). AddOutput("answer", dsgo.FieldTypeString, "Final answer based on tool results") -agent := module.NewReAct(sig, lm, []dsgo.Tool{*searchTool}) +agent := module.NewReAct(sig, lm, []dsgo.Tool{*searchTool}).WithMaxIterations(8) -result, _ := agent.Forward(ctx, map[string]any{ +pred, err := agent.Forward(ctx, map[string]any{ "question": "Who is the current president of France?", }) -fmt.Println(result.GetString("answer")) +if err != nil { + fmt.Println("error:", err) + return +} +fmt.Println(pred.GetString("answer")) ``` ### Refine - Iterative Improvement @@ -439,17 +484,33 @@ func calculate(ctx context.Context, args map[string]interface{}) (string, error) } calcTool := dsgo.NewTool("calculate", "Perform mathematical operations", calculate). 
- AddParameter("operation", dsgo.FieldTypeString, "Operation (add/multiply/divide)", true). - AddParameter("a", dsgo.FieldTypeFloat, "First number", true). - AddParameter("b", dsgo.FieldTypeFloat, "Second number", true). - AddParameter("precision", dsgo.FieldTypeInt, "Decimal places", false) // Optional + AddParameter("operation", "string", "Operation (add/multiply/divide)", true). + AddParameter("a", "number", "First number", true). + AddParameter("b", "number", "Second number", true). + AddParameter("precision", "integer", "Decimal places", false) // Optional ``` +#### Tool schema best practices + +- Keep parameter schemas **small and explicit**: prefer a few required fields over a single “blob” argument. +- Use constrained types: + - `AddEnumParameter(...)` for flags/modes + - `AddArrayParameter(...)` for lists +- Put operational constraints in the schema (so the model can self-regulate): `max_results`, `max_chars`, `page`, `timeout_ms`. + +#### Tool result size controls + +Tool outputs become part of the ReAct trajectory. Treat them like *prompt tokens*: + +- Default to **summaries**, not raw payloads. +- Add and enforce a `max_chars`/`max_bytes` argument and truncate server-side. +- Prefer “preview” tools (e.g. `read_file(path, max_bytes)`) over returning full files. + ### Multi-Tool Agents ```go weatherTool := dsgo.NewTool("get_weather", "Get current weather", getWeatherFunc). 
- AddParameter("location", dsgo.FieldTypeString, "City name", true) + AddParameter("location", "string", "City name", true) tools := []dsgo.Tool{*calcTool, *weatherTool} diff --git a/REFERENCE.md b/REFERENCE.md index 1e390e1..8a9c87f 100644 --- a/REFERENCE.md +++ b/REFERENCE.md @@ -115,7 +115,7 @@ This allows filtering either: |---|---|---| | Predict | `dsgo.NewPredict(sig, lm)` | Structured prediction | | ChainOfThought | `dsgo.NewChainOfThought(sig, lm)` | Stores reasoning in `Prediction.Rationale` (and/or explicit reasoning fields) | -| ReAct | `dsgo.NewReAct(sig, lm, tools)` | Tool-using agent loop | +| ReAct | `dsgo.NewReAct(sig, lm, tools)` | Tool-using agent loop (native tool calls + auto `finish` tool + extractor fallback) | | Refine | `dsgo.NewRefine(sig, lm)` | Refines only when `inputs["feedback"]` (or custom refinement field) is provided | | BestOfN | `dsgo.NewBestOfN(module, n)` | Requires `WithScorer(...)`; can parallelize and optionally return all completions | | Program | `dsgo.NewProgram(name)` | Sequential composition | @@ -123,6 +123,23 @@ This allows filtering either: | ProgramOfThought | `dsgo.NewProgramOfThought(sig, lm, language)` | Generates code-first solutions; execution is disabled by default | | MultiChainComparison | `dsgo.NewMultiChainComparison(sig, lm, m)` | Synthesizes from `inputs["completions"]` and prepends `rationale` output | +### ReAct: key knobs + +- **Trajectory limit (per run)**: `WithMaxIterations(n)` (default: 10) +- **Trajectory limit (multi-turn)**: `WithHistory(dsgo.NewHistoryWithLimit(n))` to bound remembered messages +- **Prompted finishing**: ReAct guides the model to terminate by calling a synthetic `finish` tool whose arguments match your output `Signature`. 
+- **Extractor / schema enforcement**: configure globally via: + +```go +dsgo.Configure( + dsgo.WithStructuredOutputEnabled(true), + dsgo.WithStructuredOutputMaxAttempts(3), + dsgo.WithStructuredOutputTemperature(0.0), +) +``` + +- **Tool result size control**: design tool schemas with explicit bounds like `max_chars`, `max_results`, and enforce them in the tool implementation. + ## Adapters | Adapter | Description | diff --git a/examples/react_experiment/main.go b/examples/react_experiment/main.go index d72df19..b2f4a7d 100644 --- a/examples/react_experiment/main.go +++ b/examples/react_experiment/main.go @@ -3,6 +3,9 @@ package main import ( "context" "fmt" + "os" + "sort" + "strconv" "strings" "sync" "time" @@ -12,41 +15,51 @@ import ( "github.com/assagman/dsgo/internal/core" ) -const MaxIterations = 10 +const ( + MaxIterations = 10 + // Limit concurrency to reduce rate limit errors. + DefaultMaxConcurrent = 3 +) func main() { ctx := context.Background() - models := []string{ - "openrouter/google/gemini-2.5-flash-lite-preview-09-2025", - "openrouter/openai/gpt-4o-mini", - "openrouter/openai/gpt-5-mini-2025-08-07", - "openrouter/openai/gpt-5-nano-2025-08-07", - "openrouter/openai/gpt-4.1-2025-04-14", - "openrouter/google/gemini-2.5-flash", - "openrouter/x-ai/grok-code-fast-1", - "openrouter/deepseek/deepseek-v3.2", - "openrouter/qwen/qwen3-next-80b-a3b-instruct", - "openrouter/z-ai/glm-4.6:exacto", - "openrouter/moonshotai/kimi-k2-0905:exacto", - "openrouter/openai/gpt-oss-120b:exacto", - "openrouter/qwen/qwen3-coder:exacto", + // This example is intended for experimentation; allow running models that aren't in the catalog. 
+ dsgo.Configure() + + models := selectModels() + if len(models) == 0 { + fmt.Println("No models selected.\n" + + "Set OPENROUTER_API_KEY and/or OPENAI_API_KEY, or update examples/react_experiment/main.go.") + return + } + + maxConcurrent := DefaultMaxConcurrent + if v := os.Getenv("DSGO_EXPERIMENT_CONCURRENCY"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + maxConcurrent = n + } } - fmt.Println("🧪 ReAct Experiment: Cost Package Analysis") + fmt.Println("🧪 ReAct Experiment: Core Package Analysis") fmt.Println("=" + strings.Repeat("=", 60)) fmt.Printf("Models: %s\n", strings.Join(models, ", ")) fmt.Printf("Max Iterations: %d\n", MaxIterations) + fmt.Printf("Max Concurrent: %d\n", maxConcurrent) fmt.Println() var wg sync.WaitGroup results := make(map[string]*ExperimentResult) var mu sync.Mutex + sem := make(chan struct{}, maxConcurrent) for _, modelName := range models { wg.Add(1) go func(model string) { defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + result := runExperiment(ctx, model) mu.Lock() results[model] = result @@ -61,6 +74,158 @@ func main() { displayComparison(results) } +func selectModels() []string { + hasOpenRouter := os.Getenv("OPENROUTER_API_KEY") != "" + hasOpenAI := os.Getenv("OPENAI_API_KEY") != "" + + // Curated selection: + // - max 10 models total + // - 2 models each from: moonshotai, google, openai, z-ai, qwen + // - avoid ":free" variants (often unstable) + maxTotal := 10 + perOrg := 2 + + type candidate struct { + id string + cost float64 // PromptPrice + CompletionPrice (USD / 1M tokens) + newest time.Time // parsed from LastUpdated/ReleaseDate + } + + parseDate := func(s string) time.Time { + s = strings.TrimSpace(s) + if s == "" { + return time.Time{} + } + // Allow YYYY-MM-DD, YYYY-MM, or YYYY. 
+ for _, layout := range []string{"2006-01-02", "2006-01", "2006"} { + if t, err := time.Parse(layout, s); err == nil { + return t + } + } + return time.Time{} + } + + pickForPrefix := func(provider, prefix string, max int) []string { + models := dsgo.ListModelsByProvider(provider) + cands := make([]candidate, 0, len(models)) + for _, m := range models { + if !strings.HasPrefix(m.ID, prefix) { + continue + } + // Avoid free-tier models; they tend to be unstable / rate-limited. + if strings.Contains(m.ID, ":free") { + continue + } + // ReAct relies on tool calling (native when supported). + if !m.Capabilities.ToolCall { + continue + } + // Some catalog entries use 0 output tokens as "unknown"; avoid those. + if m.Limits.OutputTokens <= 0 { + continue + } + + newest := parseDate(m.Metadata.LastUpdated) + if newest.IsZero() { + newest = parseDate(m.Metadata.ReleaseDate) + } + + cands = append(cands, candidate{ + id: m.ID, + cost: m.Pricing.PromptPrice + m.Pricing.CompletionPrice, + newest: newest, + }) + } + + if len(cands) == 0 { + return nil + } + + // Pick 1 cheapest and 1 newest (if distinct). This matches "newest and cheapest". + cheapestIdx := 0 + for i := 1; i < len(cands); i++ { + if cands[i].cost < cands[cheapestIdx].cost || (cands[i].cost == cands[cheapestIdx].cost && cands[i].id < cands[cheapestIdx].id) { + cheapestIdx = i + } + } + + newestIdx := 0 + for i := 1; i < len(cands); i++ { + if cands[i].newest.After(cands[newestIdx].newest) { + newestIdx = i + continue + } + if cands[i].newest.Equal(cands[newestIdx].newest) { + // tie-breaker: cheaper first + if cands[i].cost < cands[newestIdx].cost || (cands[i].cost == cands[newestIdx].cost && cands[i].id < cands[newestIdx].id) { + newestIdx = i + } + } + } + + picked := make([]string, 0, max) + picked = append(picked, cands[cheapestIdx].id) + if len(picked) >= max { + return picked + } + if newestIdx != cheapestIdx { + picked = append(picked, cands[newestIdx].id) + } + + // If we still need more (e.g. 
cheapest==newest), fill by (newest desc, cost asc). + if len(picked) < max { + sort.Slice(cands, func(i, j int) bool { + if cands[i].newest.Equal(cands[j].newest) { + if cands[i].cost == cands[j].cost { + return cands[i].id < cands[j].id + } + return cands[i].cost < cands[j].cost + } + return cands[i].newest.After(cands[j].newest) + }) + for i := 0; i < len(cands) && len(picked) < max; i++ { + already := false + for _, id := range picked { + if id == cands[i].id { + already = true + break + } + } + if already { + continue + } + picked = append(picked, cands[i].id) + } + } + + return picked + } + + out := make([]string, 0, maxTotal) + + if hasOpenRouter { + orgPrefixes := []string{ + "openrouter/moonshotai/", + "openrouter/google/", + "openrouter/openai/", + "openrouter/z-ai/", + "openrouter/qwen/", + } + for _, prefix := range orgPrefixes { + out = append(out, pickForPrefix("openrouter", prefix, perOrg)...) + } + // Exactly 10 (5 orgs * 2 models). + return out + } + + // Fallback when OpenRouter isn't configured: only OpenAI direct models can be selected. + if hasOpenAI { + return pickForPrefix("openai", "openai/", perOrg) + } + + return nil +} + // SpyLM wraps an LM to capture interactions type SpyLM struct { dsgo.LM @@ -146,6 +311,7 @@ func runExperiment(ctx context.Context, modelName string) *ExperimentResult { react = react.WithOptions(&dsgo.GenerateOptions{ Temperature: 0.3, MaxTokens: 2000, + RetryConfig: &dsgo.RetryConfig{MaxRetries: 2}, }) startTime := time.Now() @@ -154,7 +320,7 @@ func runExperiment(ctx context.Context, modelName string) *ExperimentResult { defer cancel() inputs := map[string]any{ - "task": "Analyze the internal/core package. Read the source files to understand: 1) The overall architecture and design, 2) Key interfaces like LM, Module, and Provider, 3) The pipeline and execution flow, 4) Error handling and logging, 5) Test coverage and examples. 
Use the filesystem tools to read actual source code.", + "task": "Analyze the internal/core package. First, list internal/core and only read files that exist in that listing (do not guess file names). Focus on: overall architecture, LM + Module interfaces, signatures/validation, adapters/parsing flow, prediction/usage metadata, tools and tool calling surfaces, caching, history, settings/config, and collectors.", } prediction, err := react.Forward(ctx, inputs) @@ -176,6 +342,18 @@ func runExperiment(ctx context.Context, modelName string) *ExperimentResult { result.FinalOutput = analysis } + // Prefer module-reported iteration count (includes termination and extractor info). + if it, ok := prediction.Metadata["react_iterations_used"]; ok { + switch v := it.(type) { + case int: + result.Iterations = v + case int64: + result.Iterations = int(v) + case float64: + result.Iterations = int(v) + } + } + // Process captured interactions to extract thoughts and tool calls processInteractions(spyLM.Interactions, result) } diff --git a/internal/core/prediction.go b/internal/core/prediction.go index de7dc44..ef466fe 100644 --- a/internal/core/prediction.go +++ b/internal/core/prediction.go @@ -23,6 +23,11 @@ type Prediction struct { // Parse diagnostics (for partial outputs and validation tracking) ParseDiagnostics *ValidationDiagnostics // Validation diagnostics for partial outputs + + // Metadata is arbitrary module/provider diagnostics. + // This is intentionally flexible to allow modules to attach iteration counts, + // termination reasons, tool traces, etc. 
+ Metadata map[string]any } // NewPrediction creates a new prediction from outputs @@ -30,6 +35,7 @@ func NewPrediction(outputs map[string]any) *Prediction { return &Prediction{ Outputs: outputs, Completions: []map[string]any{}, + Metadata: map[string]any{}, } } @@ -75,6 +81,26 @@ func (p *Prediction) WithParseDiagnostics(diag *ValidationDiagnostics) *Predicti return p } +// WithMetadata replaces the prediction metadata map. +// Callers should prefer AddMetadata for incremental updates. +func (p *Prediction) WithMetadata(metadata map[string]any) *Prediction { + if metadata == nil { + p.Metadata = map[string]any{} + return p + } + p.Metadata = metadata + return p +} + +// AddMetadata adds a single metadata key/value. +func (p *Prediction) AddMetadata(key string, value any) *Prediction { + if p.Metadata == nil { + p.Metadata = map[string]any{} + } + p.Metadata[key] = value + return p +} + // Get retrieves a value from outputs func (p *Prediction) Get(key string) (any, bool) { val, ok := p.Outputs[key] diff --git a/internal/module/program.go b/internal/module/program.go index 3224228..2f65a67 100644 --- a/internal/module/program.go +++ b/internal/module/program.go @@ -733,6 +733,11 @@ func copyPrediction(p *core.Prediction) *core.Prediction { } // Create new prediction + var diagCopy *core.ValidationDiagnostics + if p.ParseDiagnostics != nil { + diagCopy = p.ParseDiagnostics.Clone() + } + resultCopy := &core.Prediction{ Outputs: core.DeepCopyMap(p.Outputs), Usage: p.Usage, // Usage is a struct, copies by value @@ -745,7 +750,8 @@ func copyPrediction(p *core.Prediction) *core.Prediction { ParseSuccess: p.ParseSuccess, ParseAttempts: p.ParseAttempts, FallbackUsed: p.FallbackUsed, - ParseDiagnostics: p.ParseDiagnostics.Clone(), + ParseDiagnostics: diagCopy, + Metadata: core.DeepCopyMap(p.Metadata), } return resultCopy diff --git a/internal/module/react.go b/internal/module/react.go index 780e0ed..bd57720 100644 --- a/internal/module/react.go +++ 
b/internal/module/react.go @@ -5,37 +5,64 @@ import ( "encoding/json" "fmt" "maps" + "os" "regexp" "strconv" "strings" "time" "github.com/assagman/dsgo/internal/core" + "github.com/assagman/dsgo/internal/jsonutil" "github.com/assagman/dsgo/internal/logging" ) const ( MaxReActIterations = 10 + + defaultReActMaxToolResultBytes = 16 * 1024 + // A conservative prompt budget used for trajectory rendering. + // ReAct will still detect provider context overflow errors and shrink further. + defaultReActMaxPromptBytes = 256 * 1024 ) -// ReAct implements the Reasoning and Acting pattern +// ReAct implements the Reasoning and Acting pattern. +// +// Key properties: +// - Uses an explicit trajectory object rather than mutating a shared []Message. +// - Encodes tool observations as a bounded JSON envelope. +// - Detects context overflow errors and truncates oldest trajectory steps. +// - Supports a prompted planning mode for LMs that don't support native tool calling. +// - Prefers a signature-valid in-loop final candidate; otherwise falls back to an extractor stage. +// +// ReAct instances are not safe for concurrent use. For parallel execution, +// call Clone() per concurrent worker and configure each clone before use. +// +// NOTE: This module is intentionally breaking-change tolerant, but dsgo.NewReAct +// remains available as the primary constructor. type ReAct struct { - Signature *core.Signature - LM core.LM - Tools []core.Tool - Options *core.GenerateOptions - Adapter core.Adapter - History *core.History // Optional conversation history - Demos []core.Example // Optional few-shot examples + Signature *core.Signature + LM core.LM + Tools []core.Tool + Options *core.GenerateOptions + Adapter core.Adapter + History *core.History // Optional conversation history + Demos []core.Example // Optional few-shot examples + MaxIterations int Verbose bool + + // MaxToolResultBytes controls deterministic truncation of tool results + // before they are added to the prompt. 
+ MaxToolResultBytes int + + // MaxPromptBytes is the soft prompt budget for trajectory rendering. + // If 0, a conservative default is used. + MaxPromptBytes int } -// NewReAct creates a new ReAct module +// NewReAct creates a new ReAct module. // // Panics if signature or lm is nil to fail fast on invalid configuration. -// ReAct instances are not safe for concurrent use. For parallel execution, -// call Clone() per concurrent worker and configure each clone before use. func NewReAct(signature *core.Signature, lm core.LM, tools []core.Tool) *ReAct { if signature == nil { panic("NewReAct: signature cannot be nil") @@ -47,21 +74,30 @@ func NewReAct(signature *core.Signature, lm core.LM, tools []core.Tool) *ReAct { tools = []core.Tool{} } - // Clone tools slice to avoid mutating caller's backing array when appending finish tool clonedTools := make([]core.Tool, len(tools)) copy(clonedTools, tools) r := &ReAct{ - Signature: signature, - LM: lm, - Tools: clonedTools, - Options: core.DefaultGenerateOptions(), - Adapter: core.NewFallbackAdapter(), - MaxIterations: MaxReActIterations, - Verbose: false, + Signature: signature, + LM: lm, + Tools: clonedTools, + Options: core.DefaultGenerateOptions(), + Adapter: core.NewFallbackAdapter(), + MaxIterations: MaxReActIterations, + Verbose: false, + MaxToolResultBytes: defaultReActMaxToolResultBytes, + MaxPromptBytes: defaultReActMaxPromptBytes, + } + + // Allow env override for prompt budget in tests / tuning. + if v := strings.TrimSpace(os.Getenv("DSGO_REACT_MAX_PROMPT_BYTES")); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + r.MaxPromptBytes = n + } } - // AUTO-INJECT finish tool if not present + // AUTO-INJECT finish tool if not present. + // finish is a termination signal; the final answer still goes through extraction. 
if r.findTool("finish") == nil { finishTool := buildFinishTool(signature) r.Tools = append(r.Tools, *finishTool) @@ -70,8 +106,8 @@ func NewReAct(signature *core.Signature, lm core.LM, tools []core.Tool) *ReAct { return r } -// WithOptions sets custom generation options -// If nil is passed, defaults are used +// WithOptions sets custom generation options. +// If nil is passed, defaults are used. func (r *ReAct) WithOptions(options *core.GenerateOptions) *ReAct { if options == nil { r.Options = core.DefaultGenerateOptions() @@ -81,50 +117,48 @@ func (r *ReAct) WithOptions(options *core.GenerateOptions) *ReAct { return r } -// WithAdapter sets a custom adapter +// WithAdapter sets a custom adapter. func (r *ReAct) WithAdapter(adapter core.Adapter) *ReAct { r.Adapter = adapter return r } -// WithHistory sets conversation history for multi-turn interactions +// WithHistory sets conversation history for multi-turn interactions. func (r *ReAct) WithHistory(history *core.History) *ReAct { r.History = history return r } -// WithDemos sets few-shot examples for in-context learning +// WithDemos sets few-shot examples for in-context learning. func (r *ReAct) WithDemos(demos []core.Example) *ReAct { r.Demos = demos return r } -// WithMaxIterations sets the maximum number of ReAct iterations +// WithMaxIterations sets the maximum number of ReAct iterations. func (r *ReAct) WithMaxIterations(max int) *ReAct { r.MaxIterations = max return r } -// WithVerbose enables verbose logging +// WithVerbose enables verbose logging. func (r *ReAct) WithVerbose(verbose bool) *ReAct { r.Verbose = verbose return r } -// GetSignature returns the module's signature +// GetSignature returns the module's signature. func (r *ReAct) GetSignature() *core.Signature { return r.Signature } -// Forward executes the ReAct loop +// Forward executes the ReAct loop. 
func (r *ReAct) Forward(ctx context.Context, inputs map[string]any) (*core.Prediction, error) { - // Ensure context has IDs ctx = logging.EnsureRequestID(ctx) ctx = logging.EnsureCorrelationID(ctx) startTime := time.Now() logging.LogPredictionStart(ctx, logging.ModuleReAct, r.Signature.Description) - var predErr error defer func() { logging.LogPredictionEnd(ctx, logging.ModuleReAct, time.Since(startTime), predErr) @@ -135,460 +169,348 @@ func (r *ReAct) Forward(ctx context.Context, inputs map[string]any) (*core.Predi return nil, predErr } - // Use adapter to format messages with demos newMessages, err := r.Adapter.Format(r.Signature, inputs, r.Demos) if err != nil { predErr = fmt.Errorf("failed to format messages: %w", err) return nil, predErr } - // Build initial message list - var messages []core.Message - - // Add system prompt for ReAct pattern - systemPrompt := r.buildSystemPrompt() - if systemPrompt != "" { - messages = append(messages, core.Message{Role: "system", Content: systemPrompt}) + // Build base trajectory messages. + var base []core.Message + if systemPrompt := r.buildSystemPrompt(); systemPrompt != "" { + base = append(base, core.Message{Role: "system", Content: systemPrompt}) } - - // Prepend history if available if r.History != nil && !r.History.IsEmpty() { - historyMessages := r.Adapter.FormatHistory(r.History) - messages = append(messages, historyMessages...) + base = append(base, r.Adapter.FormatHistory(r.History)...) } + base = append(base, newMessages...) - // Add new messages from adapter - messages = append(messages, newMessages...) - - // Track observations for stagnation detection - var lastObservation string - var finalMode bool + traj := newReActTrajectory(base) + term := newReActTermination() - // Track total usage across all iterations for accurate billing/monitoring + // Aggregate usage across all LM calls (loop + extraction). 
totalUsage := core.Usage{} - // Track whether history has been updated to avoid duplicate entries - historyUpdated := false - - // ReAct loop: Thought -> Action -> Observation + iterationsUsed := 0 + extractionUsed := false for i := 0; i < r.MaxIterations; i++ { - // Check for context cancellation before each iteration + iterationsUsed = i + 1 if err := ctx.Err(); err != nil { predErr = fmt.Errorf("context canceled before iteration %d: %w", i+1, err) return nil, predErr } - if r.Verbose { - fmt.Printf("\n=== ReAct Iteration %d ===\n", i+1) - } - logging.GetLogger().Debug(ctx, "ReAct iteration started", map[string]any{ - "iteration": i + 1, - "max_iterations": r.MaxIterations, - }) - - // Activate final mode on last iteration - if i == r.MaxIterations-1 { - finalMode = true - if r.Verbose { - fmt.Println("⚠️ Final iteration - forcing final answer mode") + if r.LM.SupportsTools() { + usage, done, stepErr := r.iterationNativeTools(ctx, traj, term, i) + totalUsage = addUsage(totalUsage, usage) + if stepErr != nil { + predErr = stepErr + return nil, predErr } - } - - // Copy options to avoid mutation - options := r.Options.Copy() - - // In final mode, signal we don't want tool calls but keep tools available - // for providers that require tool definitions when conversation has tool history - // (e.g., Amazon Bedrock requires toolConfig when toolUse/toolResult blocks exist) - if finalMode { - // Only pass tools if LM supports them to avoid confusing non-tool LMs - if r.LM.SupportsTools() && r.hasRealTools() { - options.Tools = r.Tools - options.ToolChoice = "none" - } else { - options.Tools = nil - options.ToolChoice = "" + if done { + break } + continue + } - // Inject user message to prompt for final answer - finalPrompt := r.buildFinalAnswerPrompt() - messages = append(messages, core.Message{ - Role: "user", - Content: finalPrompt, - }) - - if r.LM.SupportsJSON() { - options.ResponseFormat = "json" - // Auto-generate JSON schema from signature for structured outputs - 
if options.ResponseSchema == nil { - // Use OpenAI-compliant schema for OpenAI providers to avoid strict mode errors - if r.LM.IsOpenAI() { - options.ResponseSchema = r.Signature.SignatureToOpenAIJSONSchema() - } else { - options.ResponseSchema = r.Signature.SignatureToJSONSchema() - } - } + if r.hasRealTools() { + // Prompted planning mode for non-tool LMs. + usage, done, stepErr := r.iterationPrompted(ctx, traj, term, i) + totalUsage = addUsage(totalUsage, usage) + if stepErr != nil { + predErr = stepErr + return nil, predErr } - } else { - // Normal mode: enable tools if available - if r.LM.SupportsTools() && len(r.Tools) > 0 { - options.Tools = r.Tools - options.ToolChoice = "auto" + if done { + break } + continue } - // Enable JSON mode when tools are not used (for final answer) - if r.LM.SupportsJSON() && len(options.Tools) == 0 { - if _, isJSON := r.Adapter.(*core.JSONAdapter); isJSON { - options.ResponseFormat = "json" - // Auto-generate JSON schema from signature for structured outputs - if options.ResponseSchema == nil { - // Use OpenAI-compliant schema for OpenAI providers to avoid strict mode errors - if r.LM.IsOpenAI() { - options.ResponseSchema = r.Signature.SignatureToOpenAIJSONSchema() - } else { - options.ResponseSchema = r.Signature.SignatureToJSONSchema() - } - } - } - } + // No usable tools and no native tool support: skip loop and rely on extractor. 
+ break + } - result, err := r.LM.Generate(ctx, messages, options) - if err != nil { - predErr = fmt.Errorf("LM generation failed at iteration %d: %w", i+1, err) - return nil, predErr - } + if err := ctx.Err(); err != nil { + predErr = fmt.Errorf("context canceled before extraction: %w", err) + return nil, predErr + } - // Accumulate usage across iterations for accurate total tracking - totalUsage.PromptTokens += result.Usage.PromptTokens - totalUsage.CompletionTokens += result.Usage.CompletionTokens - totalUsage.TotalTokens += result.Usage.TotalTokens - totalUsage.Cost += result.Usage.Cost - // For sequential iterations, sum the latencies - totalUsage.Latency += result.Usage.Latency - - // Implicit Finish: model chose direct answer over tools. - // When using native tool calling APIs, no tool calls = intentional direct answer. - // This is architecturally correct for LangChain/LlamaIndex-style native tool calling. - if len(result.ToolCalls) == 0 { - if r.Verbose { - fmt.Printf("Thought: %s\n", core.StripMarkers(result.Content)) - fmt.Println("Action: None (Final Answer)") - } + // Prefer returning a signature-valid candidate captured in-loop. + if pred, ok := r.tryFinalizeFromCandidate(inputs, newMessages, totalUsage, term); ok { + pred.AddMetadata("react_iterations_used", iterationsUsed) + pred.AddMetadata("react_max_iterations", r.MaxIterations) + pred.AddMetadata("react_termination_reason", string(term.Reason())) + pred.AddMetadata("react_extraction_used", false) + return pred, nil + } - // Apply hardened parsing (P2) - cleanedContent := stripToJSON(result.Content) - - // Use adapter to parse output - outputs, err := r.Adapter.Parse(r.Signature, cleanedContent) - if err != nil { - // Safeguard: guide model to use tools if parse fails early. - // In early iterations, nudge the model to use tools instead of accepting malformed output. - // Only do this if the LM actually supports tools and we have real tools available. 
- hasRealTools := r.hasRealTools() - if !finalMode && i < r.MaxIterations-2 && hasRealTools && r.LM.SupportsTools() { - if r.Verbose { - fmt.Println("⚠️ Parsing failed and tools available - requesting tool use") - } - messages = append(messages, core.Message{ - Role: "assistant", - Content: result.Content, - }) - messages = append(messages, core.Message{ - Role: "user", - Content: "Please use the available tools to gather the information needed, then provide a complete answer in the requested format. Do not include any meta-commentary or explanations - just the answer.", - }) - continue - } + // Fall back to extractor stage over the rendered trajectory. + extractionUsed = true + pred, err := r.runExtractWithContextRetry(ctx, traj, inputs, newMessages, totalUsage) + if err != nil { + predErr = err + return nil, predErr + } - // If in final mode and parsing fails, run extraction (P1) - if finalMode { - if r.Verbose { - fmt.Println("⚠️ Final answer parsing failed - running extraction") - } - var res *core.Prediction - res, predErr = r.runExtract(ctx, messages, inputs, newMessages, totalUsage, historyUpdated) - return res, predErr - } + pred.AddMetadata("react_iterations_used", iterationsUsed) + pred.AddMetadata("react_max_iterations", r.MaxIterations) + pred.AddMetadata("react_termination_reason", string(term.Reason())) + pred.AddMetadata("react_extraction_used", extractionUsed) + return pred, nil +} - // FALLBACK: If structured parsing fails, attempt text extraction for string fields - // This makes ReAct resilient to less capable models that don't follow structured formats - extractedOutputs := r.extractTextOutputs(cleanedContent, messages) - if len(extractedOutputs) > 0 { - if r.Verbose { - fmt.Println("⚠️ Structured parsing failed - falling back to raw text extraction") - } - outputs = extractedOutputs - } else { - // Last resort: run extraction - if r.Verbose { - fmt.Println("⚠️ All parsing failed - running extraction") - } - var res *core.Prediction - res, 
predErr = r.runExtract(ctx, messages, inputs, newMessages, totalUsage, historyUpdated) - return res, predErr - } - } +func (r *ReAct) iterationNativeTools(ctx context.Context, traj *reactTrajectory, term *reactTermination, iteration int) (core.Usage, bool, error) { + options := r.Options.Copy() + options.ResponseFormat = "" + options.ResponseSchema = nil + options.Tools = r.Tools + options.ToolChoice = "auto" - // Apply type coercion (P2) - outputs = coerceBasicTypes(r.Signature, outputs) + result, err := r.generateWithContextRetry(ctx, traj, options, nil) + if err != nil { + return core.Usage{}, false, fmt.Errorf("LM generation failed at iteration %d: %w", iteration+1, err) + } - // Apply output normalization - outputs = core.NormalizeOutputKeys(r.Signature, outputs) + step := traj.AddStep(result.Content, result.ToolCalls) - // Validate outputs; if validation fails, fall back to extraction - if err := r.Signature.ValidateOutputs(outputs); err != nil { - // Validation failed - try extraction as fallback - if r.Verbose { - fmt.Printf("⚠️ Output validation failed: %v - running extraction\n", err) - } - var res *core.Prediction - res, predErr = r.runExtract(ctx, messages, inputs, newMessages, totalUsage, historyUpdated) - return res, predErr - } + if r.Verbose { + fmt.Printf("\n=== ReAct Iteration %d (native tools) ===\n", iteration+1) + fmt.Printf("Thought: %s\n", core.StripMarkers(result.Content)) + } - // Extract adapter metadata - adapterUsed, parseAttempts, fallbackUsed := core.ExtractAdapterMetadata(outputs) + if len(result.ToolCalls) == 0 { + // Implicit finish: model chose to respond directly. 
+ term.SetFinalContent(result.Content) + term.MarkDone(terminationNoToolCalls) + return result.Usage, true, nil + } - // Extract rationale if present - rationale := "" - if reasoning, exists := outputs["reasoning"]; exists { - rationale = fmt.Sprintf("%v", reasoning) - // Remove reasoning from outputs if not part of signature - if r.Signature.GetOutputField("reasoning") == nil { - delete(outputs, "reasoning") - } - } + // Track finish but still execute ALL tool calls. + // Provider APIs (OpenAI/OpenRouter) require a tool message for every tool_call_id. + hasFinish := false + for _, tc := range result.ToolCalls { + if strings.EqualFold(tc.Name, "finish") { + term.SetFinalToolArgs(tc.Arguments) + hasFinish = true + } + } - // Update history if present (only here since we return immediately after) - if r.History != nil { - // Add only the new user message(s) (not from history) - for _, msg := range newMessages { - if msg.Role == "user" { - r.History.Add(msg) - } - } + // Execute every tool call and append a tool message for each. + // Even if termination conditions are triggered mid-iteration, we must respond to all tool calls + // from this assistant message. + for _, tc := range result.ToolCalls { + term.ObserveToolCall(tc) + + tool := r.findTool(tc.Name) + var toolOut any + var toolErr error + if tool == nil { + toolErr = fmt.Errorf("tool '%s' not found", tc.Name) + } else { + toolOut, toolErr = tool.Execute(ctx, tc.Arguments) + } - // Add assistant response - r.History.Add(core.Message{ - Role: "assistant", - Content: result.Content, - }) - } + env, truncated, obsHash := encodeToolResult(tc.Name, tc.ID, toolOut, toolErr, r.MaxToolResultBytes) + step.AddToolResult(reactToolResult{ToolCallID: tc.ID, ToolName: tc.Name, Content: env, Truncated: truncated, Err: toolErr}) + term.ObserveToolResult(tc, obsHash, toolErr) - // Build Prediction object with accumulated usage from all iterations - prediction := core.NewPrediction(outputs). - WithRationale(rationale). 
- WithUsage(totalUsage). - WithModuleName(logging.ModuleReAct). - WithInputs(inputs) + if r.Verbose { + fmt.Printf("Action: %s(%v)\n", tc.Name, tc.Arguments) + fmt.Printf("Observation: %s\n", env) + } + } - // Add adapter metrics if available - if adapterUsed != "" { - prediction.WithAdapterMetrics(adapterUsed, parseAttempts, fallbackUsed) - } + if hasFinish { + term.MarkDone(terminationFinishTool) + return result.Usage, true, nil + } - return prediction, nil - } + return result.Usage, term.ShouldStop(), nil +} - // Add assistant's response with tool calls - messages = append(messages, core.Message{ - Role: "assistant", - Content: result.Content, - ToolCalls: result.ToolCalls, - }) +func (r *ReAct) iterationPrompted(ctx context.Context, traj *reactTrajectory, term *reactTermination, iteration int) (core.Usage, bool, error) { + planningPrompt := r.buildPlanningPrompt() + options := r.Options.Copy() + options.Tools = nil + options.ToolChoice = "" + options.ResponseSchema = nil + if r.LM.SupportsJSON() { + options.ResponseFormat = "json" + } else { + options.ResponseFormat = "" + } - if r.Verbose { - fmt.Printf("Thought: %s\n", core.StripMarkers(result.Content)) - } + result, err := r.generateWithContextRetry(ctx, traj, options, []core.Message{{Role: "user", Content: planningPrompt}}) + if err != nil { + return core.Usage{}, false, fmt.Errorf("planning generation failed at iteration %d: %w", iteration+1, err) + } - // Execute tool calls and add observations - var currentObservation string - for _, toolCall := range result.ToolCalls { - if r.Verbose { - fmt.Printf("Action: %s(%v)\n", toolCall.Name, toolCall.Arguments) - } + plan, parseErr := parsePlanningResult(result.Content) + if parseErr != nil { + // Record the assistant content as a step for debugging, then terminate into extraction. 
+ traj.AddStep(result.Content, nil) + term.ObserveError(parseErr) + term.MarkDone(terminationPlanningParseError) + return result.Usage, true, nil + } - // Check if this is a "finish" tool call - treat as final answer - if strings.ToLower(toolCall.Name) == "finish" { - if r.Verbose { - fmt.Println("Finish tool called - extracting final answer") - } + if r.Verbose { + fmt.Printf("\n=== ReAct Iteration %d (prompted) ===\n", iteration+1) + fmt.Printf("Plan: %s\n", strings.TrimSpace(result.Content)) + } - // Extract outputs from finish tool arguments - outputs := make(map[string]any) - maps.Copy(outputs, toolCall.Arguments) - - // Apply type coercion and normalization for consistency with direct answer path - outputs = coerceBasicTypes(r.Signature, outputs) - outputs = core.NormalizeOutputKeys(r.Signature, outputs) - - // Validate outputs match signature - if err := r.Signature.ValidateOutputs(outputs); err != nil { - // If finish tool args don't match signature, continue and let model try again - observation := fmt.Sprintf("Error: finish tool arguments don't match required outputs: %v", err) - messages = append(messages, core.Message{ - Role: "tool", - Content: observation, - ToolID: toolCall.ID, - }) - if r.Verbose { - fmt.Printf("Observation: %s\n", observation) - } - currentObservation = observation - continue - } + if plan.Done || plan.NextToolName == "" || strings.EqualFold(plan.NextToolName, "finish") { + traj.AddStep(result.Content, nil) + term.MarkDone(terminationPlanningDone) + return result.Usage, true, nil + } - // Update history with final answer for multi-turn consistency (only here since we return immediately after) - if r.History != nil { - for _, msg := range newMessages { - if msg.Role == "user" { - r.History.Add(msg) - } - } - contentBytes, _ := json.Marshal(outputs) - r.History.Add(core.Message{ - Role: "assistant", - Content: string(contentBytes), - }) - } + // Synthesize a deterministic tool call ID for prompted mode. 
+ toolCallID := fmt.Sprintf("prompted_%d", iteration+1) + toolCall := core.ToolCall{ID: toolCallID, Name: plan.NextToolName, Arguments: plan.NextToolArgs} + step := traj.AddStep(result.Content, []core.ToolCall{toolCall}) - // Build prediction and return with accumulated usage - prediction := core.NewPrediction(outputs). - WithUsage(totalUsage). - WithModuleName(logging.ModuleReAct). - WithInputs(inputs) + term.ObserveToolCall(toolCall) + if term.ShouldStop() { + term.MarkDone(terminationStagnation) + return result.Usage, true, nil + } - return prediction, nil - } + tool := r.findTool(plan.NextToolName) + var toolOut any + var toolErr error + if tool == nil { + toolErr = fmt.Errorf("tool '%s' not found", plan.NextToolName) + } else { + toolOut, toolErr = tool.Execute(ctx, plan.NextToolArgs) + } - tool := r.findTool(toolCall.Name) - if tool == nil { - observation := fmt.Sprintf("Error: Tool '%s' not found", toolCall.Name) - messages = append(messages, core.Message{ - Role: "tool", - Content: observation, - ToolID: toolCall.ID, - }) - if r.Verbose { - fmt.Printf("Observation: %s\n", observation) - } - currentObservation = observation - continue - } + env, truncated, obsHash := encodeToolResult(plan.NextToolName, toolCallID, toolOut, toolErr, r.MaxToolResultBytes) + step.AddToolResult(reactToolResult{ToolCallID: toolCallID, ToolName: plan.NextToolName, Content: env, Truncated: truncated, Err: toolErr}) + term.ObserveToolResult(toolCall, obsHash, toolErr) - toolResult, err := tool.Execute(ctx, toolCall.Arguments) - if err != nil { - observation := fmt.Sprintf("Error executing tool '%s': %v", toolCall.Name, err) - logging.GetLogger().Warn(ctx, "Tool execution failed", map[string]any{ - "tool": toolCall.Name, - "tool_id": toolCall.ID, - "error": err.Error(), - "iteration": i + 1, - }) - messages = append(messages, core.Message{ - Role: "tool", - Content: observation, - ToolID: toolCall.ID, - }) - if r.Verbose { - fmt.Printf("Observation: %s\n", observation) - } - 
currentObservation = observation - continue - } + return result.Usage, term.ShouldStop(), nil +} - observation := fmt.Sprintf("%v", toolResult) - messages = append(messages, core.Message{ - Role: "tool", - Content: observation, - ToolID: toolCall.ID, - }) - if r.Verbose { - fmt.Printf("Observation: %s\n", observation) - } - currentObservation = observation +type planningResult struct { + NextToolName string + NextToolArgs map[string]any + Done bool +} + +func parsePlanningResult(content string) (planningResult, error) { + cleaned := stripToJSON(content) + repaired := jsonutil.RepairJSON(cleaned) + + var raw map[string]any + if err := json.Unmarshal([]byte(repaired), &raw); err != nil { + return planningResult{}, fmt.Errorf("parse planning json: %w", err) + } + + // Be tolerant to naming variations. + name := "" + if v, ok := raw["next_tool_name"]; ok { + name, _ = v.(string) + } + if name == "" { + if v, ok := raw["tool_name"]; ok { + name, _ = v.(string) } + } + if name == "" { + if v, ok := raw["tool"]; ok { + name, _ = v.(string) + } + } - // Detect stagnation: if same observation appears twice in a row, force final answer - if currentObservation != "" && currentObservation == lastObservation { - if r.Verbose { - fmt.Println("\n⚠️ Stagnation detected - activating final mode") + args := map[string]any{} + if v, ok := raw["next_tool_args"]; ok { + if m, ok := v.(map[string]any); ok { + args = m + } + } + if len(args) == 0 { + if v, ok := raw["tool_args"]; ok { + if m, ok := v.(map[string]any); ok { + args = m } - finalMode = true - messages = append(messages, core.Message{ - Role: "user", - Content: "You've received the same observation twice. Please provide your final answer now as a JSON object with all required fields. 
Do not call any more tools.", - }) } - lastObservation = currentObservation } - - // Check context before extraction - if err := ctx.Err(); err != nil { - predErr = fmt.Errorf("context canceled before extraction: %w", err) - return nil, predErr + if len(args) == 0 { + // Some models return args as a JSON string. + if v, ok := raw["args"]; ok { + switch vv := v.(type) { + case map[string]any: + args = vv + case string: + _ = json.Unmarshal([]byte(jsonutil.RepairJSON(vv)), &args) + } + } } - // Max iterations exceeded - run extraction to salvage an answer (P1) - if r.Verbose { - fmt.Printf("\n⚠️ Exceeded maximum iterations (%d) - running extraction\n", r.MaxIterations) + done := false + if v, ok := raw["done"]; ok { + if b, ok := v.(bool); ok { + done = b + } + } + if v, ok := raw["final"]; ok { + if b, ok := v.(bool); ok { + done = done || b + } } - var res *core.Prediction - res, predErr = r.runExtract(ctx, messages, inputs, newMessages, totalUsage, historyUpdated) - return res, predErr + + return planningResult{NextToolName: strings.TrimSpace(name), NextToolArgs: args, Done: done}, nil } func (r *ReAct) buildSystemPrompt() string { - // Don't build system prompt if only the finish tool exists (no real tools) - if len(r.Tools) == 0 || (len(r.Tools) == 1 && r.Tools[0].Name == "finish") { + if !r.hasRealTools() { return "" } - var prompt strings.Builder - prompt.WriteString("You are a helpful AI assistant that uses tools to answer questions.\n\n") - prompt.WriteString("Follow these steps:\n") - prompt.WriteString("1. Use the available tools to gather the information you need\n") - prompt.WriteString("2. Once you have enough information, call the 'finish' tool with your complete answer\n") - prompt.WriteString("3. 
If you already have the answer, call 'finish' immediately\n\n") - - prompt.WriteString("IMPORTANT:\n") - prompt.WriteString("- Use the native tool calling mechanism\n") - prompt.WriteString("- Do NOT write textual representations like 'Action: search(...)' or 'Thought:'\n") - prompt.WriteString("- When calling 'finish', provide ALL required fields in the tool arguments\n") - prompt.WriteString("- Do not include explanations or meta-commentary\n") - - return prompt.String() + var b strings.Builder + b.WriteString("You are a helpful AI assistant. Use tools when needed.\n\n") + if r.LM.SupportsTools() { + b.WriteString("When you need external information, call tools using the native tool calling mechanism.\n") + b.WriteString("When you are done, you may either respond directly or call the 'finish' tool.\n") + b.WriteString("Do not write textual tool call syntax; use the tool calling API.\n") + } else { + b.WriteString("This model does not support native tool calling.\n") + b.WriteString("You will be asked to output a JSON plan indicating the next tool to run.\n") + b.WriteString("When you are done, set next_tool_name to \"finish\" and done=true.\n") + } + return b.String() } -func (r *ReAct) buildFinalAnswerPrompt() string { - var prompt strings.Builder - prompt.WriteString("Based on all the information gathered above, please provide your final answer now.\n\n") - - // Add output format specification (P3: clearer JSON instructions) - prompt.WriteString("Respond with a valid JSON object containing these fields:\n") - for _, field := range r.Signature.OutputFields { - optional := "" - if field.Optional { - optional = " (optional)" - } - classInfo := "" - if field.Type == core.FieldTypeClass && len(field.Classes) > 0 { - classInfo = fmt.Sprintf(" [one of: %s]", strings.Join(field.Classes, ", ")) +func (r *ReAct) buildPlanningPrompt() string { + var b strings.Builder + b.WriteString("Decide the next action.\n\n") + b.WriteString("Return ONLY a JSON object with: \n") + 
b.WriteString("- next_tool_name: string (tool name, or \"finish\" if done)\n") + b.WriteString("- next_tool_args: object (arguments for that tool)\n") + b.WriteString("- done: boolean\n\n") + b.WriteString("Available tools:\n") + for _, t := range r.Tools { + if strings.EqualFold(t.Name, "finish") { + continue } - if field.Description != "" { - prompt.WriteString(fmt.Sprintf("- %s (%s)%s%s: %s\n", field.Name, field.Type, optional, classInfo, field.Description)) - } else { - prompt.WriteString(fmt.Sprintf("- %s (%s)%s%s\n", field.Name, field.Type, optional, classInfo)) + b.WriteString("- ") + b.WriteString(t.Name) + if t.Description != "" { + b.WriteString(": ") + b.WriteString(t.Description) } + b.WriteString("\n") } + b.WriteString("\nIf no tool is needed, set next_tool_name=\"finish\" and done=true.\n") + return b.String() +} - prompt.WriteString("\nCRITICAL REQUIREMENTS:\n") - prompt.WriteString("- Return ONLY a valid JSON object (no code fences, no explanations)\n") - prompt.WriteString("- Include all required fields with appropriate values\n") - prompt.WriteString("- Use the exact field names specified above\n") - prompt.WriteString("- Provide a complete answer based on all observations you've gathered\n") - +func (r *ReAct) buildExtractionPrompt() string { + var prompt strings.Builder + prompt.WriteString("Based on the conversation above (including tool observations), synthesize the final answer.\n") + prompt.WriteString("Return ONLY a JSON object with the required fields.\n\n") return prompt.String() } @@ -601,136 +523,33 @@ func (r *ReAct) findTool(name string) *core.Tool { return nil } -// hasRealTools returns true if there are tools beyond the auto-injected "finish" tool +// hasRealTools returns true if there are tools beyond the auto-injected "finish" tool. 
func (r *ReAct) hasRealTools() bool { for _, t := range r.Tools { - if strings.ToLower(t.Name) != "finish" { + if !strings.EqualFold(t.Name, "finish") { return true } } return false } -// extractTextOutputs attempts to extract output fields from raw text when structured parsing fails -// This is a last-resort fallback for less capable models that don't follow JSON/Chat formats -func (r *ReAct) extractTextOutputs(content string, messages []core.Message) map[string]any { - outputs := make(map[string]any) - content = strings.TrimSpace(content) - - // If content is empty or very short, try to synthesize from message history - if len(content) < 10 { - if r.Verbose { - fmt.Println("⚠️ Content too short, synthesizing from observations") - } - content = r.synthesizeAnswerFromHistory(messages) - } - - // Only attempt extraction for string fields - var stringFields []core.Field - for _, field := range r.Signature.OutputFields { - if field.Type == core.FieldTypeString { - stringFields = append(stringFields, field) - } - } - - if len(stringFields) == 0 { - return nil - } - - // Strategy 1: If only one string field and it's "answer", use entire content - if len(stringFields) == 1 && stringFields[0].Name == "answer" { - outputs["answer"] = content - return outputs - } - - // Strategy 2: For multiple fields, try simple heuristics - // - If all required fields are strings, split content or use entire content for primary field - var primaryFieldName string - if answerField := r.Signature.GetOutputField("answer"); answerField != nil { - primaryFieldName = "answer" - } else if len(stringFields) > 0 { - primaryFieldName = stringFields[0].Name - } - - if primaryFieldName != "" { - // Use entire content for primary field (answer) - outputs[primaryFieldName] = content - - // For other string fields, try to provide reasonable defaults - for _, field := range stringFields { - if field.Name != primaryFieldName { - if field.Name == "sources" { - // Extract mentions that look like sources - 
outputs["sources"] = "Based on search results and tool observations" - } else if !field.Optional { - // Provide placeholder for required fields - outputs[field.Name] = content - } - } - } - } - - return outputs -} - -// synthesizeAnswerFromHistory extracts and summarizes observations from the message history -// Used as a fallback when the model produces empty content in final mode -func (r *ReAct) synthesizeAnswerFromHistory(messages []core.Message) string { - var observations []string - - // Extract tool observations from message history - for _, msg := range messages { - if msg.Role == "tool" && msg.Content != "" { - // Skip error messages - if !strings.HasPrefix(msg.Content, "Error:") { - observations = append(observations, msg.Content) - } - } - } - - if len(observations) == 0 { - return "No information available from tools" - } - - // Use the most recent relevant observation - // Take the last non-duplicate observation - seen := make(map[string]bool) - var uniqueObs []string - for i := len(observations) - 1; i >= 0 && len(uniqueObs) < 3; i-- { - obs := observations[i] - if !seen[obs] && len(obs) > 20 { - uniqueObs = append([]string{obs}, uniqueObs...) - seen[obs] = true - } - } - - if len(uniqueObs) > 0 { - return strings.Join(uniqueObs, " ") - } - - return observations[len(observations)-1] -} - // buildFinishTool creates a synthetic "finish" tool that allows models to explicitly -// conclude the ReAct loop by providing final outputs matching the signature +// conclude the ReAct loop by providing final outputs matching the signature. func buildFinishTool(signature *core.Signature) *core.Tool { tool := core.NewTool( "finish", - "Call this tool when you have gathered enough information and are ready to provide the final answer. 
 Use the tool arguments to provide your complete answer.",
+		"Call this tool when you are ready to conclude the reasoning/tool loop.",
 		func(ctx context.Context, args map[string]any) (any, error) {
-			// This tool is intercepted in Forward() before execution
-			return "Final answer provided", nil
+			return "finish", nil
 		},
 	)
-	// Add parameters matching the output signature
 	for _, field := range signature.OutputFields {
 		description := field.Description
 		if description == "" {
 			description = fmt.Sprintf("The %s field of the final answer", field.Name)
 		}
-		// Determine parameter type
 		paramType := "string"
 		switch field.Type {
 		case core.FieldTypeInt:
@@ -738,166 +557,218 @@ func buildFinishTool(signature *core.Signature) *core.Tool {
 		case core.FieldTypeBool:
 			paramType = "boolean"
 		}
-
-		// Add class information to description
 		if field.Type == core.FieldTypeClass && len(field.Classes) > 0 {
 			description = fmt.Sprintf("%s (one of: %s)", description, strings.Join(field.Classes, ", "))
 		}
 		tool.AddParameter(field.Name, paramType, description, !field.Optional)
 	}
 
 	return tool
 }
 
-// stripToJSON removes common LLM artifacts from JSON output
-// Handles: code fences, trailing commentary, leading/trailing text
-func stripToJSON(content string) string {
-	content = strings.TrimSpace(content)
-
-	// Remove markdown code fences
-	re := regexp.MustCompile("(?s)```(?:json)?\n?(.*?)\n?```")
-	if matches := re.FindStringSubmatch(content); len(matches) > 1 {
-		content = strings.TrimSpace(matches[1])
+// Clone creates an independent copy of the ReAct module.
+func (r *ReAct) Clone() core.Module { + cloned := &ReAct{ + Signature: r.Signature, + LM: r.LM, + Tools: make([]core.Tool, len(r.Tools)), + Options: r.Options, + Adapter: r.Adapter, + History: nil, + Demos: make([]core.Example, len(r.Demos)), + MaxIterations: r.MaxIterations, + Verbose: r.Verbose, + MaxToolResultBytes: r.MaxToolResultBytes, + MaxPromptBytes: r.MaxPromptBytes, } - // Find JSON object boundaries - start := strings.Index(content, "{") - end := strings.LastIndex(content, "}") + copy(cloned.Demos, r.Demos) + copy(cloned.Tools, r.Tools) - if start != -1 && end != -1 && end > start { - content = content[start : end+1] + if r.History != nil { + cloned.History = r.History.Clone() } - return strings.TrimSpace(content) + return cloned } -// coerceBasicTypes handles type mismatches in parsed outputs -// Converts: string numbers to ints, string bools to bools, etc. -func coerceBasicTypes(signature *core.Signature, outputs map[string]any) map[string]any { - coerced := make(map[string]any) - - for key, value := range outputs { - field := signature.GetOutputField(key) - if field == nil { - coerced[key] = value - continue +func (r *ReAct) tryFinalizeFromCandidate(inputs map[string]any, newMessages []core.Message, usage core.Usage, term *reactTermination) (*core.Prediction, bool) { + // 1) Finish tool args (already structured). + if args := term.FinalToolArgs(); args != nil { + outputs := cloneMap(args) + outputs = coerceBasicTypes(r.Signature, outputs) + outputs = core.NormalizeOutputKeys(r.Signature, outputs) + + // Extract rationale/reasoning if present and not in signature. 
+ rationale := "" + if val, ok := outputs["rationale"]; ok { + rationale = fmt.Sprintf("%v", val) + if r.Signature.GetOutputField("rationale") == nil { + delete(outputs, "rationale") + } + } + if rationale == "" { + if val, ok := outputs["reasoning"]; ok { + rationale = fmt.Sprintf("%v", val) + if r.Signature.GetOutputField("reasoning") == nil { + delete(outputs, "reasoning") + } + } } - switch field.Type { - case core.FieldTypeInt: - // Try to convert string to int - if strVal, ok := value.(string); ok { - // Extract first number from string (e.g., "5 years" -> 5) - re := regexp.MustCompile(`-?\d+`) - if match := re.FindString(strVal); match != "" { - if intVal, err := strconv.Atoi(match); err == nil { - coerced[key] = intVal - continue + if err := r.Signature.ValidateOutputs(outputs); err == nil { + if r.History != nil { + for _, msg := range newMessages { + if msg.Role == "user" { + r.History.Add(msg) } } + contentBytes, _ := json.Marshal(outputs) + r.History.Add(core.Message{Role: "assistant", Content: string(contentBytes)}) } - // Try float64 (from JSON unmarshaling) to int - if floatVal, ok := value.(float64); ok { - coerced[key] = int(floatVal) - continue + + pred := core.NewPrediction(outputs). + WithUsage(usage). + WithModuleName(logging.ModuleReAct). + WithInputs(inputs) + if rationale != "" { + pred = pred.WithRationale(rationale) } - coerced[key] = value + return pred, true + } + } - case core.FieldTypeBool: - // Try to convert string to bool - if strVal, ok := value.(string); ok { - strVal = strings.ToLower(strings.TrimSpace(strVal)) - if strVal == "true" || strVal == "yes" || strVal == "1" { - coerced[key] = true - continue + // 2) Direct answer content (parse + validate). 
+ content := strings.TrimSpace(term.FinalContent()) + if content != "" { + parsed, err := r.Adapter.Parse(r.Signature, content) + if err != nil { + cleaned := stripToJSON(content) + if cleaned != content { + parsed, err = r.Adapter.Parse(r.Signature, cleaned) + } + } + if err == nil { + parsed = coerceBasicTypes(r.Signature, parsed) + parsed = core.NormalizeOutputKeys(r.Signature, parsed) + rationale := "" + if val, ok := parsed["rationale"]; ok { + rationale = fmt.Sprintf("%v", val) + if r.Signature.GetOutputField("rationale") == nil { + delete(parsed, "rationale") } - if strVal == "false" || strVal == "no" || strVal == "0" { - coerced[key] = false - continue + } + if rationale == "" { + if val, ok := parsed["reasoning"]; ok { + rationale = fmt.Sprintf("%v", val) + if r.Signature.GetOutputField("reasoning") == nil { + delete(parsed, "reasoning") + } } } - coerced[key] = value - case core.FieldTypeString: - // Convert any type to string if needed - if value != nil { - coerced[key] = fmt.Sprintf("%v", value) - } else { - coerced[key] = value + if err := r.Signature.ValidateOutputs(parsed); err == nil { + adapterUsed, parseAttempts, fallbackUsed := core.ExtractAdapterMetadata(parsed) + + if r.History != nil { + for _, msg := range newMessages { + if msg.Role == "user" { + r.History.Add(msg) + } + } + r.History.Add(core.Message{Role: "assistant", Content: content}) + } + + pred := core.NewPrediction(parsed). + WithUsage(usage). + WithModuleName(logging.ModuleReAct). 
+ WithInputs(inputs) + if rationale != "" { + pred = pred.WithRationale(rationale) + } + if adapterUsed != "" { + pred.WithAdapterMetrics(adapterUsed, parseAttempts, fallbackUsed) + } + return pred, true } + } + } - default: - coerced[key] = value + return nil, false +} + +// --- Extraction helpers (always uses extractor) --- + +func (r *ReAct) runExtractWithContextRetry( + ctx context.Context, + traj *reactTrajectory, + inputs map[string]any, + newMessages []core.Message, + priorUsage core.Usage, +) (*core.Prediction, error) { + // Retry extraction a few times on context overflow, truncating oldest steps. + for attempt := 0; attempt < 3; attempt++ { + pred, err := r.runExtract(ctx, traj.Render(r.maxPromptBytes()), inputs, newMessages, priorUsage, false) + if err == nil { + return pred, nil + } + if !isContextOverflowError(err) { + return nil, err + } + if traj.DropOldestSteps(1) == 0 { + return nil, err } } + // Final attempt without further truncation. + return r.runExtract(ctx, traj.Render(r.maxPromptBytes()), inputs, newMessages, priorUsage, false) +} - return coerced +func (r *ReAct) maxPromptBytes() int { + if r.MaxPromptBytes > 0 { + return r.MaxPromptBytes + } + return defaultReActMaxPromptBytes } // runExtract performs post-loop extraction to synthesize a final answer -// from the accumulated message history (trajectory). This is the critical -// fallback that ensures ReAct always returns something, even if the main -// loop fails or produces unparseable output. -// -// This phase uses structured output enforcement (when enabled) to ensure -// the extraction converges to valid outputs, with bounded retries. -// -// Parameters: -// - newMessages: the original user messages (for history update) -// - priorUsage: accumulated usage from prior iterations (to be merged with extraction usage) -// - historyUpdated: indicates if history has already been updated (to avoid duplicates) +// from the accumulated message history (trajectory). 
func (r *ReAct) runExtract(ctx context.Context, messages []core.Message, inputs map[string]any, newMessages []core.Message, priorUsage core.Usage, historyUpdated bool) (*core.Prediction, error) { - if r.Verbose { - fmt.Println("\n=== Running Post-Loop Extraction (with reasoning) ===") - } - - // Check if structured outputs are enabled settings := core.GetSettings() useStructuredMode := settings.StructuredOutput.Enabled - if useStructuredMode { return r.runExtractStructured(ctx, messages, inputs, newMessages, priorUsage, historyUpdated) } - return r.runExtractLegacy(ctx, messages, inputs, newMessages, priorUsage, historyUpdated) } -// runExtractStructured performs extraction with structured output enforcement func (r *ReAct) runExtractStructured(ctx context.Context, messages []core.Message, inputs map[string]any, newMessages []core.Message, priorUsage core.Usage, historyUpdated bool) (*core.Prediction, error) { settings := core.GetSettings() - // Build extraction prompt - extractPrompt := r.buildExtractionPrompt() - - // Append extraction request to message history extractMessages := make([]core.Message, len(messages)) copy(extractMessages, messages) - extractMessages = append(extractMessages, core.Message{ - Role: "user", - Content: extractPrompt, - }) + extractMessages = append(extractMessages, core.Message{Role: "user", Content: r.buildExtractionPrompt()}) - // Create a custom adapter wrapper that includes extraction messages wrappedAdapter := &reactExtractAdapter{ base: core.NewSchemaFirstAdapter(r.LM.SupportsJSON()).WithReasoning(true), messages: extractMessages, } - // Copy options and set Tools/ToolChoice for Bedrock compatibility: - // when conversation history contains tool calls, some providers require - // toolConfig to be present even when not requesting tool use. 
extractOptions := r.Options.Copy() - extractOptions.Tools = r.Tools - extractOptions.ToolChoice = "none" + if r.LM.SupportsTools() { + extractOptions.Tools = r.Tools + extractOptions.ToolChoice = "none" + } else { + extractOptions.Tools = nil + extractOptions.ToolChoice = "" + } - // Call structured output enforcement loop result, err := core.GenerateStructured( ctx, r.LM, r.Signature, inputs, - []core.Example{}, // No demos for extraction + []core.Example{}, core.GenerateStructuredOptions{ Adapter: wrappedAdapter, BaseOptions: extractOptions, @@ -907,43 +777,31 @@ func (r *ReAct) runExtractStructured(ctx context.Context, messages []core.Messag StreamCallback: r.Options.StreamCallback, }, ) - if err != nil { return nil, fmt.Errorf("extraction generation failed: %w", err) } - // Extract and remove rationale/reasoning field from outputs - var rationale string outputs := result.Outputs + + // Extract rationale/reasoning when not part of signature. + rationale := "" if val, ok := outputs["rationale"]; ok { - if str, ok := val.(string); ok { - rationale = str + rationale, _ = val.(string) + if r.Signature.GetOutputField("rationale") == nil { delete(outputs, "rationale") } } if rationale == "" { if val, ok := outputs["reasoning"]; ok { - if str, ok := val.(string); ok { - rationale = str + rationale, _ = val.(string) + if r.Signature.GetOutputField("reasoning") == nil { delete(outputs, "reasoning") } } } - // Merge extraction usage with prior accumulated usage - totalUsage := core.Usage{ - PromptTokens: priorUsage.PromptTokens + result.Usage.PromptTokens, - CompletionTokens: priorUsage.CompletionTokens + result.Usage.CompletionTokens, - TotalTokens: priorUsage.TotalTokens + result.Usage.TotalTokens, - Cost: priorUsage.Cost + result.Usage.Cost, - Latency: priorUsage.Latency, - } - if result.Usage.Latency > totalUsage.Latency { - totalUsage.Latency = result.Usage.Latency - } + totalUsage := addUsage(priorUsage, result.Usage) - // Update history with final answer for 
multi-turn consistency - // Only update if not already done in Forward (to avoid duplicates) if r.History != nil && !historyUpdated { for _, msg := range newMessages { if msg.Role == "user" { @@ -951,13 +809,9 @@ func (r *ReAct) runExtractStructured(ctx context.Context, messages []core.Messag } } contentBytes, _ := json.Marshal(outputs) - r.History.Add(core.Message{ - Role: "assistant", - Content: string(contentBytes), - }) + r.History.Add(core.Message{Role: "assistant", Content: string(contentBytes)}) } - // Build prediction with diagnostics and rationale pred := core.NewPrediction(outputs). WithRationale(rationale). WithUsage(totalUsage). @@ -967,25 +821,19 @@ func (r *ReAct) runExtractStructured(ctx context.Context, messages []core.Messag if result.Diagnostics != nil { pred.WithParseDiagnostics(result.Diagnostics) } - - if r.Verbose { - fmt.Printf("Extracted outputs: %+v\n", outputs) - if result.Diagnostics != nil && result.Diagnostics.HasErrors() { - fmt.Printf("⚠️ Extraction diagnostics: %v\n", result.Diagnostics) - } + if result.AdapterUsed != "" { + pred.WithAdapterMetrics(result.AdapterUsed, result.ParseAttempts, result.FallbackUsed) } return pred, nil } -// reactExtractAdapter wraps adapter to inject extraction messages type reactExtractAdapter struct { base core.Adapter messages []core.Message } func (rea *reactExtractAdapter) Format(sig *core.Signature, inputs map[string]any, demos []core.Example) ([]core.Message, error) { - // Return pre-built extraction messages instead of formatting return rea.messages, nil } @@ -997,25 +845,13 @@ func (rea *reactExtractAdapter) FormatHistory(history *core.History) []core.Mess return rea.base.FormatHistory(history) } -// runExtractLegacy performs extraction using the legacy path (without structured output enforcement) func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, inputs map[string]any, newMessages []core.Message, priorUsage core.Usage, historyUpdated bool) (*core.Prediction, error) { - 
// Build extraction prompt - extractPrompt := r.buildExtractionPrompt() - - // Append extraction request to message history extractMessages := make([]core.Message, len(messages)) copy(extractMessages, messages) - extractMessages = append(extractMessages, core.Message{ - Role: "user", - Content: extractPrompt, - }) - - // Copy options and force JSON mode - // Keep tools available for providers that require tool definitions when - // conversation has tool history (e.g., Amazon Bedrock) + extractMessages = append(extractMessages, core.Message{Role: "user", Content: r.buildExtractionPrompt()}) + options := r.Options.Copy() - // Only pass tools if LM supports them to avoid confusing non-tool LMs - if r.LM.SupportsTools() && r.hasRealTools() { + if r.LM.SupportsTools() { options.Tools = r.Tools options.ToolChoice = "none" } else { @@ -1026,7 +862,6 @@ func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, i if r.LM.SupportsJSON() { options.ResponseFormat = "json" if options.ResponseSchema == nil { - // Use OpenAI-compliant schema for OpenAI providers to avoid strict mode errors if r.LM.IsOpenAI() { options.ResponseSchema = r.Signature.SignatureToOpenAIJSONSchema() } else { @@ -1035,29 +870,18 @@ func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, i } } - // Generate extraction result, err := r.LM.Generate(ctx, extractMessages, options) if err != nil { return nil, fmt.Errorf("extraction generation failed: %w", err) } - if r.Verbose { - fmt.Printf("Extraction response: %s\n", result.Content) - } - - // Apply hardened parsing cleanedContent := stripToJSON(result.Content) - // Create temporary adapter WITH reasoning for extraction phase extractAdapter := core.NewFallbackAdapter().WithReasoning(true) - - // Try adapter parsing first (with reasoning) outputs, err := extractAdapter.Parse(r.Signature, cleanedContent) if err != nil { - // Fallback: try direct JSON parsing outputs = make(map[string]any) if jsonErr := 
json.Unmarshal([]byte(cleanedContent), &outputs); jsonErr != nil { - // Last resort: extract text outputs outputs = r.extractTextOutputs(cleanedContent, extractMessages) if len(outputs) == 0 { return nil, fmt.Errorf("extraction failed to parse output: %w (JSON error: %v)", err, jsonErr) @@ -1065,49 +889,30 @@ func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, i } } - // Extract and remove rationale/reasoning field from outputs var rationale string if val, ok := outputs["rationale"]; ok { - if str, ok := val.(string); ok { - rationale = str + rationale, _ = val.(string) + if r.Signature.GetOutputField("rationale") == nil { delete(outputs, "rationale") } } if rationale == "" { if val, ok := outputs["reasoning"]; ok { - if str, ok := val.(string); ok { - rationale = str + rationale, _ = val.(string) + if r.Signature.GetOutputField("reasoning") == nil { delete(outputs, "reasoning") } } } - // Apply type coercion outputs = coerceBasicTypes(r.Signature, outputs) - - // Apply output normalization outputs = core.NormalizeOutputKeys(r.Signature, outputs) - - // Use partial validation (allow missing optional fields) diagnostics := r.Signature.ValidateOutputsPartial(outputs) - // Extract adapter metadata adapterUsed, parseAttempts, fallbackUsed := core.ExtractAdapterMetadata(outputs) - // Merge extraction usage with prior accumulated usage - totalUsage := core.Usage{ - PromptTokens: priorUsage.PromptTokens + result.Usage.PromptTokens, - CompletionTokens: priorUsage.CompletionTokens + result.Usage.CompletionTokens, - TotalTokens: priorUsage.TotalTokens + result.Usage.TotalTokens, - Cost: priorUsage.Cost + result.Usage.Cost, - Latency: priorUsage.Latency, - } - if result.Usage.Latency > totalUsage.Latency { - totalUsage.Latency = result.Usage.Latency - } + totalUsage := addUsage(priorUsage, result.Usage) - // Update history with final answer for multi-turn consistency - // Only update if not already done in Forward (to avoid duplicates) if r.History 
!= nil && !historyUpdated { for _, msg := range newMessages { if msg.Role == "user" { @@ -1115,13 +920,9 @@ func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, i } } contentBytes, _ := json.Marshal(outputs) - r.History.Add(core.Message{ - Role: "assistant", - Content: string(contentBytes), - }) + r.History.Add(core.Message{Role: "assistant", Content: string(contentBytes)}) } - // Build prediction with diagnostics, rationale, and merged usage pred := core.NewPrediction(outputs). WithUsage(totalUsage). WithModuleName(logging.ModuleReAct). @@ -1129,76 +930,176 @@ func (r *ReAct) runExtractLegacy(ctx context.Context, messages []core.Message, i WithAdapterMetrics(adapterUsed, parseAttempts, fallbackUsed). WithParseDiagnostics(diagnostics) - // Attach rationale if found if rationale != "" { pred = pred.WithRationale(rationale) - if r.Verbose { - fmt.Printf("Extracted rationale: %s\n", rationale) + } + + return pred, nil +} + +// extractTextOutputs attempts to extract output fields from raw text when structured parsing fails. +func (r *ReAct) extractTextOutputs(content string, messages []core.Message) map[string]any { + outputs := make(map[string]any) + content = strings.TrimSpace(content) + + // If content is empty/very short, synthesize from tool observations. 
+ if len(content) < 10 { + content = r.synthesizeAnswerFromHistory(messages) + } + + var stringFields []core.Field + for _, field := range r.Signature.OutputFields { + if field.Type == core.FieldTypeString { + stringFields = append(stringFields, field) } } + if len(stringFields) == 0 { + return nil + } - if r.Verbose { - fmt.Printf("Extracted outputs: %+v\n", outputs) - if diagnostics != nil && diagnostics.HasErrors() { - fmt.Printf("⚠️ Extraction diagnostics: %v\n", diagnostics) + if len(stringFields) == 1 && stringFields[0].Name == "answer" { + outputs["answer"] = content + return outputs + } + + primaryField := "" + if r.Signature.GetOutputField("answer") != nil { + primaryField = "answer" + } else { + primaryField = stringFields[0].Name + } + outputs[primaryField] = content + + for _, field := range stringFields { + if field.Name != primaryField && !field.Optional { + outputs[field.Name] = content } } - return pred, nil + return outputs } -// buildExtractionPrompt creates a prompt for post-loop extraction -func (r *ReAct) buildExtractionPrompt() string { - var prompt strings.Builder - prompt.WriteString("Based on the conversation above, including all tool observations and reasoning, ") - prompt.WriteString("please synthesize a final answer now.\n\n") - - prompt.WriteString("Respond with a JSON object containing:\n") - for _, field := range r.Signature.OutputFields { - optional := "" - if field.Optional { - optional = " (optional)" +// synthesizeAnswerFromHistory extracts recent tool observations from history. +// Used as a fallback when the model produces empty content in extraction. 
+func (r *ReAct) synthesizeAnswerFromHistory(messages []core.Message) string { + var observations []string + for _, msg := range messages { + if msg.Role == "tool" && strings.TrimSpace(msg.Content) != "" { + if strings.HasPrefix(strings.TrimSpace(msg.Content), "Error:") { + continue + } + observations = append(observations, strings.TrimSpace(msg.Content)) } - classInfo := "" - if field.Type == core.FieldTypeClass && len(field.Classes) > 0 { - classInfo = fmt.Sprintf(" [one of: %s]", strings.Join(field.Classes, ", ")) + } + if len(observations) == 0 { + return "No information available from tools" + } + + seen := make(map[string]bool) + unique := make([]string, 0, 3) + for i := len(observations) - 1; i >= 0 && len(unique) < 3; i-- { + obs := observations[i] + if seen[obs] { + continue } - if field.Description != "" { - prompt.WriteString(fmt.Sprintf("- %s (%s)%s%s: %s\n", field.Name, field.Type, optional, classInfo, field.Description)) - } else { - prompt.WriteString(fmt.Sprintf("- %s (%s)%s%s\n", field.Name, field.Type, optional, classInfo)) + if len(obs) <= 20 { + continue } + seen[obs] = true + unique = append([]string{obs}, unique...) + } + if len(unique) > 0 { + return strings.Join(unique, " ") } + return observations[len(observations)-1] +} - prompt.WriteString("\nIMPORTANT:\n") - prompt.WriteString("- Use all information from the tool observations above\n") - prompt.WriteString("- Provide your best answer even if some information is missing\n") - prompt.WriteString("- Return ONLY valid JSON with the required fields\n") - prompt.WriteString("- Do not include any explanations or commentary\n") +// stripToJSON removes common LLM artifacts from JSON output. 
+func stripToJSON(content string) string { + content = strings.TrimSpace(content) + re := regexp.MustCompile("(?s)```(?:json)?\\n?(.*?)\\n?```") + if matches := re.FindStringSubmatch(content); len(matches) > 1 { + content = strings.TrimSpace(matches[1]) + } - return prompt.String() + start := strings.Index(content, "{") + end := strings.LastIndex(content, "}") + if start != -1 && end != -1 && end > start { + content = content[start : end+1] + } + return strings.TrimSpace(content) } -// Clone creates an independent copy of ReAct module -func (r *ReAct) Clone() core.Module { - cloned := &ReAct{ - Signature: r.Signature, - LM: r.LM, - Tools: make([]core.Tool, len(r.Tools)), - Options: r.Options, - Adapter: r.Adapter, - History: nil, - Demos: make([]core.Example, len(r.Demos)), - MaxIterations: r.MaxIterations, - Verbose: r.Verbose, - } +// coerceBasicTypes handles basic type mismatches in parsed outputs. +func coerceBasicTypes(signature *core.Signature, outputs map[string]any) map[string]any { + coerced := make(map[string]any) + for key, value := range outputs { + field := signature.GetOutputField(key) + if field == nil { + coerced[key] = value + continue + } - copy(cloned.Demos, r.Demos) - copy(cloned.Tools, r.Tools) + switch field.Type { + case core.FieldTypeInt: + if strVal, ok := value.(string); ok { + re := regexp.MustCompile(`-?\d+`) + if match := re.FindString(strVal); match != "" { + if intVal, err := strconv.Atoi(match); err == nil { + coerced[key] = intVal + continue + } + } + } + if floatVal, ok := value.(float64); ok { + coerced[key] = int(floatVal) + continue + } + coerced[key] = value - if r.History != nil { - cloned.History = r.History.Clone() + case core.FieldTypeBool: + if strVal, ok := value.(string); ok { + strVal = strings.ToLower(strings.TrimSpace(strVal)) + if strVal == "true" || strVal == "yes" || strVal == "1" { + coerced[key] = true + continue + } + if strVal == "false" || strVal == "no" || strVal == "0" { + coerced[key] = false + continue + } 
+ } + coerced[key] = value + + case core.FieldTypeString: + if value != nil { + coerced[key] = fmt.Sprintf("%v", value) + } else { + coerced[key] = value + } + + default: + coerced[key] = value + } } + return coerced +} - return cloned +func addUsage(a, b core.Usage) core.Usage { + a.PromptTokens += b.PromptTokens + a.CompletionTokens += b.CompletionTokens + a.TotalTokens += b.TotalTokens + a.Cost += b.Cost + a.Latency += b.Latency + return a +} + +// For safety, keep output map independent when using tool args. +func cloneMap(m map[string]any) map[string]any { + if m == nil { + return nil + } + out := make(map[string]any, len(m)) + maps.Copy(out, m) + return out } diff --git a/internal/module/react_context.go b/internal/module/react_context.go new file mode 100644 index 0000000..a2b347a --- /dev/null +++ b/internal/module/react_context.go @@ -0,0 +1,78 @@ +package module + +import ( + "context" + "errors" + "strings" + + "github.com/assagman/dsgo/internal/core" +) + +const reactContextOverflowMaxRetries = 3 + +// contextLengthSentinel is an internal error marker. +// It is not used by providers directly but allows tests to simulate overflow. +type contextLengthSentinel struct{} + +func (contextLengthSentinel) Error() string { return "context length exceeded" } + +// generateWithContextRetry calls the LM and retries on context overflow errors by +// truncating the oldest trajectory steps (DSPy-style) and retrying. +func (r *ReAct) generateWithContextRetry(ctx context.Context, traj *reactTrajectory, options *core.GenerateOptions, extra []core.Message) (*core.GenerateResult, error) { + var lastErr error + for attempt := 0; attempt < reactContextOverflowMaxRetries; attempt++ { + messages := traj.Render(r.maxPromptBytes()) + if len(extra) > 0 { + messages = append(messages, extra...) 
+ } + + result, err := r.LM.Generate(ctx, messages, options) + if err == nil { + return result, nil + } + lastErr = err + if !isContextOverflowError(err) { + return nil, err + } + if traj.DropOldestSteps(1) == 0 { + return nil, err + } + } + + if lastErr == nil { + lastErr = errors.New("context overflow retry exhausted") + } + return nil, lastErr +} + +// isContextOverflowError attempts to detect provider context window exceeded errors. +// This is heuristic by design because different providers surface different error types. +func isContextOverflowError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, contextLengthSentinel{}) { + return true + } + + msg := strings.ToLower(err.Error()) + patterns := []string{ + "context_length_exceeded", + "maximum context length", + "max context length", + "maximum context window", + "context window", + "too many tokens", + "exceeded the context", + "please reduce the length of the messages", + "prompt is too long", + "tokens exceeded", + "input is too long", + } + for _, p := range patterns { + if strings.Contains(msg, p) { + return true + } + } + return false +} diff --git a/internal/module/react_termination.go b/internal/module/react_termination.go new file mode 100644 index 0000000..78f042f --- /dev/null +++ b/internal/module/react_termination.go @@ -0,0 +1,146 @@ +package module + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + + "github.com/assagman/dsgo/internal/core" +) + +type terminationReason string + +const ( + terminationNone terminationReason = "" + terminationNoToolCalls terminationReason = "no_tool_calls" + terminationFinishTool terminationReason = "finish_tool" + terminationPlanningDone terminationReason = "planning_done" + terminationPlanningParseError terminationReason = "planning_parse_error" + terminationRepeatedToolCall terminationReason = "repeated_tool_call" + terminationRepeatedErrors terminationReason = "repeated_errors" + terminationStagnation terminationReason 
= "stagnation" +) + +// reactTermination implements loop termination policies: +// - repeated tool+args fingerprint +// - repeated tool errors +// - repeated identical observations (stagnation) +// +// It does not inject prompts; it only signals that the loop should stop. +type reactTermination struct { + done bool + reason terminationReason + + // Optional final candidates captured from the loop. + finalToolArgs map[string]any + finalContent string + + lastToolFingerprint string + repeatToolCalls int + + lastObservationHash string + repeatObservations int + + consecutiveErrors int +} + +func newReActTermination() *reactTermination { + return &reactTermination{} +} + +func (t *reactTermination) MarkDone(reason terminationReason) { + if t.done { + return + } + t.done = true + t.reason = reason +} + +func (t *reactTermination) ShouldStop() bool { + return t.done +} + +func (t *reactTermination) Reason() terminationReason { + return t.reason +} + +func (t *reactTermination) SetFinalToolArgs(args map[string]interface{}) { + if args == nil { + return + } + copied := make(map[string]any, len(args)) + for k, v := range args { + copied[k] = v + } + t.finalToolArgs = copied +} + +func (t *reactTermination) SetFinalContent(content string) { + t.finalContent = content +} + +func (t *reactTermination) FinalToolArgs() map[string]any { + return t.finalToolArgs +} + +func (t *reactTermination) FinalContent() string { + return t.finalContent +} + +func (t *reactTermination) ObserveToolCall(tc core.ToolCall) { + fp := toolFingerprint(tc) + if fp == "" { + return + } + if fp == t.lastToolFingerprint { + t.repeatToolCalls++ + if t.repeatToolCalls >= 2 { + t.MarkDone(terminationRepeatedToolCall) + } + return + } + t.lastToolFingerprint = fp + t.repeatToolCalls = 0 +} + +func (t *reactTermination) ObserveToolResult(tc core.ToolCall, observationHash string, err error) { + if err != nil { + t.consecutiveErrors++ + if t.consecutiveErrors >= 2 { + t.MarkDone(terminationRepeatedErrors) + } 
+ } else { + t.consecutiveErrors = 0 + } + + hash := observationHash + if hash == "" { + h := sha256.Sum256([]byte(tc.Name)) + hash = hex.EncodeToString(h[:]) + } + if hash == t.lastObservationHash { + t.repeatObservations++ + if t.repeatObservations >= 2 { + t.MarkDone(terminationStagnation) + } + return + } + t.lastObservationHash = hash + t.repeatObservations = 0 +} + +func (t *reactTermination) ObserveError(err error) { + if err == nil { + return + } + t.consecutiveErrors++ + if t.consecutiveErrors >= 2 { + t.MarkDone(terminationRepeatedErrors) + } +} + +func toolFingerprint(tc core.ToolCall) string { + argsJSON, _ := json.Marshal(tc.Arguments) + h := sha256.Sum256(argsJSON) + return tc.Name + ":" + hex.EncodeToString(h[:]) +} diff --git a/internal/module/react_test.go b/internal/module/react_test.go index da0974e..9731a9b 100644 --- a/internal/module/react_test.go +++ b/internal/module/react_test.go @@ -2,6 +2,7 @@ package module import ( "context" + "encoding/json" "errors" "fmt" "strings" @@ -50,25 +51,38 @@ func TestReAct_Forward_WithToolCalls(t *testing.T) { SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ - if callCount == 1 { + + switch callCount { + case 1: + // First call: make a tool call return &core.GenerateResult{ Content: "Let me search", ToolCalls: []core.ToolCall{ {ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test"}}, }, }, nil + case 2: + // Second call: make same tool call (stagnation) + return &core.GenerateResult{ + Content: "Let me search again", + ToolCalls: []core.ToolCall{ + {ID: "2", Name: "search", Arguments: map[string]interface{}{"query": "test"}}, + }, + }, nil + default: + // After stagnation message: provide final answer + return &core.GenerateResult{ + Content: `{"answer": "forced final answer"}`, + }, nil } - return &core.GenerateResult{ - Content: `{"answer": "final answer"}`, - }, nil }, } 
searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "search result", nil + return "same result", nil }) - react := NewReAct(sig, lm, []core.Tool{*searchTool}) + react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(10) outputs, err := react.Forward(context.Background(), map[string]interface{}{ "question": "test", }) @@ -77,140 +91,21 @@ func TestReAct_Forward_WithToolCalls(t *testing.T) { t.Fatalf("Forward() error = %v", err) } - if outputs.Outputs["answer"] != "final answer" { - t.Errorf("Expected final answer, got %v", outputs.Outputs["answer"]) + // Verify that the final answer was forced after stagnation + if outputs.Outputs["answer"] != "forced final answer" { + t.Errorf("Expected forced final answer after stagnation, got %v", outputs.Outputs["answer"]) } -} -func TestCoerceBasicTypes(t *testing.T) { - t.Parallel() - tests := []struct { - name string - signature *core.Signature - inputs map[string]any - expected map[string]any - }{ - { - name: "int from string with number", - signature: core.NewSignature("test"). - AddOutput("age", core.FieldTypeInt, "Age"), - inputs: map[string]any{"age": "5 years"}, - expected: map[string]any{"age": 5}, - }, - { - name: "int from string negative", - signature: core.NewSignature("test"). - AddOutput("temp", core.FieldTypeInt, "Temperature"), - inputs: map[string]any{"temp": "-10 degrees"}, - expected: map[string]any{"temp": -10}, - }, - { - name: "int from float64", - signature: core.NewSignature("test"). - AddOutput("count", core.FieldTypeInt, "Count"), - inputs: map[string]any{"count": float64(42.7)}, - expected: map[string]any{"count": 42}, - }, - { - name: "int from unparseable string", - signature: core.NewSignature("test"). 
- AddOutput("val", core.FieldTypeInt, "Value"), - inputs: map[string]any{"val": "no number here"}, - expected: map[string]any{"val": "no number here"}, - }, - { - name: "bool from string true variants", - signature: core.NewSignature("test"). - AddOutput("flag1", core.FieldTypeBool, "Flag1"). - AddOutput("flag2", core.FieldTypeBool, "Flag2"). - AddOutput("flag3", core.FieldTypeBool, "Flag3"), - inputs: map[string]any{ - "flag1": "true", - "flag2": "YES", - "flag3": " 1 ", - }, - expected: map[string]any{ - "flag1": true, - "flag2": true, - "flag3": true, - }, - }, - { - name: "bool from string false variants", - signature: core.NewSignature("test"). - AddOutput("flag1", core.FieldTypeBool, "Flag1"). - AddOutput("flag2", core.FieldTypeBool, "Flag2"). - AddOutput("flag3", core.FieldTypeBool, "Flag3"), - inputs: map[string]any{ - "flag1": "false", - "flag2": "NO", - "flag3": " 0 ", - }, - expected: map[string]any{ - "flag1": false, - "flag2": false, - "flag3": false, - }, - }, - { - name: "bool from unparseable string", - signature: core.NewSignature("test"). - AddOutput("flag", core.FieldTypeBool, "Flag"), - inputs: map[string]any{"flag": "maybe"}, - expected: map[string]any{"flag": "maybe"}, - }, - { - name: "string from non-nil value", - signature: core.NewSignature("test"). - AddOutput("text", core.FieldTypeString, "Text"), - inputs: map[string]any{"text": 123}, - expected: map[string]any{"text": "123"}, - }, - { - name: "string from nil value", - signature: core.NewSignature("test"). - AddOutput("text", core.FieldTypeString, "Text"), - inputs: map[string]any{"text": nil}, - expected: map[string]any{"text": nil}, - }, - { - name: "field not in signature", - signature: core.NewSignature("test"). - AddOutput("known", core.FieldTypeString, "Known"), - inputs: map[string]any{"unknown": "value", "known": "test"}, - expected: map[string]any{"unknown": "value", "known": "test"}, - }, - { - name: "default type passthrough", - signature: core.NewSignature("test"). 
- AddOutput("data", core.FieldTypeJSON, "Data"), - inputs: map[string]any{"data": map[string]any{"key": "value"}}, - expected: map[string]any{"data": map[string]any{"key": "value"}}, - }, - } + // Stagnation triggers early termination into extraction; no extra prompt injection. - for _, tt := range tests { - tt := tt // Capture range variable - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - result := coerceBasicTypes(tt.signature, tt.inputs) - - for key, expectedVal := range tt.expected { - actualVal, ok := result[key] - if !ok { - t.Errorf("Expected key %q not found in result", key) - continue - } - if fmt.Sprintf("%v", actualVal) != fmt.Sprintf("%v", expectedVal) { - t.Errorf("For key %q: expected %v (%T), got %v (%T)", - key, expectedVal, expectedVal, actualVal, actualVal) - } - } - }) + // Verify the model was called at least 3 times (2 tool calls + 1 final answer after stagnation) + if callCount < 3 { + t.Errorf("Expected at least 3 LM calls (stagnation + recovery), got %d", callCount) } } -func TestReAct_RunExtract_Success(t *testing.T) { +// TestReAct_Forward_WithHistory tests history management +func TestReAct_Forward_WithHistory(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). 
@@ -220,124 +115,144 @@ func TestReAct_RunExtract_Success(t *testing.T) { SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { return &core.GenerateResult{ - Content: `{"rationale": "my reasoning", "answer": "extracted answer"}`, - Usage: core.Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, + Content: `{"answer": "final answer with history"}`, }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{ - {Role: "user", Content: "What is the answer?"}, - } - inputs := map[string]any{"question": "test"} + history := core.NewHistory() + history.Add(core.Message{Role: "user", Content: "previous question"}) + history.Add(core.Message{Role: "assistant", Content: "previous answer"}) + + react := NewReAct(sig, lm, []core.Tool{}).WithHistory(history) + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "current question", + }) - pred, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) if err != nil { - t.Fatalf("runExtract() error = %v", err) + t.Fatalf("Forward() error = %v", err) } - if pred.Outputs["answer"] != "extracted answer" { - t.Errorf("Expected answer='extracted answer', got %v", pred.Outputs["answer"]) + if outputs.Outputs["answer"] != "final answer with history" { + t.Errorf("Expected answer with history, got %v", outputs.Outputs["answer"]) } - if pred.Rationale != "my reasoning" { - t.Errorf("Expected rationale='my reasoning', got %v", pred.Rationale) + // Verify history was updated + if history.Len() != 4 { // 2 previous + 1 user + 1 assistant + t.Errorf("Expected 4 messages in history, got %d", history.Len()) } } -func TestReAct_RunExtract_FallbackToDirectJSON(t *testing.T) { +// TestReAct_Forward_WithFinishTool tests the "finish" tool detection +func TestReAct_Forward_WithFinishTool(t *testing.T) { t.Parallel() sig := 
core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + AddOutput("answer", core.FieldTypeString, "Answer"). + AddOutput("confidence", core.FieldTypeFloat, "Confidence") + callCount := 0 lm := &MockLM{ - SupportsJSONVal: true, + SupportsToolsVal: true, + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - // Return JSON that might fail adapter parsing but is valid JSON - return &core.GenerateResult{ - Content: `{"answer": "direct json answer"}`, - Usage: core.Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, - }, nil + callCount++ + if callCount == 1 { + // Loop call: finish is a termination signal. + return &core.GenerateResult{ + Content: "I have the answer", + ToolCalls: []core.ToolCall{ + { + ID: "finish-1", + Name: "finish", + Arguments: map[string]interface{}{ + "answer": "The answer is 42", + "confidence": 0.95, + }, + }, + }, + }, nil + } + // Extraction call: produce signature-valid JSON. 
+ return &core.GenerateResult{Content: `{"answer":"The answer is 42","confidence":0.95}`}, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{{Role: "user", Content: "test"}} - inputs := map[string]any{"question": "test"} + dummyTool := core.NewTool("dummy", "unused", func(ctx context.Context, args map[string]any) (any, error) { + return "unused", nil + }) - pred, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) + react := NewReAct(sig, lm, []core.Tool{*dummyTool}) + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "What is the answer?", + }) if err != nil { - t.Fatalf("runExtract() error = %v", err) - } - - if pred.Outputs["answer"] != "direct json answer" { - t.Errorf("Expected answer from direct JSON, got %v", pred.Outputs["answer"]) + t.Fatalf("Forward() error = %v", err) } -} - -func TestReAct_RunExtract_FallbackToTextExtraction(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). 
- AddOutput("answer", core.FieldTypeString, "Answer") - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - // Return non-JSON text that needs text extraction - return &core.GenerateResult{ - Content: `The answer is: fallback text answer`, - Usage: core.Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, - }, nil - }, + if outputs.Outputs["answer"] != "The answer is 42" { + t.Errorf("expected answer, got %v", outputs.Outputs["answer"]) } - - react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{{Role: "user", Content: "test"}} - inputs := map[string]any{"question": "test"} - - pred, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) - // Should succeed using extractTextOutputs as last resort - if err != nil { - t.Fatalf("runExtract() should succeed with text extraction, got error: %v", err) + if outputs.Outputs["confidence"] != 0.95 { + t.Errorf("expected confidence 0.95, got %v", outputs.Outputs["confidence"]) } - - // Check that some output was extracted - if len(pred.Outputs) == 0 { - t.Errorf("Expected text extraction to produce outputs") + if callCount != 1 { + t.Errorf("expected 1 LM call (finish args validated), got %d", callCount) } } -func TestReAct_RunExtract_GenerationError(t *testing.T) { +// TestReAct_Forward_WithFinishTool_InvalidOutputs tests finish tool with validation errors +func TestReAct_Forward_WithFinishTool_InvalidOutputs(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + AddOutput("answer", core.FieldTypeString, "Answer"). 
+ AddOutput("score", core.FieldTypeInt, "Score") + callCount := 0 lm := &MockLM{ - SupportsJSONVal: true, + SupportsToolsVal: true, + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - return nil, errors.New("generation failed") + callCount++ + if callCount == 1 { + // First call: finish tool (possibly invalid args). The loop terminates and extraction produces outputs. + return &core.GenerateResult{ + Content: "Trying to finish", + ToolCalls: []core.ToolCall{ + { + ID: "finish-1", + Name: "finish", + Arguments: map[string]interface{}{"answer": "incomplete"}, + }, + }, + }, nil + } + // Extraction call: proper final answer. + return &core.GenerateResult{Content: `{"answer":"complete answer","score":85}`}, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{{Role: "user", Content: "test"}} - inputs := map[string]any{"question": "test"} + dummyTool := core.NewTool("dummy", "unused", func(ctx context.Context, args map[string]any) (any, error) { + return "unused", nil + }) - _, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) - if err == nil { - t.Fatal("runExtract() should fail when generation fails") + react := NewReAct(sig, lm, []core.Tool{*dummyTool}) + outputs, err := react.Forward(context.Background(), map[string]interface{}{"question": "test"}) + if err != nil { + t.Fatalf("Forward() error = %v", err) } - if !strings.Contains(err.Error(), "extraction generation failed") { - t.Errorf("Expected error about generation failure, got: %v", err) + if outputs.Outputs["answer"] != "complete answer" { + t.Errorf("expected recovered answer, got %v", outputs.Outputs["answer"]) + } + if callCount != 2 { + t.Errorf("expected 2 calls (loop + extraction), got %d", callCount) } } -func TestReAct_RunExtract_CompleteFailure(t *testing.T) { +// TestReAct_Forward_WithReasoning tests reasoning field 
extraction and cleanup +func TestReAct_Forward_WithReasoning(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). @@ -346,583 +261,799 @@ func TestReAct_RunExtract_CompleteFailure(t *testing.T) { lm := &MockLM{ SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - // Return unparseable JSON-like content return &core.GenerateResult{ - Content: `{invalid json`, - Usage: core.Usage{PromptTokens: 10, CompletionTokens: 5, TotalTokens: 15}, + Content: `{"reasoning": "Let me think about this...", "answer": "The answer"}`, }, nil }, } react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{{Role: "user", Content: "test"}} - inputs := map[string]any{"question": "test"} + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "test", + }) - // Even with invalid JSON, extractTextOutputs will extract something - pred, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) if err != nil { - t.Fatalf("runExtract() should succeed with text extraction fallback, got error: %v", err) + t.Fatalf("Forward() error = %v", err) + } + + // Reasoning should be extracted to rationale + if outputs.Rationale != "Let me think about this..." 
{ + t.Errorf("Expected rationale to be set, got %q", outputs.Rationale) } - // Should have extracted something via text extraction - if len(pred.Outputs) == 0 { - t.Error("Expected text extraction to produce outputs") + // Reasoning should be removed from outputs if not in signature + if _, exists := outputs.Outputs["reasoning"]; exists { + t.Error("Reasoning should be removed from outputs when not in signature") } } -func TestReAct_RunExtract_WithReasoningField(t *testing.T) { +// TestReAct_Forward_WithReasoningInSignature tests when reasoning is part of the signature +func TestReAct_Forward_WithReasoningInSignature(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). + AddOutput("reasoning", core.FieldTypeString, "Reasoning"). AddOutput("answer", core.FieldTypeString, "Answer") lm := &MockLM{ SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { return &core.GenerateResult{ - Content: `{"reasoning": "alternative reasoning field", "answer": "test answer"}`, - Usage: core.Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, + Content: `{"reasoning": "Thinking step by step", "answer": "42"}`, }, nil }, } react := NewReAct(sig, lm, []core.Tool{}) - messages := []core.Message{{Role: "user", Content: "test"}} - inputs := map[string]any{"question": "test"} + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "test", + }) - pred, err := react.runExtract(context.Background(), messages, inputs, []core.Message{}, core.Usage{}, false) if err != nil { - t.Fatalf("runExtract() error = %v", err) - } - - if pred.Rationale != "alternative reasoning field" { - t.Errorf("Expected rationale from 'reasoning' field, got %v", pred.Rationale) - } - - // Reasoning field should be removed from outputs - if _, exists := pred.Outputs["reasoning"]; exists { - 
t.Errorf("reasoning field should be removed from outputs") - } -} - -func TestReAct_Forward_InvalidInput(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test"). - AddInput("required", core.FieldTypeString, "Required") - - lm := &MockLM{} - react := NewReAct(sig, lm, []core.Tool{}) - - _, err := react.Forward(context.Background(), map[string]interface{}{}) - if err == nil { - t.Error("Forward() should error on invalid input") + t.Fatalf("Forward() error = %v", err) } -} - -func TestReAct_Forward_LMError(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test"). - AddInput("question", core.FieldTypeString, "Question") - lm := &MockLM{ - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - return nil, errors.New("LM error") - }, + // Reasoning should be in rationale + if outputs.Rationale != "Thinking step by step" { + t.Errorf("Expected rationale to be set, got %q", outputs.Rationale) } - react := NewReAct(sig, lm, []core.Tool{}) - _, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", - }) - - if err == nil { - t.Error("Forward() should propagate LM error") + // Reasoning should remain in outputs when it's in the signature + if _, exists := outputs.Outputs["reasoning"]; !exists { + t.Error("Reasoning should remain in outputs when it's part of the signature") } } -func TestReAct_Forward_MaxIterations(t *testing.T) { +// TestReAct_Forward_JSONModeWithJSONAdapter tests JSON mode enablement with JSONAdapter +func TestReAct_Forward_JSONModeWithJSONAdapter(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). + sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). 
AddOutput("answer", core.FieldTypeString, "Answer") + optionsCaptured := false lm := &MockLM{ + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + if options.ResponseFormat == "json" { + optionsCaptured = true + } return &core.GenerateResult{ - Content: "thinking", - ToolCalls: []core.ToolCall{ - {ID: "1", Name: "search", Arguments: map[string]interface{}{}}, - }, + Content: `{"answer": "json mode answer"}`, }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}).WithMaxIterations(2) - result, err := react.Forward(context.Background(), map[string]interface{}{ + react := NewReAct(sig, lm, []core.Tool{}).WithAdapter(core.NewJSONAdapter()) + outputs, err := react.Forward(context.Background(), map[string]interface{}{ "question": "test", }) - // With the extraction phase, ReAct should now return a result instead of erroring if err != nil { - t.Errorf("Forward() should not error when max iterations exceeded, got: %v", err) + t.Fatalf("Forward() error = %v", err) } - if result == nil { - t.Error("Forward() should return a result via extraction") + + if !optionsCaptured { + t.Error("JSON mode should be enabled when using JSONAdapter and LM supports JSON") } - // Verify that extraction was called (should have made additional LM call) - if result != nil { - answer, _ := result.GetString("answer") - if answer == "" { - t.Error("Extraction should have produced an answer") - } + + if outputs.Outputs["answer"] != "json mode answer" { + t.Errorf("Expected answer, got %v", outputs.Outputs["answer"]) } } -func TestReAct_Forward_ToolNotFound(t *testing.T) { +// TestReAct_Forward_MultipleToolCalls tests multiple tool calls in one iteration +func TestReAct_Forward_MultipleToolCalls(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). + sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). 
AddOutput("answer", core.FieldTypeString, "Answer") callCount := 0 lm := &MockLM{ + SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ if callCount == 1 { return &core.GenerateResult{ - Content: "Using tool", + Content: "Using multiple tools", ToolCalls: []core.ToolCall{ - {ID: "1", Name: "nonexistent", Arguments: map[string]interface{}{}}, + {ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test1"}}, + {ID: "2", Name: "calculate", Arguments: map[string]interface{}{"expr": "2+2"}}, }, }, nil } return &core.GenerateResult{ - Content: `{"answer": "recovered"}`, + Content: `{"answer": "combined result"}`, }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + searchTool := core.NewTool("search", "Search", func(ctx context.Context, args map[string]any) (any, error) { + return "search result", nil }) - + calcTool := core.NewTool("calculate", "Calculate", func(ctx context.Context, args map[string]any) (any, error) { + return "4", nil + }) + + react := NewReAct(sig, lm, []core.Tool{*searchTool, *calcTool}) + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "test", + }) + if err != nil { - t.Fatalf("Forward() should handle missing tool gracefully, got error: %v", err) + t.Fatalf("Forward() error = %v", err) } - if outputs.Outputs["answer"] != "recovered" { - t.Error("Should recover from tool not found error") + if outputs.Outputs["answer"] != "combined result" { + t.Error("Should handle multiple tool calls in one iteration") } } -func TestReAct_Forward_ToolError(t *testing.T) { +// TestReAct_Forward_WithDemos tests few-shot examples +func TestReAct_Forward_WithDemos(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). + sig := core.NewSignature("Answer question"). 
AddInput("question", core.FieldTypeString, "Question"). AddOutput("answer", core.FieldTypeString, "Answer") - callCount := 0 lm := &MockLM{ + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - callCount++ - if callCount == 1 { - return &core.GenerateResult{ - Content: "Using tool", - ToolCalls: []core.ToolCall{ - {ID: "1", Name: "failing_tool", Arguments: map[string]interface{}{}}, - }, - }, nil - } return &core.GenerateResult{ - Content: `{"answer": "recovered from error"}`, + Content: `{"answer": "demo-informed answer"}`, }, nil }, } - failingTool := core.NewTool("failing_tool", "Fails", func(ctx context.Context, args map[string]any) (any, error) { - return nil, errors.New("tool failed") - }) + demos := []core.Example{ + { + Inputs: map[string]any{"question": "What is 2+2?"}, + Outputs: map[string]any{"answer": "4"}, + }, + } - react := NewReAct(sig, lm, []core.Tool{*failingTool}) + react := NewReAct(sig, lm, []core.Tool{}).WithDemos(demos) outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + "question": "What is 3+3?", }) if err != nil { - t.Fatalf("Forward() should handle tool errors, got: %v", err) - } - - if outputs.Outputs["answer"] != "recovered from error" { - t.Error("Should recover from tool execution error") - } -} - -func TestReAct_WithOptions(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test") - lm := &MockLM{} - react := NewReAct(sig, lm, []core.Tool{}) - - customOpts := &core.GenerateOptions{Temperature: 0.9} - react.WithOptions(customOpts) - - if react.Options.Temperature != 0.9 { - t.Error("WithOptions should set custom options") - } -} - -func TestReAct_WithMaxIterations(t *testing.T) { - t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) - react.WithMaxIterations(5) - - if react.MaxIterations != 5 { - t.Error("WithMaxIterations should set max 
iterations") - } -} - -func TestReAct_WithVerbose(t *testing.T) { - t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) - react.WithVerbose(true) - - if !react.Verbose { - t.Error("WithVerbose should enable verbose mode") + t.Fatalf("Forward() error = %v", err) } -} -func TestReAct_GetSignature(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test") - react := NewReAct(sig, &MockLM{}, []core.Tool{}) - - if react.GetSignature() != sig { - t.Error("GetSignature should return the signature") + if outputs.Outputs["answer"] != "demo-informed answer" { + t.Errorf("Expected demo-informed answer, got %v", outputs.Outputs["answer"]) } } -// TestReAct_FixJSONNewlines removed - functionality moved to internal/jsonutil package -// See internal/jsonutil/extract_test.go for comprehensive JSON extraction and newline fixing tests - -func TestReAct_BuildSystemPrompt_NoTools(t *testing.T) { +// TestReAct_Forward_AdapterMetrics tests adapter metadata extraction +func TestReAct_Forward_AdapterMetrics(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test") - react := NewReAct(sig, &MockLM{}, []core.Tool{}) + sig := core.NewSignature("Answer question"). + AddInput("question", core.FieldTypeString, "Question"). 
+ AddOutput("answer", core.FieldTypeString, "Answer") - prompt := react.buildSystemPrompt() - if prompt != "" { - t.Error("System prompt should be empty when no tools") + lm := &MockLM{ + SupportsJSONVal: true, + GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + // JSON format succeeds on first try (JSONAdapter is now first in fallback chain) + return &core.GenerateResult{ + Content: `{"answer": "test"}`, + }, nil + }, } -} -func TestReAct_BuildSystemPrompt_WithTools(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test") - tool := core.NewTool("test", "Test tool", nil) - react := NewReAct(sig, &MockLM{}, []core.Tool{*tool}) + react := NewReAct(sig, lm, []core.Tool{}) + outputs, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "test", + }) - prompt := react.buildSystemPrompt() - if prompt == "" { - t.Error("System prompt should not be empty with tools") + if err != nil { + t.Fatalf("Forward() error = %v", err) } - if !contains(prompt, "tools") { - t.Error("System prompt should mention tools") - } - if !contains(prompt, "finish") { - t.Error("System prompt should mention finish tool") + if outputs.AdapterUsed == "" { + t.Error("Expected adapter metadata to be extracted") } -} -func TestReAct_FindTool(t *testing.T) { - t.Parallel() - tool1 := core.NewTool("search", "Search", nil) - tool2 := core.NewTool("calculate", "Calculate", nil) - - sig := core.NewSignature("Test") - react := NewReAct(sig, &MockLM{}, []core.Tool{*tool1, *tool2}) - - found := react.findTool("search") - if found == nil || found.Name != "search" { - t.Error("Should find existing tool") + // JSON format should succeed on first attempt (no fallback needed) + if outputs.ParseAttempts != 1 { + t.Errorf("Expected 1 parse attempt, got %d", outputs.ParseAttempts) } - notFound := react.findTool("nonexistent") - if notFound != nil { - t.Error("Should return nil for missing tool") + if 
outputs.FallbackUsed { + t.Error("Expected fallback_used to be false for JSON format") } } -func TestReAct_StagnationDetection(t *testing.T) { +// TestReAct_Forward_OutputValidationError tests validation errors after parsing +func TestReAct_Forward_OutputValidationError(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + AddOutput("answer", core.FieldTypeString, "Answer"). + AddOutput("score", core.FieldTypeInt, "Required score") callCount := 0 - var capturedMessages []core.Message lm := &MockLM{ SupportsToolsVal: true, + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ - capturedMessages = messages - - switch callCount { - case 1: - // First call: make a tool call - return &core.GenerateResult{ - Content: "Let me search", - ToolCalls: []core.ToolCall{ - {ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test"}}, - }, - }, nil - case 2: - // Second call: make same tool call (stagnation) - return &core.GenerateResult{ - Content: "Let me search again", - ToolCalls: []core.ToolCall{ - {ID: "2", Name: "search", Arguments: map[string]interface{}{"query": "test"}}, - }, - }, nil - default: - // After stagnation message: provide final answer - return &core.GenerateResult{ - Content: `{"answer": "forced final answer"}`, - }, nil + if callCount == 1 { + // Loop call (implicit finish): missing required field. + return &core.GenerateResult{Content: `{"answer":"incomplete"}`}, nil } + // Extraction call - provide complete answer. 
+ return &core.GenerateResult{Content: `{"answer":"extracted answer","score":42}`}, nil }, } - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "same result", nil - }) + dummyTool := core.NewTool("dummy", "unused", func(ctx context.Context, args map[string]any) (any, error) { return "unused", nil }) - react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(10) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ + react := NewReAct(sig, lm, []core.Tool{*dummyTool}) + result, err := react.Forward(context.Background(), map[string]interface{}{ "question": "test", }) + // With extraction, validation failures should be handled gracefully if err != nil { - t.Fatalf("Forward() error = %v", err) + t.Errorf("Forward() should not error with extraction fallback, got: %v", err) } - // Verify that the final answer was forced after stagnation - if outputs.Outputs["answer"] != "forced final answer" { - t.Errorf("Expected forced final answer after stagnation, got %v", outputs.Outputs["answer"]) + if result == nil { + t.Error("Forward() should return a result via extraction") } - // Verify that a stagnation prevention message was injected - stagnationMessageFound := false - for _, msg := range capturedMessages { - if msg.Role == "user" && contains(msg.Content, "same observation twice") { - stagnationMessageFound = true - break - } + if callCount != 2 { + t.Errorf("Expected 2 LM calls (initial + extraction), got %d", callCount) } +} - if !stagnationMessageFound { - t.Error("Expected stagnation prevention message to be injected") - } +func TestReAct_ExtractTextOutputs_ShortContent(t *testing.T) { + t.Parallel() + sig := core.NewSignature("Test"). 
+ AddOutput("answer", core.FieldTypeString, "Answer") - // Verify the model was called at least 3 times (2 tool calls + 1 final answer after stagnation) - if callCount < 3 { - t.Errorf("Expected at least 3 LM calls (stagnation + recovery), got %d", callCount) + react := NewReAct(sig, &MockLM{}, []core.Tool{}) + + // Test with short content (< 10 chars) + messages := []core.Message{} + outputs := react.extractTextOutputs("short", messages) + + // Should synthesize from history even though there's no history + if outputs == nil { + t.Error("extractTextOutputs should return outputs for short content") } } -// TestReAct_Forward_WithHistory tests history management -func TestReAct_Forward_WithHistory(t *testing.T) { +func TestReAct_ExtractTextOutputs_NoStringFields(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + sig := core.NewSignature("Test"). + AddOutput("count", core.FieldTypeInt, "Count") - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - return &core.GenerateResult{ - Content: `{"answer": "final answer with history"}`, - }, nil - }, + react := NewReAct(sig, &MockLM{}, []core.Tool{}) + + messages := []core.Message{} + outputs := react.extractTextOutputs("long enough content here", messages) + + if outputs != nil { + t.Error("extractTextOutputs should return nil when no string output fields") } +} - history := core.NewHistory() - history.Add(core.Message{Role: "user", Content: "previous question"}) - history.Add(core.Message{Role: "assistant", Content: "previous answer"}) +func TestReAct_ExtractTextOutputs_SingleField(t *testing.T) { + t.Parallel() + sig := core.NewSignature("Test"). 
+ AddOutput("answer", core.FieldTypeString, "Answer") - react := NewReAct(sig, lm, []core.Tool{}).WithHistory(history) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "current question", - }) + react := NewReAct(sig, &MockLM{}, []core.Tool{}) - if err != nil { - t.Fatalf("Forward() error = %v", err) - } + content := "This is the final answer to the question" + messages := []core.Message{} + outputs := react.extractTextOutputs(content, messages) - if outputs.Outputs["answer"] != "final answer with history" { - t.Errorf("Expected answer with history, got %v", outputs.Outputs["answer"]) + if outputs == nil { + t.Fatal("extractTextOutputs should extract single field") } - // Verify history was updated - if history.Len() != 4 { // 2 previous + 1 user + 1 assistant - t.Errorf("Expected 4 messages in history, got %d", history.Len()) + if answer, ok := outputs["answer"].(string); !ok || answer != content { + t.Errorf("Expected answer='%s', got %v", content, outputs["answer"]) } } -// TestReAct_Forward_WithFinishTool tests the "finish" tool detection -func TestReAct_Forward_WithFinishTool(t *testing.T) { +func TestReAct_ExtractTextOutputs_MultipleFields(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). + sig := core.NewSignature("Test"). AddOutput("answer", core.FieldTypeString, "Answer"). 
- AddOutput("confidence", core.FieldTypeFloat, "Confidence") + AddOutput("reasoning", core.FieldTypeString, "Reasoning") - lm := &MockLM{ - SupportsToolsVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - return &core.GenerateResult{ - Content: "I have the answer", - ToolCalls: []core.ToolCall{ - { - ID: "finish-1", - Name: "finish", - Arguments: map[string]interface{}{ - "answer": "The answer is 42", - "confidence": 0.95, - }, - }, - }, - }, nil - }, - } + react := NewReAct(sig, &MockLM{}, []core.Tool{}) - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "What is the answer?", - }) + content := "Based on my analysis, the final answer is 42" + messages := []core.Message{} + outputs := react.extractTextOutputs(content, messages) - if err != nil { - t.Fatalf("Forward() error = %v", err) + if outputs == nil { + t.Fatal("extractTextOutputs should extract multiple fields") } - if outputs.Outputs["answer"] != "The answer is 42" { - t.Errorf("Expected finish tool answer, got %v", outputs.Outputs["answer"]) + // First field should get the content + if answer, ok := outputs["answer"].(string); !ok || answer != content { + t.Errorf("Expected answer to be content, got %v", outputs["answer"]) } - if outputs.Outputs["confidence"] != 0.95 { - t.Errorf("Expected confidence 0.95, got %v", outputs.Outputs["confidence"]) + // Second required field should get a placeholder + if reasoning, ok := outputs["reasoning"].(string); !ok || reasoning == "" { + t.Errorf("Expected reasoning placeholder, got %v", outputs["reasoning"]) } } -// TestReAct_Forward_WithFinishTool_InvalidOutputs tests finish tool with validation errors -func TestReAct_Forward_WithFinishTool_InvalidOutputs(t *testing.T) { +func TestReAct_SynthesizeAnswerFromHistory_NoObservations(t *testing.T) { + t.Parallel() + react := 
NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + + messages := []core.Message{ + {Role: "user", Content: "test question"}, + {Role: "assistant", Content: "thinking"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + if result != "No information available from tools" { + t.Errorf("Expected 'No information available' message, got '%s'", result) + } +} + +func TestReAct_SynthesizeAnswerFromHistory_WithObservations(t *testing.T) { + t.Parallel() + react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + + messages := []core.Message{ + {Role: "user", Content: "test question"}, + {Role: "tool", Content: "The weather is sunny"}, + {Role: "assistant", Content: "thinking"}, + {Role: "tool", Content: "Temperature is 25 degrees"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + + // Should use recent observations + if result == "No information available from tools" { + t.Error("Should synthesize from tool observations") + } + + // Should contain one of the tool observations + if !contains(result, "sunny") && !contains(result, "25 degrees") { + t.Errorf("Result should contain tool observations, got '%s'", result) + } +} + +func TestReAct_SynthesizeAnswerFromHistory_SkipsErrors(t *testing.T) { + t.Parallel() + react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + + messages := []core.Message{ + {Role: "tool", Content: "Error: tool failed"}, + {Role: "tool", Content: "Valid observation here and it is definitely longer than 20 characters"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + + // Should not include error messages + if contains(result, "Error:") { + t.Error("Should skip error messages in synthesis") + } + + if !contains(result, "Valid observation") { + t.Errorf("Should include valid observation, got '%s'", result) + } +} + +func TestReAct_SynthesizeAnswerFromHistory_DeduplicatesObservations(t *testing.T) { + t.Parallel() + react := NewReAct(core.NewSignature("Test"), 
&MockLM{}, []core.Tool{}) + + duplicateObs := "This is a long observation that will be duplicated to test deduplication" + messages := []core.Message{ + {Role: "tool", Content: duplicateObs}, + {Role: "tool", Content: duplicateObs}, // Duplicate + {Role: "tool", Content: "Different observation that is also long enough to be considered"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + + // Should only have unique observations (up to 3) + // Count occurrences of duplicate string + count := 0 + content := result + for i := 0; i < len(content); { + idx := strings.Index(content[i:], "duplicated") + if idx == -1 { + break + } + count++ + i += idx + 1 + } + + if count > 1 { + t.Errorf("Should deduplicate observations, found %d occurrences", count) + } +} + +func TestReAct_SynthesizeAnswerFromHistory_LimitsToThreeObservations(t *testing.T) { + t.Parallel() + react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + + messages := []core.Message{ + {Role: "tool", Content: "First observation is definitely longer than twenty characters"}, + {Role: "tool", Content: "Second observation is definitely longer than twenty characters"}, + {Role: "tool", Content: "Third observation is definitely longer than twenty characters"}, + {Role: "tool", Content: "Fourth observation is definitely longer than twenty characters"}, + {Role: "tool", Content: "Fifth observation is definitely longer than twenty characters"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + + // Should use most recent 3 unique observations + if contains(result, "First") && contains(result, "Second") { + t.Error("Should limit to 3 most recent observations") + } +} + +func TestReAct_SynthesizeAnswerFromHistory_SkipsShortObservations(t *testing.T) { + t.Parallel() + react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + + messages := []core.Message{ + {Role: "tool", Content: "short"}, + {Role: "tool", Content: "This is a longer observation that should be 
included"}, + } + + result := react.synthesizeAnswerFromHistory(messages) + + if contains(result, "short") && !contains(result, "longer observation") { + t.Errorf("Should skip observations <= 20 chars, got '%s'", result) + } +} + +// TestReAct_ExtractionWithReasoning verifies that runExtract uses reasoning adapter +// and attaches rationale to the prediction when hitting MaxIterations +func TestReAct_ExtractionWithReasoning(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). AddOutput("answer", core.FieldTypeString, "Answer"). - AddOutput("score", core.FieldTypeInt, "Score") + AddOutput("confidence", core.FieldTypeInt, "Confidence score") - callCount := 0 + iterationCount := 0 lm := &MockLM{ SupportsToolsVal: true, + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - callCount++ - if callCount == 1 { - // First call: finish tool with invalid outputs (missing score) + iterationCount++ + + // Check ToolChoice to determine mode (tools are now always present for provider compatibility) + // ToolChoice == "auto" means tool-using mode, ToolChoice == "none" means final/extraction mode + toolsEnabled := options.ToolChoice != "none" && len(options.Tools) > 0 + + // Tool-using mode: return tool calls to force hitting MaxIterations + // Use different queries to avoid stagnation detection + if toolsEnabled { + query := fmt.Sprintf("test query %d", iterationCount) return &core.GenerateResult{ - Content: "Trying to finish", + Content: "Using search tool", ToolCalls: []core.ToolCall{ { - ID: "finish-1", - Name: "finish", - Arguments: map[string]interface{}{ - "answer": "incomplete", + ID: fmt.Sprintf("call_%d", iterationCount), + Name: "search", + Arguments: map[string]any{ + "query": query, }, }, }, }, nil } - // Second call: proper final answer + + // No tools mode (final mode or extraction) + // 
During final mode (iteration 2): return malformed JSON to force extraction + // During extraction (iteration 3): return proper JSON with reasoning + if iterationCount == 2 { + // Return malformed JSON that will fail parsing and trigger extraction + return &core.GenerateResult{ + Content: "I'm thinking about it but not formatting correctly", + }, nil + } + + // Extraction phase (iteration 3): return proper answer with reasoning return &core.GenerateResult{ - Content: `{"answer": "complete answer", "score": 85}`, + Content: `{ + "rationale": "Based on all the tool observations, I can now provide the final answer.", + "answer": "The answer based on search results", + "confidence": 95 + }`, }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + callNumber := 0 + searchTool := core.NewTool( + "search", + "Search for information", + func(ctx context.Context, args map[string]any) (any, error) { + callNumber++ + return fmt.Sprintf("Search results %d: relevant information", callNumber), nil + }, + ).AddParameter("query", "string", "Search query", true) + + react := NewReAct(sig, lm, []core.Tool{*searchTool}). + WithMaxIterations(2). 
+ WithVerbose(false) + + result, err := react.Forward(context.Background(), map[string]any{ + "question": "What is the answer?", }) if err != nil { t.Fatalf("Forward() error = %v", err) } - if outputs.Outputs["answer"] != "complete answer" { - t.Error("Should recover from invalid finish tool and provide proper answer") + // Should have hit MaxIterations and triggered extraction + // 2 tool-using iterations + 1 extraction call = 3 total + if iterationCount < 3 { + t.Errorf("Expected at least 3 LM calls (2 iterations + extraction), got %d", iterationCount) } - if callCount != 2 { - t.Errorf("Expected 2 calls (invalid finish + recovery), got %d", callCount) + // Check that answer was extracted + answer, ok := result.GetString("answer") + if !ok { + t.Error("Expected answer field in result") + } + if !contains(answer, "answer based on search") { + t.Errorf("Expected answer to contain extracted text, got: %s", answer) + } + + // CRITICAL: Check that rationale was attached to prediction + if result.Rationale == "" { + t.Error("Expected non-empty rationale from extraction phase with reasoning adapter") + } + if !contains(result.Rationale, "tool observations") { + t.Errorf("Expected rationale to contain reasoning, got: %s", result.Rationale) + } + + // Verify rationale was removed from outputs (not part of signature) + if _, exists := result.Outputs["rationale"]; exists { + t.Error("Rationale should be removed from outputs map") + } + if _, exists := result.Outputs["reasoning"]; exists { + t.Error("Reasoning should be removed from outputs map") } } -// TestReAct_Forward_WithReasoning tests reasoning field extraction and cleanup -func TestReAct_Forward_WithReasoning(t *testing.T) { +// TestReAct_ImplicitFinish tests that ReAct accepts direct answers without tool calls. +// This validates the "Implicit Finish" pattern where the model provides a valid answer +// directly instead of using tools, which is correct behavior for native tool calling APIs. 
+func TestReAct_ImplicitFinish(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). AddOutput("answer", core.FieldTypeString, "Answer") + callCount := 0 lm := &MockLM{ - SupportsJSONVal: true, + SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + callCount++ + // Model returns valid JSON without making any tool calls (implicit finish) return &core.GenerateResult{ - Content: `{"reasoning": "Let me think about this...", "answer": "The answer"}`, + Content: `{"answer": "42"}`, + ToolCalls: []core.ToolCall{}, // Empty - no tool calls }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { + t.Error("Tool should not be executed in implicit finish scenario") + return "search result", nil + }) + + react := NewReAct(sig, lm, []core.Tool{*searchTool}) + result, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "What is the answer to life?", }) + // Verify: err == nil (success) if err != nil { - t.Fatalf("Forward() error = %v", err) + t.Fatalf("Forward() error = %v, want nil", err) } - // Reasoning should be extracted to rationale - if outputs.Rationale != "Let me think about this..." { - t.Errorf("Expected rationale to be set, got %q", outputs.Rationale) + // Direct answer was signature-valid; no extraction call needed. 
+ if callCount != 1 { + t.Errorf("Expected 1 LM call, got %d", callCount) } - // Reasoning should be removed from outputs if not in signature - if _, exists := outputs.Outputs["reasoning"]; exists { - t.Error("Reasoning should be removed from outputs when not in signature") + // Verify: result.Outputs["answer"] == "42" + if result.Outputs["answer"] != "42" { + t.Errorf("Expected answer='42', got %v", result.Outputs["answer"]) } } -// TestReAct_Forward_WithReasoningInSignature tests when reasoning is part of the signature -func TestReAct_Forward_WithReasoningInSignature(t *testing.T) { +// TestReAct_ImplicitFinish_MalformedRetry tests the retry mechanism when implicit finish +// fails validation in early iterations. The model should be guided to use tools. +// Note: This test uses int fields to ensure malformed content fails validation, +// triggering the retry mechanism. String-only signatures would use text extraction +// as a fallback and accept malformed content. +func TestReAct_ImplicitFinish_MalformedRetry(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). + // Use an int output field so malformed text fails validation + sig := core.NewSignature("Calculate something"). AddInput("question", core.FieldTypeString, "Question"). - AddOutput("reasoning", core.FieldTypeString, "Reasoning"). 
- AddOutput("answer", core.FieldTypeString, "Answer") + AddOutput("count", core.FieldTypeInt, "Count result") + callCount := 0 lm := &MockLM{ - SupportsJSONVal: true, + SupportsToolsVal: true, + SupportsJSONVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + callCount++ + + if callCount == 1 { + // First call: return malformed text without tool calls + // This fails validation because "count" expects int, gets no valid int + return &core.GenerateResult{ + Content: "thinking about this problem without any numbers", + ToolCalls: []core.ToolCall{}, + }, nil + } + // Second call: return valid JSON with int (recovery) return &core.GenerateResult{ - Content: `{"reasoning": "Thinking step by step", "answer": "42"}`, + Content: `{"count": 42}`, + ToolCalls: []core.ToolCall{}, }, nil }, } - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ + searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { + return "search result", nil + }) + + react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(5) + result, err := react.Forward(context.Background(), map[string]interface{}{ + "question": "What is the count?", + }) + + // Verify: err == nil (success after retry) + if err != nil { + t.Fatalf("Forward() error = %v, want nil", err) + } + + // Loop terminates and extractor produces the final structured output. + if callCount != 2 { + t.Errorf("Expected 2 LM calls (loop + extraction), got %d", callCount) + } + + // Verify: result.Outputs["count"] == 42 + count, ok := result.GetInt("count") + if !ok || count != 42 { + t.Errorf("Expected count=42, got %v", result.Outputs["count"]) + } +} + +// TestReAct_WithMethods tests all ReAct configuration methods +func TestReAct_WithMethods(t *testing.T) { + t.Parallel() + sig := core.NewSignature("test"). 
+ AddInput("question", core.FieldTypeString, ""). + AddOutput("answer", core.FieldTypeString, "") + + lm := &MockLM{} + tools := []core.Tool{} + history := core.NewHistory() + demos := []core.Example{ + *core.NewExample( + map[string]any{"question": "test"}, + map[string]any{"answer": "test"}, + ), + } + adapter := core.NewJSONAdapter() + + react := NewReAct(sig, lm, tools). + WithAdapter(adapter). + WithHistory(history). + WithDemos(demos) + + if react.Adapter != adapter { + t.Error("WithAdapter should set adapter") + } + if react.History != history { + t.Error("WithHistory should set history") + } + if len(react.Demos) != 1 { + t.Error("WithDemos should set demos") + } +} + +// TestReAct_UsageAccumulation tests that usage (tokens, cost, latency) accumulates correctly across multiple iterations +func TestReAct_UsageAccumulation(t *testing.T) { + t.Parallel() + sig := core.NewSignature("Answer question"). + AddInput("question", core.FieldTypeString, "Question"). + AddOutput("answer", core.FieldTypeString, "Answer") + + callCount := 0 + lm := &MockLM{ + SupportsToolsVal: true, + SupportsJSONVal: true, + GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + callCount++ + switch callCount { + case 1: + // Loop: tool call + return &core.GenerateResult{ + Content: "Let me search", + ToolCalls: []core.ToolCall{{ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test"}}}, + Usage: core.Usage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + Cost: 0.001, + Latency: 500 * 1_000_000, + }, + }, nil + case 2: + // Loop: invalid direct answer (forces extraction). + return &core.GenerateResult{ + Content: `{"wrong":"field"}`, + Usage: core.Usage{ + PromptTokens: 200, + CompletionTokens: 100, + TotalTokens: 300, + Cost: 0.002, + Latency: 600 * 1_000_000, + }, + }, nil + default: + // Extraction: final structured answer. 
+ return &core.GenerateResult{ + Content: `{"answer":"final answer"}`, + Usage: core.Usage{ + PromptTokens: 50, + CompletionTokens: 25, + TotalTokens: 75, + Cost: 0.0005, + Latency: 250 * 1_000_000, + }, + }, nil + } + }, + } + + searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { + return "search result", nil + }) + + react := NewReAct(sig, lm, []core.Tool{*searchTool}) + pred, err := react.Forward(context.Background(), map[string]interface{}{ "question": "test", }) @@ -930,867 +1061,853 @@ func TestReAct_Forward_WithReasoningInSignature(t *testing.T) { t.Fatalf("Forward() error = %v", err) } - // Reasoning should be in rationale - if outputs.Rationale != "Thinking step by step" { - t.Errorf("Expected rationale to be set, got %q", outputs.Rationale) + // Verify usage accumulation across loop + extraction (3 LM calls) + expectedPromptTokens := 100 + 200 + 50 + expectedCompletionTokens := 50 + 100 + 25 + expectedTotalTokens := 150 + 300 + 75 + expectedCost := 0.001 + 0.002 + 0.0005 + expectedLatency := 500 + 600 + 250 + + if pred.Usage.PromptTokens != expectedPromptTokens { + t.Errorf("PromptTokens: expected %d, got %d", expectedPromptTokens, pred.Usage.PromptTokens) + } + if pred.Usage.CompletionTokens != expectedCompletionTokens { + t.Errorf("CompletionTokens: expected %d, got %d", expectedCompletionTokens, pred.Usage.CompletionTokens) + } + if pred.Usage.TotalTokens != expectedTotalTokens { + t.Errorf("TotalTokens: expected %d, got %d", expectedTotalTokens, pred.Usage.TotalTokens) + } + if pred.Usage.Cost != expectedCost { + t.Errorf("Cost: expected %.6f, got %.6f", expectedCost, pred.Usage.Cost) } - // Reasoning should remain in outputs when it's in the signature - if _, exists := outputs.Outputs["reasoning"]; !exists { - t.Error("Reasoning should remain in outputs when it's part of the signature") + // Ensure extraction answer wins. 
+ if pred.Outputs["answer"] != "final answer" { + t.Errorf("expected final answer, got %v", pred.Outputs["answer"]) + } + + expectedLatencyNs := int64(expectedLatency) * 1_000_000 + if pred.Usage.Latency != expectedLatencyNs { + t.Errorf("Latency: expected %d ns (%.2fms), got %d ns (%.2fms)", + expectedLatencyNs, float64(expectedLatencyNs)/1_000_000, + pred.Usage.Latency, float64(pred.Usage.Latency)/1_000_000) } } -// TestReAct_Forward_JSONModeWithJSONAdapter tests JSON mode enablement with JSONAdapter -func TestReAct_Forward_JSONModeWithJSONAdapter(t *testing.T) { +// TestReAct_ToolsSliceNotMutated tests that the caller's tools slice is not mutated by NewReAct +func TestReAct_ToolsSliceNotMutated(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). AddOutput("answer", core.FieldTypeString, "Answer") - optionsCaptured := false - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - if options.ResponseFormat == "json" { - optionsCaptured = true - } - return &core.GenerateResult{ - Content: `{"answer": "json mode answer"}`, - }, nil - }, - } + lm := &MockLM{} - react := NewReAct(sig, lm, []core.Tool{}).WithAdapter(core.NewJSONAdapter()) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + // Create tools slice with specific capacity to detect append mutation + originalTools := make([]core.Tool, 1, 10) // capacity > len to allow in-place append + searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { + return "result", nil }) + originalTools[0] = *searchTool - if err != nil { - t.Fatalf("Forward() error = %v", err) - } + // Capture original length + originalLen := len(originalTools) - if !optionsCaptured { - t.Error("JSON mode should be enabled when using JSONAdapter 
and LM supports JSON") + // Create ReAct which auto-injects finish tool + _ = NewReAct(sig, lm, originalTools) + + // Verify caller's slice was NOT modified + if len(originalTools) != originalLen { + t.Errorf("Caller's tools slice was mutated: expected len %d, got %d", originalLen, len(originalTools)) } - if outputs.Outputs["answer"] != "json mode answer" { - t.Errorf("Expected answer, got %v", outputs.Outputs["answer"]) + // Verify finish tool was NOT appended to caller's slice + for _, tool := range originalTools { + if strings.ToLower(tool.Name) == "finish" { + t.Error("Finish tool should not appear in caller's original tools slice") + } } } -// TestReAct_Forward_MultipleToolCalls tests multiple tool calls in one iteration -func TestReAct_Forward_MultipleToolCalls(t *testing.T) { +// TestReAct_NonToolLM_NoToolsPassedInFinalMode tests that tools are not passed to LMs that don't support them in final mode +func TestReAct_NonToolLM_NoToolsPassedInFinalMode(t *testing.T) { t.Parallel() sig := core.NewSignature("Answer question"). AddInput("question", core.FieldTypeString, "Question"). 
AddOutput("answer", core.FieldTypeString, "Answer") + var finalModeOptions *core.GenerateOptions callCount := 0 lm := &MockLM{ - SupportsToolsVal: true, + SupportsToolsVal: false, // LM does NOT support tools GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ if callCount == 1 { + // First call: return something that doesn't parse well to trigger iteration return &core.GenerateResult{ - Content: "Using multiple tools", - ToolCalls: []core.ToolCall{ - {ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test1"}}, - {ID: "2", Name: "calculate", Arguments: map[string]interface{}{"expr": "2+2"}}, - }, + Content: "I need to think about this...", + }, nil + } + if callCount == 2 { + // Second call: still no good answer, will trigger final mode + return &core.GenerateResult{ + Content: "Still thinking...", }, nil } + // Third call onward: final mode - capture options here + finalModeOptions = options return &core.GenerateResult{ - Content: `{"answer": "combined result"}`, + Content: `{"answer": "final answer"}`, }, nil }, } - searchTool := core.NewTool("search", "Search", func(ctx context.Context, args map[string]any) (any, error) { - return "search result", nil - }) - calcTool := core.NewTool("calculate", "Calculate", func(ctx context.Context, args map[string]any) (any, error) { - return "4", nil - }) - - react := NewReAct(sig, lm, []core.Tool{*searchTool, *calcTool}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", + searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { + return "result", nil }) + react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(4) + _, err := react.Forward(context.Background(), map[string]any{"question": "test"}) if err != nil { t.Fatalf("Forward() error = %v", err) } - if outputs.Outputs["answer"] != "combined 
result" { - t.Error("Should handle multiple tool calls in one iteration") + // Verify final mode options: tools should NOT be passed since LM doesn't support them + if finalModeOptions != nil { + if len(finalModeOptions.Tools) > 0 { + t.Errorf("Tools should not be passed to non-tool LM in final mode, got %d tools", len(finalModeOptions.Tools)) + } + if finalModeOptions.ToolChoice == "none" { + t.Errorf("ToolChoice should not be 'none' for non-tool LM, got %q", finalModeOptions.ToolChoice) + } } } -// TestReAct_Forward_WithDemos tests few-shot examples -func TestReAct_Forward_WithDemos(t *testing.T) { +// ============================================================================ +// reactTrajectory tests +// ============================================================================ + +func TestReActTrajectory_Render_EmptySteps(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + base := []core.Message{{Role: "system", Content: "You are helpful."}} + traj := newReActTrajectory(base) - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - return &core.GenerateResult{ - Content: `{"answer": "demo-informed answer"}`, - }, nil - }, + msgs := traj.Render(1000) + if len(msgs) != 1 { + t.Errorf("expected 1 message, got %d", len(msgs)) } - - demos := []core.Example{ - { - Inputs: map[string]any{"question": "What is 2+2?"}, - Outputs: map[string]any{"answer": "4"}, - }, + if msgs[0].Role != "system" { + t.Errorf("expected system message, got %s", msgs[0].Role) } +} - react := NewReAct(sig, lm, []core.Tool{}).WithDemos(demos) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "What is 3+3?", - }) +func TestReActTrajectory_Render_AlwaysIncludesNewestStep(t *testing.T) { + 
t.Parallel() + base := []core.Message{{Role: "system", Content: "base"}} + traj := newReActTrajectory(base) - if err != nil { - t.Fatalf("Forward() error = %v", err) - } + // Add a step that exceeds the budget by itself. + largeThought := string(make([]byte, 500)) + traj.AddStep(largeThought, nil) - if outputs.Outputs["answer"] != "demo-informed answer" { - t.Errorf("Expected demo-informed answer, got %v", outputs.Outputs["answer"]) + // Budget is 100 bytes, base is ~4 bytes, but newest step should still be included. + msgs := traj.Render(100) + if len(msgs) < 2 { + t.Errorf("expected at least 2 messages (base + newest step), got %d", len(msgs)) } } -// TestReAct_Forward_AdapterMetrics tests adapter metadata extraction -func TestReAct_Forward_AdapterMetrics(t *testing.T) { +func TestReActTrajectory_Render_SelectsSuffix(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + base := []core.Message{{Role: "system", Content: "x"}} + traj := newReActTrajectory(base) - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - // JSON format succeeds on first try (JSONAdapter is now first in fallback chain) - return &core.GenerateResult{ - Content: `{"answer": "test"}`, - }, nil - }, + // Add 5 steps, each ~10 bytes. + for i := 0; i < 5; i++ { + traj.AddStep("thought123", nil) } - react := NewReAct(sig, lm, []core.Tool{}) - outputs, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", - }) - - if err != nil { - t.Fatalf("Forward() error = %v", err) - } + // Budget allows base (~1 byte) + ~30 bytes = ~3 steps. + msgs := traj.Render(35) - if outputs.AdapterUsed == "" { - t.Error("Expected adapter metadata to be extracted") + // Should get base + some suffix of steps. 
+ // We expect at least 2 messages (base + at least 1 step). + if len(msgs) < 2 { + t.Errorf("expected at least 2 messages, got %d", len(msgs)) } - // JSON format should succeed on first attempt (no fallback needed) - if outputs.ParseAttempts != 1 { - t.Errorf("Expected 1 parse attempt, got %d", outputs.ParseAttempts) + // The last message should be from the newest step (step 5). + found := false + for _, m := range msgs { + if m.Content == "thought123" && m.Role == "assistant" { + found = true + } } - - if outputs.FallbackUsed { - t.Error("Expected fallback_used to be false for JSON format") + if !found { + t.Error("expected newest step content to be present") } } -// TestReAct_Forward_OutputValidationError tests validation errors after parsing -func TestReAct_Forward_OutputValidationError(t *testing.T) { +func TestReActTrajectory_Render_BaseExceedsBudget(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer"). - AddOutput("score", core.FieldTypeInt, "Required score") + // Base itself exceeds budget. + largeBase := []core.Message{{Role: "system", Content: string(make([]byte, 1000))}} + traj := newReActTrajectory(largeBase) + traj.AddStep("small", nil) - callCount := 0 - lm := &MockLM{ - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - callCount++ - if callCount == 1 { - // Missing required "score" field - will trigger extraction - return &core.GenerateResult{ - Content: `{"answer": "incomplete"}`, - }, nil - } - // Extraction call - provide complete answer - return &core.GenerateResult{ - Content: `{"answer": "extracted answer", "score": 42}`, - }, nil - }, + // Budget is only 100 bytes, base is 1000 bytes. + msgs := traj.Render(100) + + // Should only return base, no steps (remaining budget is <= 0). 
+ if len(msgs) != 1 { + t.Errorf("expected only base message, got %d messages", len(msgs)) } +} - react := NewReAct(sig, lm, []core.Tool{}) - result, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", - }) +func TestReActTrajectory_DropOldestSteps(t *testing.T) { + t.Parallel() + base := []core.Message{{Role: "system", Content: "x"}} + traj := newReActTrajectory(base) - // With extraction, validation failures should be handled gracefully - if err != nil { - t.Errorf("Forward() should not error with extraction fallback, got: %v", err) + for i := 0; i < 5; i++ { + traj.AddStep("step", nil) } - if result == nil { - t.Error("Forward() should return a result via extraction") + // Drop 2 oldest. + dropped := traj.DropOldestSteps(2) + if dropped != 2 { + t.Errorf("expected 2 dropped, got %d", dropped) } - if callCount != 2 { - t.Errorf("Expected 2 LM calls (initial + extraction), got %d", callCount) + // Should have 3 steps remaining. + msgs := traj.Render(100000) + // Base (1) + 3 steps (3 assistant messages) = 4. + if len(msgs) != 4 { + t.Errorf("expected 4 messages, got %d", len(msgs)) } } -func TestReAct_ExtractTextOutputs_ShortContent(t *testing.T) { +func TestReActTrajectory_DropOldestSteps_MoreThanAvailable(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). 
- AddOutput("answer", core.FieldTypeString, "Answer") + traj := newReActTrajectory(nil) + traj.AddStep("s1", nil) + traj.AddStep("s2", nil) - react := NewReAct(sig, &MockLM{}, []core.Tool{}) + dropped := traj.DropOldestSteps(10) + if dropped != 2 { + t.Errorf("expected 2 dropped (all available), got %d", dropped) + } - // Test with short content (< 10 chars) - messages := []core.Message{} - outputs := react.extractTextOutputs("short", messages) + msgs := traj.Render(100000) + if len(msgs) != 0 { + t.Errorf("expected 0 messages after dropping all, got %d", len(msgs)) + } +} - // Should synthesize from history even though there's no history - if outputs == nil { - t.Error("extractTextOutputs should return outputs for short content") +func TestReActTrajectory_DropOldestSteps_Zero(t *testing.T) { + t.Parallel() + traj := newReActTrajectory(nil) + traj.AddStep("s1", nil) + + dropped := traj.DropOldestSteps(0) + if dropped != 0 { + t.Errorf("expected 0 dropped, got %d", dropped) } } -func TestReAct_ExtractTextOutputs_NoStringFields(t *testing.T) { +func TestReActTrajectory_HasToolContent(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). 
- AddOutput("count", core.FieldTypeInt, "Count") - react := NewReAct(sig, &MockLM{}, []core.Tool{}) + t.Run("no tool content", func(t *testing.T) { + traj := newReActTrajectory(nil) + traj.AddStep("thought only", nil) + if traj.HasToolContent() { + t.Error("expected no tool content") + } + }) - messages := []core.Message{} - outputs := react.extractTextOutputs("long enough content here", messages) + t.Run("with tool calls", func(t *testing.T) { + traj := newReActTrajectory(nil) + traj.AddStep("thought", []core.ToolCall{{ID: "1", Name: "search"}}) + if !traj.HasToolContent() { + t.Error("expected tool content from tool calls") + } + }) + + t.Run("with tool results", func(t *testing.T) { + traj := newReActTrajectory(nil) + step := traj.AddStep("thought", nil) + step.AddToolResult(reactToolResult{ToolCallID: "1", Content: "result"}) + if !traj.HasToolContent() { + t.Error("expected tool content from tool results") + } + }) +} + +func TestReActStep_ToMessages(t *testing.T) { + t.Parallel() + step := &reactStep{ + Thought: "I'm thinking", + ToolCalls: []core.ToolCall{{ID: "1", Name: "search"}}, + ToolResults: []reactToolResult{ + {ToolCallID: "1", Content: `{"tool":"search","ok":true}`}, + }, + Errors: []string{"error message"}, + } - if outputs != nil { - t.Error("extractTextOutputs should return nil when no string output fields") + msgs := step.toMessages() + if len(msgs) != 3 { + t.Fatalf("expected 3 messages, got %d", len(msgs)) + } + + if msgs[0].Role != "assistant" { + t.Errorf("expected assistant role, got %s", msgs[0].Role) + } + if msgs[1].Role != "tool" { + t.Errorf("expected tool role, got %s", msgs[1].Role) + } + if msgs[2].Role != "system" { + t.Errorf("expected system role for error, got %s", msgs[2].Role) } } -func TestReAct_ExtractTextOutputs_SingleField(t *testing.T) { - t.Parallel() - sig := core.NewSignature("Test"). 
- AddOutput("answer", core.FieldTypeString, "Answer") +// ============================================================================ +// encodeToolResult tests +// ============================================================================ - react := NewReAct(sig, &MockLM{}, []core.Tool{}) +func TestEncodeToolResult_SmallResult(t *testing.T) { + t.Parallel() + content, truncated, hash := encodeToolResult("search", "call-1", "hello", nil, 1000) - content := "This is the final answer to the question" - messages := []core.Message{} - outputs := react.extractTextOutputs(content, messages) + if truncated { + t.Error("expected not truncated") + } - if outputs == nil { - t.Fatal("extractTextOutputs should extract single field") + var env toolResultEnvelope + if err := json.Unmarshal([]byte(content), &env); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if env.Tool != "search" { + t.Errorf("expected tool=search, got %s", env.Tool) + } + if !env.OK { + t.Error("expected OK=true") + } + if env.Result != "hello" { + t.Errorf("expected result=hello, got %v", env.Result) } - if answer, ok := outputs["answer"].(string); !ok || answer != content { - t.Errorf("Expected answer='%s', got %v", content, outputs["answer"]) + if hash == "" { + t.Error("expected non-empty hash") } } -func TestReAct_ExtractTextOutputs_MultipleFields(t *testing.T) { +func TestEncodeToolResult_WithError(t *testing.T) { t.Parallel() - sig := core.NewSignature("Test"). - AddOutput("answer", core.FieldTypeString, "Answer"). 
- AddOutput("reasoning", core.FieldTypeString, "Reasoning") - - react := NewReAct(sig, &MockLM{}, []core.Tool{}) - - content := "Based on my analysis, the final answer is 42" - messages := []core.Message{} - outputs := react.extractTextOutputs(content, messages) + content, _, _ := encodeToolResult("search", "call-1", nil, errors.New("tool failed"), 1000) - if outputs == nil { - t.Fatal("extractTextOutputs should extract multiple fields") + var env toolResultEnvelope + if err := json.Unmarshal([]byte(content), &env); err != nil { + t.Fatalf("invalid JSON: %v", err) } - - // First field should get the content - if answer, ok := outputs["answer"].(string); !ok || answer != content { - t.Errorf("Expected answer to be content, got %v", outputs["answer"]) + if env.OK { + t.Error("expected OK=false for error") } - - // Second required field should get a placeholder - if reasoning, ok := outputs["reasoning"].(string); !ok || reasoning == "" { - t.Errorf("Expected reasoning placeholder, got %v", outputs["reasoning"]) + if env.Error != "tool failed" { + t.Errorf("expected error message, got %s", env.Error) } } -func TestReAct_SynthesizeAnswerFromHistory_NoObservations(t *testing.T) { +func TestEncodeToolResult_Truncation(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + largeResult := string(make([]byte, 10000)) + content, truncated, _ := encodeToolResult("search", "call-1", largeResult, nil, 200) - messages := []core.Message{ - {Role: "user", Content: "test question"}, - {Role: "assistant", Content: "thinking"}, + if !truncated { + t.Error("expected truncated=true") + } + if len(content) > 200 { + t.Errorf("expected content <= 200 bytes, got %d", len(content)) } - result := react.synthesizeAnswerFromHistory(messages) - if result != "No information available from tools" { - t.Errorf("Expected 'No information available' message, got '%s'", result) + var env toolResultEnvelope + if err := json.Unmarshal([]byte(content), 
&env); err != nil { + t.Fatalf("truncated result should still be valid JSON: %v", err) + } + if !env.Truncated { + t.Error("expected Truncated=true in envelope") + } + if env.OriginalSize == 0 { + t.Error("expected OriginalSize to be set") } } -func TestReAct_SynthesizeAnswerFromHistory_WithObservations(t *testing.T) { +func TestEncodeToolResult_StableHashIgnoresToolCallID(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + // Same tool, same result, different tool_call_id should produce same hash. + _, _, hash1 := encodeToolResult("search", "call-1", "result", nil, 1000) + _, _, hash2 := encodeToolResult("search", "call-2", "result", nil, 1000) - messages := []core.Message{ - {Role: "user", Content: "test question"}, - {Role: "tool", Content: "The weather is sunny"}, - {Role: "assistant", Content: "thinking"}, - {Role: "tool", Content: "Temperature is 25 degrees"}, + if hash1 != hash2 { + t.Errorf("expected same hash for same content, different call IDs: %s vs %s", hash1, hash2) } +} - result := react.synthesizeAnswerFromHistory(messages) - - // Should use recent observations - if result == "No information available from tools" { - t.Error("Should synthesize from tool observations") - } +func TestEncodeToolResult_DifferentResultsDifferentHashes(t *testing.T) { + t.Parallel() + _, _, hash1 := encodeToolResult("search", "call-1", "result1", nil, 1000) + _, _, hash2 := encodeToolResult("search", "call-1", "result2", nil, 1000) - // Should contain one of the tool observations - if !contains(result, "sunny") && !contains(result, "25 degrees") { - t.Errorf("Result should contain tool observations, got '%s'", result) + if hash1 == hash2 { + t.Error("expected different hashes for different results") } } -func TestReAct_SynthesizeAnswerFromHistory_SkipsErrors(t *testing.T) { +func TestEncodeToolResult_UTF8SafeTruncation(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + 
// Use a multi-byte UTF-8 character. + result := "Hello 世界 🌍" + content, _, _ := encodeToolResult("search", "call-1", result, nil, 50) - messages := []core.Message{ - {Role: "tool", Content: "Error: tool failed"}, - {Role: "tool", Content: "Valid observation here and it is definitely longer than 20 characters"}, + var env toolResultEnvelope + if err := json.Unmarshal([]byte(content), &env); err != nil { + t.Fatalf("should produce valid JSON: %v", err) } +} - result := react.synthesizeAnswerFromHistory(messages) +func TestEncodeToolResult_MinimalEnvelope(t *testing.T) { + t.Parallel() + // Extremely small budget. + content, truncated, _ := encodeToolResult("x", "1", "huge data", nil, 50) - // Should not include error messages - if contains(result, "Error:") { - t.Error("Should skip error messages in synthesis") + if !truncated { + t.Error("expected truncated") } - if !contains(result, "Valid observation") { - t.Errorf("Should include valid observation, got '%s'", result) + var env toolResultEnvelope + if err := json.Unmarshal([]byte(content), &env); err != nil { + t.Fatalf("should produce valid JSON even with tiny budget: %v", err) } } -func TestReAct_SynthesizeAnswerFromHistory_DeduplicatesObservations(t *testing.T) { +// ============================================================================ +// reactTermination tests +// ============================================================================ + +func TestReActTermination_MarkDone_Idempotent(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + term := newReActTermination() - duplicateObs := "This is a long observation that will be duplicated to test deduplication" - messages := []core.Message{ - {Role: "tool", Content: duplicateObs}, - {Role: "tool", Content: duplicateObs}, // Duplicate - {Role: "tool", Content: "Different observation that is also long enough to be considered"}, + term.MarkDone(terminationRepeatedToolCall) + 
term.MarkDone(terminationStagnation) // Should be ignored. + + if term.Reason() != terminationRepeatedToolCall { + t.Errorf("expected reason %s, got %s", terminationRepeatedToolCall, term.Reason()) } +} - result := react.synthesizeAnswerFromHistory(messages) +func TestReActTermination_ObserveToolCall_RepeatedTriggers(t *testing.T) { + t.Parallel() + term := newReActTermination() - // Should only have unique observations (up to 3) - // Count occurrences of duplicate string - count := 0 - content := result - for i := 0; i < len(content); { - idx := strings.Index(content[i:], "duplicated") - if idx == -1 { - break - } - count++ - i += idx + 1 + tc := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test"}} + + // First call: sets fingerprint. + term.ObserveToolCall(tc) + if term.ShouldStop() { + t.Error("should not stop after first call") } - if count > 1 { - t.Errorf("Should deduplicate observations, found %d occurrences", count) + // Second call: increments count to 1. + term.ObserveToolCall(tc) + if term.ShouldStop() { + t.Error("should not stop after second call") + } + + // Third call: count becomes 2, triggers termination. 
+ term.ObserveToolCall(tc) + if !term.ShouldStop() { + t.Error("should stop after third repeated call") + } + if term.Reason() != terminationRepeatedToolCall { + t.Errorf("expected reason %s, got %s", terminationRepeatedToolCall, term.Reason()) } } -func TestReAct_SynthesizeAnswerFromHistory_LimitsToThreeObservations(t *testing.T) { +func TestReActTermination_ObserveToolCall_DifferentCallsReset(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + term := newReActTermination() - messages := []core.Message{ - {Role: "tool", Content: "First observation is definitely longer than twenty characters"}, - {Role: "tool", Content: "Second observation is definitely longer than twenty characters"}, - {Role: "tool", Content: "Third observation is definitely longer than twenty characters"}, - {Role: "tool", Content: "Fourth observation is definitely longer than twenty characters"}, - {Role: "tool", Content: "Fifth observation is definitely longer than twenty characters"}, - } + tc1 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test1"}} + tc2 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test2"}} - result := react.synthesizeAnswerFromHistory(messages) + term.ObserveToolCall(tc1) + term.ObserveToolCall(tc1) // repeat count = 1 + term.ObserveToolCall(tc2) // Different call, resets. 
- // Should use most recent 3 unique observations - if contains(result, "First") && contains(result, "Second") { - t.Error("Should limit to 3 most recent observations") + if term.ShouldStop() { + t.Error("should not stop after changing tool call args") } } -func TestReAct_SynthesizeAnswerFromHistory_SkipsShortObservations(t *testing.T) { +func TestReActTermination_ObserveToolResult_Stagnation(t *testing.T) { t.Parallel() - react := NewReAct(core.NewSignature("Test"), &MockLM{}, []core.Tool{}) + term := newReActTermination() - messages := []core.Message{ - {Role: "tool", Content: "short"}, - {Role: "tool", Content: "This is a longer observation that should be included"}, + tc := core.ToolCall{Name: "search"} + hash := "same-hash-123" + + // First result: sets hash. + term.ObserveToolResult(tc, hash, nil) + if term.ShouldStop() { + t.Error("should not stop after first observation") } - result := react.synthesizeAnswerFromHistory(messages) + // Second result: increments. + term.ObserveToolResult(tc, hash, nil) + if term.ShouldStop() { + t.Error("should not stop after second observation") + } - if contains(result, "short") && !contains(result, "longer observation") { - t.Errorf("Should skip observations <= 20 chars, got '%s'", result) + // Third result: triggers stagnation. + term.ObserveToolResult(tc, hash, nil) + if !term.ShouldStop() { + t.Error("should stop after third repeated observation") + } + if term.Reason() != terminationStagnation { + t.Errorf("expected reason %s, got %s", terminationStagnation, term.Reason()) } } -// TestReAct_ExtractionWithReasoning verifies that runExtract uses reasoning adapter -// and attaches rationale to the prediction when hitting MaxIterations -func TestReAct_ExtractionWithReasoning(t *testing.T) { +func TestReActTermination_ObserveError_RepeatedErrors(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). 
- AddOutput("answer", core.FieldTypeString, "Answer"). - AddOutput("confidence", core.FieldTypeInt, "Confidence score") + term := newReActTermination() - iterationCount := 0 - lm := &MockLM{ - SupportsToolsVal: true, - SupportsJSONVal: true, - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - iterationCount++ + term.ObserveError(errors.New("err1")) + if term.ShouldStop() { + t.Error("should not stop after first error") + } - // Check ToolChoice to determine mode (tools are now always present for provider compatibility) - // ToolChoice == "auto" means tool-using mode, ToolChoice == "none" means final/extraction mode - toolsEnabled := options.ToolChoice != "none" && len(options.Tools) > 0 + term.ObserveError(errors.New("err2")) + if !term.ShouldStop() { + t.Error("should stop after second consecutive error") + } + if term.Reason() != terminationRepeatedErrors { + t.Errorf("expected reason %s, got %s", terminationRepeatedErrors, term.Reason()) + } +} - // Tool-using mode: return tool calls to force hitting MaxIterations - // Use different queries to avoid stagnation detection - if toolsEnabled { - query := fmt.Sprintf("test query %d", iterationCount) - return &core.GenerateResult{ - Content: "Using search tool", - ToolCalls: []core.ToolCall{ - { - ID: fmt.Sprintf("call_%d", iterationCount), - Name: "search", - Arguments: map[string]any{ - "query": query, - }, - }, - }, - }, nil - } +func TestReActTermination_ObserveToolResult_ErrorResetsOnSuccess(t *testing.T) { + t.Parallel() + term := newReActTermination() - // No tools mode (final mode or extraction) - // During final mode (iteration 2): return malformed JSON to force extraction - // During extraction (iteration 3): return proper JSON with reasoning - if iterationCount == 2 { - // Return malformed JSON that will fail parsing and trigger extraction - return &core.GenerateResult{ - Content: "I'm thinking about it but not formatting 
correctly", - }, nil - } + tc := core.ToolCall{Name: "search"} - // Extraction phase (iteration 3): return proper answer with reasoning - return &core.GenerateResult{ - Content: `{ - "rationale": "Based on all the tool observations, I can now provide the final answer.", - "answer": "The answer based on search results", - "confidence": 95 - }`, - }, nil - }, + term.ObserveToolResult(tc, "h1", errors.New("err")) + term.ObserveToolResult(tc, "h2", nil) // Success resets error count. + term.ObserveToolResult(tc, "h3", errors.New("err")) + + if term.ShouldStop() { + t.Error("should not stop; error count was reset by success") } +} - callNumber := 0 - searchTool := core.NewTool( - "search", - "Search for information", - func(ctx context.Context, args map[string]any) (any, error) { - callNumber++ - return fmt.Sprintf("Search results %d: relevant information", callNumber), nil - }, - ).AddParameter("query", "string", "Search query", true) +func TestReActTermination_SetFinalToolArgs(t *testing.T) { + t.Parallel() + term := newReActTermination() - react := NewReAct(sig, lm, []core.Tool{*searchTool}). - WithMaxIterations(2). - WithVerbose(false) + args := map[string]any{"answer": "42"} + term.SetFinalToolArgs(args) - result, err := react.Forward(context.Background(), map[string]any{ - "question": "What is the answer?", - }) + // Modify original to ensure copy was made. 
+ args["answer"] = "modified" - if err != nil { - t.Fatalf("Forward() error = %v", err) + if term.FinalToolArgs()["answer"] != "42" { + t.Error("expected args to be copied, not referenced") } +} - // Should have hit MaxIterations and triggered extraction - // 2 tool-using iterations + 1 extraction call = 3 total - if iterationCount < 3 { - t.Errorf("Expected at least 3 LM calls (2 iterations + extraction), got %d", iterationCount) - } +func TestReActTermination_SetFinalContent(t *testing.T) { + t.Parallel() + term := newReActTermination() + term.SetFinalContent(`{"answer":"42"}`) - // Check that answer was extracted - answer, ok := result.GetString("answer") - if !ok { - t.Error("Expected answer field in result") - } - if !contains(answer, "answer based on search") { - t.Errorf("Expected answer to contain extracted text, got: %s", answer) + if term.FinalContent() != `{"answer":"42"}` { + t.Errorf("unexpected final content: %s", term.FinalContent()) } +} - // CRITICAL: Check that rationale was attached to prediction - if result.Rationale == "" { - t.Error("Expected non-empty rationale from extraction phase with reasoning adapter") - } - if !contains(result.Rationale, "tool observations") { - t.Errorf("Expected rationale to contain reasoning, got: %s", result.Rationale) - } +// ============================================================================ +// generateWithContextRetry / isContextOverflowError tests +// ============================================================================ - // Verify rationale was removed from outputs (not part of signature) - if _, exists := result.Outputs["rationale"]; exists { - t.Error("Rationale should be removed from outputs map") - } - if _, exists := result.Outputs["reasoning"]; exists { - t.Error("Reasoning should be removed from outputs map") +func TestIsContextOverflowError_Sentinel(t *testing.T) { + t.Parallel() + err := contextLengthSentinel{} + if !isContextOverflowError(err) { + t.Error("expected sentinel to be 
detected as overflow") } } -// TestReAct_ImplicitFinish tests that ReAct accepts direct answers without tool calls. -// This validates the "Implicit Finish" pattern where the model provides a valid answer -// directly instead of using tools, which is correct behavior for native tool calling APIs. -func TestReAct_ImplicitFinish(t *testing.T) { +func TestIsContextOverflowError_Patterns(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + tests := []struct { + msg string + want bool + }{ + {"context_length_exceeded", true}, + {"maximum context length", true}, + {"max context length exceeded", true}, + {"maximum context window exceeded", true}, + {"context window exceeded", true}, + {"too many tokens", true}, + {"exceeded the context", true}, + {"please reduce the length of the messages", true}, + {"prompt is too long", true}, + {"tokens exceeded", true}, + {"input is too long", true}, + {"some other error", false}, + {"network timeout", false}, + } + + for _, tc := range tests { + t.Run(tc.msg, func(t *testing.T) { + err := errors.New(tc.msg) + if got := isContextOverflowError(err); got != tc.want { + t.Errorf("isContextOverflowError(%q) = %v, want %v", tc.msg, got, tc.want) + } + }) + } +} +func TestIsContextOverflowError_Nil(t *testing.T) { + t.Parallel() + if isContextOverflowError(nil) { + t.Error("expected nil to not be overflow error") + } +} + +func TestGenerateWithContextRetry_NoOverflow(t *testing.T) { + t.Parallel() callCount := 0 lm := &MockLM{ - SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ - // Model returns valid JSON without making any tool calls (implicit finish) - return &core.GenerateResult{ - Content: `{"answer": "42"}`, - ToolCalls: []core.ToolCall{}, // Empty - no tool calls - }, nil + return 
&core.GenerateResult{Content: "success"}, nil }, } - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - t.Error("Tool should not be executed in implicit finish scenario") - return "search result", nil - }) - - react := NewReAct(sig, lm, []core.Tool{*searchTool}) - result, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "What is the answer to life?", - }) + r := &ReAct{LM: lm, MaxPromptBytes: 10000} + traj := newReActTrajectory([]core.Message{{Role: "user", Content: "hi"}}) - // Verify: err == nil (success) + result, err := r.generateWithContextRetry(context.Background(), traj, &core.GenerateOptions{}, nil) if err != nil { - t.Fatalf("Forward() error = %v, want nil", err) + t.Fatalf("unexpected error: %v", err) } - - // Verify: callCount == 1 (single LM call, no retry) - if callCount != 1 { - t.Errorf("Expected 1 LM call for implicit finish, got %d", callCount) + if result.Content != "success" { + t.Errorf("unexpected content: %s", result.Content) } - - // Verify: result.Outputs["answer"] == "42" - if result.Outputs["answer"] != "42" { - t.Errorf("Expected answer='42', got %v", result.Outputs["answer"]) + if callCount != 1 { + t.Errorf("expected 1 call, got %d", callCount) } } -// TestReAct_ImplicitFinish_MalformedRetry tests the retry mechanism when implicit finish -// fails validation in early iterations. The model should be guided to use tools. -// Note: This test uses int fields to ensure malformed content fails validation, -// triggering the retry mechanism. String-only signatures would use text extraction -// as a fallback and accept malformed content. -func TestReAct_ImplicitFinish_MalformedRetry(t *testing.T) { +func TestGenerateWithContextRetry_OverflowRetry(t *testing.T) { t.Parallel() - // Use an int output field so malformed text fails validation - sig := core.NewSignature("Calculate something"). - AddInput("question", core.FieldTypeString, "Question"). 
- AddOutput("count", core.FieldTypeInt, "Count result") - callCount := 0 - var capturedMessages []core.Message lm := &MockLM{ - SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ - capturedMessages = messages - - if callCount == 1 { - // First call: return malformed text without tool calls - // This fails validation because "count" expects int, gets no valid int - return &core.GenerateResult{ - Content: "thinking about this problem without any numbers", - ToolCalls: []core.ToolCall{}, - }, nil + if callCount < 3 { + return nil, contextLengthSentinel{} } - // Second call: return valid JSON with int (recovery) - return &core.GenerateResult{ - Content: `{"count": 42}`, - ToolCalls: []core.ToolCall{}, - }, nil + return &core.GenerateResult{Content: "success after retry"}, nil }, } - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "search result", nil - }) - - react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(5) - result, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "What is the count?", - }) + r := &ReAct{LM: lm, MaxPromptBytes: 10000} + traj := newReActTrajectory([]core.Message{{Role: "user", Content: "hi"}}) + traj.AddStep("step1", nil) + traj.AddStep("step2", nil) + traj.AddStep("step3", nil) - // Verify: err == nil (success after retry) + result, err := r.generateWithContextRetry(context.Background(), traj, &core.GenerateOptions{}, nil) if err != nil { - t.Fatalf("Forward() error = %v, want nil", err) + t.Fatalf("unexpected error: %v", err) } - - // Verify: callCount == 2 (retry occurred) - if callCount != 2 { - t.Errorf("Expected 2 LM calls (malformed + retry), got %d", callCount) + if result.Content != "success after retry" { + t.Errorf("unexpected content: %s", result.Content) } - - // Verify: Messages contain "Please 
use the available tools" - foundToolGuidance := false - for _, msg := range capturedMessages { - if msg.Role == "user" && contains(msg.Content, "Please use the available tools") { - foundToolGuidance = true - break - } + if callCount != 3 { + t.Errorf("expected 3 calls (2 failures + 1 success), got %d", callCount) } - if !foundToolGuidance { - t.Error("Expected retry message containing 'Please use the available tools'") +} + +func TestGenerateWithContextRetry_ExhaustedRetries(t *testing.T) { + t.Parallel() + callCount := 0 + lm := &MockLM{ + GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + callCount++ + return nil, contextLengthSentinel{} + }, } - // Verify: result.Outputs["count"] == 42 - count, ok := result.GetInt("count") - if !ok || count != 42 { - t.Errorf("Expected count=42, got %v", result.Outputs["count"]) + r := &ReAct{LM: lm, MaxPromptBytes: 10000} + traj := newReActTrajectory([]core.Message{{Role: "user", Content: "hi"}}) + traj.AddStep("step1", nil) + traj.AddStep("step2", nil) + traj.AddStep("step3", nil) + + _, err := r.generateWithContextRetry(context.Background(), traj, &core.GenerateOptions{}, nil) + if err == nil { + t.Fatal("expected error after exhausting retries") + } + // Should have tried reactContextOverflowMaxRetries (3) times. + if callCount != 3 { + t.Errorf("expected 3 calls, got %d", callCount) } } -// TestReAct_WithMethods tests all ReAct configuration methods -func TestReAct_WithMethods(t *testing.T) { +func TestGenerateWithContextRetry_NoStepsToDropReturnsImmediately(t *testing.T) { t.Parallel() - sig := core.NewSignature("test"). - AddInput("question", core.FieldTypeString, ""). 
- AddOutput("answer", core.FieldTypeString, "") - - lm := &MockLM{} - tools := []core.Tool{} - history := core.NewHistory() - demos := []core.Example{ - *core.NewExample( - map[string]any{"question": "test"}, - map[string]any{"answer": "test"}, - ), + callCount := 0 + lm := &MockLM{ + GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { + callCount++ + return nil, contextLengthSentinel{} + }, } - adapter := core.NewJSONAdapter() - react := NewReAct(sig, lm, tools). - WithAdapter(adapter). - WithHistory(history). - WithDemos(demos) + r := &ReAct{LM: lm, MaxPromptBytes: 10000} + traj := newReActTrajectory([]core.Message{{Role: "user", Content: "hi"}}) + // No steps added - nothing to drop. - if react.Adapter != adapter { - t.Error("WithAdapter should set adapter") - } - if react.History != history { - t.Error("WithHistory should set history") + _, err := r.generateWithContextRetry(context.Background(), traj, &core.GenerateOptions{}, nil) + if err == nil { + t.Fatal("expected error") } - if len(react.Demos) != 1 { - t.Error("WithDemos should set demos") + // Should return immediately after first failure since no steps to drop. + if callCount != 1 { + t.Errorf("expected 1 call (immediate return), got %d", callCount) } } -// TestReAct_UsageAccumulation tests that usage (tokens, cost, latency) accumulates correctly across multiple iterations -func TestReAct_UsageAccumulation(t *testing.T) { +func TestGenerateWithContextRetry_NonOverflowErrorReturnsImmediately(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). 
- AddOutput("answer", core.FieldTypeString, "Answer") - callCount := 0 lm := &MockLM{ - SupportsToolsVal: true, GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { callCount++ - if callCount == 1 { - // First iteration with tool call - return &core.GenerateResult{ - Content: "Let me search", - ToolCalls: []core.ToolCall{ - {ID: "1", Name: "search", Arguments: map[string]interface{}{"query": "test"}}, - }, - Usage: core.Usage{ - PromptTokens: 100, - CompletionTokens: 50, - TotalTokens: 150, - Cost: 0.001, - Latency: 500 * 1_000_000, // 500ms in nanoseconds - }, - }, nil - } - // Second iteration with final answer - return &core.GenerateResult{ - Content: `{"answer": "final answer"}`, - Usage: core.Usage{ - PromptTokens: 200, - CompletionTokens: 100, - TotalTokens: 300, - Cost: 0.002, - Latency: 600 * 1_000_000, // 600ms in nanoseconds - }, - }, nil + return nil, errors.New("network error") }, } - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "search result", nil - }) - - react := NewReAct(sig, lm, []core.Tool{*searchTool}) - pred, err := react.Forward(context.Background(), map[string]interface{}{ - "question": "test", - }) + r := &ReAct{LM: lm, MaxPromptBytes: 10000} + traj := newReActTrajectory([]core.Message{{Role: "user", Content: "hi"}}) + traj.AddStep("step1", nil) - if err != nil { - t.Fatalf("Forward() error = %v", err) + _, err := r.generateWithContextRetry(context.Background(), traj, &core.GenerateOptions{}, nil) + if err == nil || err.Error() != "network error" { + t.Fatalf("expected network error, got: %v", err) + } + if callCount != 1 { + t.Errorf("expected 1 call (immediate return for non-overflow error), got %d", callCount) } +} - // Verify usage accumulation across 2 iterations - expectedPromptTokens := 100 + 200 // First + second iteration - expectedCompletionTokens := 50 + 100 // First + second 
iteration - expectedTotalTokens := 150 + 300 // First + second iteration - expectedCost := 0.001 + 0.002 // First + second iteration - expectedLatency := 500 + 600 // Sum of latencies (in milliseconds) +// ============================================================================ +// Helper: approxMessagesBytes / approxBytes tests +// ============================================================================ - if pred.Usage.PromptTokens != expectedPromptTokens { - t.Errorf("PromptTokens: expected %d, got %d", expectedPromptTokens, pred.Usage.PromptTokens) - } - if pred.Usage.CompletionTokens != expectedCompletionTokens { - t.Errorf("CompletionTokens: expected %d, got %d", expectedCompletionTokens, pred.Usage.CompletionTokens) +func TestApproxMessagesBytes(t *testing.T) { + t.Parallel() + msgs := []core.Message{ + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "world"}, } - if pred.Usage.TotalTokens != expectedTotalTokens { - t.Errorf("TotalTokens: expected %d, got %d", expectedTotalTokens, pred.Usage.TotalTokens) + + bytes := approxMessagesBytes(msgs) + if bytes != 10 { // "hello" (5) + "world" (5) + t.Errorf("expected 10 bytes, got %d", bytes) } - if pred.Usage.Cost != expectedCost { - t.Errorf("Cost: expected %.6f, got %.6f", expectedCost, pred.Usage.Cost) +} + +func TestApproxMessagesBytes_WithToolCalls(t *testing.T) { + t.Parallel() + msgs := []core.Message{ + { + Role: "assistant", + Content: "x", + ToolCalls: []core.ToolCall{ + {ID: "1", Name: "search", Arguments: map[string]any{"q": "test"}}, + }, + }, } - expectedLatencyNs := int64(expectedLatency) * 1_000_000 - if pred.Usage.Latency != expectedLatencyNs { - t.Errorf("Latency: expected %d ns (%.2fms), got %d ns (%.2fms)", - expectedLatencyNs, float64(expectedLatencyNs)/1_000_000, - pred.Usage.Latency, float64(pred.Usage.Latency)/1_000_000) + + bytes := approxMessagesBytes(msgs) + // Content "x" (1 byte) + JSON of tool calls. 
+ if bytes <= 1 { + t.Errorf("expected > 1 byte due to tool calls, got %d", bytes) } } -// TestReAct_ToolsSliceNotMutated tests that the caller's tools slice is not mutated by NewReAct -func TestReAct_ToolsSliceNotMutated(t *testing.T) { +func TestReActStep_ApproxBytes(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") - - lm := &MockLM{} + step := &reactStep{ + Thought: "thinking", + ToolResults: []reactToolResult{ + {Content: "result"}, + }, + Errors: []string{"error"}, + } - // Create tools slice with specific capacity to detect append mutation - originalTools := make([]core.Tool, 1, 10) // capacity > len to allow in-place append - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "result", nil - }) - originalTools[0] = *searchTool + bytes := step.approxBytes() + // "thinking" (8) + "result" (6) + "error" (5) = 19 + if bytes != 19 { + t.Errorf("expected 19 bytes, got %d", bytes) + } +} - // Capture original length - originalLen := len(originalTools) +// ============================================================================ +// toolFingerprint tests +// ============================================================================ - // Create ReAct which auto-injects finish tool - _ = NewReAct(sig, lm, originalTools) +func TestToolFingerprint_SameCallsSameFingerprint(t *testing.T) { + t.Parallel() + tc1 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test"}} + tc2 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test"}} - // Verify caller's slice was NOT modified - if len(originalTools) != originalLen { - t.Errorf("Caller's tools slice was mutated: expected len %d, got %d", originalLen, len(originalTools)) - } + fp1 := toolFingerprint(tc1) + fp2 := toolFingerprint(tc2) - // Verify finish tool was NOT appended to 
caller's slice - for _, tool := range originalTools { - if strings.ToLower(tool.Name) == "finish" { - t.Error("Finish tool should not appear in caller's original tools slice") - } + if fp1 != fp2 { + t.Errorf("expected same fingerprint for identical calls: %s vs %s", fp1, fp2) } } -// TestReAct_NonToolLM_NoToolsPassedInFinalMode tests that tools are not passed to LMs that don't support them in final mode -func TestReAct_NonToolLM_NoToolsPassedInFinalMode(t *testing.T) { +func TestToolFingerprint_DifferentArgsDifferentFingerprint(t *testing.T) { t.Parallel() - sig := core.NewSignature("Answer question"). - AddInput("question", core.FieldTypeString, "Question"). - AddOutput("answer", core.FieldTypeString, "Answer") + tc1 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test1"}} + tc2 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test2"}} - var finalModeOptions *core.GenerateOptions - callCount := 0 - lm := &MockLM{ - SupportsToolsVal: false, // LM does NOT support tools - GenerateFunc: func(ctx context.Context, messages []core.Message, options *core.GenerateOptions) (*core.GenerateResult, error) { - callCount++ - if callCount == 1 { - // First call: return something that doesn't parse well to trigger iteration - return &core.GenerateResult{ - Content: "I need to think about this...", - }, nil - } - if callCount == 2 { - // Second call: still no good answer, will trigger final mode - return &core.GenerateResult{ - Content: "Still thinking...", - }, nil - } - // Third call onward: final mode - capture options here - finalModeOptions = options - return &core.GenerateResult{ - Content: `{"answer": "final answer"}`, - }, nil - }, + fp1 := toolFingerprint(tc1) + fp2 := toolFingerprint(tc2) + + if fp1 == fp2 { + t.Error("expected different fingerprints for different args") } +} - searchTool := core.NewTool("search", "Search for info", func(ctx context.Context, args map[string]any) (any, error) { - return "result", nil - }) +func 
TestToolFingerprint_IncludesToolName(t *testing.T) { + t.Parallel() + tc1 := core.ToolCall{Name: "search", Arguments: map[string]any{"q": "test"}} + tc2 := core.ToolCall{Name: "calculate", Arguments: map[string]any{"q": "test"}} - react := NewReAct(sig, lm, []core.Tool{*searchTool}).WithMaxIterations(4) - _, err := react.Forward(context.Background(), map[string]any{"question": "test"}) - if err != nil { - t.Fatalf("Forward() error = %v", err) - } + fp1 := toolFingerprint(tc1) + fp2 := toolFingerprint(tc2) - // Verify final mode options: tools should NOT be passed since LM doesn't support them - if finalModeOptions != nil { - if len(finalModeOptions.Tools) > 0 { - t.Errorf("Tools should not be passed to non-tool LM in final mode, got %d tools", len(finalModeOptions.Tools)) - } - if finalModeOptions.ToolChoice == "none" { - t.Errorf("ToolChoice should not be 'none' for non-tool LM, got %q", finalModeOptions.ToolChoice) - } + if fp1 == fp2 { + t.Error("expected different fingerprints for different tool names") } } diff --git a/internal/module/react_tool_result.go b/internal/module/react_tool_result.go new file mode 100644 index 0000000..35ded67 --- /dev/null +++ b/internal/module/react_tool_result.go @@ -0,0 +1,168 @@ +package module + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "unicode/utf8" +) + +type reactToolResult struct { + ToolCallID string + ToolName string + Content string + Truncated bool + Err error +} + +type toolResultEnvelope struct { + Tool string `json:"tool"` + ToolCallID string `json:"tool_call_id,omitempty"` + OK bool `json:"ok"` + Result any `json:"result,omitempty"` + Error string `json:"error,omitempty"` + Truncated bool `json:"truncated,omitempty"` + OriginalSize int `json:"original_bytes,omitempty"` +} + +// encodeToolResult encodes tool outputs as a JSON envelope and truncates deterministically +// to maxBytes, ensuring the resulting string is valid JSON. 
+// +// It also returns a stable hash of the observation content that excludes per-call identifiers +// (like tool_call_id), so termination policies can detect repeated observations. +func encodeToolResult(toolName, toolCallID string, value any, toolErr error, maxBytes int) (content string, truncated bool, stableHash string) { + if maxBytes <= 0 { + maxBytes = defaultReActMaxToolResultBytes + } + + env := toolResultEnvelope{ + Tool: toolName, + ToolCallID: toolCallID, + OK: toolErr == nil, + } + if toolErr != nil { + env.Error = toolErr.Error() + } else { + env.Result = value + } + + data, err := json.Marshal(env) + if err != nil { + fallback := toolResultEnvelope{Tool: toolName, ToolCallID: toolCallID, OK: false, Error: fmt.Sprintf("marshal error: %v", err)} + fallbackBytes, _ := json.Marshal(fallback) + stable := stableToolEnvelope(toolName, toolErr == nil, nil, fallback.Error, false) + return string(fallbackBytes), false, stableHashJSON(stable) + } + if len(data) <= maxBytes { + stable := stableToolEnvelope(toolName, toolErr == nil, value, env.Error, false) + return string(data), false, stableHashJSON(stable) + } + + origSize := len(data) + + // Truncation strategy: represent the tool result as a JSON string excerpt. + var excerptBytes []byte + if toolErr != nil { + excerptBytes = []byte(toolErr.Error()) + } else { + if b, err := json.Marshal(value); err == nil { + excerptBytes = b + } else { + excerptBytes = []byte(fmt.Sprintf("%v", value)) + } + } + + trunc := toolResultEnvelope{ + Tool: toolName, + ToolCallID: toolCallID, + OK: toolErr == nil, + Truncated: true, + OriginalSize: origSize, + } + if toolErr != nil { + trunc.Error = toolErr.Error() + } + + // Binary search the maximum excerpt prefix that still fits. 
+ low, high := 0, len(excerptBytes) + best := 0 + for low <= high { + mid := (low + high) / 2 + trunc.Result = truncateUTF8Bytes(excerptBytes, mid) + b, _ := json.Marshal(trunc) + if len(b) <= maxBytes { + best = mid + low = mid + 1 + } else { + high = mid - 1 + } + } + + trunc.Result = truncateUTF8Bytes(excerptBytes, best) + b, _ := json.Marshal(trunc) + if len(b) <= maxBytes { + stable := stableToolEnvelope(toolName, toolErr == nil, trunc.Result, trunc.Error, true) + return string(b), true, stableHashJSON(stable) + } + + // If still too large (extreme maxBytes), fall back to a minimal envelope. + minimal := toolResultEnvelope{ + Tool: toolName, + ToolCallID: toolCallID, + OK: toolErr == nil, + Truncated: true, + } + minimalBytes, _ := json.Marshal(minimal) + if len(minimalBytes) <= maxBytes { + stable := stableToolEnvelope(toolName, toolErr == nil, nil, "", true) + return string(minimalBytes), true, stableHashJSON(stable) + } + + // Last resort: produce valid JSON even if it exceeds budget. + stable := stableToolEnvelope(toolName, toolErr == nil, nil, "", true) + return string(minimalBytes), true, stableHashJSON(stable) +} + +type stableToolResultEnvelope struct { + Tool string `json:"tool"` + OK bool `json:"ok"` + Result any `json:"result,omitempty"` + Error string `json:"error,omitempty"` + Truncated bool `json:"truncated,omitempty"` +} + +func stableToolEnvelope(tool string, ok bool, result any, errMsg string, truncated bool) stableToolResultEnvelope { + env := stableToolResultEnvelope{Tool: tool, OK: ok, Truncated: truncated} + if !ok { + env.Error = errMsg + } else { + env.Result = result + } + return env +} + +func stableHashJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + // Fall back to hashing the error string to remain stable. 
+ h := sha256.Sum256([]byte(fmt.Sprintf("marshal error: %v", err))) + return hex.EncodeToString(h[:]) + } + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +func truncateUTF8Bytes(b []byte, n int) string { + if n <= 0 { + return "" + } + if n >= len(b) { + return string(b) + } + cut := b[:n] + for len(cut) > 0 && !utf8.Valid(cut) { + cut = cut[:len(cut)-1] + } + return string(cut) +} diff --git a/internal/module/react_trajectory.go b/internal/module/react_trajectory.go new file mode 100644 index 0000000..07d450e --- /dev/null +++ b/internal/module/react_trajectory.go @@ -0,0 +1,146 @@ +package module + +import ( + "encoding/json" + + "github.com/assagman/dsgo/internal/core" +) + +// reactTrajectory stores ReAct loop state as structured steps and can render them +// into a provider-friendly []core.Message under a soft budget. +// +// Budgeting is approximate (byte-based) and intended for deterministic truncation +// and overflow recovery. +type reactTrajectory struct { + base []core.Message + steps []*reactStep +} + +type reactStep struct { + Thought string + ToolCalls []core.ToolCall + ToolResults []reactToolResult + Errors []string +} + +func newReActTrajectory(base []core.Message) *reactTrajectory { + cloned := make([]core.Message, len(base)) + copy(cloned, base) + return &reactTrajectory{base: cloned, steps: []*reactStep{}} +} + +func (t *reactTrajectory) AddStep(thought string, toolCalls []core.ToolCall) *reactStep { + step := &reactStep{Thought: thought} + if toolCalls != nil { + step.ToolCalls = make([]core.ToolCall, len(toolCalls)) + copy(step.ToolCalls, toolCalls) + } + t.steps = append(t.steps, step) + return step +} + +func (s *reactStep) AddToolResult(res reactToolResult) { + s.ToolResults = append(s.ToolResults, res) +} + +func (t *reactTrajectory) DropOldestSteps(n int) int { + if n <= 0 || len(t.steps) == 0 { + return 0 + } + if n >= len(t.steps) { + dropped := len(t.steps) + t.steps = nil + return dropped + } + t.steps = t.steps[n:] + 
return n +} + +func (t *reactTrajectory) HasToolContent() bool { + for _, s := range t.steps { + if len(s.ToolCalls) > 0 || len(s.ToolResults) > 0 { + return true + } + } + return false +} + +func (t *reactTrajectory) Render(budgetBytes int) []core.Message { + if budgetBytes <= 0 { + budgetBytes = defaultReActMaxPromptBytes + } + + base := make([]core.Message, len(t.base)) + copy(base, t.base) + if len(t.steps) == 0 { + return base + } + + baseBytes := approxMessagesBytes(base) + remaining := budgetBytes - baseBytes + if remaining <= 0 { + return base + } + + // Select suffix of steps that fits within remaining. + includeFrom := len(t.steps) - 1 + used := 0 + for i := len(t.steps) - 1; i >= 0; i-- { + stepBytes := t.steps[i].approxBytes() + // Always include the newest step even if it exceeds the budget. + if used+stepBytes > remaining && i != len(t.steps)-1 { + break + } + includeFrom = i + used += stepBytes + } + + msgs := append([]core.Message{}, base...) + for _, s := range t.steps[includeFrom:] { + msgs = append(msgs, s.toMessages()...) 
+ } + return msgs +} + +func (s *reactStep) toMessages() []core.Message { + msgs := []core.Message{} + if s.Thought != "" || len(s.ToolCalls) > 0 { + msgs = append(msgs, core.Message{Role: "assistant", Content: s.Thought, ToolCalls: s.ToolCalls}) + } + for _, tr := range s.ToolResults { + msgs = append(msgs, core.Message{Role: "tool", Content: tr.Content, ToolID: tr.ToolCallID}) + } + for _, e := range s.Errors { + msgs = append(msgs, core.Message{Role: "system", Content: e}) + } + return msgs +} + +func (s *reactStep) approxBytes() int { + b := len(s.Thought) + if len(s.ToolCalls) > 0 { + if data, err := json.Marshal(s.ToolCalls); err == nil { + b += len(data) + } + } + for _, tr := range s.ToolResults { + b += len(tr.Content) + } + for _, e := range s.Errors { + b += len(e) + } + return b +} + +func approxMessagesBytes(msgs []core.Message) int { + total := 0 + for _, m := range msgs { + total += len(m.Content) + if len(m.ToolCalls) > 0 { + if data, err := json.Marshal(m.ToolCalls); err == nil { + total += len(data) + } + } + } + return total +} diff --git a/internal/providers/mock/lm.go b/internal/providers/mock/lm.go index 498de40..c537f5c 100644 --- a/internal/providers/mock/lm.go +++ b/internal/providers/mock/lm.go @@ -339,9 +339,10 @@ func (m *mockHTTP) buildRequest(messages []core.Message, options *core.GenerateO req["tools"] = tools if options.ToolChoice != "" && options.ToolChoice != "auto" { - if options.ToolChoice == "none" { - req["tool_choice"] = "none" - } else { + switch options.ToolChoice { + case "none", "required": + req["tool_choice"] = options.ToolChoice + default: req["tool_choice"] = map[string]any{ "type": "function", "function": map[string]string{ @@ -459,9 +460,10 @@ func convertTool(tool *core.Tool) map[string]any { "name": tool.Name, "description": tool.Description, "parameters": map[string]any{ - "type": "object", - "properties": properties, - "required": required, + "type": "object", + "properties": properties, + "required": required, + 
"additionalProperties": false, }, }, } diff --git a/internal/providers/mock/lm_test.go b/internal/providers/mock/lm_test.go index 1734b2c..a631d5b 100644 --- a/internal/providers/mock/lm_test.go +++ b/internal/providers/mock/lm_test.go @@ -124,3 +124,31 @@ func TestMockHTTP_Stream(t *testing.T) { t.Fatalf("usage.total=%d, want 3", last.Usage.TotalTokens) } } + +func TestMockHTTP_buildRequest_ToolChoiceRequired(t *testing.T) { + m := &mockHTTP{model: "gpt-4o"} + + tool := core.NewTool("test_tool", "A test tool", nil) + tool.AddParameter("q", "string", "Query", true) + + req := m.buildRequest([]core.Message{{Role: "user", Content: "hi"}}, &core.GenerateOptions{ + Tools: []core.Tool{*tool}, + ToolChoice: "required", + }) + + if req["tool_choice"] != "required" { + t.Fatalf("tool_choice=%v, want required", req["tool_choice"]) + } +} + +func TestMockHTTP_convertTool_AdditionalPropertiesFalse(t *testing.T) { + tool := core.NewTool("test_tool", "A test tool", nil) + tool.AddParameter("q", "string", "Query", true) + + converted := convertTool(tool) + fn, _ := converted["function"].(map[string]any) + params, _ := fn["parameters"].(map[string]any) + if params["additionalProperties"] != false { + t.Fatalf("additionalProperties=%v, want false", params["additionalProperties"]) + } +} diff --git a/internal/providers/openai/lm.go b/internal/providers/openai/lm.go index 31e13b3..ab791f7 100644 --- a/internal/providers/openai/lm.go +++ b/internal/providers/openai/lm.go @@ -17,6 +17,7 @@ import ( "github.com/assagman/dsgo/internal/logging" "github.com/assagman/dsgo/internal/modelcatalog" "github.com/assagman/dsgo/internal/providers/util" + "github.com/assagman/dsgo/internal/retry" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/shared" @@ -131,20 +132,59 @@ func (o *openAI) Generate(ctx context.Context, messages []core.Message, options params := o.buildParams(messages, options) - var rawResp *http.Response - reqOpts := 
[]option.RequestOption{option.WithResponseInto(&rawResp)} - if o.BaseURL != "" { - reqOpts = append(reqOpts, option.WithBaseURL(o.BaseURL)) + // Apply exponential backoff on 429/5xx to survive bursty concurrency. + retryOpts := retry.DefaultOptions() + if options != nil && options.RetryConfig != nil { + retryOpts.MergeFrom( + options.RetryConfig.MaxRetries, + options.RetryConfig.InitialBackoff, + options.RetryConfig.MaxBackoff, + options.RetryConfig.JitterFactor, + ) } - if options.RetryConfig != nil { - reqOpts = append(reqOpts, option.WithMaxRetries(options.RetryConfig.MaxRetries)) + + var rawResp *http.Response + var chatCompletion *openai.ChatCompletion + var lastErr error + + httpFn := func() (*http.Response, error) { + chatCompletion = nil + lastErr = nil + + var attemptResp *http.Response + reqOpts := []option.RequestOption{option.WithResponseInto(&attemptResp)} + if o.BaseURL != "" { + reqOpts = append(reqOpts, option.WithBaseURL(o.BaseURL)) + } + // Disable SDK retries: we do our own (with Retry-After support). + reqOpts = append(reqOpts, option.WithMaxRetries(0)) + + cc, err := o.Client.Chat.Completions.New(ctx, params, reqOpts...) + if err != nil { + lastErr = err + if attemptResp == nil { + return nil, err + } + return attemptResp, nil + } + + chatCompletion = cc + rawResp = attemptResp + return attemptResp, nil } - chatCompletion, err := o.Client.Chat.Completions.New(ctx, params, reqOpts...) 
+ _, err := retry.WithExponentialBackoffOpts(ctx, httpFn, retryOpts) if err != nil { logging.LogAPIError(ctx, "provider.OpenAI", o.Model, err) return nil, fmt.Errorf("request failed: %w", err) } + if lastErr != nil { + logging.LogAPIError(ctx, "provider.OpenAI", o.Model, lastErr) + return nil, fmt.Errorf("request failed: %w", lastErr) + } + if chatCompletion == nil { + return nil, fmt.Errorf("request failed: empty response") + } result, err := o.parseResponse(chatCompletion) if err != nil { @@ -241,11 +281,17 @@ func (o *openAI) buildParams(messages []core.Message, options *core.GenerateOpti params.Tools = tools if options.ToolChoice != "" && options.ToolChoice != "auto" { - if options.ToolChoice == "none" { + switch options.ToolChoice { + case "none": params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ OfAuto: openai.String("none"), } - } else { + case "required": + // Require the model to call a tool. + params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("required"), + } + default: params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice( openai.ChatCompletionNamedToolChoiceFunctionParam{ Name: options.ToolChoice, @@ -376,9 +422,10 @@ func (o *openAI) convertTool(tool *core.Tool) openai.ChatCompletionToolUnionPara Name: tool.Name, Description: openai.String(tool.Description), Parameters: shared.FunctionParameters{ - "type": "object", - "properties": properties, - "required": required, + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": false, }, }) } diff --git a/internal/providers/openai/lm_test.go b/internal/providers/openai/lm_test.go index 2f9d138..0055f69 100644 --- a/internal/providers/openai/lm_test.go +++ b/internal/providers/openai/lm_test.go @@ -162,6 +162,21 @@ func TestOpenAI_BuildParams_ToolChoice(t *testing.T) { } }) + t.Run("required", func(t *testing.T) { + lm := &openAI{Model: "gpt-4o"} + tool := core.NewTool("test_tool", "A test tool", 
nil) + + opts := &core.GenerateOptions{Tools: []core.Tool{*tool}, ToolChoice: "required"} + params := lm.buildParams([]core.Message{{Role: "user", Content: "hi"}}, opts) + + if !params.ToolChoice.OfAuto.Valid() { + t.Fatal("expected ToolChoice.OfAuto to be set") + } + if params.ToolChoice.OfAuto.Value != "required" { + t.Errorf("expected tool choice required, got %q", params.ToolChoice.OfAuto.Value) + } + }) + t.Run("with provider params", func(t *testing.T) { lm := &openAI{Model: "gpt-4o"} messages := []core.Message{{Role: "user", Content: "test"}} @@ -317,6 +332,10 @@ func TestOpenAI_ConvertTool(t *testing.T) { if len(required) != 1 || required[0] != "param1" { t.Errorf("expected required to be [param1], got %v", required) } + + if fn.Parameters["additionalProperties"] != false { + t.Errorf("expected additionalProperties=false, got %v", fn.Parameters["additionalProperties"]) + } } func TestOpenAI_ParseResponse_InvalidToolArgs(t *testing.T) { diff --git a/internal/providers/openrouter/lm.go b/internal/providers/openrouter/lm.go index 087a00a..5230a41 100644 --- a/internal/providers/openrouter/lm.go +++ b/internal/providers/openrouter/lm.go @@ -273,13 +273,18 @@ func (o *openRouter) buildParams(messages []core.Message, options *core.Generate params.Tools = tools if options.ToolChoice != "" && options.ToolChoice != "auto" { - if options.ToolChoice == "none" { + switch options.ToolChoice { + case "none": if !isZAIModel { params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ OfAuto: openai.String("none"), } } - } else { + case "required": + params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("required"), + } + default: params.ToolChoice = openai.ToolChoiceOptionFunctionToolChoice( openai.ChatCompletionNamedToolChoiceFunctionParam{ Name: options.ToolChoice, @@ -301,8 +306,12 @@ func (o *openRouter) buildParams(messages []core.Message, options *core.Generate } if isZAIModel && (len(options.Tools) > 0 || 
hasToolContent) { - params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ - OfAuto: openai.String("auto"), + // Z.AI models have strict tool_choice handling and may reject "none". + // We only override when the caller didn't explicitly ask for a non-auto policy. + if options.ToolChoice == "" || options.ToolChoice == "auto" || options.ToolChoice == "none" { + params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{ + OfAuto: openai.String("auto"), + } } } @@ -448,9 +457,10 @@ func (o *openRouter) convertTool(tool *core.Tool) openai.ChatCompletionToolUnion Name: tool.Name, Description: openai.String(tool.Description), Parameters: shared.FunctionParameters{ - "type": "object", - "properties": properties, - "required": required, + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": false, }, }) } diff --git a/internal/providers/openrouter/lm_test.go b/internal/providers/openrouter/lm_test.go index e829554..4a925bc 100644 --- a/internal/providers/openrouter/lm_test.go +++ b/internal/providers/openrouter/lm_test.go @@ -420,6 +420,24 @@ func TestOpenRouter_BuildParams_ToolChoice(t *testing.T) { t.Errorf("expected tool_choice 'none' for non-ZAI model, got %v", m["tool_choice"]) } }) + + t.Run("non-zai model sends tool_choice required", func(t *testing.T) { + lm := &openRouter{Model: "openai/gpt-4"} + messages := []core.Message{{Role: "user", Content: "test"}} + options := &core.GenerateOptions{ + Tools: []core.Tool{*core.NewTool("test_tool", "A test tool", nil)}, + ToolChoice: "required", + } + params := lm.buildParams(messages, options) + + data, _ := json.Marshal(params) + var m map[string]any + _ = json.Unmarshal(data, &m) + + if m["tool_choice"] != "required" { + t.Errorf("expected tool_choice 'required' for non-ZAI model, got %v", m["tool_choice"]) + } + }) } func TestOpenRouter_BuildParams_ZAIToolChoice(t *testing.T) { @@ -507,8 +525,23 @@ func TestOpenRouter_ConvertTool(t *testing.T) { converted 
:= lm.convertTool(tool) - if converted.GetFunction() == nil { - t.Error("expected function to be set") + fn := converted.GetFunction() + if fn == nil { + t.Fatal("expected function to be set") + } + + data, err := json.Marshal(converted) + if err != nil { + t.Fatalf("marshal tool: %v", err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("unmarshal tool: %v", err) + } + function, _ := m["function"].(map[string]any) + params, _ := function["parameters"].(map[string]any) + if params["additionalProperties"] != false { + t.Errorf("expected additionalProperties=false, got %v", params["additionalProperties"]) } }