diff --git a/docs/architecture/README.md b/docs/architecture/README.md new file mode 100644 index 000000000..cdb890a7c --- /dev/null +++ b/docs/architecture/README.md @@ -0,0 +1,52 @@ +# PicoClaw Multi-Agent Architecture + +This directory contains C4 model diagrams (rendered with Mermaid) documenting the multi-agent collaboration framework for PicoClaw. + +## Reference Implementation + +**OpenClaw (moltbot)** — the state-of-the-art personal AI gateway whose founder was hired by OpenAI. picoclaw ports and improves upon OpenClaw's validated patterns in a lightweight Go single-binary. + +- Reference code: `/home/leeaandrob/Projects/Personal/llm/auto-agents/moltbot` +- See [PRP](../prp/multi-agent-hardening.md) for detailed implementation plan + +## Documents + +| Document | Scope | Description | +|----------|-------|-------------| +| [C1 - System Context](./c1-system-context.md) | Highest level | PicoClaw in its ecosystem: users, channels, LLM providers | +| [C2 - Container](./c2-container.md) | Runtime containers | Gateway, Agent Loop, Provider Layer, Channels | +| [C3 - Component](./c3-component-multi-agent.md) | Multi-agent internals | Current + planned components across 4 hardening phases | +| [C4 - Code](./c4-code-detail.md) | Key structs/interfaces | Go interfaces, data flow, tool execution | +| [Sequence Diagrams](./sequences.md) | Runtime flows | Handoff, Blackboard sync, Fallback chain | +| [Roadmap](./roadmap.md) | Phased plan | 4-phase hardening based on OpenClaw gap analysis | + +## Related + +- **PRP**: [Multi-Agent Hardening](../prp/multi-agent-hardening.md) — Full implementation plan with acceptance criteria +- **Issue**: [#294 - Base Multi-agent Collaboration Framework](https://github.com/sipeed/picoclaw/issues/294) +- **Issue**: [#283 - Refactor Provider Architecture](https://github.com/sipeed/picoclaw/issues/283) +- **Discussion**: [#122 - Provider Architecture Proposal](https://github.com/sipeed/picoclaw/discussions/122) + +## Status + +| Phase | 
Status | PR | +|-------|--------|----| +| Provider Protocol Refactor | Merged | [#213](https://github.com/sipeed/picoclaw/pull/213) | +| Model Fallback + Routing | Merged | [#131](https://github.com/sipeed/picoclaw/pull/131) | +| Blackboard + Handoff + Discovery | WIP | [#423](https://github.com/sipeed/picoclaw/pull/423) | +| Phase 1: Foundation Fix | Planned | PR #423 | +| Phase 2: Tool Policy | Planned | PR #423 | +| Phase 3: Resilience | Planned | TBD | +| Phase 4: Async Multi-Agent | Planned | TBD | +| SOUL.md Bootstrap | In Progress (other dev) | TBD | + +## picoclaw Advantages Over OpenClaw + +| Area | picoclaw | OpenClaw | +|------|----------|----------| +| Shared agent state | Blackboard (real-time) | None (announce-only) | +| Runtime | Go single binary | Node.js | +| Memory footprint | ~10x smaller | Node.js overhead | +| Deployment | Copy binary and run | npm install + config | +| Concurrency | goroutines (native) | async/await | +| Type safety | Compile-time | Runtime | diff --git a/docs/architecture/c1-system-context.md b/docs/architecture/c1-system-context.md new file mode 100644 index 000000000..69d92945d --- /dev/null +++ b/docs/architecture/c1-system-context.md @@ -0,0 +1,59 @@ +# C1 - System Context Diagram + +PicoClaw as a multi-agent platform within its ecosystem. 
+ +```mermaid +C4Context + title System Context - PicoClaw Multi-Agent Platform + + Person(user, "User", "Interacts via messaging channels or CLI") + Person(dev, "Developer", "Configures agents, skills, and providers") + + System(picoclaw, "PicoClaw", "Multi-agent AI platform that routes user messages to specialized agents backed by multiple LLM providers") + + System_Ext(discord, "Discord", "Chat platform") + System_Ext(telegram, "Telegram", "Chat platform") + System_Ext(slack, "Slack", "Workspace messaging") + System_Ext(whatsapp, "WhatsApp", "Messaging") + System_Ext(cli, "CLI", "Direct terminal access") + + System_Ext(openai, "OpenAI API", "GPT models, Codex") + System_Ext(anthropic, "Anthropic API", "Claude models") + System_Ext(gemini, "Google Gemini", "Gemini models") + System_Ext(openrouter, "OpenRouter", "Multi-model gateway") + System_Ext(groq, "Groq", "Fast inference") + System_Ext(ollama, "Ollama", "Local LLM") + System_Ext(claude_cli, "Claude Code CLI", "Subprocess provider") + System_Ext(codex_cli, "Codex CLI", "Subprocess provider") + + Rel(user, discord, "Sends messages") + Rel(user, telegram, "Sends messages") + Rel(user, slack, "Sends messages") + Rel(user, whatsapp, "Sends messages") + Rel(user, cli, "Direct input") + Rel(dev, picoclaw, "Configures via config.json") + + Rel(discord, picoclaw, "Webhook/Bot events") + Rel(telegram, picoclaw, "Bot API") + Rel(slack, picoclaw, "Events API") + Rel(whatsapp, picoclaw, "Webhook") + Rel(cli, picoclaw, "Stdin/Stdout") + + Rel(picoclaw, openai, "HTTPS/REST") + Rel(picoclaw, anthropic, "HTTPS/REST") + Rel(picoclaw, gemini, "HTTPS/REST") + Rel(picoclaw, openrouter, "HTTPS/REST") + Rel(picoclaw, groq, "HTTPS/REST") + Rel(picoclaw, ollama, "HTTP localhost") + Rel(picoclaw, claude_cli, "Subprocess stdio") + Rel(picoclaw, codex_cli, "Subprocess stdio") +``` + +## Key interactions + +| Boundary | Protocol | Direction | +|----------|----------|-----------| +| User -> Channels | Platform-native (Discord bot, 
Telegram bot, etc.) | Inbound | +| Channels -> PicoClaw | Go channel bus (`pkg/bus`) | Internal | +| PicoClaw -> LLM Providers | HTTPS REST / Subprocess stdio | Outbound | +| Developer -> PicoClaw | `~/.picoclaw/config.json` + workspace files | Config | diff --git a/docs/architecture/c2-container.md b/docs/architecture/c2-container.md new file mode 100644 index 000000000..7276b7ad1 --- /dev/null +++ b/docs/architecture/c2-container.md @@ -0,0 +1,58 @@ +# C2 - Container Diagram + +Runtime containers inside PicoClaw. + +```mermaid +C4Container + title Container Diagram - PicoClaw Runtime + + Person(user, "User") + + System_Boundary(picoclaw, "PicoClaw Process") { + Container(gateway, "Gateway", "Go HTTP server", "Exposes health/ready endpoints, manages lifecycle") + Container(channel_mgr, "Channel Manager", "pkg/channels", "Manages Discord, Telegram, Slack, WhatsApp, CLI connections") + Container(msg_bus, "Message Bus", "pkg/bus", "Pub/sub event bus routing messages between channels and agents") + Container(agent_loop, "Agent Loop", "pkg/agent", "Core orchestrator: routes messages to agents, manages tool loops, sessions") + Container(registry, "Agent Registry", "pkg/agent", "Stores AgentInstance configs, resolves agent by ID or route") + Container(router, "Route Resolver", "pkg/routing", "Matches incoming messages to agents based on channel/chat/peer bindings") + Container(multiagent, "Multi-Agent Framework", "pkg/multiagent", "Blackboard shared context, Handoff mechanism, Agent discovery tools") + Container(tools, "Tool Registry", "pkg/tools", "Shell, file, web, session, message, spawn, exec tools") + Container(providers, "Provider Layer", "pkg/providers", "LLM provider abstraction: HTTP, CLI, OAuth, Fallback chain") + Container(session, "Session Store", "pkg/session", "Per-agent session persistence with conversation history") + Container(skills, "Skills Engine", "pkg/skills", "Loads SKILL.md files, provides skill tools to agents") + Container(config, "Config", 
"pkg/config", "Loads config.json, agent definitions, model_list") + } + + System_Ext(llm, "LLM Providers", "OpenAI, Anthropic, Gemini, Groq, Ollama, Claude CLI, Codex CLI") + System_Ext(channels_ext, "Messaging Platforms", "Discord, Telegram, Slack, WhatsApp") + + Rel(user, channels_ext, "Sends message") + Rel(channels_ext, channel_mgr, "Delivers event") + Rel(channel_mgr, msg_bus, "Publishes message") + Rel(msg_bus, agent_loop, "Delivers to agent") + Rel(agent_loop, router, "Resolves target agent") + Rel(agent_loop, registry, "Gets AgentInstance") + Rel(agent_loop, multiagent, "Blackboard sync, Handoff") + Rel(agent_loop, tools, "Executes tool calls") + Rel(agent_loop, session, "Load/save history") + Rel(agent_loop, skills, "Resolves skill tools") + Rel(agent_loop, providers, "LLM Chat()") + Rel(providers, llm, "API calls") + Rel(gateway, agent_loop, "Lifecycle management") + Rel(config, agent_loop, "Agent definitions") + Rel(config, providers, "Provider config") + Rel(config, registry, "AgentConfig list") +``` + +## Container responsibilities + +| Container | Package | Key types | +|-----------|---------|-----------| +| Agent Loop | `pkg/agent` | `AgentLoop`, `RunToolLoop()` | +| Agent Registry | `pkg/agent` | `AgentRegistry`, `AgentInstance` | +| Route Resolver | `pkg/routing` | `RouteResolver`, `SessionKeyBuilder` | +| Multi-Agent | `pkg/multiagent` | `Blackboard`, `HandoffTool`, `ListAgentsTool` | +| Provider Layer | `pkg/providers` | `LLMProvider`, `FallbackChain`, `HTTPProvider` | +| Tool Registry | `pkg/tools` | `Tool`, `ContextualTool`, `AsyncTool` | +| Session Store | `pkg/session` | `SessionStore`, conversation history | +| Config | `pkg/config` | `Config`, `AgentConfig`, `ModelConfig` | diff --git a/docs/architecture/c3-component-multi-agent.md b/docs/architecture/c3-component-multi-agent.md new file mode 100644 index 000000000..a7d0ba898 --- /dev/null +++ b/docs/architecture/c3-component-multi-agent.md @@ -0,0 +1,346 @@ +# C3 - Component Diagram: 
Multi-Agent Framework + +Detailed view of the multi-agent collaboration components. +Includes both current (PR #423) and planned (Phases 1-4) components. + +## Core Multi-Agent Components (Current) + +```mermaid +C4Component + title Component Diagram - Multi-Agent Collaboration (pkg/multiagent + pkg/agent) + + Container_Boundary(agent_pkg, "pkg/agent") { + Component(loop, "AgentLoop", "loop.go", "Core orchestrator: tool loop, LLM calls, session management") + Component(registry, "AgentRegistry", "registry.go", "Stores AgentInstance map, resolves by ID, lists all agents") + Component(instance, "AgentInstance", "instance.go", "Per-agent config: ID, Name, Role, SystemPrompt, tools, workspace") + Component(resolver_adapter, "registryResolver", "loop.go", "Adapter: bridges AgentRegistry to multiagent.AgentResolver interface") + } + + Container_Boundary(multiagent_pkg, "pkg/multiagent") { + Component(blackboard, "Blackboard", "blackboard.go", "Thread-safe shared key-value store with author/scope/timestamp metadata") + Component(bb_tool, "BlackboardTool", "blackboard_tool.go", "LLM tool: read/write/list/delete on shared context") + Component(handoff, "ExecuteHandoff", "handoff.go", "Resolves target agent, writes context to blackboard, delegates via RunToolLoop") + Component(handoff_tool, "HandoffTool", "handoff_tool.go", "LLM tool: delegates sub-task to another agent with optional context") + Component(list_tool, "ListAgentsTool", "list_agents_tool.go", "LLM tool: returns all registered agents with ID/Name/Role") + Component(agent_resolver, "AgentResolver", "handoff.go", "Interface: GetAgentInfo(id), ListAgents() - decouples from pkg/agent") + } + + Container_Boundary(routing_pkg, "pkg/routing") { + Component(route_resolver, "RouteResolver", "route.go", "Matches message to agent based on channel/chat/peer bindings") + Component(session_key, "SessionKeyBuilder", "session_key.go", "Builds per-agent session keys from channel+chat+agent") + Component(agent_id, "AgentID", 
"agent_id.go", "Normalizes agent identifiers") + } + + Container_Boundary(providers_pkg, "pkg/providers") { + Component(fallback, "FallbackChain", "fallback.go", "Tries candidates in order, skips cooled-down, classifies errors") + Component(cooldown, "CooldownTracker", "cooldown.go", "Per-model failure tracking with exponential backoff") + Component(error_cls, "ErrorClassifier", "error_classifier.go", "Maps HTTP errors to FailoverReason: rate_limit, billing, auth, etc.") + Component(factory, "CreateProvider", "factory.go", "Resolves config to provider: HTTP, CLI, OAuth, Fallback") + } + + Rel(loop, registry, "GetInstance(agentID)") + Rel(loop, resolver_adapter, "Creates on init") + Rel(resolver_adapter, registry, "Delegates to") + Rel(resolver_adapter, agent_resolver, "Implements") + + Rel(loop, blackboard, "getOrCreateBlackboard(sessionKey)") + Rel(loop, bb_tool, "Registers when >1 agent") + Rel(loop, handoff_tool, "Registers when >1 agent") + Rel(loop, list_tool, "Registers when >1 agent") + + Rel(handoff_tool, handoff, "Calls ExecuteHandoff()") + Rel(handoff, agent_resolver, "GetAgentInfo(targetID)") + Rel(handoff, blackboard, "Writes handoff context") + Rel(handoff, loop, "Calls RunToolLoop() for target agent") + + Rel(bb_tool, blackboard, "CRUD operations") + Rel(list_tool, agent_resolver, "ListAgents()") + + Rel(loop, route_resolver, "ResolveAgent(msg)") + Rel(loop, session_key, "BuildKey(channel, chat, agent)") + Rel(loop, fallback, "Chat() with fallback") + Rel(fallback, cooldown, "Check/update cooldown") + Rel(fallback, error_cls, "ClassifyError()") +``` + +## Planned Components (Phases 1-4) + +```mermaid +C4Component + title Planned Components - Hardening Phases + + Container_Boundary(tools_pkg, "pkg/tools (Phase 1-3)") { + Component(hooks, "ToolHook", "hooks.go", "BeforeExecute/AfterExecute interface for tool call interception") + Component(groups, "ToolGroups", "groups.go", "Named tool groups: fs, web, exec, sessions, memory") + Component(policy, 
"PolicyPipeline", "policy.go", "Layered allow/deny: global -> per-agent -> per-depth") + Component(loop_det, "LoopDetector", "loop_detector.go", "Generic repeat + ping-pong detection with configurable thresholds") + } + + Container_Boundary(multiagent_new, "pkg/multiagent (Phase 3-4)") { + Component(cascade, "CascadeStop", "cascade.go", "RunRegistry + recursive context cancellation") + Component(spawn, "AsyncSpawn", "spawn.go", "Non-blocking agent invocation via goroutines") + Component(announce, "AnnounceProtocol", "announce.go", "Result delivery: steer/queue/direct modes") + } + + Container_Boundary(providers_new, "pkg/providers (Phase 3)") { + Component(auth_rot, "AuthRotator", "auth_rotation.go", "Round-robin profiles + 2-track cooldown (transient + billing)") + } + + Container_Boundary(gateway_new, "pkg/gateway (Phase 4)") { + Component(dedup, "DedupCache", "dedup.go", "Idempotency layer with TTL-based deduplication") + } + + Rel(hooks, loop_det, "AfterExecute feeds detection") + Rel(hooks, policy, "BeforeExecute applies policy") + Rel(policy, groups, "Resolves group references") + Rel(cascade, spawn, "Tracks child runs") + Rel(spawn, announce, "Delivers results") + Rel(auth_rot, fallback, "Enhances with profile rotation") +``` + +## Known Issues (Pre-Phase 1) + +```mermaid +graph TD + BUG1[Blackboard Split-Brain]:::critical + BUG2[No Recursion Guard]:::critical + BUG3[Handoff Ignores Allowlist]:::high + BUG4[SubagentsConfig.Model Unused]:::low + + BUG1 --> FIX1[Phase 1a: Unify board per session] + BUG2 --> FIX2[Phase 1b: Depth + cycle detection] + BUG3 --> FIX3[Phase 1c: Check CanSpawnSubagent] + BUG4 --> FIX4[Defer to Phase 4] + + classDef critical fill:#ef4444,color:#fff + classDef high fill:#f59e0b,color:#000 + classDef low fill:#6b7280,color:#fff +``` + +### Blackboard Split-Brain Detail + +```mermaid +sequenceDiagram + participant RS as registerSharedTools + participant BT as BlackboardTool + participant RL as runAgentLoop + participant SP as System 
Prompt + + Note over RS: At startup + RS->>RS: sharedBoard := NewBlackboard() + RS->>BT: NewBlackboardTool(sharedBoard, agentID) + + Note over RL: At runtime (per message) + RL->>RL: sessionBoard := getOrCreateBlackboard(sessionKey) + RL->>SP: sessionBoard.Snapshot() → inject into system prompt + + Note over BT,SP: BUG: sharedBoard ≠ sessionBoard + BT->>RS: Write to sharedBoard ← WRONG BOARD + SP->>RL: Read from sessionBoard ← DIFFERENT OBJECT + + Note over BT,SP: FIX: SetBoard(sessionBoard) before execution +``` + +## Blackboard Data Model + +```mermaid +classDiagram + class Blackboard { + -entries map[string]BlackboardEntry + -mu sync.RWMutex + +Set(key, value, author, scope) + +Get(key) string + +GetEntry(key) BlackboardEntry + +Delete(key) + +List() []BlackboardEntry + +Snapshot() string + +Size() int + +MarshalJSON() []byte + +UnmarshalJSON([]byte) + } + + class BlackboardEntry { + +Key string + +Value string + +Author string + +Scope string + +Timestamp time.Time + } + + class BlackboardTool { + -board *Blackboard + -agentID string + +Name() string + +Execute(args) string + +SetBoard(board) void + } + + class HandoffRequest { + +FromAgentID string + +ToAgentID string + +Task string + +Context map[string]string + +Depth int + +Visited []string + } + + class HandoffResult { + +AgentID string + +Response string + +Success bool + +Error string + +Iterations int + } + + class AgentResolver { + <<interface>> + +GetAgentInfo(agentID) *AgentInfo + +ListAgents() []AgentInfo + } + + class AllowlistChecker { + <<interface>> + +CanHandoff(from, to) bool + } + + class AgentInfo { + +ID string + +Name string + +Role string + +SystemPrompt string + } + + Blackboard "1" --> "*" BlackboardEntry : stores + BlackboardTool --> Blackboard : operates on + HandoffRequest ..> AgentResolver : resolved via + HandoffRequest ..> AllowlistChecker : verified by + AgentResolver --> AgentInfo : returns +``` + +## Tool Policy Pipeline (Phase 2) + +```mermaid +graph LR + subgraph "Input" + ALL[All Registered 
Tools] + end + + subgraph "Pipeline Steps" + S1[Global Allow/Deny] + S2[Per-Agent Allow/Deny] + S3[Per-Depth Deny] + S4[Sandbox Override] + end + + subgraph "Output" + FINAL[Filtered Tools for Agent] + end + + ALL --> S1 --> S2 --> S3 --> S4 --> FINAL + + style S1 fill:#3b82f6,color:#fff + style S2 fill:#f59e0b,color:#000 + style S3 fill:#ef4444,color:#fff + style S4 fill:#8b5cf6,color:#fff +``` + +```mermaid +graph TD + subgraph "Tool Groups" + GFS["group:fs
read_file, write_file, edit_file, append_file, list_dir"] + GWEB["group:web
web_search, web_fetch"] + GEXEC["group:exec
exec"] + GSESS["group:sessions
blackboard, handoff, list_agents, spawn"] + GMEM["group:memory
memory_search, memory_get"] + end + + subgraph "Depth Deny Rules" + D0["Depth 0 (main)
Full access"] + D1["Depth 1+ (subagent)
Deny: gateway"] + DL["Depth = max (leaf)
Deny: spawn, handoff, list_agents"] + end +``` + +## Async Multi-Agent Flow (Phase 4) + +```mermaid +sequenceDiagram + participant U as User + participant MA as Main Agent + participant SP as AsyncSpawn + participant SA1 as Subagent: Researcher + participant SA2 as Subagent: Analyst + participant AN as AnnounceProtocol + participant BB as Blackboard + + U->>MA: "Research and analyze market trends" + MA->>SP: AsyncSpawn(researcher, "find market data") + SP-->>MA: RunID: abc-123 + MA->>SP: AsyncSpawn(analyst, "prepare analysis framework") + SP-->>MA: RunID: def-456 + MA-->>U: "Working on it — spawned 2 agents..." + + par Parallel execution + SA1->>BB: write("market_data", findings) + SA1-->>AN: Complete: "Found 5 key trends" + and + SA2->>BB: write("framework", analysis_template) + SA2-->>AN: Complete: "Framework ready" + end + + AN->>MA: Announce(researcher result) [steer] + AN->>MA: Announce(analyst result) [queue] + + MA->>BB: read("market_data") + MA->>BB: read("framework") + MA->>MA: Synthesize results + MA-->>U: "Here's the market analysis..." 
+``` + +## Provider Protocol Architecture (PR #213 + #283) + +```mermaid +graph TB + subgraph "Config Layer" + CFG[config.json] + ML[model_list - future] + end + + subgraph "Factory (factory.go)" + RS[resolveProviderSelection] + CP[CreateProvider] + end + + subgraph "Protocol Families" + subgraph "OpenAI-Compatible" + OC[openai_compat/provider.go] + HTTP[HTTPProvider - thin delegate] + end + subgraph "Anthropic" + ANT[anthropic/provider.go] + CP2[ClaudeProvider] + end + subgraph "CLI-Based" + CC[ClaudeCliProvider] + CX[CodexCliProvider] + end + subgraph "Resilience" + FB[FallbackChain] + CD[CooldownTracker] + EC[ErrorClassifier] + AR[AuthRotator - Phase 3] + end + end + + CFG --> RS + ML -.-> RS + RS --> CP + CP --> HTTP + CP --> ANT + CP --> CC + CP --> CX + HTTP --> OC + FB --> CD + FB --> EC + AR -.-> FB +``` diff --git a/docs/architecture/c4-code-detail.md b/docs/architecture/c4-code-detail.md new file mode 100644 index 000000000..89aee8d6f --- /dev/null +++ b/docs/architecture/c4-code-detail.md @@ -0,0 +1,148 @@ +# C4 - Code Detail + +Key interfaces, structs, and data flows at the code level. + +## Core Interfaces + +```mermaid +classDiagram + class LLMProvider { + <<interface>> + +Chat(ctx, messages, tools, model, opts) *LLMResponse + +GetDefaultModel() string + } + + class Tool { + <<interface>> + +Name() string + +Description() string + +Parameters() map[string]any + +Execute(args map[string]any) (string, error) + } + + class ContextualTool { + <<interface>> + +SetContext(ctx ToolContext) + } + + class AsyncTool { + <<interface>> + +ExecuteAsync(ctx, args) (string, error) + } + + class AgentResolver { + <<interface>> + +GetAgentInfo(agentID string) *AgentInfo + +ListAgents() []AgentInfo + } + + Tool <|-- ContextualTool + Tool <|-- AsyncTool + LLMProvider <|.. HTTPProvider + LLMProvider <|.. ClaudeCliProvider + LLMProvider <|.. CodexCliProvider + LLMProvider <|.. CodexProvider + LLMProvider <|.. ClaudeProvider + LLMProvider <|.. FallbackChain + Tool <|.. BlackboardTool + Tool <|.. HandoffTool + Tool <|.. 
ListAgentsTool + ContextualTool <|.. HandoffTool + AgentResolver <|.. registryResolver +``` + +## Tool Loop Execution + +```mermaid +flowchart TD + MSG[Incoming Message] --> RR[RouteResolver.ResolveAgent] + RR --> AI[AgentInstance selected] + AI --> SL[Session.Load history] + SL --> SP[Build system prompt] + SP --> BS{Multi-agent?} + + BS -->|Yes| INJ[Inject Blackboard snapshot into system prompt] + BS -->|No| LLM + + INJ --> LLM[provider.Chat - send to LLM] + LLM --> RESP{Response type?} + + RESP -->|Text only| OUT[Return text to channel] + RESP -->|Tool calls| TC[Execute tool calls] + + TC --> WHICH{Which tool?} + + WHICH -->|blackboard| BB[BlackboardTool.Execute] + BB --> WL[Read/Write/List/Delete shared context] + WL --> LLM + + WHICH -->|handoff| HO[HandoffTool.Execute] + HO --> EH[ExecuteHandoff] + EH --> TA[Resolve target agent] + TA --> WC[Write context to blackboard] + WC --> RTL[RunToolLoop for target agent] + RTL --> HR[HandoffResult] + HR --> LLM + + WHICH -->|list_agents| LA[ListAgentsTool.Execute] + LA --> LLM + + WHICH -->|shell, file, web...| OT[Other tools execute] + OT --> LLM + + RESP -->|stop| SS[Session.Save] + SS --> OUT +``` + +## Blackboard Lifecycle + +```mermaid +stateDiagram-v2 + [*] --> Empty: getOrCreateBlackboard(sessionKey) + + Empty --> HasEntries: Agent writes via BlackboardTool + HasEntries --> HasEntries: Agent reads/writes/deletes + HasEntries --> Snapshot: System prompt build + Snapshot --> HasEntries: Snapshot injected, loop continues + + HasEntries --> HandoffContext: Handoff writes "handoff_context_*" + HandoffContext --> TargetReads: Target agent reads context + TargetReads --> HasEntries: Target completes, result written + + HasEntries --> Empty: All entries deleted + Empty --> [*]: Session ends + + note right of Snapshot + Snapshot format: + ## Shared Context + - key1: value1 (by agent-a) + - key2: value2 (by agent-b) + end note +``` + +## Fallback Chain Decision Tree + +```mermaid +flowchart TD + REQ[Chat Request] --> 
P[Try Primary Model] + P --> PS{Success?} + PS -->|Yes| RST[Reset cooldown] --> RET[Return response] + PS -->|No| CLS[ClassifyError] + + CLS --> RT{Retriable?} + RT -->|No: auth, format| FAIL[Return error immediately] + RT -->|Yes| NXT[Next candidate] + + NXT --> CD{In cooldown?} + CD -->|Yes| SKIP[Skip, try next] + CD -->|No| TRY[Try candidate] + + SKIP --> MORE{More candidates?} + TRY --> TS{Success?} + TS -->|Yes| RST2[Reset cooldown] --> RET + TS -->|No| REC[Record failure, update cooldown] + REC --> MORE + + MORE -->|Yes| NXT + MORE -->|No| EXHAUST[FallbackExhaustedError] +``` diff --git a/docs/architecture/roadmap.md b/docs/architecture/roadmap.md new file mode 100644 index 000000000..921253f06 --- /dev/null +++ b/docs/architecture/roadmap.md @@ -0,0 +1,136 @@ +# Multi-Agent Feature Roadmap + +Phased implementation plan based on issues #294, #283, and discussion #122. +Updated with OpenClaw (moltbot) gap analysis — patterns validated by OpenAI (founder hired). + +## Phase Overview + +```mermaid +gantt + title PicoClaw Multi-Agent Roadmap + dateFormat YYYY-MM-DD + axisFormat %b %d + + section Provider Refactor (#283) + Phase 1: Protocol packages (PR #213) :done, p1, 2026-02-01, 2026-02-18 + Phase 2: model_list + explicit api_type :p2, 2026-03-15, 2026-04-01 + Phase 3: Independent Gemini protocol :p3, after p2, 7d + Phase 4: Local LLM + cleanup :p4, after p3, 5d + + section Multi-Agent (#294) + Fallback chain + routing (PR #131) :done, ma1, 2026-02-01, 2026-02-18 + Blackboard + Handoff + Discovery (PR #423) :done, ma2, 2026-02-18, 2026-03-01 + Phase 1: Foundation Fix + Guardrails :active, h1, 2026-02-19, 2026-02-28 + Phase 2: Tool Policy Pipeline :h2, after h1, 10d + Phase 3: Resilience :h3, after h2, 10d + Phase 4: Async Multi-Agent :h4, after h3, 14d + + section Workspace + SOUL.md Bootstrap (separate PR) :active, soul, 2026-02-19, 2026-03-01 + + section Integration + model_list + multi-agent config merge :int1, after p2, 7d + Community agent marketplace 
:int2, after h4, 21d +``` + +## Hardening Phases (Based on OpenClaw Gap Analysis) + +### Phase 1: Foundation Fix + Guardrails + +| Task | Description | OpenClaw Reference | +|------|-------------|-------------------| +| Fix blackboard split-brain | Unify static board and session board | N/A (picoclaw-specific bug) | +| Recursion guard | Depth counter + cycle detection in handoff | `subagent-depth.ts` | +| Handoff allowlist | Enforce CanSpawnSubagent in ExecuteHandoff | `subagent-spawn.ts` (allowlist check) | +| Before-tool-call hooks | ToolHook interface for extensibility | `pi-tools.before-tool-call.ts` | + +### Phase 2: Tool Policy Pipeline + +| Task | Description | OpenClaw Reference | +|------|-------------|-------------------| +| Tool groups | Named groups: fs, web, exec, sessions, memory | `tool-policy.ts` (TOOL_GROUPS) | +| Per-agent allow/deny | Config-driven tool filtering | `tool-policy-pipeline.ts` | +| Subagent deny-by-depth | Leaf agents can't spawn/handoff | `pi-tools.policy.ts` | +| Pipeline composition | Layered: global → agent → depth | `tool-policy-pipeline.ts` | + +### Phase 3: Resilience + +| Task | Description | OpenClaw Reference | +|------|-------------|-------------------| +| Loop detection | Generic repeat + ping-pong detectors | `tool-loop-detection.ts` | +| Context overflow recovery | Compaction → truncation → user error | `pi-embedded-runner/run.ts` | +| Auth profile rotation | Round-robin + 2-track cooldown | `auth-profiles/order.ts` + `usage.ts` | +| Cascade stop | Context cancellation propagation | `subagents-tool.ts` | + +### Phase 4: Async Multi-Agent + +| Task | Description | OpenClaw Reference | +|------|-------------|-------------------| +| Async spawn | Non-blocking via goroutines | `subagent-spawn.ts` | +| Announce protocol | Result injection: steer/queue/direct | `subagent-announce.ts` | +| Process isolation | Scope-keyed exec tool | Process supervisor (scope-keyed) | +| Idempotency | Dedup cache for gateway RPC | 
`server-methods/agent.ts` | + +## Completed + +| Phase | PR | Description | +|-------|----|-------------| +| Provider Protocol Refactor | [#213](https://github.com/sipeed/picoclaw/pull/213) | `protocoltypes/`, `openai_compat/`, `anthropic/` packages | +| Fallback Chain + Routing | [#131](https://github.com/sipeed/picoclaw/pull/131) | `FallbackChain`, `CooldownTracker`, `RouteResolver` | +| Blackboard + Handoff + Discovery | [#423](https://github.com/sipeed/picoclaw/pull/423) | Blackboard, HandoffTool, ListAgentsTool, AgentResolver | +| golangci-lint compliance | [#304](https://github.com/sipeed/picoclaw/pull/304) | 62 lint issues fixed in our code | +| C4 Architecture docs | [#423](https://github.com/sipeed/picoclaw/pull/423) | This directory | + +## Dependency Graph + +```mermaid +graph TD + PR213[PR #213: Protocol Refactor]:::done --> PR131[PR #131: Fallback + Routing]:::done + PR131 --> PR423[PR #423: Blackboard + Handoff]:::done + + PR423 --> H1[Phase 1: Foundation Fix]:::active + H1 --> H2[Phase 2: Tool Policy]:::planned + H2 --> H3[Phase 3: Resilience]:::planned + H3 --> H4[Phase 4: Async Multi-Agent]:::planned + + PR213 --> P2[Provider Phase 2: model_list]:::future + P2 --> P3[Provider Phase 3: Gemini]:::future + P3 --> P4[Provider Phase 4: Local LLM]:::future + + SOUL[SOUL.md Bootstrap]:::active + + H4 --> SWARM[Swarm Mode]:::future + H4 --> DASH[AIEOS Dashboard]:::future + P2 --> MCONFIG[model_list + multi-agent merge]:::future + + classDef done fill:#22c55e,color:#fff + classDef active fill:#eab308,color:#000 + classDef planned fill:#3b82f6,color:#fff + classDef future fill:#6b7280,color:#fff +``` + +## Key Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| Shared context pattern | Blackboard (key-value) | OpenClaw has no shared state (announce-only). Blackboard is more flexible. 
| +| Handoff mechanism | Synchronous → async in Phase 4 | Start simple, add async when foundation is solid | +| Tool policy model | Layered pipeline (like OpenClaw) | Composable, debuggable, per-layer narrowing | +| Loop detection | 2 detectors (repeat + ping-pong) | OpenClaw has 4 — start with 2 most impactful | +| Auth rotation | 2-track cooldown (transient + billing) | Directly ported from OpenClaw, proven in production | +| Recursion guard | Depth counter + visited set | Simple, O(1) check, prevents both depth and cycles | +| Reference implementation | OpenClaw (moltbot) | Founder hired by OpenAI — patterns are industry-validated | + +## picoclaw vs. OpenClaw Comparison + +| Feature | OpenClaw (Node.js) | picoclaw (Go) | Advantage | +|---------|-------------------|---------------|-----------| +| Shared agent state | None (announce-only) | Blackboard | picoclaw | +| Performance | Node.js runtime | Single Go binary | picoclaw (10x less memory) | +| Deployment | npm install + node | Copy binary | picoclaw | +| Tool policy | 8-layer pipeline | Planned (Phase 2) | OpenClaw (for now) | +| Loop detection | 4 detectors | Planned (Phase 3) | OpenClaw (for now) | +| Auth rotation | 2-track + file lock | FallbackChain only | OpenClaw (for now) | +| Async spawn | Lane-based + announce | Planned (Phase 4) | OpenClaw (for now) | +| Concurrency model | async/await | goroutines | picoclaw | +| Type safety | TypeScript | Go compiler | picoclaw | diff --git a/docs/architecture/sequences.md b/docs/architecture/sequences.md new file mode 100644 index 000000000..380b9e51c --- /dev/null +++ b/docs/architecture/sequences.md @@ -0,0 +1,389 @@ +# Sequence Diagrams + +Runtime interaction flows for multi-agent collaboration. + +## 1. Agent Handoff Flow (Current) + +A main agent delegates a sub-task to a specialized agent. 
+ +```mermaid +sequenceDiagram + participant U as User (Discord/Telegram) + participant CH as Channel Manager + participant AL as AgentLoop + participant RR as RouteResolver + participant MA as Main Agent + participant LLM as LLM Provider + participant HT as HandoffTool + participant BB as Blackboard + participant AR as AgentResolver + participant SA as Specialized Agent + + U->>CH: "Translate this code to Python" + CH->>AL: Message{channel, chat_id, content} + AL->>RR: ResolveAgent(channel, chat_id) + RR-->>AL: "main" + AL->>MA: Load AgentInstance + session + AL->>BB: getOrCreateBlackboard(sessionKey) + AL->>AL: Inject BB snapshot into system prompt + AL->>LLM: Chat(messages + tools) + LLM-->>AL: ToolCall{name: "handoff", args: {target: "coder", task: "translate to python"}} + + AL->>HT: Execute(args) + HT->>AR: GetAgentInfo("coder") + AR-->>HT: AgentInfo{ID: "coder", Role: "Code Expert"} + HT->>BB: Set("handoff_context_coder", task + context) + HT->>AL: RunToolLoop(coderAgent, task) + + AL->>SA: Load AgentInstance("coder") + AL->>LLM: Chat(coder_system_prompt + task) + LLM-->>AL: "Here's the Python translation..." + AL-->>HT: HandoffResult{Response: "...", Success: true} + + HT-->>AL: Tool result string + AL->>LLM: Chat(messages + tool_result) + LLM-->>AL: "The coder agent translated your code: ..." + AL->>CH: Send response + CH->>U: "The coder agent translated your code: ..." +``` + +## 2. Blackboard Shared Context Flow (Current) + +Multiple agents share data through the blackboard within a session. 
+ +```mermaid +sequenceDiagram + participant A1 as Agent: Researcher + participant BB as Blackboard + participant AL as AgentLoop + participant A2 as Agent: Writer + + Note over A1, A2: Same session, shared blackboard + + A1->>BB: BlackboardTool.write("findings", "3 key points...") + BB-->>A1: OK + + A1->>BB: BlackboardTool.write("sources", "arxiv:2024...") + BB-->>A1: OK + + Note over AL: Handoff from Researcher to Writer + + AL->>BB: Snapshot() + BB-->>AL: "findings: 3 key points... (by researcher)\nsources: arxiv:2024... (by researcher)" + + AL->>A2: System prompt + snapshot + task + + A2->>BB: BlackboardTool.read("findings") + BB-->>A2: "3 key points..." + + A2->>BB: BlackboardTool.read("sources") + BB-->>A2: "arxiv:2024..." + + A2->>BB: BlackboardTool.write("draft", "Article based on findings...") + BB-->>A2: OK + + Note over BB: Blackboard state:
Note over BB: findings (by researcher)
Note over BB: sources (by researcher)
draft (by writer) +``` + +## 3. Model Fallback Chain Flow (Current) + +Provider resilience with automatic failover. + +```mermaid +sequenceDiagram + participant AL as AgentLoop + participant FB as FallbackChain + participant CD as CooldownTracker + participant EC as ErrorClassifier + participant P1 as Primary: GPT-4o + participant P2 as Fallback: Claude-3.5 + participant P3 as Fallback: DeepSeek + + AL->>FB: Chat(messages) + + FB->>CD: IsAvailable("gpt-4o")? + CD-->>FB: Yes + + FB->>P1: Chat(messages) + P1-->>FB: Error 429 (rate limit) + + FB->>EC: ClassifyError(err) + EC-->>FB: FailoverReason: RATE_LIMITED (retriable) + + FB->>CD: RecordFailure("gpt-4o", RATE_LIMITED) + Note over CD: gpt-4o cooldown: 30s + + FB->>CD: IsAvailable("claude-3.5")? + CD-->>FB: Yes + + FB->>P2: Chat(messages) + P2-->>FB: Error 503 (overloaded) + + FB->>EC: ClassifyError(err) + EC-->>FB: FailoverReason: OVERLOADED (retriable) + + FB->>CD: RecordFailure("claude-3.5", OVERLOADED) + + FB->>CD: IsAvailable("deepseek")? + CD-->>FB: Yes + + FB->>P3: Chat(messages) + P3-->>FB: LLMResponse{Content: "..."} + + FB->>CD: RecordSuccess("deepseek") + FB-->>AL: LLMResponse +``` + +## 4. Route Resolution Flow (Current) + +How incoming messages are routed to the correct agent. + +```mermaid +sequenceDiagram + participant MSG as Incoming Message + participant RR as RouteResolver + participant REG as AgentRegistry + participant SK as SessionKeyBuilder + + MSG->>RR: ResolveAgent(channel:"discord", chat:"123", peer_kind:"guild", peer_id:"456") + + RR->>RR: Check bindings + Note over RR: Binding: {channel: "discord", chat: "123"} -> "support-agent" + + alt Match found + RR-->>MSG: "support-agent" + else No match + RR->>REG: GetDefault() + REG-->>RR: "main" + RR-->>MSG: "main" + end + + MSG->>SK: BuildKey(channel, chat, agentID) + SK-->>MSG: "discord:123:support-agent" + Note over MSG: Session key used for:
Note over MSG: - Session history
Note over MSG: - Blackboard lookup
- State persistence +``` + +## 5. Multi-Agent Configuration Lifecycle (Current) + +From config.json to running agents. + +```mermaid +sequenceDiagram + participant CFG as config.json + participant REG as AgentRegistry + participant INST as AgentInstance + participant LOOP as AgentLoop + participant TOOLS as Tool Registry + + CFG->>REG: Parse agents.list[] + + loop For each AgentConfig + REG->>INST: NewAgentInstance(agentCfg, cfg) + INST->>INST: Set ID, Name, Role, SystemPrompt + INST->>INST: Create per-agent tools (shell, file, exec) + INST-->>REG: Register(instance) + end + + alt No agents.list configured + REG->>INST: Create implicit "main" agent + INST-->>REG: Register as default + end + + REG-->>LOOP: Registry ready + + LOOP->>LOOP: Check registry.ListAgentIDs() + + alt len(agents) > 1 + LOOP->>TOOLS: Register BlackboardTool + LOOP->>TOOLS: Register HandoffTool + LOOP->>TOOLS: Register ListAgentsTool + Note over TOOLS: Multi-agent tools active + else Single agent + Note over TOOLS: No multi-agent tools (zero overhead) + end +``` + +## 6. Handoff with Guardrails (Phase 1 — Planned) + +Handoff with depth limit, cycle detection, and allowlist enforcement. 
+ +```mermaid +sequenceDiagram + participant MA as Main Agent (depth=0) + participant HT as HandoffTool + participant GD as Guards + participant AL as AllowlistChecker + participant SA as Agent B (depth=1) + participant SC as Agent C (depth=2) + + MA->>HT: handoff(target="B", task="research") + HT->>GD: Check depth (0 < maxDepth=3) + GD-->>HT: OK + HT->>GD: Check cycle (visited=["main"]) + GD-->>HT: OK (B not in visited) + HT->>AL: CanHandoff("main", "B") + AL-->>HT: Allowed + + HT->>SA: ExecuteHandoff(depth=1, visited=["main","B"]) + SA->>HT: handoff(target="C", task="analyze") + HT->>GD: Check depth (1 < maxDepth=3) + GD-->>HT: OK + HT->>GD: Check cycle (visited=["main","B"]) + GD-->>HT: OK (C not in visited) + + HT->>SC: ExecuteHandoff(depth=2, visited=["main","B","C"]) + SC-->>SA: Result + SA-->>MA: Combined result + + Note over MA: Cycle detection example + MA->>HT: handoff(target="B", task="...") + SA->>HT: handoff(target="main", task="...") + HT->>GD: Check cycle (visited=["main","B"]) + GD-->>HT: BLOCKED: "main" already in visited + HT-->>SA: Error: handoff cycle detected +``` + +## 7. Tool Policy Pipeline (Phase 2 — Planned) + +How tools are filtered before agent execution. + +```mermaid +sequenceDiagram + participant CFG as Config + participant PP as PolicyPipeline + participant GR as ToolGroups + participant AG as Agent Tools + participant DEPTH as DepthPolicy + + Note over PP: Agent "researcher" at depth=1 + + CFG->>PP: Global deny: ["gateway"] + PP->>GR: Resolve "gateway" → ["gateway"] + PP->>PP: Remove "gateway" from tool set + + CFG->>PP: Agent allow: ["group:web", "group:fs", "blackboard"] + PP->>GR: Resolve groups → [web_search, web_fetch, read_file, ...] + PP->>PP: Keep only allowed tools + + DEPTH->>PP: Depth=1 deny: ["spawn"] + PP->>PP: Remove "spawn" from tool set + + PP->>AG: Final tools: [web_search, web_fetch, read_file, write_file, blackboard, handoff] + Note over AG: Each layer narrows, never widens +``` + +## 8. 
Loop Detection (Phase 3 — Planned) + +How tool call loops are detected and blocked. + +```mermaid +sequenceDiagram + participant LLM as LLM Provider + participant HK as ToolHook + participant LD as LoopDetector + participant T as Tool + + loop Normal execution (calls 1-9) + LLM->>HK: BeforeExecute("web_search", {query: "same query"}) + HK->>LD: Check(hash("web_search:same query")) + LD-->>HK: OK (repeat count < 10) + HK->>T: Execute + T-->>HK: Result + HK->>LD: Record(hash, outcome) + end + + LLM->>HK: BeforeExecute("web_search", {query: "same query"}) + HK->>LD: Check (repeat count = 10) + LD-->>HK: WARNING: possible loop + Note over HK: Warning injected into tool result + + loop Calls 11-19 (with warning) + LLM->>HK: BeforeExecute("web_search", {query: "same query"}) + HK->>LD: Check + LD-->>HK: WARNING + end + + LLM->>HK: BeforeExecute("web_search", {query: "same query"}) + HK->>LD: Check (repeat count = 20) + LD-->>HK: BLOCKED: tool call repeated too many times + HK-->>LLM: Error: loop detected, try a different approach +``` + +## 9. Async Spawn + Announce (Phase 4 — Planned) + +Parallel agent execution with result delivery. + +```mermaid +sequenceDiagram + participant U as User + participant MA as Main Agent + participant SP as AsyncSpawn + participant RR as RunRegistry + participant SA1 as Researcher + participant SA2 as Analyst + participant AN as AnnounceProtocol + participant BB as Blackboard + + U->>MA: "Research market trends and analyze competitors" + + MA->>SP: AsyncSpawn(researcher, "find market data") + SP->>RR: Register(runID=abc, parent=main) + SP-->>MA: RunID: abc (non-blocking) + + MA->>SP: AsyncSpawn(analyst, "competitor analysis") + SP->>RR: Register(runID=def, parent=main) + SP-->>MA: RunID: def (non-blocking) + + MA-->>U: "Working on it — 2 agents spawned..." 
+ + par Parallel Execution + SA1->>BB: write("market_data", findings) + SA1->>RR: Complete(abc) + SA1->>AN: Announce(to=main, content=findings) + and + SA2->>BB: write("competitors", analysis) + SA2->>RR: Complete(def) + SA2->>AN: Announce(to=main, content=analysis) + end + + AN->>MA: Steer: "Researcher completed: found 5 trends" + AN->>MA: Queue: "Analyst completed: 3 competitors analyzed" + + MA->>BB: read("market_data") + read("competitors") + MA->>MA: Synthesize + MA-->>U: "Here's the complete market analysis..." +``` + +## 10. Cascade Stop (Phase 3 — Planned) + +Stopping a parent agent cascades to all children. + +```mermaid +sequenceDiagram + participant U as User + participant RR as RunRegistry + participant MA as Main Agent + participant SA1 as Subagent 1 + participant SA2 as Subagent 2 + participant SA3 as Sub-subagent (child of SA1) + + Note over MA: Active run tree:
Note over MA: main → SA1 → SA3
main → SA2 + + U->>RR: CascadeStop("main") + + RR->>MA: Cancel context + MA->>MA: Stop processing + + RR->>RR: Find children of "main" + RR->>SA1: Cancel context + SA1->>SA1: Stop processing + + RR->>RR: Find children of SA1 + RR->>SA3: Cancel context + SA3->>SA3: Stop processing + + RR->>SA2: Cancel context + SA2->>SA2: Stop processing + + RR-->>U: Killed 4 runs (main + SA1 + SA2 + SA3) +``` diff --git a/docs/prp/multi-agent-hardening.md b/docs/prp/multi-agent-hardening.md new file mode 100644 index 000000000..8435438e4 --- /dev/null +++ b/docs/prp/multi-agent-hardening.md @@ -0,0 +1,770 @@ +# PRP: Multi-Agent Framework Hardening + +## Goal + +Harden the picoclaw multi-agent collaboration framework (PR #423) to production-grade quality by porting validated patterns from OpenClaw (moltbot) — the state-of-the-art personal AI gateway whose founder was hired by OpenAI. + +**PR**: [#423](https://github.com/sipeed/picoclaw/pull/423) +**Issue**: [#294](https://github.com/sipeed/picoclaw/issues/294) +**Branch**: `feat/multi-agent-framework` +**Reference**: `/home/leeaandrob/Projects/Personal/llm/auto-agents/moltbot` + +--- + +## What + +Transform picoclaw's multi-agent framework from a functional prototype (Blackboard + Handoff + Routing) into a production-ready orchestration system with guardrails, tool policy, resilience, and async capabilities — matching and exceeding OpenClaw's patterns in a lightweight Go single-binary. 
+ +--- + +## Success Criteria + +### Phase 1: Foundation Fix + Guardrails +- [ ] Blackboard split-brain bug is fixed — tools and system prompt use the same board per session +- [ ] Handoff recursion is bounded — max depth enforced, cycle detection prevents A→B→A loops +- [ ] Handoff respects allowlist — same CanSpawnSubagent check as spawn tool +- [ ] Before-tool-call hook infrastructure exists — extensible for loop detection and policy +- [ ] All existing tests pass + new tests for each fix +- [ ] Zero regression in single-agent mode + +### Phase 2: Tool Policy Pipeline +- [ ] Tool groups defined: `group:fs`, `group:web`, `group:exec`, `group:sessions`, `group:memory` +- [ ] Per-agent `tools.allow` / `tools.deny` in config, supports group references +- [ ] Subagent tool restriction by depth (leaf agents can't spawn/handoff) +- [ ] Pipeline composes: global → per-agent → per-depth (each layer narrows, never widens) +- [ ] Config backward-compatible (no tools config = full access, like today) + +### Phase 3: Resilience +- [ ] Loop detection: generic repeat (hash-based) + ping-pong detector +- [ ] Context overflow recovery: auto-compaction + tool result truncation + user error +- [ ] Auth profile rotation: round-robin with 2-track cooldown (transient + billing) +- [ ] Cascade stop: context cancellation propagates through handoff/spawn chains + +### Phase 4: Async Multi-Agent +- [ ] Async spawn: non-blocking agent invocation via goroutines +- [ ] Announce protocol: result injection into parent session (steer/queue/direct) +- [ ] Scope-keyed process isolation: exec tool scoped by session key +- [ ] Idempotency: dedup map for duplicate message prevention + +--- + +## Current State Analysis + +### What exists (PR #423) + +| Component | Package | Status | Quality | +|-----------|---------|--------|---------| +| Blackboard | `pkg/multiagent` | Implemented | Good base, has split-brain bug | +| BlackboardTool | `pkg/multiagent` | Implemented | Works but uses wrong board | +| 
ExecuteHandoff | `pkg/multiagent` | Implemented | No recursion guard, ignores allowlist | +| HandoffTool | `pkg/multiagent` | Implemented | No self-handoff guard | +| ListAgentsTool | `pkg/multiagent` | Implemented | Exposes too little info | +| AgentResolver | `pkg/multiagent` | Implemented | Clean interface | +| RouteResolver | `pkg/routing` | Implemented | 7-tier cascade, complete | +| SessionKeyBuilder | `pkg/routing` | Implemented | DM scope support | +| AgentRegistry | `pkg/agent` | Implemented | CanSpawnSubagent exists | +| AgentLoop integration | `pkg/agent` | Implemented | Snapshot injection, conditional tools | + +### Critical Bug: Blackboard Split-Brain + +``` +registerSharedTools() → creates static `sharedBoard` per agent +runAgentLoop() → creates per-session board via getOrCreateBlackboard(sessionKey) + +BlackboardTool.Execute() → writes to static sharedBoard ← WRONG +messages[0].Content → reads from session board ← DIFFERENT OBJECT +``` + +**Impact**: Agents think they're sharing context but they're writing to void. Any multi-agent demo fails silently. + +### What's missing vs. OpenClaw + +| Feature | OpenClaw | picoclaw | Gap | +|---------|----------|----------|-----| +| Tool policy | 8-layer pipeline | None | Critical | +| Recursion guard | maxSpawnDepth + maxChildren | None | Critical | +| Loop detection | 4 detectors + circuit breaker | None | High | +| Context overflow | 3-tier cascade recovery | Basic forceCompression | High | +| Auth rotation | Round-robin + 2-track cooldown | FallbackChain only | Medium | +| Cascade stop | cascadeKillChildren + abort | None | Medium | +| Before-tool hook | wrapToolWithBeforeToolCallHook | None | Medium | +| Async spawn | Lane-based + announce | None | Low (Phase 4) | +| Process isolation | Scope-keyed | None | Low (Phase 4) | +| Idempotency | Gateway dedup + announce dedup | None | Low (Phase 4) | + +--- + +## Phase 1: Foundation Fix + Guardrails + +### 1a. 
Fix Blackboard Split-Brain + +**Problem**: `registerSharedTools` in `loop.go:178` creates one `sharedBoard` per agent at startup. `runAgentLoop` in `loop.go` creates separate per-session blackboards in `AgentLoop.blackboards` sync.Map. Tools write to one, system prompt reads from other. + +**Solution**: Make BlackboardTool and HandoffTool session-aware. + +**Approach A (Recommended)**: ContextualTool pattern +```go +// BlackboardTool already has agentID — add board setter +type BlackboardTool struct { + board *Blackboard + agentID string + mu sync.RWMutex +} + +func (t *BlackboardTool) SetBoard(board *Blackboard) { + t.mu.Lock() + defer t.mu.Unlock() + t.board = board +} +``` + +In `runAgentLoop`, before tool execution: +```go +bb := al.getOrCreateBlackboard(opts.SessionKey) +// Update all session-aware tools with the current session's board +for _, tool := range agent.Tools.List() { + if setter, ok := tool.(BoardSetter); ok { + setter.SetBoard(bb) + } +} +``` + +**Approach B**: Lazy board resolution via callback +```go +type BlackboardTool struct { + resolveBoard func() *Blackboard + agentID string +} +``` + +**Files to modify**: +- `pkg/multiagent/blackboard_tool.go` — Add `SetBoard` or callback +- `pkg/multiagent/handoff_tool.go` — Same pattern +- `pkg/agent/loop.go` — Wire session board to tools before execution + +**Tests**: +- Verify tool writes are visible in system prompt snapshot +- Verify cross-agent writes via handoff are visible +- Verify session isolation (board A != board B) + +### 1b. Recursion Guard + +**Problem**: Agent A can hand off to B, which hands off to A, causing infinite recursion. + +**Solution**: Add depth tracking and cycle detection to ExecuteHandoff. 
+ +```go +// Add to HandoffRequest +type HandoffRequest struct { + FromAgentID string + ToAgentID string + Task string + Context map[string]string + Depth int // NEW: current nesting depth + Visited []string // NEW: agent IDs in the chain +} + +// In ExecuteHandoff +const DefaultMaxHandoffDepth = 3 + +func ExecuteHandoff(ctx context.Context, resolver AgentResolver, board *Blackboard, + req *HandoffRequest, channel, chatID string, maxDepth int) *HandoffResult { + + // Depth check + if req.Depth >= maxDepth { + return &HandoffResult{ + Error: fmt.Sprintf("handoff depth limit reached (%d/%d)", req.Depth, maxDepth), + Success: false, + } + } + + // Cycle detection + for _, visited := range req.Visited { + if visited == req.ToAgentID { + return &HandoffResult{ + Error: fmt.Sprintf("handoff cycle detected: %s already in chain %v", + req.ToAgentID, req.Visited), + Success: false, + } + } + } + + // Propagate depth + visited to nested handoffs + // ... (inject into target agent's tool context) +} +``` + +**Config**: +```json +{ + "agents": { + "defaults": { + "max_handoff_depth": 3 + } + } +} +``` + +**Files to modify**: +- `pkg/multiagent/handoff.go` — Depth + cycle detection +- `pkg/multiagent/handoff_tool.go` — Propagate depth/visited +- `pkg/config/config.go` — Add `MaxHandoffDepth` to `AgentDefaults` + +**Tests**: +- Direct handoff succeeds (depth 0→1) +- Chain handoff A→B→C succeeds (depth 0→1→2) +- Depth limit exceeded returns error +- Cycle A→B→A returns error +- Self-handoff returns error + +### 1c. Handoff Allowlist Enforcement + +**Problem**: `CanSpawnSubagent` only checked by spawn tool, not by handoff. + +**Solution**: Add allowlist check to `ExecuteHandoff`. + +```go +func ExecuteHandoff(...) *HandoffResult { + // ... depth/cycle checks above ... 
+ + // Allowlist check (new) + if checker, ok := resolver.(AllowlistChecker); ok { + if !checker.CanHandoff(req.FromAgentID, req.ToAgentID) { + return &HandoffResult{ + Error: fmt.Sprintf("agent %q is not allowed to handoff to %q", + req.FromAgentID, req.ToAgentID), + Success: false, + } + } + } + + // ... rest of execution +} +``` + +**Files to modify**: +- `pkg/multiagent/handoff.go` — Add `AllowlistChecker` interface + check +- `pkg/agent/loop.go` — Make `registryResolver` implement `AllowlistChecker` + +**Tests**: +- Allowed handoff succeeds +- Disallowed handoff returns error +- Wildcard `"*"` allows all + +### 1d. Before-Tool-Call Hook Infrastructure + +**Problem**: No extensibility point for tool execution. Needed for loop detection (Phase 3) and tool policy (Phase 2). + +**Solution**: ToolHook interface wrapping tool execution. + +```go +// pkg/tools/hooks.go +type ToolHook interface { + BeforeExecute(ctx context.Context, toolName string, args map[string]any) (map[string]any, error) + AfterExecute(ctx context.Context, toolName string, args map[string]any, result *ToolResult, err error) +} + +// pkg/tools/registry.go — Add hook chain +func (r *ToolRegistry) SetHooks(hooks []ToolHook) { ... 
} + +func (r *ToolRegistry) ExecuteWithHooks(ctx context.Context, toolName string, args map[string]any) (*ToolResult, error) { + // Run BeforeExecute hooks in order + currentArgs := args + for _, hook := range r.hooks { + var err error + currentArgs, err = hook.BeforeExecute(ctx, toolName, currentArgs) + if err != nil { + return &ToolResult{Content: err.Error(), IsError: true}, nil + } + } + + // Execute tool + result := tool.Execute(ctx, currentArgs) + + // Run AfterExecute hooks (fire-and-forget) + for _, hook := range r.hooks { + hook.AfterExecute(ctx, toolName, currentArgs, result, nil) + } + + return result, nil +} +``` + +**Files to create**: +- `pkg/tools/hooks.go` — ToolHook interface + chain execution + +**Files to modify**: +- `pkg/tools/registry.go` — Add hooks field and ExecuteWithHooks +- `pkg/agent/loop.go` — Use ExecuteWithHooks in runLLMIteration + +**Tests**: +- Hook blocks tool → error returned +- Hook modifies args → tool receives modified args +- AfterExecute called with result +- Multiple hooks execute in order + +--- + +## Phase 2: Tool Policy Pipeline + +### 2a. Tool Groups + +```go +// pkg/tools/groups.go +var DefaultToolGroups = map[string][]string{ + "group:fs": {"read_file", "write_file", "edit_file", "append_file", "list_dir"}, + "group:web": {"web_search", "web_fetch"}, + "group:exec": {"exec"}, + "group:sessions": {"blackboard", "handoff", "list_agents", "spawn"}, + "group:memory": {"memory_search", "memory_get"}, + "group:message": {"message"}, + "group:image": {"image_generation"}, +} + +func ResolveToolNames(refs []string, groups map[string][]string) []string { + var resolved []string + for _, ref := range refs { + if tools, ok := groups[ref]; ok { + resolved = append(resolved, tools...) + } else { + resolved = append(resolved, ref) + } + } + return resolved +} +``` + +### 2b. 
Per-Agent Allow/Deny Config + +```go +// pkg/config/config.go — Add to AgentConfig +type ToolPolicyConfig struct { + Allow []string `json:"allow,omitempty"` // tool names or group refs + Deny []string `json:"deny,omitempty"` // tool names or group refs +} + +type AgentConfig struct { + // ... existing fields ... + Tools *ToolPolicyConfig `json:"tools,omitempty"` // NEW +} +``` + +### 2c. Subagent Deny-by-Depth + +```go +// pkg/tools/policy.go +type DepthPolicy struct { + MaxDepth int +} + +func (p *DepthPolicy) DenyListForDepth(depth int) []string { + if depth == 0 { + return nil // main agent: full access + } + + // All subagents lose dangerous tools + deny := []string{"gateway"} + + if depth >= p.MaxDepth { + // Leaf workers: no spawning + deny = append(deny, "spawn", "handoff", "list_agents") + } + return deny +} +``` + +### 2d. Pipeline Composition + +```go +// pkg/tools/policy.go +type PolicyStep struct { + Allow []string + Deny []string + Label string +} + +func ApplyPolicyPipeline(tools []Tool, steps []PolicyStep) []Tool { + remaining := tools + for _, step := range steps { + if len(step.Allow) > 0 { + allowed := toSet(ResolveToolNames(step.Allow, DefaultToolGroups)) + remaining = filter(remaining, func(t Tool) bool { + return allowed[t.Name()] + }) + } + if len(step.Deny) > 0 { + denied := toSet(ResolveToolNames(step.Deny, DefaultToolGroups)) + remaining = filter(remaining, func(t Tool) bool { + return !denied[t.Name()] + }) + } + } + return remaining +} +``` + +**Files to create**: +- `pkg/tools/groups.go` — Tool group definitions +- `pkg/tools/policy.go` — PolicyStep, ApplyPolicyPipeline, DepthPolicy + +**Files to modify**: +- `pkg/config/config.go` — ToolPolicyConfig in AgentConfig +- `pkg/agent/loop.go` — Apply policy pipeline in registerSharedTools +- `pkg/multiagent/handoff.go` — Apply depth policy for handoff target + +--- + +## Phase 3: Resilience + +### 3a. 
Loop Detection + +```go +// pkg/tools/loop_detector.go +type LoopDetector struct { + history []toolCallRecord + maxHistory int // default 30 + warnAt int // default 10 + blockAt int // default 20 +} + +type toolCallRecord struct { + Hash string + ToolName string + Timestamp time.Time + Outcome string // hash of result for no-progress detection +} + +func (d *LoopDetector) Check(toolName string, args map[string]any) LoopVerdict { + hash := hashToolCall(toolName, args) + + // Generic repeat detection + repeatCount := d.countRepeats(hash) + if repeatCount >= d.blockAt { + return LoopVerdict{Blocked: true, Reason: "tool call repeated too many times"} + } + if repeatCount >= d.warnAt { + return LoopVerdict{Warning: true, Reason: "possible loop detected"} + } + + // Ping-pong detection (A,B,A,B pattern) + if d.detectPingPong(hash) { + return LoopVerdict{Blocked: true, Reason: "ping-pong loop detected"} + } + + return LoopVerdict{} +} + +func (d *LoopDetector) Record(toolName string, args map[string]any, result string) { + // Append to history ring buffer, trim to maxHistory +} +``` + +**Files to create**: `pkg/tools/loop_detector.go`, `pkg/tools/loop_detector_test.go` + +### 3b. Context Overflow Recovery + +Enhance existing `forceCompression` in `loop.go`: + +```go +func (al *AgentLoop) recoverContextOverflow(ctx context.Context, agent *AgentInstance, + sessionKey string, err error) RecoveryResult { + + // Tier 1: LLM-based compaction (summarize history) + for attempt := 0; attempt < 3; attempt++ { + if al.compactSession(ctx, agent, sessionKey) { + return RecoveryResult{Recovered: true, Method: "compaction"} + } + } + + // Tier 2: Truncate oversized tool results in history + if al.truncateToolResults(ctx, agent, sessionKey) { + return RecoveryResult{Recovered: true, Method: "tool_truncation"} + } + + // Tier 3: Give up with user-facing error + return RecoveryResult{Recovered: false, Method: "none"} +} +``` + +**Files to modify**: `pkg/agent/loop.go` + +### 3c. 
Auth Profile Rotation + +Enhance existing `FallbackChain`: + +```go +// pkg/providers/auth_rotation.go +type AuthProfile struct { + ID string + Provider string + APIKey string + ErrorCount int + CooldownUntil time.Time + DisabledUntil time.Time // billing track + LastUsed time.Time +} + +type AuthRotator struct { + profiles []AuthProfile + mu sync.RWMutex +} + +func (r *AuthRotator) NextAvailable() *AuthProfile { + // Round-robin: sort by lastUsed (oldest first), skip cooldown +} + +func (r *AuthRotator) MarkFailure(profileID string, reason FailoverReason) { + // Transient: exponential 1min → 5min → 25min → 1hr + // Billing: separate disabledUntil track +} + +func (r *AuthRotator) MarkSuccess(profileID string) { + // Reset errorCount, update lastUsed +} +``` + +**Files to create**: `pkg/providers/auth_rotation.go` + +### 3d. Cascade Stop + +```go +// pkg/multiagent/cascade.go +type RunRegistry struct { + active sync.Map // sessionKey -> *ActiveRun +} + +type ActiveRun struct { + SessionKey string + Cancel context.CancelFunc + Children []string // child session keys +} + +func (r *RunRegistry) CascadeStop(sessionKey string) int { + killed := 0 + if run, ok := r.active.Load(sessionKey); ok { + ar := run.(*ActiveRun) + ar.Cancel() + killed++ + for _, child := range ar.Children { + killed += r.CascadeStop(child) + } + r.active.Delete(sessionKey) + } + return killed +} +``` + +**Files to create**: `pkg/multiagent/cascade.go` + +--- + +## Phase 4: Async Multi-Agent + +### 4a. 
Async Spawn + +```go +// pkg/multiagent/spawn.go +type SpawnResult struct { + RunID string + SessionKey string +} + +func AsyncSpawn(ctx context.Context, resolver AgentResolver, board *Blackboard, + req *HandoffRequest, channel, chatID string) *SpawnResult { + + runID := uuid.New().String() + childKey := fmt.Sprintf("agent:%s:subagent:%s", req.ToAgentID, runID) + + go func() { + result := ExecuteHandoff(ctx, resolver, board, req, channel, chatID, maxDepth) + // Announce result back to parent + announceResult(req.FromAgentID, result, childKey) + }() + + return &SpawnResult{RunID: runID, SessionKey: childKey} +} +``` + +### 4b. Announce Protocol + +```go +// pkg/multiagent/announce.go +type AnnounceMode string + +const ( + AnnounceSteer AnnounceMode = "steer" // inject into active LLM call + AnnounceQueue AnnounceMode = "queue" // hold until idle + AnnounceDirect AnnounceMode = "direct" // send immediately +) + +type Announcement struct { + FromSessionKey string + ToSessionKey string + Content string + Mode AnnounceMode + RunID string +} +``` + +### 4c. Scope-Keyed Process Isolation + +Scope the exec tool's process visibility by session key: + +```go +// In exec tool construction +type ScopedExecTool struct { + scopeKey string + // processes only visible within this scope +} +``` + +### 4d. 
Idempotency + +```go +// pkg/gateway/dedup.go +type DedupCache struct { + entries sync.Map // key -> *DedupEntry + ttl time.Duration // default 5min + maxSize int // default 1000 +} +``` + +--- + +## Dependencies + +``` +Phase 1a (blackboard fix) ← no deps, start immediately +Phase 1b (recursion guard) ← no deps +Phase 1c (allowlist) ← 1b (uses same HandoffRequest changes) +Phase 1d (tool hooks) ← no deps + +Phase 2a (groups) ← no deps +Phase 2b (per-agent policy) ← 2a +Phase 2c (depth policy) ← 2b + 1b (needs depth tracking) +Phase 2d (pipeline) ← 2a + 2b + 2c + 1d (hooks infrastructure) + +Phase 3a (loop detection) ← 1d (hooks) +Phase 3b (context recovery) ← no deps +Phase 3c (auth rotation) ← no deps (enhances existing FallbackChain) +Phase 3d (cascade stop) ← 1b (depth tracking) + +Phase 4a (async spawn) ← 1b + 3d (depth + cascade) +Phase 4b (announce) ← 4a +Phase 4c (process isolation) ← no deps +Phase 4d (idempotency) ← no deps +``` + +```mermaid +graph TD + P1A[1a: Fix Blackboard] --> P2D + P1B[1b: Recursion Guard] --> P1C[1c: Allowlist] + P1B --> P2C + P1B --> P3D + P1D[1d: Tool Hooks] --> P2D[2d: Pipeline] + P1D --> P3A + + P2A[2a: Tool Groups] --> P2B[2b: Per-Agent Policy] + P2B --> P2C[2c: Depth Policy] + P2C --> P2D + + P3A[3a: Loop Detection] -.-> P2D + P3B[3b: Context Recovery] + P3C[3c: Auth Rotation] + P3D[3d: Cascade Stop] --> P4A + + P4A[4a: Async Spawn] --> P4B[4b: Announce Protocol] + P4C[4c: Process Isolation] + P4D[4d: Idempotency] + + style P1A fill:#ef4444,color:#fff + style P1B fill:#ef4444,color:#fff + style P1C fill:#ef4444,color:#fff + style P1D fill:#ef4444,color:#fff + style P2A fill:#f59e0b,color:#000 + style P2B fill:#f59e0b,color:#000 + style P2C fill:#f59e0b,color:#000 + style P2D fill:#f59e0b,color:#000 + style P3A fill:#3b82f6,color:#fff + style P3B fill:#3b82f6,color:#fff + style P3C fill:#3b82f6,color:#fff + style P3D fill:#3b82f6,color:#fff + style P4A fill:#8b5cf6,color:#fff + style P4B fill:#8b5cf6,color:#fff + style P4C 
fill:#8b5cf6,color:#fff + style P4D fill:#8b5cf6,color:#fff +``` + +--- + +## File Inventory + +### New Files + +| File | Phase | Description | +|------|-------|-------------| +| `pkg/tools/hooks.go` | 1d | ToolHook interface, chain execution | +| `pkg/tools/groups.go` | 2a | Tool group definitions | +| `pkg/tools/policy.go` | 2b-d | PolicyStep, pipeline, DepthPolicy | +| `pkg/tools/policy_test.go` | 2 | Policy pipeline tests | +| `pkg/tools/loop_detector.go` | 3a | Loop detection (repeat + ping-pong) | +| `pkg/tools/loop_detector_test.go` | 3a | Loop detection tests | +| `pkg/providers/auth_rotation.go` | 3c | Auth profile rotation + cooldown | +| `pkg/providers/auth_rotation_test.go` | 3c | Auth rotation tests | +| `pkg/multiagent/cascade.go` | 3d | RunRegistry, CascadeStop | +| `pkg/multiagent/cascade_test.go` | 3d | Cascade stop tests | +| `pkg/multiagent/spawn.go` | 4a | AsyncSpawn | +| `pkg/multiagent/announce.go` | 4b | Announce protocol | +| `pkg/gateway/dedup.go` | 4d | Idempotency cache | + +### Modified Files + +| File | Phase | Changes | +|------|-------|---------| +| `pkg/multiagent/blackboard_tool.go` | 1a | Add SetBoard / BoardSetter interface | +| `pkg/multiagent/handoff_tool.go` | 1a, 1b | SetBoard + depth propagation | +| `pkg/multiagent/handoff.go` | 1b, 1c | Depth/cycle guard + allowlist check | +| `pkg/agent/loop.go` | 1a, 1d, 2d | Wire session board, hook chain, policy pipeline | +| `pkg/config/config.go` | 1b, 2b | MaxHandoffDepth, ToolPolicyConfig | +| `pkg/tools/registry.go` | 1d | Add hooks field, ExecuteWithHooks | + +--- + +## OpenClaw Reference Map + +For each phase, the OpenClaw file to study: + +| Phase | picoclaw Target | OpenClaw Reference | +|-------|-----------------|-------------------| +| 1d | `pkg/tools/hooks.go` | `src/agents/pi-tools.before-tool-call.ts` | +| 2a | `pkg/tools/groups.go` | `src/agents/tool-policy.ts` (TOOL_GROUPS) | +| 2b-d | `pkg/tools/policy.go` | `src/agents/tool-policy-pipeline.ts` | +| 2c | depth deny | 
`src/agents/pi-tools.policy.ts` (resolveSubagentDenyList) | +| 3a | `pkg/tools/loop_detector.go` | `src/agents/tool-loop-detection.ts` | +| 3b | context recovery | `src/agents/pi-embedded-runner/run.ts` (overflow cascade) | +| 3c | `pkg/providers/auth_rotation.go` | `src/agents/auth-profiles/order.ts` + `usage.ts` | +| 3d | `pkg/multiagent/cascade.go` | `src/agents/tools/subagents-tool.ts` (cascadeKillChildren) | +| 4a | `pkg/multiagent/spawn.go` | `src/agents/subagent-spawn.ts` | +| 4b | `pkg/multiagent/announce.go` | `src/agents/subagent-announce.ts` | + +--- + +## Risk Assessment + +| Risk | Impact | Mitigation | +|------|--------|------------| +| Blackboard fix breaks existing tests | Medium | Fix is additive — SetBoard is optional, existing behavior preserved if not called | +| Tool policy breaks single-agent mode | High | Default: no policy config = full access (backward compatible) | +| Loop detection false positives | Medium | Start with high thresholds (warn=10, block=20), tune based on real usage | +| Auth rotation race conditions | High | Use sync.Mutex for profile state, file lock for cross-session (like OpenClaw) | +| Async spawn goroutine leaks | High | Always use context.WithTimeout, track in RunRegistry | + +--- + +## Non-Goals + +- Visual AIEOS dashboard (future) +- Community agent marketplace (future) +- A2A protocol compatibility (future — after Phase 4) +- SOUL.md bootstrap enhancement (separate PR, handled by other developer) +- model_list / provider Phase 2-4 (separate track per issue #283) diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go index 54a5396e7..ed001840f 100644 --- a/pkg/agent/instance.go +++ b/pkg/agent/instance.go @@ -17,6 +17,8 @@ import ( type AgentInstance struct { ID string Name string + Role string + SystemPrompt string Model string Fallbacks []string Workspace string @@ -26,6 +28,7 @@ type AgentInstance struct { Sessions *session.SessionManager ContextBuilder *ContextBuilder Tools *tools.ToolRegistry + Capabilities 
[]string Subagents *config.SubagentsConfig SkillsFilter []string Candidates []providers.FallbackCandidate @@ -61,14 +64,20 @@ func NewAgentInstance( agentID := routing.DefaultAgentID agentName := "" + agentRole := "" + agentSystemPrompt := "" var subagents *config.SubagentsConfig var skillsFilter []string + var capabilities []string if agentCfg != nil { agentID = routing.NormalizeAgentID(agentCfg.ID) agentName = agentCfg.Name + agentRole = agentCfg.Role + agentSystemPrompt = agentCfg.SystemPrompt subagents = agentCfg.Subagents skillsFilter = agentCfg.Skills + capabilities = agentCfg.Capabilities } maxIter := defaults.MaxToolIterations @@ -86,6 +95,8 @@ func NewAgentInstance( return &AgentInstance{ ID: agentID, Name: agentName, + Role: agentRole, + SystemPrompt: agentSystemPrompt, Model: model, Fallbacks: fallbacks, Workspace: workspace, @@ -95,6 +106,7 @@ func NewAgentInstance( Sessions: sessionsManager, ContextBuilder: contextBuilder, Tools: toolsRegistry, + Capabilities: capabilities, Subagents: subagents, SkillsFilter: skillsFilter, Candidates: candidates, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index ed69712ff..a34aaa856 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/multiagent" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/state" @@ -35,8 +36,12 @@ type AgentLoop struct { state *state.Manager running atomic.Bool summarizing sync.Map + blackboards sync.Map // sessionKey -> *multiagent.Blackboard fallback *providers.FallbackChain channelManager *channels.Manager + runRegistry *multiagent.RunRegistry // tracks active handoff/spawn runs for cascade stop + announcer *multiagent.Announcer // per-session spawn result delivery (Phase 4b) + spawnManager *multiagent.SpawnManager // async spawn 
orchestrator (Phase 4a) } // processOptions configures how a message is processed @@ -53,9 +58,29 @@ type processOptions struct { func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { registry := NewAgentRegistry(cfg, provider) + runRegistry := multiagent.NewRunRegistry() + announcer := multiagent.NewAnnouncer(32) + + // Resolve spawn limits from config (per-agent settings, with defaults). + maxChildren := multiagent.DefaultMaxChildren + spawnTimeout := multiagent.DefaultSpawnTimeout + if len(cfg.Agents.List) > 0 { + for _, ac := range cfg.Agents.List { + if ac.Subagents != nil { + if ac.Subagents.MaxChildrenPerAgent > 0 { + maxChildren = ac.Subagents.MaxChildrenPerAgent + } + if ac.Subagents.SpawnTimeoutSec > 0 { + spawnTimeout = time.Duration(ac.Subagents.SpawnTimeoutSec) * time.Second + } + break // use first configured value as global default + } + } + } + spawnManager := multiagent.NewSpawnManager(runRegistry, announcer, maxChildren, spawnTimeout) // Register shared tools to all agents - registerSharedTools(cfg, msgBus, registry, provider) + registerSharedTools(cfg, msgBus, registry, provider, runRegistry, spawnManager, announcer) // Set up shared fallback chain cooldown := providers.NewCooldownTracker() @@ -69,17 +94,61 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers } return &AgentLoop{ - bus: msgBus, - cfg: cfg, - registry: registry, - state: stateManager, - summarizing: sync.Map{}, - fallback: fallbackChain, + bus: msgBus, + cfg: cfg, + registry: registry, + state: stateManager, + summarizing: sync.Map{}, + fallback: fallbackChain, + runRegistry: runRegistry, + announcer: announcer, + spawnManager: spawnManager, } } +// registryResolver adapts AgentRegistry to multiagent.AgentResolver. 
+type registryResolver struct { + registry *AgentRegistry +} + +func (r *registryResolver) GetAgentInfo(agentID string) *multiagent.AgentInfo { + agent, ok := r.registry.GetAgent(agentID) + if !ok { + return nil + } + return &multiagent.AgentInfo{ + ID: agent.ID, + Name: agent.Name, + Role: agent.Role, + SystemPrompt: agent.SystemPrompt, + Model: agent.Model, + Provider: agent.Provider, + Tools: agent.Tools, + MaxIter: agent.MaxIterations, + Capabilities: agent.Capabilities, + } +} + +func (r *registryResolver) ListAgents() []multiagent.AgentInfo { + ids := r.registry.ListAgentIDs() + agents := make([]multiagent.AgentInfo, 0, len(ids)) + for _, id := range ids { + agent, ok := r.registry.GetAgent(id) + if !ok { + continue + } + agents = append(agents, multiagent.AgentInfo{ + ID: agent.ID, + Name: agent.Name, + Role: agent.Role, + Capabilities: agent.Capabilities, + }) + } + return agents +} + // registerSharedTools registers tools that are shared across all agents (web, message, spawn). -func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) { +func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider, runReg *multiagent.RunRegistry, spawnMgr *multiagent.SpawnManager, announcer *multiagent.Announcer) { for _, agentID := range registry.ListAgentIDs() { agent, ok := registry.GetAgent(agentID) if !ok { @@ -126,11 +195,87 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A }) agent.Tools.Register(spawnTool) - // Update context builder with the complete tools registry + // Multi-agent collaboration tools (blackboard, handoff, discovery) + // Only register when multiple agents are configured. + if len(registry.ListAgentIDs()) > 1 { + resolver := ®istryResolver{registry: registry} + + // Blackboard tool: per-agent instance sharing a placeholder blackboard. 
+ // The actual per-session blackboard is wired via SetBoard in updateToolContexts + // before each message processing cycle (fixing the split-brain bug). + placeholderBoard := multiagent.NewBlackboard() + agent.Tools.Register(multiagent.NewBlackboardTool(placeholderBoard, agentID)) + + // Handoff tool: delegate tasks to other agents + handoffTool := multiagent.NewHandoffTool(resolver, placeholderBoard, agentID) + + // Allowlist checker: default-open when no subagents config, + // enforces allow_agents when configured. + currentAgentIDForHandoff := agentID + handoffTool.SetAllowlistChecker(multiagent.AllowlistCheckerFunc(func(from, to string) bool { + parent, ok := registry.GetAgent(from) + if !ok { + return false + } + // Default open: if no allowlist configured, allow all handoffs + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return true + } + return registry.CanSpawnSubagent(currentAgentIDForHandoff, to) + })) + handoffTool.SetRunRegistry(runReg, "") + agent.Tools.Register(handoffTool) + + // List agents tool: discover available agents + agent.Tools.Register(multiagent.NewListAgentsTool(resolver)) + + // Async spawn tool: fire-and-forget agent invocation (Phase 4a). + // Results are auto-announced back to the parent session via the Announcer. + if spawnMgr != nil { + spawnAgentTool := multiagent.NewSpawnTool(resolver, placeholderBoard, spawnMgr, agentID) + currentAgentIDForSpawn := agentID + spawnAgentTool.SetAllowlistChecker(multiagent.AllowlistCheckerFunc(func(from, to string) bool { + parent, ok := registry.GetAgent(from) + if !ok { + return false + } + if parent.Subagents == nil || parent.Subagents.AllowAgents == nil { + return true + } + return registry.CanSpawnSubagent(currentAgentIDForSpawn, to) + })) + agent.Tools.Register(spawnAgentTool) + } + } + + // Apply per-agent tool policy (static, startup-time filtering). + // This removes denied tools from the registry before the LLM ever sees them. 
+ if agentCfg := findAgentConfig(cfg, agentID); agentCfg != nil && agentCfg.ToolPolicy != nil { + tools.ApplyPolicy(agent.Tools, tools.ToolPolicy{ + Allow: agentCfg.ToolPolicy.Allow, + Deny: agentCfg.ToolPolicy.Deny, + }) + } + + // Register loop detector hook (per-agent, session-isolated via context key). + // Uses production defaults: warn@10 repeats, block@20, circuit breaker@30. + agent.Tools.AddHook(tools.NewLoopDetector(tools.DefaultLoopDetectorConfig())) + + // Update context builder with the (possibly filtered) tools registry agent.ContextBuilder.SetToolsRegistry(agent.Tools) } } +// findAgentConfig returns the AgentConfig for a given agent ID, or nil if not found. +func findAgentConfig(cfg *config.Config, agentID string) *config.AgentConfig { + for i := range cfg.Agents.List { + if routing.NormalizeAgentID(cfg.Agents.List[i].ID) == agentID { + return &cfg.Agents.List[i] + } + } + return nil +} + func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) @@ -377,8 +522,8 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt } } - // 1. Update tool contexts - al.updateToolContexts(agent, opts.Channel, opts.ChatID) + // 1. Update tool contexts (including per-session blackboard wiring) + al.updateToolContexts(agent, opts.Channel, opts.ChatID, opts.SessionKey) // 2. Build messages (skip history for heartbeat) var history []providers.Message @@ -396,10 +541,21 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt opts.ChatID, ) + // 2b. Inject blackboard snapshot into system context if available + if bb := al.getOrCreateBlackboard(opts.SessionKey); bb != nil && bb.Size() > 0 { + snapshot := bb.Snapshot() + if snapshot != "" && len(messages) > 0 && messages[0].Role == "system" { + messages[0].Content += "\n\n" + snapshot + } + } + // 3. Save user message to session agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) - // 4. Run LLM iteration loop + // 4. 
Inject session key into context for loop detection + ctx = tools.WithSessionKey(ctx, opts.SessionKey) + + // 5. Run LLM iteration loop finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) if err != nil { return "", err @@ -452,6 +608,21 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, for iteration < agent.MaxIterations { iteration++ + // Drain pending spawn announcements and inject as context (Phase 4b). + // Between each LLM iteration, check if any child spawns completed + // and inject their results so the LLM can use them. + if al.announcer != nil { + anns := al.announcer.Drain(opts.SessionKey) + for _, ann := range anns { + annMsg := providers.Message{ + Role: "system", + Content: fmt.Sprintf("[Spawn result from agent %q (run: %s)]\n%s", ann.AgentID, ann.RunID, ann.Content), + } + messages = append(messages, annMsg) + agent.Sessions.AddFullMessage(opts.SessionKey, annMsg) + } + } + logger.DebugCF("agent", "LLM iteration", map[string]interface{}{ "agent_id": agent.ID, @@ -513,8 +684,8 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, }) } - // Retry loop for context/token errors - maxRetries := 2 + // Retry loop for recoverable errors (context window + rate limits). + maxRetries := 3 for retry := 0; retry <= maxRetries; retry++ { response, err = callLLM() if err == nil { @@ -522,16 +693,46 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, } errMsg := strings.ToLower(err.Error()) + + // Rate-limit / transient errors: wait with exponential backoff. 
+ isRateLimitError := strings.Contains(errMsg, "429") || + strings.Contains(errMsg, "rate limit") || + strings.Contains(errMsg, "rate_limit") || + strings.Contains(errMsg, "resource_exhausted") || + strings.Contains(errMsg, "resource exhausted") || + strings.Contains(errMsg, "too many requests") || + strings.Contains(errMsg, "overloaded") || + strings.Contains(errMsg, "quota") + + if isRateLimitError && retry < maxRetries { + backoff := time.Duration(1< cutPoint-200 && j > 0; j-- { + if runes[j] == '\n' { + cutPoint = j + break + } + } + history[i].Content = string(runes[:cutPoint]) + + "\n\n[... content truncated — original too large for context window]" + truncated++ + } + + if truncated > 0 { + agent.Sessions.SetHistory(sessionKey, history) + logger.WarnCF("agent", "Tool results truncated", map[string]interface{}{ + "session_key": sessionKey, + "truncated": truncated, + "max_chars": maxToolResultChars, + }) + } +} + // forceCompression aggressively reduces context when the limit is hit. // It drops the oldest 50% of messages (keeping system prompt and last user message). func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index f2257973c..ee5f2ddda 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -628,3 +628,311 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +// TestTruncateToolResults_OversizedResultTruncated verifies that tool results +// exceeding maxToolResultChars are truncated with a warning footer. 
+func TestTruncateToolResults_OversizedResultTruncated(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("No default agent found") + } + + // Create a large tool result (> maxToolResultChars) + bigContent := "" + for i := 0; i < maxToolResultChars+1000; i++ { + bigContent += "x" + } + + sessionKey := "test-truncate" + agent.Sessions.GetOrCreate(sessionKey) // ensure session exists + history := []providers.Message{ + {Role: "system", Content: "You are a helper."}, + {Role: "user", Content: "Do something"}, + {Role: "assistant", Content: "Calling tool..."}, + {Role: "tool", Content: bigContent, ToolCallID: "call-1"}, + {Role: "assistant", Content: "Here is the result."}, + } + agent.Sessions.SetHistory(sessionKey, history) + + al.truncateToolResults(agent, sessionKey) + + updated := agent.Sessions.GetHistory(sessionKey) + toolMsg := updated[3] + if len(toolMsg.Content) >= len(bigContent) { + t.Errorf("Expected tool result to be truncated, got length %d", len(toolMsg.Content)) + } + if toolMsg.ToolCallID != "call-1" { + t.Errorf("ToolCallID should be preserved, got %q", toolMsg.ToolCallID) + } + if !containsString(toolMsg.Content, "[... content truncated") { + t.Error("Expected truncation footer in tool result") + } +} + +// TestTruncateToolResults_SmallResultUntouched verifies that tool results +// under the threshold are not modified. 
+func TestTruncateToolResults_SmallResultUntouched(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("No default agent found") + } + + originalContent := "Short tool result" + sessionKey := "test-no-truncate" + agent.Sessions.GetOrCreate(sessionKey) + history := []providers.Message{ + {Role: "system", Content: "You are a helper."}, + {Role: "user", Content: "Do something"}, + {Role: "tool", Content: originalContent, ToolCallID: "call-1"}, + {Role: "assistant", Content: "Done."}, + } + agent.Sessions.SetHistory(sessionKey, history) + + al.truncateToolResults(agent, sessionKey) + + updated := agent.Sessions.GetHistory(sessionKey) + if updated[2].Content != originalContent { + t.Errorf("Small tool result should not be modified, got %q", updated[2].Content) + } +} + +// TestTruncateToolResults_NonToolMessagesUntouched verifies that user and +// assistant messages are never truncated, even if large. 
+func TestTruncateToolResults_NonToolMessagesUntouched(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("No default agent found") + } + + // Create large user & assistant messages + bigContent := "" + for i := 0; i < maxToolResultChars+500; i++ { + bigContent += "y" + } + + sessionKey := "test-non-tool" + agent.Sessions.GetOrCreate(sessionKey) + history := []providers.Message{ + {Role: "system", Content: "Prompt"}, + {Role: "user", Content: bigContent}, + {Role: "assistant", Content: bigContent}, + } + agent.Sessions.SetHistory(sessionKey, history) + + al.truncateToolResults(agent, sessionKey) + + updated := agent.Sessions.GetHistory(sessionKey) + if updated[1].Content != bigContent { + t.Error("User message should not be truncated") + } + if updated[2].Content != bigContent { + t.Error("Assistant message should not be truncated") + } +} + +// TestTruncateToolResults_NewlineBoundary verifies truncation prefers +// newline boundaries for cleaner output. 
+func TestTruncateToolResults_NewlineBoundary(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("No default agent found") + } + + // Build content with a newline near the cut point. + // Place a newline 50 chars before maxToolResultChars. + content := "" + for i := 0; i < maxToolResultChars-50; i++ { + content += "a" + } + content += "\n" // newline at maxToolResultChars - 49 + for i := 0; i < 1500; i++ { + content += "b" + } + + sessionKey := "test-newline" + agent.Sessions.GetOrCreate(sessionKey) + history := []providers.Message{ + {Role: "system", Content: "Prompt"}, + {Role: "tool", Content: content, ToolCallID: "call-1"}, + } + agent.Sessions.SetHistory(sessionKey, history) + + al.truncateToolResults(agent, sessionKey) + + updated := agent.Sessions.GetHistory(sessionKey) + toolContent := updated[1].Content + + // The truncation should have cut at the newline, not at maxToolResultChars + // The content before the footer should end at the newline boundary + if !containsString(toolContent, "[... content truncated") { + t.Error("Expected truncation footer") + } + // The truncated content should be shorter than maxToolResultChars + footer + if len([]rune(toolContent)) > maxToolResultChars+100 { + t.Error("Content should be truncated near the newline boundary") + } +} + +// TestTruncateToolResults_MultipleToolResults verifies that all oversized tool +// results in a session are truncated. 
+func TestTruncateToolResults_MultipleToolResults(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("No default agent found") + } + + bigContent := "" + for i := 0; i < maxToolResultChars+2000; i++ { + bigContent += "z" + } + + sessionKey := "test-multi-tool" + agent.Sessions.GetOrCreate(sessionKey) + history := []providers.Message{ + {Role: "system", Content: "Prompt"}, + {Role: "user", Content: "Step 1"}, + {Role: "tool", Content: bigContent, ToolCallID: "call-1"}, + {Role: "assistant", Content: "Got it"}, + {Role: "user", Content: "Step 2"}, + {Role: "tool", Content: bigContent, ToolCallID: "call-2"}, + {Role: "tool", Content: "small result", ToolCallID: "call-3"}, + {Role: "assistant", Content: "Done"}, + } + agent.Sessions.SetHistory(sessionKey, history) + + al.truncateToolResults(agent, sessionKey) + + updated := agent.Sessions.GetHistory(sessionKey) + + // call-1 should be truncated + if !containsString(updated[2].Content, "[... content truncated") { + t.Error("First oversized tool result should be truncated") + } + // call-2 should be truncated + if !containsString(updated[5].Content, "[... content truncated") { + t.Error("Second oversized tool result should be truncated") + } + // call-3 should be untouched + if updated[6].Content != "small result" { + t.Error("Small tool result should not be modified") + } +} + +// containsString is a simple helper to check substring presence. 
+func containsString(s, substr string) bool { + return len(s) >= len(substr) && searchSubstring(s, substr) +} + +func searchSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 682996bd6..8dd948ab4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -100,19 +100,33 @@ func (m AgentModelConfig) MarshalJSON() ([]byte, error) { return json.Marshal(raw{Primary: m.Primary, Fallbacks: m.Fallbacks}) } +// ToolPolicyConfig defines allow/deny lists for per-agent tool filtering. +// Tool names can be individual (e.g. "exec") or group refs (e.g. "group:web"). +// nil = full access (no filtering applied). +type ToolPolicyConfig struct { + Allow []string `json:"allow,omitempty"` // tool names or group refs + Deny []string `json:"deny,omitempty"` // tool names or group refs +} + type AgentConfig struct { - ID string `json:"id"` - Default bool `json:"default,omitempty"` - Name string `json:"name,omitempty"` - Workspace string `json:"workspace,omitempty"` - Model *AgentModelConfig `json:"model,omitempty"` - Skills []string `json:"skills,omitempty"` - Subagents *SubagentsConfig `json:"subagents,omitempty"` + ID string `json:"id"` + Default bool `json:"default,omitempty"` + Name string `json:"name,omitempty"` + Role string `json:"role,omitempty"` + SystemPrompt string `json:"system_prompt,omitempty"` + Workspace string `json:"workspace,omitempty"` + Model *AgentModelConfig `json:"model,omitempty"` + Skills []string `json:"skills,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + Subagents *SubagentsConfig `json:"subagents,omitempty"` + ToolPolicy *ToolPolicyConfig `json:"tool_policy,omitempty"` } type SubagentsConfig struct { - AllowAgents []string `json:"allow_agents,omitempty"` - Model *AgentModelConfig `json:"model,omitempty"` + AllowAgents []string `json:"allow_agents,omitempty"` + 
Model *AgentModelConfig `json:"model,omitempty"` + MaxChildrenPerAgent int `json:"max_children_per_agent,omitempty"` // max concurrent async spawns per parent (default 5) + SpawnTimeoutSec int `json:"spawn_timeout_sec,omitempty"` // per-spawn timeout in seconds (default 300) } type PeerMatch struct { @@ -266,11 +280,25 @@ type ProvidersConfig struct { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` - Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` - AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` - ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIKeys []string `json:"api_keys,omitempty"` // multiple keys for auth rotation (takes precedence over api_key) + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` + ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc` +} + +// ResolveAPIKeys returns the effective list of API keys for this provider. +// If APIKeys is set, returns it. Otherwise wraps APIKey as a single-element slice. +// Returns nil if no keys are configured. 
+func (pc *ProviderConfig) ResolveAPIKeys() []string { + if len(pc.APIKeys) > 0 { + return pc.APIKeys + } + if pc.APIKey != "" { + return []string{pc.APIKey} + } + return nil } type OpenAIProviderConfig struct { diff --git a/pkg/multiagent/announce.go b/pkg/multiagent/announce.go new file mode 100644 index 000000000..81e024178 --- /dev/null +++ b/pkg/multiagent/announce.go @@ -0,0 +1,139 @@ +package multiagent + +import ( + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// AnnounceMode determines how a spawn result is delivered to the parent session. +// Inspired by Google Cloud Pub/Sub delivery modes and Microsoft Azure Service Bus. +type AnnounceMode string + +const ( + // AnnounceQueue buffers the result until the parent requests it (default). + // Like Apple's GCD serial queue — ordered, non-blocking. + AnnounceQueue AnnounceMode = "queue" + + // AnnounceDirect sends the result immediately to the parent's chat channel. + // Like Google Pub/Sub push subscription — immediate delivery. + AnnounceDirect AnnounceMode = "direct" +) + +// Announcement is a completion notice from a child spawn to its parent. +type Announcement struct { + FromSessionKey string + ToSessionKey string + RunID string + AgentID string + Content string + Outcome *SpawnOutcome + Mode AnnounceMode + CreatedAt time.Time +} + +// Announcer manages per-session announcement delivery using Go channels +// (inspired by Apple's Grand Central Dispatch work queues). +// Thread-safe for concurrent producers (multiple child spawns completing +// simultaneously) and a single consumer (parent agent). +type Announcer struct { + // Per-session buffered channels (Google Pub/Sub topic model). + channels sync.Map // sessionKey -> chan *Announcement + bufSize int +} + +// NewAnnouncer creates an announcer with the given per-session buffer size. +// Buffer size follows NVIDIA's double-buffering pattern: enough to absorb +// burst completions without blocking producers. 
+func NewAnnouncer(bufSize int) *Announcer { + if bufSize <= 0 { + bufSize = 32 // default: buffer up to 32 pending announcements + } + return &Announcer{bufSize: bufSize} +} + +// Deliver sends an announcement to the target session's channel. +// Non-blocking: if the channel is full, the oldest announcement is dropped +// (Meta's back-pressure pattern for high-throughput systems). +func (a *Announcer) Deliver(targetSessionKey string, ann *Announcement) { + ann.CreatedAt = time.Now() + if ann.Mode == "" { + ann.Mode = AnnounceQueue + } + + ch := a.getOrCreateChan(targetSessionKey) + + select { + case ch <- ann: + logger.DebugCF("announce", "Announcement delivered", map[string]interface{}{ + "from": ann.FromSessionKey, + "to": targetSessionKey, + "run_id": ann.RunID, + "agent": ann.AgentID, + "mode": string(ann.Mode), + }) + default: + // Channel full — drop oldest to make room (back-pressure). + select { + case <-ch: + logger.WarnCF("announce", "Dropped oldest announcement (buffer full)", map[string]interface{}{ + "session": targetSessionKey, + }) + default: + } + // Retry delivery. + select { + case ch <- ann: + default: + logger.WarnCF("announce", "Failed to deliver announcement", map[string]interface{}{ + "session": targetSessionKey, + "run_id": ann.RunID, + }) + } + } +} + +// Drain returns all pending announcements for a session, clearing the buffer. +// The parent agent calls this between LLM iterations to collect spawn results. +// Follows Google's batch-pull pattern from Cloud Pub/Sub. +func (a *Announcer) Drain(sessionKey string) []*Announcement { + v, ok := a.channels.Load(sessionKey) + if !ok { + return nil + } + ch := v.(chan *Announcement) + + var results []*Announcement + for { + select { + case ann := <-ch: + results = append(results, ann) + default: + return results + } + } +} + +// Pending returns the number of pending announcements for a session. 
+func (a *Announcer) Pending(sessionKey string) int { + v, ok := a.channels.Load(sessionKey) + if !ok { + return 0 + } + return len(v.(chan *Announcement)) +} + +// Cleanup removes the channel for a session (called on session end). +func (a *Announcer) Cleanup(sessionKey string) { + a.channels.Delete(sessionKey) +} + +func (a *Announcer) getOrCreateChan(sessionKey string) chan *Announcement { + if v, ok := a.channels.Load(sessionKey); ok { + return v.(chan *Announcement) + } + ch := make(chan *Announcement, a.bufSize) + actual, _ := a.channels.LoadOrStore(sessionKey, ch) + return actual.(chan *Announcement) +} diff --git a/pkg/multiagent/blackboard.go b/pkg/multiagent/blackboard.go new file mode 100644 index 000000000..35300933b --- /dev/null +++ b/pkg/multiagent/blackboard.go @@ -0,0 +1,154 @@ +package multiagent + +import ( + "encoding/json" + "sort" + "sync" + "time" +) + +// BlackboardEntry represents a single entry in the shared context pool. +type BlackboardEntry struct { + Key string `json:"key"` + Value string `json:"value"` + Author string `json:"author"` + Scope string `json:"scope"` + Timestamp time.Time `json:"timestamp"` +} + +// BoardAware is implemented by tools that need the session blackboard injected +// before each execution. This fixes the split-brain bug where tools were bound +// to a static board at registration time instead of the per-session board. +type BoardAware interface { + SetBoard(board *Blackboard) +} + +// Blackboard is a thread-safe shared context pool for multi-agent collaboration. +// Agents read and write string key-value entries, each tagged with authorship +// and scope metadata. +type Blackboard struct { + entries map[string]*BlackboardEntry + mu sync.RWMutex +} + +// NewBlackboard creates an empty Blackboard. +func NewBlackboard() *Blackboard { + return &Blackboard{ + entries: make(map[string]*BlackboardEntry), + } +} + +// Set writes or overwrites an entry on the blackboard. 
+func (b *Blackboard) Set(key, value, author string) { + b.mu.Lock() + defer b.mu.Unlock() + b.entries[key] = &BlackboardEntry{ + Key: key, + Value: value, + Author: author, + Scope: "shared", + Timestamp: time.Now(), + } +} + +// Get returns the value for a key, or empty string if not found. +func (b *Blackboard) Get(key string) string { + b.mu.RLock() + defer b.mu.RUnlock() + if e, ok := b.entries[key]; ok { + return e.Value + } + return "" +} + +// GetEntry returns the full entry for a key, or nil if not found. +func (b *Blackboard) GetEntry(key string) *BlackboardEntry { + b.mu.RLock() + defer b.mu.RUnlock() + if e, ok := b.entries[key]; ok { + cp := *e + return &cp + } + return nil +} + +// Delete removes an entry by key. Returns true if it existed. +func (b *Blackboard) Delete(key string) bool { + b.mu.Lock() + defer b.mu.Unlock() + _, ok := b.entries[key] + if ok { + delete(b.entries, key) + } + return ok +} + +// List returns all keys sorted alphabetically. +func (b *Blackboard) List() []string { + b.mu.RLock() + defer b.mu.RUnlock() + keys := make([]string, 0, len(b.entries)) + for k := range b.entries { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +// Snapshot returns a string summary of all entries suitable for injection +// into an LLM system prompt. +func (b *Blackboard) Snapshot() string { + b.mu.RLock() + defer b.mu.RUnlock() + if len(b.entries) == 0 { + return "" + } + + keys := make([]string, 0, len(b.entries)) + for k := range b.entries { + keys = append(keys, k) + } + sort.Strings(keys) + + result := "## Shared Context (Blackboard)\n\n" + for _, k := range keys { + e := b.entries[k] + result += "- **" + k + "** (by " + e.Author + "): " + e.Value + "\n" + } + return result +} + +// Size returns the number of entries. +func (b *Blackboard) Size() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.entries) +} + +// MarshalJSON serializes the blackboard entries to JSON. 
+func (b *Blackboard) MarshalJSON() ([]byte, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + entries := make([]*BlackboardEntry, 0, len(b.entries)) + for _, e := range b.entries { + entries = append(entries, e) + } + return json.Marshal(entries) +} + +// UnmarshalJSON deserializes blackboard entries from JSON. +func (b *Blackboard) UnmarshalJSON(data []byte) error { + var entries []*BlackboardEntry + if err := json.Unmarshal(data, &entries); err != nil { + return err + } + + b.mu.Lock() + defer b.mu.Unlock() + b.entries = make(map[string]*BlackboardEntry, len(entries)) + for _, e := range entries { + b.entries[e.Key] = e + } + return nil +} diff --git a/pkg/multiagent/blackboard_test.go b/pkg/multiagent/blackboard_test.go new file mode 100644 index 000000000..89ff0aefb --- /dev/null +++ b/pkg/multiagent/blackboard_test.go @@ -0,0 +1,400 @@ +package multiagent + +import ( + "context" + "encoding/json" + "sync" + "testing" +) + +func TestBlackboard_SetGet(t *testing.T) { + bb := NewBlackboard() + bb.Set("goal", "build feature X", "main") + + if got := bb.Get("goal"); got != "build feature X" { + t.Errorf("Get(goal) = %q, want %q", got, "build feature X") + } +} + +func TestBlackboard_GetMissing(t *testing.T) { + bb := NewBlackboard() + if got := bb.Get("missing"); got != "" { + t.Errorf("Get(missing) = %q, want empty", got) + } +} + +func TestBlackboard_GetEntry(t *testing.T) { + bb := NewBlackboard() + bb.Set("status", "in-progress", "coder") + + entry := bb.GetEntry("status") + if entry == nil { + t.Fatal("expected non-nil entry") + } + if entry.Author != "coder" { + t.Errorf("Author = %q, want %q", entry.Author, "coder") + } + if entry.Scope != "shared" { + t.Errorf("Scope = %q, want %q", entry.Scope, "shared") + } +} + +func TestBlackboard_GetEntryMissing(t *testing.T) { + bb := NewBlackboard() + if entry := bb.GetEntry("nope"); entry != nil { + t.Error("expected nil entry for missing key") + } +} + +func TestBlackboard_Overwrite(t *testing.T) { + bb := 
NewBlackboard() + bb.Set("counter", "1", "a") + bb.Set("counter", "2", "b") + + entry := bb.GetEntry("counter") + if entry.Value != "2" { + t.Errorf("Value = %q after overwrite, want %q", entry.Value, "2") + } + if entry.Author != "b" { + t.Errorf("Author = %q after overwrite, want %q", entry.Author, "b") + } +} + +func TestBlackboard_Delete(t *testing.T) { + bb := NewBlackboard() + bb.Set("tmp", "value", "main") + + if !bb.Delete("tmp") { + t.Error("Delete(tmp) returned false, expected true") + } + if bb.Delete("tmp") { + t.Error("Delete(tmp) second call returned true, expected false") + } + if bb.Get("tmp") != "" { + t.Error("Get(tmp) after delete should return empty") + } +} + +func TestBlackboard_List(t *testing.T) { + bb := NewBlackboard() + bb.Set("b", "2", "a") + bb.Set("a", "1", "a") + bb.Set("c", "3", "a") + + keys := bb.List() + if len(keys) != 3 { + t.Fatalf("List() returned %d keys, want 3", len(keys)) + } + if keys[0] != "a" || keys[1] != "b" || keys[2] != "c" { + t.Errorf("List() = %v, want [a b c]", keys) + } +} + +func TestBlackboard_Snapshot(t *testing.T) { + bb := NewBlackboard() + if s := bb.Snapshot(); s != "" { + t.Errorf("empty blackboard Snapshot() = %q, want empty", s) + } + + bb.Set("goal", "test", "main") + s := bb.Snapshot() + if s == "" { + t.Error("Snapshot() returned empty for non-empty blackboard") + } + if !contains(s, "goal") || !contains(s, "main") || !contains(s, "test") { + t.Errorf("Snapshot() = %q, expected to contain key/author/value", s) + } +} + +func TestBlackboard_Size(t *testing.T) { + bb := NewBlackboard() + if bb.Size() != 0 { + t.Errorf("Size() = %d, want 0", bb.Size()) + } + bb.Set("a", "1", "x") + bb.Set("b", "2", "x") + if bb.Size() != 2 { + t.Errorf("Size() = %d, want 2", bb.Size()) + } +} + +func TestBlackboard_ConcurrentAccess(_ *testing.T) { + bb := NewBlackboard() + var wg sync.WaitGroup + + for range 100 { + wg.Go(func() { + key := "key" + bb.Set(key, "val", "agent") + bb.Get(key) + bb.List() + bb.Snapshot() + 
}) + } + wg.Wait() +} + +func TestBlackboard_JSON(t *testing.T) { + bb := NewBlackboard() + bb.Set("x", "1", "a") + bb.Set("y", "2", "b") + + data, err := json.Marshal(bb) + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + bb2 := NewBlackboard() + if err := json.Unmarshal(data, bb2); err != nil { + t.Fatalf("UnmarshalJSON failed: %v", err) + } + + if bb2.Get("x") != "1" || bb2.Get("y") != "2" { + t.Error("roundtrip lost data") + } +} + +func TestBlackboardTool_Write(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "test-agent") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "write", + "key": "task", + "value": "implement feature", + }) + if result.IsError { + t.Fatalf("write failed: %s", result.ForLLM) + } + + if bb.Get("task") != "implement feature" { + t.Error("write did not persist") + } + entry := bb.GetEntry("task") + if entry.Author != "test-agent" { + t.Errorf("Author = %q, want %q", entry.Author, "test-agent") + } +} + +func TestBlackboardTool_Read(t *testing.T) { + bb := NewBlackboard() + bb.Set("info", "hello", "other") + tool := NewBlackboardTool(bb, "reader") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "read", + "key": "info", + }) + if result.IsError { + t.Fatalf("read failed: %s", result.ForLLM) + } + if !contains(result.ForLLM, "hello") { + t.Errorf("read result = %q, expected to contain 'hello'", result.ForLLM) + } +} + +func TestBlackboardTool_ReadMissing(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "reader") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "read", + "key": "nope", + }) + if result.IsError { + t.Fatalf("read missing should not be error: %s", result.ForLLM) + } + if !contains(result.ForLLM, "No entry") { + t.Errorf("expected 'No entry' message, got %q", result.ForLLM) + } +} + +func TestBlackboardTool_List(t *testing.T) { + bb := NewBlackboard() + bb.Set("a", "1", "x") + 
bb.Set("b", "2", "y") + tool := NewBlackboardTool(bb, "lister") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "list", + }) + if result.IsError { + t.Fatalf("list failed: %s", result.ForLLM) + } + if !contains(result.ForLLM, "a") || !contains(result.ForLLM, "b") { + t.Errorf("list result = %q, expected keys", result.ForLLM) + } +} + +func TestBlackboardTool_Delete(t *testing.T) { + bb := NewBlackboard() + bb.Set("tmp", "val", "x") + tool := NewBlackboardTool(bb, "deleter") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "delete", + "key": "tmp", + }) + if result.IsError { + t.Fatalf("delete failed: %s", result.ForLLM) + } + if bb.Size() != 0 { + t.Error("delete did not remove entry") + } +} + +func TestBlackboardTool_InvalidAction(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "test") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "invalid", + }) + if !result.IsError { + t.Error("expected error for invalid action") + } +} + +func TestBlackboardTool_MissingKey(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "test") + + // read without key + result := tool.Execute(context.Background(), map[string]any{ + "action": "read", + }) + if !result.IsError { + t.Error("expected error for read without key") + } + + // write without key + result = tool.Execute(context.Background(), map[string]any{ + "action": "write", + "value": "test", + }) + if !result.IsError { + t.Error("expected error for write without key") + } + + // write without value + result = tool.Execute(context.Background(), map[string]any{ + "action": "write", + "key": "k", + }) + if !result.IsError { + t.Error("expected error for write without value") + } +} + +func TestBlackboardTool_SetBoard(t *testing.T) { + bb1 := NewBlackboard() + bb2 := NewBlackboard() + bb2.Set("from_session", "session_data", "system") + + tool := NewBlackboardTool(bb1, "agent1") + + // Initially reads 
from bb1 (empty) + result := tool.Execute(context.Background(), map[string]any{ + "action": "read", + "key": "from_session", + }) + if !contains(result.ForLLM, "No entry") { + t.Errorf("expected 'No entry' before SetBoard, got %q", result.ForLLM) + } + + // Switch to session board + tool.SetBoard(bb2) + + // Now reads from bb2 + result = tool.Execute(context.Background(), map[string]any{ + "action": "read", + "key": "from_session", + }) + if !contains(result.ForLLM, "session_data") { + t.Errorf("expected 'session_data' after SetBoard, got %q", result.ForLLM) + } + + // Writes go to bb2, not bb1 + tool.Execute(context.Background(), map[string]any{ + "action": "write", + "key": "new_key", + "value": "new_val", + }) + if bb1.Get("new_key") != "" { + t.Error("write went to old board after SetBoard") + } + if bb2.Get("new_key") != "new_val" { + t.Error("write didn't go to new board after SetBoard") + } +} + +// TestBlackboard_UnmarshalJSON_InvalidData verifies that UnmarshalJSON returns an error +// for malformed input instead of silently producing a broken blackboard. +func TestBlackboard_UnmarshalJSON_InvalidData(t *testing.T) { + bb := NewBlackboard() + err := bb.UnmarshalJSON([]byte("not valid json")) + if err == nil { + t.Error("expected error for invalid JSON input") + } + // Board should remain empty after a failed unmarshal + if bb.Size() != 0 { + t.Errorf("Size() = %d after failed unmarshal, want 0", bb.Size()) + } +} + +// TestBlackboardTool_ListEmpty verifies the "Blackboard is empty" message path when +// the board has no entries. 
+func TestBlackboardTool_ListEmpty(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "lister") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "list", + }) + if result.IsError { + t.Fatalf("list on empty board should not error: %s", result.ForLLM) + } + if !contains(result.ForLLM, "empty") { + t.Errorf("expected 'empty' in result for empty board, got %q", result.ForLLM) + } +} + +// TestBlackboardTool_DeleteMissing verifies the "not found" path for delete on a +// key that does not exist. +func TestBlackboardTool_DeleteMissing(t *testing.T) { + bb := NewBlackboard() + tool := NewBlackboardTool(bb, "deleter") + + result := tool.Execute(context.Background(), map[string]any{ + "action": "delete", + "key": "nonexistent_key", + }) + if result.IsError { + t.Fatalf("delete of missing key should not be an error: %s", result.ForLLM) + } + if !contains(result.ForLLM, "not found") { + t.Errorf("expected 'not found' in result, got %q", result.ForLLM) + } +} + +func TestBoardAware_Interface(t *testing.T) { + // Verify both tools implement BoardAware + bb := NewBlackboard() + var _ BoardAware = NewBlackboardTool(bb, "test") + + resolver := newMockResolver() + var _ BoardAware = NewHandoffTool(resolver, bb, "test") +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/pkg/multiagent/blackboard_tool.go b/pkg/multiagent/blackboard_tool.go new file mode 100644 index 000000000..6c57c2764 --- /dev/null +++ b/pkg/multiagent/blackboard_tool.go @@ -0,0 +1,128 @@ +package multiagent + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// BlackboardTool exposes the Blackboard to an LLM agent via the tool interface. 
+// Each instance is bound to a specific agent ID for authorship tracking. +type BlackboardTool struct { + board *Blackboard + agentID string +} + +// NewBlackboardTool creates a blackboard tool bound to a specific agent. +func NewBlackboardTool(board *Blackboard, agentID string) *BlackboardTool { + return &BlackboardTool{ + board: board, + agentID: agentID, + } +} + +// SetBoard replaces the blackboard reference, allowing the tool to be wired +// to the correct per-session board before each execution. +func (t *BlackboardTool) SetBoard(board *Blackboard) { + t.board = board +} + +// Name returns the tool name. +func (t *BlackboardTool) Name() string { return "blackboard" } + +// Description returns a human-readable description of the tool. +func (t *BlackboardTool) Description() string { + return "Read, write, list, or delete entries in the shared context blackboard. " + + "Use this to share information between agents in a multi-agent session." +} + +// Parameters returns the JSON Schema for the tool's input. +func (t *BlackboardTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "action": map[string]any{ + "type": "string", + "enum": []string{"read", "write", "list", "delete"}, + "description": "The action to perform on the blackboard", + }, + "key": map[string]any{ + "type": "string", + "description": "The key to read, write, or delete (not required for list)", + }, + "value": map[string]any{ + "type": "string", + "description": "The value to write (only required for write action)", + }, + }, + "required": []string{"action"}, + } +} + +// Execute runs the blackboard action specified in args. 
+func (t *BlackboardTool) Execute(_ context.Context, args map[string]any) *tools.ToolResult { + action, ok := args["action"].(string) + if !ok { + action = "" + } + key, ok := args["key"].(string) + if !ok { + key = "" + } + value, ok := args["value"].(string) + if !ok { + value = "" + } + + switch strings.ToLower(action) { + case "read": + if key == "" { + return tools.ErrorResult("key is required for read action") + } + entry := t.board.GetEntry(key) + if entry == nil { + return tools.NewToolResult(fmt.Sprintf("No entry found for key %q", key)) + } + return tools.NewToolResult(fmt.Sprintf("Key: %s\nValue: %s\nAuthor: %s\nScope: %s", + entry.Key, entry.Value, entry.Author, entry.Scope)) + + case "write": + if key == "" { + return tools.ErrorResult("key is required for write action") + } + if value == "" { + return tools.ErrorResult("value is required for write action") + } + t.board.Set(key, value, t.agentID) + return tools.NewToolResult(fmt.Sprintf("Written key %q to blackboard", key)) + + case "list": + keys := t.board.List() + if len(keys) == 0 { + return tools.NewToolResult("Blackboard is empty") + } + var sb strings.Builder + fmt.Fprintf(&sb, "Blackboard entries (%d):\n", len(keys)) + for _, k := range keys { + entry := t.board.GetEntry(k) + if entry != nil { + fmt.Fprintf(&sb, "- %s (by %s): %s\n", k, entry.Author, entry.Value) + } + } + return tools.NewToolResult(sb.String()) + + case "delete": + if key == "" { + return tools.ErrorResult("key is required for delete action") + } + if t.board.Delete(key) { + return tools.NewToolResult(fmt.Sprintf("Deleted key %q from blackboard", key)) + } + return tools.NewToolResult(fmt.Sprintf("Key %q not found on blackboard", key)) + + default: + return tools.ErrorResult(fmt.Sprintf("unknown action %q; use read, write, list, or delete", action)) + } +} diff --git a/pkg/multiagent/cascade.go b/pkg/multiagent/cascade.go new file mode 100644 index 000000000..338b0c8fb --- /dev/null +++ b/pkg/multiagent/cascade.go @@ -0,0 
+1,135 @@ +package multiagent + +import ( + "context" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// ActiveRun represents a running handoff or spawn that can be cancelled. +type ActiveRun struct { + SessionKey string + AgentID string + ParentKey string // parent session key ("" for top-level) + Cancel context.CancelFunc // cancels this run's context + StartedAt time.Time +} + +// RunRegistry tracks active agent runs for cascade cancellation. +// Thread-safe via sync.Map. +type RunRegistry struct { + runs sync.Map // sessionKey -> *ActiveRun +} + +// NewRunRegistry creates an empty run registry. +func NewRunRegistry() *RunRegistry { + return &RunRegistry{} +} + +// Register adds an active run to the registry. +func (r *RunRegistry) Register(run *ActiveRun) { + r.runs.Store(run.SessionKey, run) + logger.DebugCF("cascade", "Run registered", + map[string]interface{}{ + "session_key": run.SessionKey, + "agent_id": run.AgentID, + "parent_key": run.ParentKey, + }) +} + +// Deregister removes a run from the registry (normal completion). +func (r *RunRegistry) Deregister(sessionKey string) { + r.runs.Delete(sessionKey) + logger.DebugCF("cascade", "Run deregistered", + map[string]interface{}{ + "session_key": sessionKey, + }) +} + +// CascadeStop cancels a run and all its descendants. +// Returns the number of runs cancelled. Uses a seen-set to prevent infinite loops. 
+func (r *RunRegistry) CascadeStop(sessionKey string) int { + seen := make(map[string]bool) + killed := r.cascadeStop(sessionKey, seen) + if killed > 0 { + logger.InfoCF("cascade", "Cascade stop completed", + map[string]interface{}{ + "root_key": sessionKey, + "killed": killed, + }) + } + return killed +} + +func (r *RunRegistry) cascadeStop(sessionKey string, seen map[string]bool) int { + if seen[sessionKey] { + return 0 + } + seen[sessionKey] = true + killed := 0 + + // Cancel and remove this run + if v, ok := r.runs.LoadAndDelete(sessionKey); ok { + run := v.(*ActiveRun) + run.Cancel() + killed++ + logger.DebugCF("cascade", "Run cancelled", + map[string]interface{}{ + "session_key": sessionKey, + "agent_id": run.AgentID, + }) + } + + // Find and cascade-stop all children (runs whose ParentKey == sessionKey) + r.runs.Range(func(key, value interface{}) bool { + childRun := value.(*ActiveRun) + if childRun.ParentKey == sessionKey { + killed += r.cascadeStop(key.(string), seen) + } + return true + }) + + return killed +} + +// StopAll cancels every active run. Returns the number cancelled. +func (r *RunRegistry) StopAll() int { + killed := 0 + r.runs.Range(func(key, value interface{}) bool { + run := value.(*ActiveRun) + run.Cancel() + r.runs.Delete(key) + killed++ + return true + }) + if killed > 0 { + logger.InfoCF("cascade", "Stop all completed", + map[string]interface{}{"killed": killed}) + } + return killed +} + +// ActiveCount returns the number of currently active runs. +func (r *RunRegistry) ActiveCount() int { + count := 0 + r.runs.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} + +// GetChildren returns session keys of all direct children of the given parent. 
+func (r *RunRegistry) GetChildren(parentKey string) []string { + var children []string + r.runs.Range(func(key, value interface{}) bool { + run := value.(*ActiveRun) + if run.ParentKey == parentKey { + children = append(children, key.(string)) + } + return true + }) + return children +} diff --git a/pkg/multiagent/cascade_test.go b/pkg/multiagent/cascade_test.go new file mode 100644 index 000000000..926333198 --- /dev/null +++ b/pkg/multiagent/cascade_test.go @@ -0,0 +1,327 @@ +package multiagent + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" +) + +func TestRunRegistry_RegisterAndDeregister(t *testing.T) { + reg := NewRunRegistry() + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + reg.Register(&ActiveRun{ + SessionKey: "run-1", + AgentID: "coder", + Cancel: cancel, + StartedAt: time.Now(), + }) + + if reg.ActiveCount() != 1 { + t.Fatalf("ActiveCount = %d, want 1", reg.ActiveCount()) + } + + reg.Deregister("run-1") + + if reg.ActiveCount() != 0 { + t.Fatalf("ActiveCount = %d after deregister, want 0", reg.ActiveCount()) + } +} + +func TestRunRegistry_CascadeStop_SingleRun(t *testing.T) { + reg := NewRunRegistry() + ctx, cancel := context.WithCancel(context.Background()) + + reg.Register(&ActiveRun{ + SessionKey: "run-1", + AgentID: "coder", + Cancel: cancel, + StartedAt: time.Now(), + }) + + killed := reg.CascadeStop("run-1") + if killed != 1 { + t.Fatalf("killed = %d, want 1", killed) + } + + // Context should be cancelled + select { + case <-ctx.Done(): + // expected + default: + t.Fatal("context not cancelled after CascadeStop") + } + + if reg.ActiveCount() != 0 { + t.Fatalf("ActiveCount = %d after CascadeStop, want 0", reg.ActiveCount()) + } +} + +func TestRunRegistry_CascadeStop_ParentChildChain(t *testing.T) { + reg := NewRunRegistry() + + // Build: parent → child → grandchild + _, cancelParent := context.WithCancel(context.Background()) + ctxChild, cancelChild := context.WithCancel(context.Background()) + 
ctxGrandchild, cancelGrandchild := context.WithCancel(context.Background()) + + reg.Register(&ActiveRun{ + SessionKey: "parent", + AgentID: "main", + ParentKey: "", + Cancel: cancelParent, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "child", + AgentID: "coder", + ParentKey: "parent", + Cancel: cancelChild, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "grandchild", + AgentID: "researcher", + ParentKey: "child", + Cancel: cancelGrandchild, + StartedAt: time.Now(), + }) + + if reg.ActiveCount() != 3 { + t.Fatalf("ActiveCount = %d, want 3", reg.ActiveCount()) + } + + // Cascade stop from parent should kill all 3 + killed := reg.CascadeStop("parent") + if killed != 3 { + t.Fatalf("killed = %d, want 3", killed) + } + + // All contexts should be cancelled + select { + case <-ctxChild.Done(): + default: + t.Fatal("child context not cancelled") + } + select { + case <-ctxGrandchild.Done(): + default: + t.Fatal("grandchild context not cancelled") + } + + if reg.ActiveCount() != 0 { + t.Fatalf("ActiveCount = %d after cascade, want 0", reg.ActiveCount()) + } +} + +func TestRunRegistry_CascadeStop_MidChain(t *testing.T) { + reg := NewRunRegistry() + + // Build: parent → child → grandchild + _, cancelParent := context.WithCancel(context.Background()) + _, cancelChild := context.WithCancel(context.Background()) + ctxGrandchild, cancelGrandchild := context.WithCancel(context.Background()) + + reg.Register(&ActiveRun{ + SessionKey: "parent", + AgentID: "main", + ParentKey: "", + Cancel: cancelParent, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "child", + AgentID: "coder", + ParentKey: "parent", + Cancel: cancelChild, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "grandchild", + AgentID: "researcher", + ParentKey: "child", + Cancel: cancelGrandchild, + StartedAt: time.Now(), + }) + + // Cascade stop from child should kill child + grandchild, but NOT parent + killed := 
reg.CascadeStop("child") + if killed != 2 { + t.Fatalf("killed = %d, want 2", killed) + } + + // Grandchild context cancelled + select { + case <-ctxGrandchild.Done(): + default: + t.Fatal("grandchild context not cancelled") + } + + // Parent still active + if reg.ActiveCount() != 1 { + t.Fatalf("ActiveCount = %d, want 1 (parent survives)", reg.ActiveCount()) + } +} + +func TestRunRegistry_CascadeStop_MultipleSiblings(t *testing.T) { + reg := NewRunRegistry() + + _, cancelParent := context.WithCancel(context.Background()) + _, cancelSibling1 := context.WithCancel(context.Background()) + _, cancelSibling2 := context.WithCancel(context.Background()) + + reg.Register(&ActiveRun{ + SessionKey: "parent", + AgentID: "main", + Cancel: cancelParent, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "sibling-1", + AgentID: "coder", + ParentKey: "parent", + Cancel: cancelSibling1, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "sibling-2", + AgentID: "researcher", + ParentKey: "parent", + Cancel: cancelSibling2, + StartedAt: time.Now(), + }) + + killed := reg.CascadeStop("parent") + if killed != 3 { + t.Fatalf("killed = %d, want 3 (parent + 2 siblings)", killed) + } + if reg.ActiveCount() != 0 { + t.Fatalf("ActiveCount = %d, want 0", reg.ActiveCount()) + } +} + +func TestRunRegistry_CascadeStop_CycleProtection(t *testing.T) { + reg := NewRunRegistry() + + // Manually create a cycle: A→B→A (shouldn't happen in practice, but guard against it) + _, cancelA := context.WithCancel(context.Background()) + _, cancelB := context.WithCancel(context.Background()) + + reg.Register(&ActiveRun{ + SessionKey: "A", + AgentID: "agent-a", + ParentKey: "B", + Cancel: cancelA, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "B", + AgentID: "agent-b", + ParentKey: "A", + Cancel: cancelB, + StartedAt: time.Now(), + }) + + // Should not infinite-loop, should kill both + killed := reg.CascadeStop("A") + if killed != 2 { + 
t.Fatalf("killed = %d, want 2", killed) + } +} + +func TestRunRegistry_CascadeStop_NonExistent(t *testing.T) { + reg := NewRunRegistry() + + killed := reg.CascadeStop("ghost") + if killed != 0 { + t.Fatalf("killed = %d, want 0 for non-existent key", killed) + } +} + +func TestRunRegistry_StopAll(t *testing.T) { + reg := NewRunRegistry() + + var cancelled atomic.Int32 + for i := 0; i < 5; i++ { + _, cancel := context.WithCancel(context.Background()) + idx := i + reg.Register(&ActiveRun{ + SessionKey: fmt.Sprintf("run-%d", idx), + AgentID: "agent", + Cancel: func() { + cancelled.Add(1) + cancel() + }, + StartedAt: time.Now(), + }) + } + + killed := reg.StopAll() + if killed != 5 { + t.Fatalf("StopAll killed = %d, want 5", killed) + } + if cancelled.Load() != 5 { + t.Fatalf("cancel called %d times, want 5", cancelled.Load()) + } + if reg.ActiveCount() != 0 { + t.Fatalf("ActiveCount = %d after StopAll, want 0", reg.ActiveCount()) + } +} + +func TestRunRegistry_GetChildren(t *testing.T) { + reg := NewRunRegistry() + _, cancel := context.WithCancel(context.Background()) + defer cancel() + + reg.Register(&ActiveRun{SessionKey: "parent", Cancel: cancel, StartedAt: time.Now()}) + reg.Register(&ActiveRun{SessionKey: "child-1", ParentKey: "parent", Cancel: cancel, StartedAt: time.Now()}) + reg.Register(&ActiveRun{SessionKey: "child-2", ParentKey: "parent", Cancel: cancel, StartedAt: time.Now()}) + reg.Register(&ActiveRun{SessionKey: "other", ParentKey: "other-parent", Cancel: cancel, StartedAt: time.Now()}) + + children := reg.GetChildren("parent") + if len(children) != 2 { + t.Fatalf("GetChildren = %d, want 2", len(children)) + } +} + +func TestRunRegistry_ContextCancellationPropagates(t *testing.T) { + reg := NewRunRegistry() + + // Simulate Go context tree: parent ctx → child ctx + parentCtx, parentCancel := context.WithCancel(context.Background()) + childCtx, childCancel := context.WithCancel(parentCtx) + + reg.Register(&ActiveRun{ + SessionKey: "parent", + AgentID: 
"main", + Cancel: parentCancel, + StartedAt: time.Now(), + }) + reg.Register(&ActiveRun{ + SessionKey: "child", + AgentID: "coder", + ParentKey: "parent", + Cancel: childCancel, + StartedAt: time.Now(), + }) + + // Cascade stop on parent should cancel parent context + reg.CascadeStop("parent") + + // Parent context cancelled + select { + case <-parentCtx.Done(): + default: + t.Fatal("parent context not cancelled") + } + + // Child context should ALSO be cancelled (Go context tree propagation) + select { + case <-childCtx.Done(): + default: + t.Fatal("child context not cancelled via Go context tree") + } +} diff --git a/pkg/multiagent/dedup.go b/pkg/multiagent/dedup.go new file mode 100644 index 000000000..b4f5b93a3 --- /dev/null +++ b/pkg/multiagent/dedup.go @@ -0,0 +1,170 @@ +package multiagent + +import ( + "crypto/sha256" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Idempotency cache defaults. +// Follows Stripe's idempotency key pattern: deterministic keys with TTL-based expiry. +const ( + DefaultDedupTTL = 5 * time.Minute + DefaultDedupSweepInterval = 60 * time.Second +) + +// DedupEntry tracks a single idempotent operation. +type DedupEntry struct { + Key string + CreatedAt time.Time + ExpiresAt time.Time + Result string // cached result for idempotent replay +} + +// DedupCache provides idempotent execution guarantees for spawn and announce +// operations. Uses deterministic keys (like Stripe) with TTL-based expiry +// and periodic sweep (like Google Cloud Tasks dedup). +type DedupCache struct { + mu sync.RWMutex + entries map[string]*DedupEntry + ttl time.Duration + stop chan struct{} +} + +// NewDedupCache creates a dedup cache with the given TTL and starts +// the background sweep goroutine. 
+func NewDedupCache(ttl time.Duration) *DedupCache { + if ttl <= 0 { + ttl = DefaultDedupTTL + } + dc := &DedupCache{ + entries: make(map[string]*DedupEntry), + ttl: ttl, + stop: make(chan struct{}), + } + go dc.sweepLoop() + return dc +} + +// Check returns true if the key has already been processed (duplicate). +// If not a duplicate, registers the key and returns false. +// Thread-safe via mutex (simpler than CAS for this throughput level). +func (dc *DedupCache) Check(key string) bool { + dc.mu.Lock() + defer dc.mu.Unlock() + + now := time.Now() + + // Check if key exists and hasn't expired. + if entry, ok := dc.entries[key]; ok { + if now.Before(entry.ExpiresAt) { + return true // duplicate + } + // Expired — remove and treat as new. + delete(dc.entries, key) + } + + // Register new key. + dc.entries[key] = &DedupEntry{ + Key: key, + CreatedAt: now, + ExpiresAt: now.Add(dc.ttl), + } + return false // not a duplicate +} + +// CheckWithResult returns the cached result if the key is a duplicate. +// If not a duplicate, registers the key and returns ("", false). +func (dc *DedupCache) CheckWithResult(key string) (string, bool) { + dc.mu.Lock() + defer dc.mu.Unlock() + + now := time.Now() + if entry, ok := dc.entries[key]; ok { + if now.Before(entry.ExpiresAt) { + return entry.Result, true + } + delete(dc.entries, key) + } + + dc.entries[key] = &DedupEntry{ + Key: key, + CreatedAt: now, + ExpiresAt: now.Add(dc.ttl), + } + return "", false +} + +// SetResult stores the result for an already-registered key. +func (dc *DedupCache) SetResult(key, result string) { + dc.mu.Lock() + defer dc.mu.Unlock() + if entry, ok := dc.entries[key]; ok { + entry.Result = result + } +} + +// Size returns the current number of entries. +func (dc *DedupCache) Size() int { + dc.mu.RLock() + defer dc.mu.RUnlock() + return len(dc.entries) +} + +// Stop stops the background sweep goroutine. 
+func (dc *DedupCache) Stop() { + close(dc.stop) +} + +// sweepLoop periodically removes expired entries (Google Cloud Tasks pattern). +func (dc *DedupCache) sweepLoop() { + ticker := time.NewTicker(DefaultDedupSweepInterval) + defer ticker.Stop() + + for { + select { + case <-dc.stop: + return + case <-ticker.C: + dc.sweep() + } + } +} + +func (dc *DedupCache) sweep() { + dc.mu.Lock() + defer dc.mu.Unlock() + + now := time.Now() + expired := 0 + for key, entry := range dc.entries { + if now.After(entry.ExpiresAt) { + delete(dc.entries, key) + expired++ + } + } + if expired > 0 { + logger.DebugCF("dedup", "Sweep completed", map[string]interface{}{ + "expired": expired, + "remaining": len(dc.entries), + }) + } +} + +// BuildSpawnKey creates a deterministic dedup key for a spawn request. +// Format: "spawn:v1:{from}:{to}:{task_hash}" +// Same task from the same agent to the same target within TTL = idempotent. +func BuildSpawnKey(fromAgentID, toAgentID, task string) string { + h := sha256.Sum256([]byte(task)) + return fmt.Sprintf("spawn:v1:%s:%s:%x", fromAgentID, toAgentID, h[:8]) +} + +// BuildAnnounceKey creates a deterministic dedup key for an announcement. +// Format: "announce:v1:{childSessionKey}:{runID}" +// Prevents duplicate announcements for the same spawn completion. 
+func BuildAnnounceKey(childSessionKey, runID string) string { + return fmt.Sprintf("announce:v1:%s:%s", childSessionKey, runID) +} diff --git a/pkg/multiagent/dedup_test.go b/pkg/multiagent/dedup_test.go new file mode 100644 index 000000000..6969dd992 --- /dev/null +++ b/pkg/multiagent/dedup_test.go @@ -0,0 +1,160 @@ +package multiagent + +import ( + "sync" + "testing" + "time" +) + +func TestDedupCache_FirstCallNotDuplicate(t *testing.T) { + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + if dc.Check("key-1") { + t.Error("first call should not be a duplicate") + } +} + +func TestDedupCache_SecondCallIsDuplicate(t *testing.T) { + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + dc.Check("key-1") + if !dc.Check("key-1") { + t.Error("second call with same key should be a duplicate") + } +} + +func TestDedupCache_DifferentKeysNotDuplicate(t *testing.T) { + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + dc.Check("key-1") + if dc.Check("key-2") { + t.Error("different key should not be a duplicate") + } +} + +func TestDedupCache_ExpiredEntryNotDuplicate(t *testing.T) { + dc := NewDedupCache(50 * time.Millisecond) // very short TTL + defer dc.Stop() + + dc.Check("key-1") + time.Sleep(100 * time.Millisecond) // wait for expiry + + if dc.Check("key-1") { + t.Error("expired entry should not be treated as duplicate") + } +} + +func TestDedupCache_CheckWithResult(t *testing.T) { + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + // First call: not a duplicate + result, isDup := dc.CheckWithResult("key-1") + if isDup || result != "" { + t.Error("first call should not be a duplicate") + } + + // Set result + dc.SetResult("key-1", "cached-result") + + // Second call: duplicate with cached result + result, isDup = dc.CheckWithResult("key-1") + if !isDup { + t.Error("second call should be a duplicate") + } + if result != "cached-result" { + t.Errorf("expected cached-result, got %q", result) + } +} + +func TestDedupCache_Size(t *testing.T) 
{ + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + if dc.Size() != 0 { + t.Error("expected size 0") + } + + dc.Check("key-1") + dc.Check("key-2") + dc.Check("key-3") + + if dc.Size() != 3 { + t.Errorf("expected size 3, got %d", dc.Size()) + } +} + +func TestDedupCache_ConcurrentAccess(t *testing.T) { + dc := NewDedupCache(5 * time.Minute) + defer dc.Stop() + + var wg sync.WaitGroup + duplicates := 0 + var mu sync.Mutex + + // 100 goroutines all trying the same key + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if dc.Check("same-key") { + mu.Lock() + duplicates++ + mu.Unlock() + } + }() + } + wg.Wait() + + // Exactly 99 should be duplicates (first one registers) + if duplicates != 99 { + t.Errorf("expected 99 duplicates, got %d", duplicates) + } +} + +func TestDedupCache_Sweep(t *testing.T) { + dc := NewDedupCache(50 * time.Millisecond) + defer dc.Stop() + + dc.Check("key-1") + dc.Check("key-2") + dc.Check("key-3") + + if dc.Size() != 3 { + t.Fatalf("expected 3, got %d", dc.Size()) + } + + // Wait for entries to expire + time.Sleep(100 * time.Millisecond) + + // Manually trigger sweep + dc.sweep() + + if dc.Size() != 0 { + t.Errorf("expected 0 after sweep, got %d", dc.Size()) + } +} + +func TestBuildSpawnKey_Deterministic(t *testing.T) { + k1 := BuildSpawnKey("main", "worker", "do X") + k2 := BuildSpawnKey("main", "worker", "do X") + k3 := BuildSpawnKey("main", "worker", "do Y") + + if k1 != k2 { + t.Error("same inputs should produce same key") + } + if k1 == k3 { + t.Error("different task should produce different key") + } +} + +func TestBuildAnnounceKey_Format(t *testing.T) { + key := BuildAnnounceKey("child-session", "run-123") + expected := "announce:v1:child-session:run-123" + if key != expected { + t.Errorf("expected %q, got %q", expected, key) + } +} diff --git a/pkg/multiagent/handoff.go b/pkg/multiagent/handoff.go new file mode 100644 index 000000000..f90794543 --- /dev/null +++ b/pkg/multiagent/handoff.go @@ -0,0 
+1,218 @@ +package multiagent + +import ( + "context" + "fmt" + "slices" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// AgentResolver looks up an agent by ID. +// Typically backed by agent.AgentRegistry.GetAgent. +type AgentResolver interface { + GetAgentInfo(agentID string) *AgentInfo + ListAgents() []AgentInfo +} + +// AllowlistChecker determines whether a handoff from one agent to another is allowed. +type AllowlistChecker interface { + CanHandoff(fromAgentID, toAgentID string) bool +} + +// AllowlistCheckerFunc adapts a function to the AllowlistChecker interface. +type AllowlistCheckerFunc func(fromAgentID, toAgentID string) bool + +func (f AllowlistCheckerFunc) CanHandoff(fromAgentID, toAgentID string) bool { + return f(fromAgentID, toAgentID) +} + +// AgentInfo is a minimal view of an agent for handoff purposes, +// decoupled from the full AgentInstance to avoid circular imports. +type AgentInfo struct { + ID string + Name string + Role string + SystemPrompt string + Model string + Provider providers.LLMProvider + Tools *tools.ToolRegistry + MaxIter int + Capabilities []string // optional tags for capability-based routing (e.g. "coding", "research") +} + +// FindAgentsByCapability returns agents that advertise the given capability. +func FindAgentsByCapability(resolver AgentResolver, capability string) []AgentInfo { + var matches []AgentInfo + for _, a := range resolver.ListAgents() { + if slices.Contains(a.Capabilities, capability) { + matches = append(matches, a) + } + } + return matches +} + +// DefaultMaxHandoffDepth is the maximum handoff chain depth when not configured. +const DefaultMaxHandoffDepth = 3 + +// HandoffRequest describes a delegation from one agent to another. 
+type HandoffRequest struct { + FromAgentID string + ToAgentID string + Task string + Context map[string]string // k-v to write to blackboard before handoff + Depth int // current depth level (0 = top-level) + Visited []string // agent IDs already in the call chain + MaxDepth int // max allowed depth (0 = use DefaultMaxHandoffDepth) + ParentRunKey string // parent run session key for cascade stop tracking +} + +// HandoffResult contains the outcome of a handoff execution. +type HandoffResult struct { + AgentID string + Content string + Iterations int + Success bool + Error string +} + +// ExecuteHandoff delegates a task to a target agent, injecting blackboard context. +// It enforces recursion guards: depth limit and cycle detection. +func ExecuteHandoff(ctx context.Context, resolver AgentResolver, board *Blackboard, req HandoffRequest, channel, chatID string) *HandoffResult { + // Recursion guard: depth limit + maxDepth := req.MaxDepth + if maxDepth == 0 { + maxDepth = DefaultMaxHandoffDepth + } + if req.Depth >= maxDepth { + return &HandoffResult{ + AgentID: req.ToAgentID, + Success: false, + Error: fmt.Sprintf("handoff depth limit reached (%d/%d): %v -> %s", req.Depth, maxDepth, req.Visited, req.ToAgentID), + } + } + + // Recursion guard: cycle detection + for _, v := range req.Visited { + if v == req.ToAgentID { + return &HandoffResult{ + AgentID: req.ToAgentID, + Success: false, + Error: fmt.Sprintf("handoff cycle detected: %q already in chain %v", req.ToAgentID, req.Visited), + } + } + } + + target := resolver.GetAgentInfo(req.ToAgentID) + if target == nil { + return &HandoffResult{ + AgentID: req.ToAgentID, + Success: false, + Error: fmt.Sprintf("agent %q not found", req.ToAgentID), + } + } + + // Write request context to blackboard + if board != nil && req.Context != nil { + for k, v := range req.Context { + board.Set(k, v, req.FromAgentID) + } + } + + // Propagate depth and visited to target agent's handoff tool + newVisited := make([]string, 
len(req.Visited)+1) + copy(newVisited, req.Visited) + newVisited[len(req.Visited)] = req.ToAgentID + + if target.Tools != nil { + // Wire session blackboard to target's tools + if tool, ok := target.Tools.Get("blackboard"); ok { + if ba, ok := tool.(BoardAware); ok { + ba.SetBoard(board) + } + } + if tool, ok := target.Tools.Get("handoff"); ok { + if ba, ok := tool.(BoardAware); ok { + ba.SetBoard(board) + } + if ht, ok := tool.(*HandoffTool); ok { + ht.depth = req.Depth + 1 + ht.visited = newVisited + ht.maxDepth = maxDepth + ht.parentSessionKey = req.ParentRunKey + } + } + } + + // Build system prompt incorporating agent role, system prompt, and blackboard + systemPrompt := buildHandoffSystemPrompt(target, board) + + messages := []providers.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: req.Task}, + } + + maxIter := target.MaxIter + if maxIter == 0 { + maxIter = 10 + } + + // Apply depth-based tool policy: clone target tools and remove depth-denied tools. + // At max depth, leaf agents lose spawn/handoff/list_agents to prevent further chaining. 
+ targetTools := target.Tools + denyList := tools.DepthDenyList(req.Depth+1, maxDepth) + if len(denyList) > 0 && targetTools != nil { + targetTools = target.Tools.Clone() + tools.ApplyPolicy(targetTools, tools.ToolPolicy{Deny: denyList}) + } + + loopResult, err := tools.RunToolLoop(ctx, tools.ToolLoopConfig{ + Provider: target.Provider, + Model: target.Model, + Tools: targetTools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, channel, chatID) + + if err != nil { + return &HandoffResult{ + AgentID: req.ToAgentID, + Success: false, + Error: err.Error(), + } + } + + return &HandoffResult{ + AgentID: req.ToAgentID, + Content: loopResult.Content, + Iterations: loopResult.Iterations, + Success: true, + } +} + +func buildHandoffSystemPrompt(agent *AgentInfo, board *Blackboard) string { + prompt := "You are " + agent.Name + if agent.Role != "" { + prompt += ", " + agent.Role + } + prompt += ".\n" + + if agent.SystemPrompt != "" { + prompt += "\n" + agent.SystemPrompt + "\n" + } + + prompt += "\nComplete the delegated task and provide a clear result." + + if board != nil { + snapshot := board.Snapshot() + if snapshot != "" { + prompt += "\n\n" + snapshot + } + } + + return prompt +} diff --git a/pkg/multiagent/handoff_test.go b/pkg/multiagent/handoff_test.go new file mode 100644 index 000000000..91e8c5e72 --- /dev/null +++ b/pkg/multiagent/handoff_test.go @@ -0,0 +1,847 @@ +package multiagent + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// mockProvider is a minimal LLM provider for testing. 
+type mockProvider struct { + response string + err error +} + +func (m *mockProvider) Chat(_ context.Context, _ []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]any) (*providers.LLMResponse, error) { + if m.err != nil { + return nil, m.err + } + return &providers.LLMResponse{ + Content: m.response, + FinishReason: "stop", + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { return "mock-model" } + +// mockResolver implements AgentResolver for testing. +type mockResolver struct { + agents map[string]*AgentInfo +} + +func newMockResolver(agents ...*AgentInfo) *mockResolver { + m := &mockResolver{agents: make(map[string]*AgentInfo)} + for _, a := range agents { + m.agents[a.ID] = a + } + return m +} + +func (m *mockResolver) GetAgentInfo(agentID string) *AgentInfo { + return m.agents[agentID] +} + +func (m *mockResolver) ListAgents() []AgentInfo { + result := make([]AgentInfo, 0, len(m.agents)) + for _, a := range m.agents { + result = append(result, *a) + } + return result +} + +func TestExecuteHandoff_Success(t *testing.T) { + provider := &mockProvider{response: "task completed successfully"} + resolver := newMockResolver(&AgentInfo{ + ID: "coder", + Name: "Code Agent", + Role: "coding specialist", + Model: "test-model", + Provider: provider, + Tools: tools.NewToolRegistry(), + MaxIter: 5, + }) + + bb := NewBlackboard() + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "coder", + Task: "write a function", + Context: map[string]string{"language": "Go"}, + }, "cli", "direct") + + if !result.Success { + t.Fatalf("expected success, got error: %s", result.Error) + } + if result.Content != "task completed successfully" { + t.Errorf("Content = %q, want %q", result.Content, "task completed successfully") + } + if result.AgentID != "coder" { + t.Errorf("AgentID = %q, want %q", result.AgentID, "coder") + } + + // Verify context was written to blackboard + if 
bb.Get("language") != "Go" { + t.Errorf("blackboard 'language' = %q, want %q", bb.Get("language"), "Go") + } +} + +func TestExecuteHandoff_UnknownAgent(t *testing.T) { + resolver := newMockResolver() + bb := NewBlackboard() + + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "nonexistent", + Task: "do something", + }, "cli", "direct") + + if result.Success { + t.Error("expected failure for unknown agent") + } + if !strings.Contains(result.Error, "not found") { + t.Errorf("Error = %q, expected 'not found'", result.Error) + } +} + +func TestExecuteHandoff_NilBlackboard(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver(&AgentInfo{ + ID: "helper", + Name: "Helper", + Model: "test", + Provider: provider, + Tools: tools.NewToolRegistry(), + MaxIter: 5, + }) + + // Should not panic with nil blackboard + result := ExecuteHandoff(context.Background(), resolver, nil, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "helper", + Task: "help me", + Context: map[string]string{"key": "value"}, + }, "cli", "direct") + + if !result.Success { + t.Fatalf("expected success, got error: %s", result.Error) + } +} + +func TestHandoffTool_Execute(t *testing.T) { + provider := &mockProvider{response: "handoff result"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Role: "orchestrator", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "coder", Name: "Coder", Role: "coding", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + result := tool.Execute(context.Background(), map[string]any{ + "agent_id": "coder", + "task": "write code", + }) + + if result.IsError { + t.Fatalf("handoff tool failed: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "handoff result") { + t.Errorf("ForLLM = %q, expected to contain 
'handoff result'", result.ForLLM) + } +} + +func TestHandoffTool_MissingArgs(t *testing.T) { + resolver := newMockResolver() + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + // Missing agent_id + result := tool.Execute(context.Background(), map[string]any{ + "task": "do something", + }) + if !result.IsError { + t.Error("expected error for missing agent_id") + } + + // Missing task + result = tool.Execute(context.Background(), map[string]any{ + "agent_id": "coder", + }) + if !result.IsError { + t.Error("expected error for missing task") + } +} + +func TestHandoffTool_Description(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main"}, + &AgentInfo{ID: "coder", Name: "Coder", Role: "coding specialist"}, + ) + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + desc := tool.Description() + if !strings.Contains(desc, "coder") { + t.Errorf("Description = %q, expected to contain 'coder'", desc) + } + if !strings.Contains(desc, "coding specialist") { + t.Errorf("Description = %q, expected to contain role", desc) + } +} + +func TestHandoffTool_Description_WithCapabilities(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main"}, + &AgentInfo{ID: "coder", Name: "Coder", Role: "coding", Capabilities: []string{"coding", "review"}}, + ) + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + desc := tool.Description() + if !strings.Contains(desc, "coding, review") { + t.Errorf("Description = %q, expected capabilities", desc) + } +} + +func TestHandoffTool_ExecuteByCapability(t *testing.T) { + provider := &mockProvider{response: "capability result"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "coder", Name: "Coder", Capabilities: []string{"coding"}, Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + bb := NewBlackboard() + tool 
:= NewHandoffTool(resolver, bb, "main") + + result := tool.Execute(context.Background(), map[string]any{ + "capability": "coding", + "task": "write a function", + }) + + if result.IsError { + t.Fatalf("handoff by capability failed: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "capability result") { + t.Errorf("ForLLM = %q, expected 'capability result'", result.ForLLM) + } +} + +func TestHandoffTool_ExecuteByCapability_NotFound(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main"}, + ) + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + result := tool.Execute(context.Background(), map[string]any{ + "capability": "nonexistent", + "task": "do something", + }) + + if !result.IsError { + t.Error("expected error for unknown capability") + } +} + +func TestHandoffTool_ExecuteNoAgentNoCapability(t *testing.T) { + resolver := newMockResolver() + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + + result := tool.Execute(context.Background(), map[string]any{ + "task": "do something", + }) + + if !result.IsError { + t.Error("expected error when neither agent_id nor capability provided") + } +} + +func TestListAgentsTool_Execute(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main Agent", Role: "general"}, + &AgentInfo{ID: "coder", Name: "Code Agent", Role: "coding"}, + ) + tool := NewListAgentsTool(resolver) + + result := tool.Execute(context.Background(), nil) + if result.IsError { + t.Fatalf("list_agents failed: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "main") || !strings.Contains(result.ForLLM, "coder") { + t.Errorf("ForLLM = %q, expected agent IDs", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "2") { + t.Errorf("ForLLM = %q, expected count", result.ForLLM) + } +} + +func TestListAgentsTool_Empty(t *testing.T) { + resolver := newMockResolver() + tool := NewListAgentsTool(resolver) + + result := 
tool.Execute(context.Background(), nil) + if result.IsError { + t.Fatalf("unexpected error: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "No agents") { + t.Errorf("ForLLM = %q, expected 'No agents' message", result.ForLLM) + } +} + +func TestFindAgentsByCapability(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "coder", Name: "Coder", Capabilities: []string{"coding", "review"}}, + &AgentInfo{ID: "researcher", Name: "Researcher", Capabilities: []string{"research", "web_search"}}, + &AgentInfo{ID: "generalist", Name: "Generalist"}, + ) + + // Find coding agents + matches := FindAgentsByCapability(resolver, "coding") + if len(matches) != 1 || matches[0].ID != "coder" { + t.Errorf("FindAgentsByCapability(coding) = %v, want [coder]", matches) + } + + // Find research agents + matches = FindAgentsByCapability(resolver, "research") + if len(matches) != 1 || matches[0].ID != "researcher" { + t.Errorf("FindAgentsByCapability(research) = %v, want [researcher]", matches) + } + + // No match + matches = FindAgentsByCapability(resolver, "design") + if len(matches) != 0 { + t.Errorf("FindAgentsByCapability(design) = %v, want empty", matches) + } +} + +func TestFindAgentsByCapability_Multiple(t *testing.T) { + resolver := newMockResolver( + &AgentInfo{ID: "a", Capabilities: []string{"coding"}}, + &AgentInfo{ID: "b", Capabilities: []string{"coding", "review"}}, + &AgentInfo{ID: "c", Capabilities: []string{"research"}}, + ) + + matches := FindAgentsByCapability(resolver, "coding") + if len(matches) != 2 { + t.Errorf("expected 2 matches, got %d", len(matches)) + } +} + +func TestFindAgentsByCapability_Empty(t *testing.T) { + resolver := newMockResolver() + matches := FindAgentsByCapability(resolver, "anything") + if len(matches) != 0 { + t.Errorf("expected empty, got %v", matches) + } +} + +func TestAgentInfo_Capabilities(t *testing.T) { + agent := &AgentInfo{ + ID: "coder", + Name: "Code Agent", + Capabilities: []string{"coding", "review", 
"testing"}, + } + if len(agent.Capabilities) != 3 { + t.Errorf("Capabilities len = %d, want 3", len(agent.Capabilities)) + } + + // Nil capabilities should not panic + agent2 := &AgentInfo{ID: "basic"} + if agent2.Capabilities != nil { + t.Error("expected nil Capabilities for unset agent") + } +} + +func TestExecuteHandoff_DepthLimit(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver(&AgentInfo{ + ID: "target", + Name: "Target", + Model: "test", + Provider: provider, + Tools: tools.NewToolRegistry(), + MaxIter: 5, + }) + + bb := NewBlackboard() + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 3, // at max depth + MaxDepth: 3, + Visited: []string{"main", "agent-a", "agent-b"}, + }, "cli", "direct") + + if result.Success { + t.Error("expected failure at max depth") + } + if !strings.Contains(result.Error, "depth limit") { + t.Errorf("Error = %q, expected 'depth limit'", result.Error) + } +} + +func TestExecuteHandoff_CycleDetection(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "coder", Name: "Coder", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb := NewBlackboard() + + // Try to hand off to "main" which is already in the visited chain + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "coder", + ToAgentID: "main", + Task: "some task", + Depth: 1, + Visited: []string{"main", "coder"}, + }, "cli", "direct") + + if result.Success { + t.Error("expected failure due to cycle detection") + } + if !strings.Contains(result.Error, "cycle detected") { + t.Errorf("Error = %q, expected 'cycle detected'", result.Error) + } +} + +func 
TestExecuteHandoff_DefaultMaxDepth(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver(&AgentInfo{ + ID: "target", Name: "Target", Model: "test", + Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5, + }) + + bb := NewBlackboard() + + // Depth 2 with default max (3) should succeed + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 2, + Visited: []string{"main", "middle"}, + }, "cli", "direct") + if !result.Success { + t.Fatalf("expected success at depth 2 (max 3), got error: %s", result.Error) + } + + // Depth 3 with default max should fail + result = ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 3, + Visited: []string{"main", "a", "b"}, + }, "cli", "direct") + if result.Success { + t.Error("expected failure at depth 3 with default max 3") + } +} + +func TestExecuteHandoff_PropagatesDepthToTarget(t *testing.T) { + provider := &mockProvider{response: "done"} + targetRegistry := tools.NewToolRegistry() + innerResolver := newMockResolver() + targetHandoff := NewHandoffTool(innerResolver, NewBlackboard(), "target") + targetRegistry.Register(targetHandoff) + + resolver := newMockResolver(&AgentInfo{ + ID: "target", Name: "Target", Model: "test", + Provider: provider, Tools: targetRegistry, MaxIter: 5, + }) + + bb := NewBlackboard() + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 1, + Visited: []string{"main"}, + MaxDepth: 5, + }, "cli", "direct") + + if !result.Success { + t.Fatalf("expected success, got error: %s", result.Error) + } + + // Verify the target's handoff tool got the propagated depth + if targetHandoff.depth != 2 { + t.Errorf("target handoff depth = %d, want 2", targetHandoff.depth) + } + if 
len(targetHandoff.visited) != 2 || targetHandoff.visited[0] != "main" || targetHandoff.visited[1] != "target" { + t.Errorf("target handoff visited = %v, want [main target]", targetHandoff.visited) + } + if targetHandoff.maxDepth != 5 { + t.Errorf("target handoff maxDepth = %d, want 5", targetHandoff.maxDepth) + } +} + +func TestHandoffTool_AllowlistBlocks(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "restricted", Name: "Restricted", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + tool.SetAllowlistChecker(AllowlistCheckerFunc(func(from, to string) bool { + return to == "allowed-agent" // only allow "allowed-agent" + })) + + result := tool.Execute(context.Background(), map[string]any{ + "agent_id": "restricted", + "task": "do something", + }) + if !result.IsError { + t.Error("expected error for blocked handoff") + } + if !strings.Contains(result.ForLLM, "not allowed") { + t.Errorf("ForLLM = %q, expected 'not allowed'", result.ForLLM) + } +} + +func TestHandoffTool_AllowlistPermits(t *testing.T) { + provider := &mockProvider{response: "allowed result"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "coder", Name: "Coder", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + tool.SetAllowlistChecker(AllowlistCheckerFunc(func(from, to string) bool { + return to == "coder" // allow coder + })) + + result := tool.Execute(context.Background(), map[string]any{ + "agent_id": "coder", + "task": "write code", + }) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) 
+ } + if !strings.Contains(result.ForLLM, "allowed result") { + t.Errorf("ForLLM = %q, expected 'allowed result'", result.ForLLM) + } +} + +func TestHandoffTool_NoAllowlistAllowsAll(t *testing.T) { + provider := &mockProvider{response: "ok"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "any", Name: "Any", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb := NewBlackboard() + tool := NewHandoffTool(resolver, bb, "main") + // No allowlist checker set + + result := tool.Execute(context.Background(), map[string]any{ + "agent_id": "any", + "task": "anything", + }) + if result.IsError { + t.Fatalf("expected success with no allowlist, got: %s", result.ForLLM) + } +} + +func TestHandoffTool_SetBoard(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver( + &AgentInfo{ID: "main", Name: "Main", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + &AgentInfo{ID: "coder", Name: "Coder", Model: "test", Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5}, + ) + + bb1 := NewBlackboard() + bb2 := NewBlackboard() + bb2.Set("session_data", "hello", "system") + + tool := NewHandoffTool(resolver, bb1, "main") + + // Switch to session board + tool.SetBoard(bb2) + + // Execute with context that writes to blackboard + tool.Execute(context.Background(), map[string]any{ + "agent_id": "coder", + "task": "write code", + "context": map[string]any{"language": "Go"}, + }) + + // Context should have been written to bb2 (session board), not bb1 + if bb1.Get("language") != "" { + t.Error("context was written to old board") + } + if bb2.Get("language") != "Go" { + t.Errorf("context not written to session board: %q", bb2.Get("language")) + } +} + +func TestAllowlistCheckerFunc(t *testing.T) { + checker := AllowlistCheckerFunc(func(from, to string) bool { + return from == "main" && to == 
"coder" + }) + + if !checker.CanHandoff("main", "coder") { + t.Error("expected main->coder to be allowed") + } + if checker.CanHandoff("main", "other") { + t.Error("expected main->other to be blocked") + } + if checker.CanHandoff("other", "coder") { + t.Error("expected other->coder to be blocked") + } +} + +// TestExecuteHandoff_DepthBoundary verifies that depth == maxDepth - 1 (one below limit) succeeds, +// while depth == maxDepth fails. This is the exact boundary behaviour of the recursion guard. +func TestExecuteHandoff_DepthBoundary(t *testing.T) { + provider := &mockProvider{response: "done"} + resolver := newMockResolver(&AgentInfo{ + ID: "target", Name: "Target", Model: "test", + Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5, + }) + bb := NewBlackboard() + + // depth == maxDepth - 1 (2 < 3): must succeed + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 2, + MaxDepth: 3, + Visited: []string{"main", "middle"}, + }, "cli", "direct") + if !result.Success { + t.Errorf("depth == maxDepth-1 should succeed, got error: %s", result.Error) + } + + // depth == maxDepth (3 >= 3): must fail + result = ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "do something", + Depth: 3, + MaxDepth: 3, + Visited: []string{"main", "a", "b"}, + }, "cli", "direct") + if result.Success { + t.Error("depth == maxDepth should fail") + } + if !strings.Contains(result.Error, "depth limit") { + t.Errorf("Error = %q, expected 'depth limit'", result.Error) + } +} + +// TestExecuteHandoff_ProviderError verifies that a provider error during RunToolLoop +// is surfaced as a failed HandoffResult with an error message. 
+func TestExecuteHandoff_ProviderError(t *testing.T) { + provider := &mockProvider{err: fmt.Errorf("LLM provider unavailable")} + resolver := newMockResolver(&AgentInfo{ + ID: "target", Name: "Target", Model: "test", + Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5, + }) + + bb := NewBlackboard() + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "failing task", + }, "cli", "direct") + + if result.Success { + t.Error("expected failure when provider returns error") + } + if !strings.Contains(result.Error, "provider unavailable") { + t.Errorf("Error = %q, expected provider error message", result.Error) + } + if result.AgentID != "target" { + t.Errorf("AgentID = %q, want 'target'", result.AgentID) + } +} + +// TestExecuteHandoff_MaxIterDefault verifies that MaxIter == 0 on the target agent +// is defaulted to 10 inside ExecuteHandoff (not left as 0 which would mean no iterations). +func TestExecuteHandoff_MaxIterDefault(t *testing.T) { + provider := &mockProvider{response: "ran with default iter"} + resolver := newMockResolver(&AgentInfo{ + ID: "target", + Name: "Target", + Model: "test", + Provider: provider, + Tools: tools.NewToolRegistry(), + MaxIter: 0, // explicitly zero, should default to 10 + }) + + bb := NewBlackboard() + result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{ + FromAgentID: "main", + ToAgentID: "target", + Task: "task with default iter", + }, "cli", "direct") + + if !result.Success { + t.Errorf("expected success with default MaxIter, got: %s", result.Error) + } +} + +// TestExecuteHandoff_CycleDetectionSingleHop verifies A->A (self-handoff) is caught. 
+func TestExecuteHandoff_CycleDetectionSingleHop(t *testing.T) {
+	provider := &mockProvider{response: "done"}
+	resolver := newMockResolver(&AgentInfo{
+		ID: "main", Name: "Main", Model: "test",
+		Provider: provider, Tools: tools.NewToolRegistry(), MaxIter: 5,
+	})
+
+	bb := NewBlackboard()
+	// "main" handing off to itself, already in visited
+	result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{
+		FromAgentID: "main",
+		ToAgentID:   "main",
+		Task:        "self task",
+		Depth:       0,
+		Visited:     []string{"main"},
+	}, "cli", "direct")
+
+	if result.Success {
+		t.Error("expected failure for self-handoff cycle")
+	}
+	if !strings.Contains(result.Error, "cycle detected") {
+		t.Errorf("Error = %q, expected 'cycle detected'", result.Error)
+	}
+}
+
+// TestHandoffTool_SetContext verifies SetContext updates origin channel and chatID.
+func TestHandoffTool_SetContext(t *testing.T) {
+	resolver := newMockResolver()
+	bb := NewBlackboard()
+	tool := NewHandoffTool(resolver, bb, "main")
+
+	tool.SetContext("telegram", "chat-123")
+
+	// White-box check: the test is in the same package, so it reads the unexported
+	// fields directly to confirm the setter overwrote the defaults.
+	if tool.originChannel != "telegram" {
+		t.Errorf("originChannel = %q, want %q", tool.originChannel, "telegram")
+	}
+	if tool.originChatID != "chat-123" {
+		t.Errorf("originChatID = %q, want %q", tool.originChatID, "chat-123")
+	}
+}
+
+// TestHandoff_DepthPolicy_LeafNoSpawn verifies that at max depth, the target agent's
+// tool registry clone has spawn/handoff/list_agents removed.
func TestHandoff_DepthPolicy_LeafNoSpawn(t *testing.T) {
	provider := &mockProvider{response: "leaf result"}
	// Target registry mixes a plain tool with the three orchestration tools
	// that the depth policy is expected to strip from a leaf agent's clone.
	targetRegistry := tools.NewToolRegistry()
	targetRegistry.Register(&simpleTool{name: "read_file"})
	targetRegistry.Register(&simpleTool{name: "spawn"})
	targetRegistry.Register(&simpleTool{name: "handoff"})
	targetRegistry.Register(&simpleTool{name: "list_agents"})

	resolver := newMockResolver(&AgentInfo{
		ID: "leaf", Name: "Leaf Agent", Model: "test",
		Provider: provider, Tools: targetRegistry, MaxIter: 5,
	})

	bb := NewBlackboard()
	// Depth 2, maxDepth 3: the target will run at depth 3 (req.Depth+1),
	// which equals maxDepth, triggering depth deny.
	result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{
		FromAgentID: "main",
		ToAgentID:   "leaf",
		Task:        "do something as a leaf",
		Depth:       2,
		MaxDepth:    3,
		Visited:     []string{"main", "middle"},
	}, "cli", "direct")

	if !result.Success {
		t.Fatalf("expected success, got error: %s", result.Error)
	}

	// Original registry should still have all 4 tools (clone was modified, not original)
	if targetRegistry.Count() != 4 {
		t.Errorf("original registry count = %d, want 4 (unmodified)", targetRegistry.Count())
	}
}

// TestHandoff_DepthPolicy_MidChain verifies that mid-chain agents retain all tools.
func TestHandoff_DepthPolicy_MidChain(t *testing.T) {
	provider := &mockProvider{response: "mid-chain result"}
	targetRegistry := tools.NewToolRegistry()
	targetRegistry.Register(&simpleTool{name: "read_file"})
	targetRegistry.Register(&simpleTool{name: "spawn"})
	targetRegistry.Register(&simpleTool{name: "handoff"})
	targetRegistry.Register(&simpleTool{name: "list_agents"})

	resolver := newMockResolver(&AgentInfo{
		ID: "mid", Name: "Mid Agent", Model: "test",
		Provider: provider, Tools: targetRegistry, MaxIter: 5,
	})

	bb := NewBlackboard()
	// Depth 0, maxDepth 3: target runs at depth 1, well below max.
	result := ExecuteHandoff(context.Background(), resolver, bb, HandoffRequest{
		FromAgentID: "main",
		ToAgentID:   "mid",
		Task:        "mid-chain task",
		Depth:       0,
		MaxDepth:    3,
		Visited:     []string{"main"},
	}, "cli", "direct")

	if !result.Success {
		t.Fatalf("expected success, got error: %s", result.Error)
	}

	// Original registry should still have all 4 tools
	if targetRegistry.Count() != 4 {
		t.Errorf("original registry count = %d, want 4", targetRegistry.Count())
	}
}

// simpleTool is a minimal tool for depth policy tests.
// It implements the tools.Tool interface with fixed, side-effect-free answers.
type simpleTool struct {
	name string // reported via Name(); lets tests register several distinct entries
}

func (s *simpleTool) Name() string                       { return s.name }
func (s *simpleTool) Description() string                { return "test tool" }
func (s *simpleTool) Parameters() map[string]interface{} { return nil }
func (s *simpleTool) Execute(_ context.Context, _ map[string]interface{}) *tools.ToolResult {
	return tools.NewToolResult("ok")
}

// TestBuildHandoffSystemPrompt checks that the generated prompt embeds the
// agent's name, role, custom system prompt, and any blackboard keys.
func TestBuildHandoffSystemPrompt(t *testing.T) {
	agent := &AgentInfo{
		Name:         "Code Agent",
		Role:         "coding specialist",
		SystemPrompt: "Focus on Go code quality.",
	}

	bb := NewBlackboard()
	bb.Set("language", "Go", "main")

	prompt := buildHandoffSystemPrompt(agent, bb)
	if !strings.Contains(prompt, "Code Agent") {
		t.Errorf("prompt missing agent name: %s", prompt)
	}
	if !strings.Contains(prompt, "coding specialist") {
		t.Errorf("prompt missing role: %s", prompt)
	}
	if !strings.Contains(prompt, "Focus on Go code quality") {
		t.Errorf("prompt missing system prompt: %s", prompt)
	}
	if !strings.Contains(prompt, "language") {
		t.Errorf("prompt missing blackboard context: %s", prompt)
	}
}
diff --git a/pkg/multiagent/handoff_tool.go b/pkg/multiagent/handoff_tool.go
new file mode 100644
index 000000000..1f4ed23e0
--- /dev/null
+++ b/pkg/multiagent/handoff_tool.go
@@ -0,0 +1,192 @@
package multiagent

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/sipeed/picoclaw/pkg/tools"
)

// HandoffTool allows an LLM agent to delegate a task to another agent.
// The delegation performed by Execute is synchronous: the caller blocks
// until the target agent finishes (contrast with SpawnTool).
type HandoffTool struct {
	resolver         AgentResolver
	board            *Blackboard
	fromAgentID      string
	originChannel    string
	originChatID     string
	depth            int              // current handoff depth (0 = top-level)
	visited          []string         // agent IDs already in the call chain
	maxDepth         int              // max allowed depth (0 = use DefaultMaxHandoffDepth)
	allowlistChecker AllowlistChecker // optional; nil = allow all
	registry         *RunRegistry     // optional; nil = no run tracking
	parentSessionKey string           // session key of the parent run
}

// NewHandoffTool creates a handoff tool bound to a specific source agent.
// Origin defaults to "cli"/"direct"; use SetContext to rebind per message.
func NewHandoffTool(resolver AgentResolver, board *Blackboard, fromAgentID string) *HandoffTool {
	return &HandoffTool{
		resolver:      resolver,
		board:         board,
		fromAgentID:   fromAgentID,
		originChannel: "cli",
		originChatID:  "direct",
	}
}

// Name returns the tool name.
func (t *HandoffTool) Name() string { return "handoff" }

// Description returns a dynamic description listing available target agents.
// The source agent itself is skipped so the LLM is not offered a self-handoff.
func (t *HandoffTool) Description() string {
	agents := t.resolver.ListAgents()
	if len(agents) <= 1 {
		return "Delegate a task to another agent. No other agents are currently available."
	}

	var sb strings.Builder
	sb.WriteString("Delegate a task to another agent. Available agents:\n")
	for _, a := range agents {
		if a.ID == t.fromAgentID {
			continue
		}
		fmt.Fprintf(&sb, "- %s", a.ID)
		if a.Name != "" {
			fmt.Fprintf(&sb, " (%s)", a.Name)
		}
		if a.Role != "" {
			fmt.Fprintf(&sb, ": %s", a.Role)
		}
		if len(a.Capabilities) > 0 {
			fmt.Fprintf(&sb, " [%s]", strings.Join(a.Capabilities, ", "))
		}
		sb.WriteString("\n")
	}
	return sb.String()
}

// Parameters returns the JSON Schema for the tool's input.
+func (t *HandoffTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "agent_id": map[string]any{ + "type": "string", + "description": "The ID of the target agent to hand off to (required if capability is not set)", + }, + "capability": map[string]any{ + "type": "string", + "description": "Route to an agent with this capability instead of by ID (e.g. \"coding\", \"research\")", + }, + "task": map[string]any{ + "type": "string", + "description": "The task description for the target agent", + }, + "context": map[string]any{ + "type": "object", + "description": "Optional key-value context to share via blackboard before handoff", + }, + }, + "required": []string{"task"}, + } +} + +// SetBoard replaces the blackboard reference, allowing the tool to be wired +// to the correct per-session board before each execution. +func (t *HandoffTool) SetBoard(board *Blackboard) { + t.board = board +} + +// SetAllowlistChecker sets an optional checker that controls which agents +// can be handed off to. If nil, all handoffs are allowed. +func (t *HandoffTool) SetAllowlistChecker(checker AllowlistChecker) { + t.allowlistChecker = checker +} + +// SetRunRegistry sets the registry for tracking active runs (cascade cancellation). +func (t *HandoffTool) SetRunRegistry(registry *RunRegistry, parentSessionKey string) { + t.registry = registry + t.parentSessionKey = parentSessionKey +} + +// SetContext updates the origin channel and chat ID for handoff routing. +func (t *HandoffTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +// Execute delegates a task to the specified target agent. 
+func (t *HandoffTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + agentID, _ := args["agent_id"].(string) + capability, _ := args["capability"].(string) + task, _ := args["task"].(string) + + if task == "" { + return tools.ErrorResult("task is required") + } + + // Resolve agent: by ID or by capability + if agentID == "" && capability != "" { + matches := FindAgentsByCapability(t.resolver, capability) + if len(matches) == 0 { + return tools.ErrorResult(fmt.Sprintf("no agent found with capability %q", capability)) + } + agentID = matches[0].ID + } + if agentID == "" { + return tools.ErrorResult("agent_id or capability is required") + } + + // Allowlist check: if a checker is set and it denies the handoff, block it. + if t.allowlistChecker != nil && !t.allowlistChecker.CanHandoff(t.fromAgentID, agentID) { + return tools.ErrorResult(fmt.Sprintf("handoff from %q to %q not allowed by policy", t.fromAgentID, agentID)) + } + + // Parse optional context map + var contextMap map[string]string + if ctxRaw, ok := args["context"].(map[string]any); ok { + contextMap = make(map[string]string, len(ctxRaw)) + for k, v := range ctxRaw { + contextMap[k] = fmt.Sprintf("%v", v) + } + } + + // Create cancellable context for cascade stop support. + // If the parent context is cancelled, this handoff is also cancelled. + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Register this run in the registry for cascade cancellation. 
+ childSessionKey := fmt.Sprintf("handoff:%s:%s:%d:%d", t.fromAgentID, agentID, t.depth, time.Now().UnixNano()) + if t.registry != nil { + t.registry.Register(&ActiveRun{ + SessionKey: childSessionKey, + AgentID: agentID, + ParentKey: t.parentSessionKey, + Cancel: cancel, + StartedAt: time.Now(), + }) + defer t.registry.Deregister(childSessionKey) + } + + result := ExecuteHandoff(childCtx, t.resolver, t.board, HandoffRequest{ + FromAgentID: t.fromAgentID, + ToAgentID: agentID, + Task: task, + Context: contextMap, + Depth: t.depth, + Visited: t.visited, + MaxDepth: t.maxDepth, + ParentRunKey: childSessionKey, + }, t.originChannel, t.originChatID) + + if !result.Success { + return tools.ErrorResult(fmt.Sprintf("Handoff to %q failed: %s", agentID, result.Error)) + } + + return &tools.ToolResult{ + ForLLM: fmt.Sprintf("Agent %q completed task (iterations: %d):\n%s", agentID, result.Iterations, result.Content), + ForUser: result.Content, + } +} diff --git a/pkg/multiagent/list_agents_tool.go b/pkg/multiagent/list_agents_tool.go new file mode 100644 index 000000000..f46838848 --- /dev/null +++ b/pkg/multiagent/list_agents_tool.go @@ -0,0 +1,57 @@ +package multiagent + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// ListAgentsTool allows the LLM to discover all available agents. +type ListAgentsTool struct { + resolver AgentResolver +} + +// NewListAgentsTool creates a discovery tool backed by an AgentResolver. +func NewListAgentsTool(resolver AgentResolver) *ListAgentsTool { + return &ListAgentsTool{resolver: resolver} +} + +// Name returns the tool name. +func (t *ListAgentsTool) Name() string { return "list_agents" } + +// Description returns a human-readable description of the tool. +func (t *ListAgentsTool) Description() string { + return "List all available agents with their IDs, names, and roles." +} + +// Parameters returns the JSON Schema for the tool's input. 
+func (t *ListAgentsTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +// Execute lists all registered agents with their metadata. +func (t *ListAgentsTool) Execute(_ context.Context, _ map[string]any) *tools.ToolResult { + agents := t.resolver.ListAgents() + if len(agents) == 0 { + return tools.NewToolResult("No agents registered.") + } + + var sb strings.Builder + fmt.Fprintf(&sb, "Available agents (%d):\n", len(agents)) + for _, a := range agents { + fmt.Fprintf(&sb, "- ID: %s", a.ID) + if a.Name != "" { + fmt.Fprintf(&sb, ", Name: %s", a.Name) + } + if a.Role != "" { + fmt.Fprintf(&sb, ", Role: %s", a.Role) + } + sb.WriteString("\n") + } + return tools.NewToolResult(sb.String()) +} diff --git a/pkg/multiagent/spawn.go b/pkg/multiagent/spawn.go new file mode 100644 index 000000000..3096ece95 --- /dev/null +++ b/pkg/multiagent/spawn.go @@ -0,0 +1,235 @@ +package multiagent + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Spawn concurrency defaults. +// MaxChildrenPerAgent follows NVIDIA's stream scheduling pattern: +// limit parallel work to prevent resource exhaustion while maximizing throughput. +const ( + DefaultMaxChildren = 5 + DefaultSpawnTimeout = 5 * time.Minute +) + +// SpawnRequest describes an async agent invocation. +type SpawnRequest struct { + FromAgentID string + ToAgentID string + Task string + Context map[string]string // k-v to write to blackboard + Depth int + Visited []string + MaxDepth int + ParentRunKey string +} + +// SpawnResult is returned immediately to the caller (fire-and-forget). +type SpawnResult struct { + RunID string // unique identifier for this spawn + SessionKey string // child session key for tracking + Status string // "accepted" or "rejected" + Error string // rejection reason if status != "accepted" +} + +// SpawnOutcome is the final result written to the announcer when the spawn completes. 
+type SpawnOutcome struct { + RunID string + SessionKey string + AgentID string + Content string + Iterations int + Success bool + Error string + Duration time.Duration +} + +// SpawnManager orchestrates async agent spawns with concurrency limiting +// (semaphore pattern, inspired by NVIDIA CUDA stream scheduling and +// Apple GCD quality-of-service queues). +type SpawnManager struct { + registry *RunRegistry + announcer *Announcer + maxChildren int + timeout time.Duration + + // Per-parent semaphore: limits concurrent children per session. + // Google's MapReduce uses similar fan-out caps per mapper. + semaphores sync.Map // parentSessionKey -> *semaphore +} + +type semaphore struct { + ch chan struct{} +} + +func newSemaphore(max int) *semaphore { + return &semaphore{ch: make(chan struct{}, max)} +} + +func (s *semaphore) acquire() bool { + select { + case s.ch <- struct{}{}: + return true + default: + return false + } +} + +func (s *semaphore) release() { + <-s.ch +} + +func (s *semaphore) count() int { + return len(s.ch) +} + +// NewSpawnManager creates a spawn manager with the given limits. +func NewSpawnManager(registry *RunRegistry, announcer *Announcer, maxChildren int, timeout time.Duration) *SpawnManager { + if maxChildren <= 0 { + maxChildren = DefaultMaxChildren + } + if timeout <= 0 { + timeout = DefaultSpawnTimeout + } + return &SpawnManager{ + registry: registry, + announcer: announcer, + maxChildren: maxChildren, + timeout: timeout, + } +} + +// AsyncSpawn launches an agent in a background goroutine and returns immediately. +// Inspired by Google's fan-out pattern and Anthropic's parallel tool execution. +// The result is delivered via the Announcer when the spawn completes. +func (sm *SpawnManager) AsyncSpawn( + ctx context.Context, + resolver AgentResolver, + board *Blackboard, + req SpawnRequest, + channel, chatID string, +) *SpawnResult { + // Generate unique run ID and session key. 
+ runID := fmt.Sprintf("spawn:%s:%s:%d", req.FromAgentID, req.ToAgentID, time.Now().UnixNano()) + childSessionKey := fmt.Sprintf("spawn:%s:%s:%d:%d", req.FromAgentID, req.ToAgentID, req.Depth, time.Now().UnixNano()) + + // Acquire per-parent semaphore (NVIDIA stream scheduling pattern). + sem := sm.getOrCreateSemaphore(req.ParentRunKey) + if !sem.acquire() { + return &SpawnResult{ + RunID: runID, + SessionKey: childSessionKey, + Status: "rejected", + Error: fmt.Sprintf("max concurrent children reached (%d/%d) for parent session", sem.count(), sm.maxChildren), + } + } + + // Create cancellable context with timeout to prevent goroutine leaks. + // Microsoft Azure Functions uses similar timeout patterns for durable functions. + spawnCtx, cancel := context.WithTimeout(ctx, sm.timeout) + + // Register in RunRegistry for cascade cancellation (built in Phase 3d). + sm.registry.Register(&ActiveRun{ + SessionKey: childSessionKey, + AgentID: req.ToAgentID, + ParentKey: req.ParentRunKey, + Cancel: cancel, + StartedAt: time.Now(), + }) + + logger.InfoCF("spawn", "Async spawn started", map[string]interface{}{ + "run_id": runID, + "from": req.FromAgentID, + "to": req.ToAgentID, + "depth": req.Depth, + "parent": req.ParentRunKey, + "timeout": sm.timeout.String(), + "active": sem.count(), + "max": sm.maxChildren, + }) + + // Fire-and-forget goroutine (Google MapReduce worker pattern). + go func() { + defer cancel() + defer sem.release() + defer sm.registry.Deregister(childSessionKey) + + start := time.Now() + + // Execute the handoff synchronously inside the goroutine. 
+ result := ExecuteHandoff(spawnCtx, resolver, board, HandoffRequest{ + FromAgentID: req.FromAgentID, + ToAgentID: req.ToAgentID, + Task: req.Task, + Context: req.Context, + Depth: req.Depth, + Visited: req.Visited, + MaxDepth: req.MaxDepth, + ParentRunKey: childSessionKey, + }, channel, chatID) + + outcome := &SpawnOutcome{ + RunID: runID, + SessionKey: childSessionKey, + AgentID: req.ToAgentID, + Content: result.Content, + Iterations: result.Iterations, + Success: result.Success, + Error: result.Error, + Duration: time.Since(start), + } + + // Push result to parent via Announcer (Anthropic's auto-announce pattern). + if sm.announcer != nil { + sm.announcer.Deliver(req.ParentRunKey, &Announcement{ + FromSessionKey: childSessionKey, + ToSessionKey: req.ParentRunKey, + RunID: runID, + AgentID: req.ToAgentID, + Content: formatOutcomeMessage(outcome), + Outcome: outcome, + }) + } + + logger.InfoCF("spawn", "Async spawn completed", map[string]interface{}{ + "run_id": runID, + "agent_id": req.ToAgentID, + "success": result.Success, + "iterations": result.Iterations, + "duration": outcome.Duration.Round(time.Millisecond).String(), + }) + }() + + return &SpawnResult{ + RunID: runID, + SessionKey: childSessionKey, + Status: "accepted", + } +} + +// ActiveChildCount returns the number of active children for a parent session. 
+func (sm *SpawnManager) ActiveChildCount(parentSessionKey string) int { + return len(sm.registry.GetChildren(parentSessionKey)) +} + +func (sm *SpawnManager) getOrCreateSemaphore(parentKey string) *semaphore { + if v, ok := sm.semaphores.Load(parentKey); ok { + return v.(*semaphore) + } + sem := newSemaphore(sm.maxChildren) + actual, _ := sm.semaphores.LoadOrStore(parentKey, sem) + return actual.(*semaphore) +} + +func formatOutcomeMessage(o *SpawnOutcome) string { + if !o.Success { + return fmt.Sprintf("[Subagent %q failed after %s: %s]", o.AgentID, o.Duration.Round(time.Millisecond), o.Error) + } + return fmt.Sprintf("[Subagent %q completed in %s (%d iterations)]:\n%s", + o.AgentID, o.Duration.Round(time.Millisecond), o.Iterations, o.Content) +} diff --git a/pkg/multiagent/spawn_test.go b/pkg/multiagent/spawn_test.go new file mode 100644 index 000000000..7c4007e23 --- /dev/null +++ b/pkg/multiagent/spawn_test.go @@ -0,0 +1,378 @@ +package multiagent + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// mockAgentResolver provides test agents. +type mockAgentResolver struct { + agents map[string]*AgentInfo +} + +func (r *mockAgentResolver) GetAgentInfo(id string) *AgentInfo { + return r.agents[id] +} + +func (r *mockAgentResolver) ListAgents() []AgentInfo { + var list []AgentInfo + for _, a := range r.agents { + list = append(list, *a) + } + return list +} + +// mockLLMProvider returns a fixed response after a configurable delay. 
+type mockLLMProvider struct { + response string + delay time.Duration +} + +func (m *mockLLMProvider) Chat(ctx context.Context, messages []providers.Message, t []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + if m.delay > 0 { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(m.delay): + } + } + return &providers.LLMResponse{Content: m.response}, nil +} + +func (m *mockLLMProvider) GetDefaultModel() string { return "mock" } + +func newTestResolver() *mockAgentResolver { + toolReg := tools.NewToolRegistry() + return &mockAgentResolver{ + agents: map[string]*AgentInfo{ + "worker-a": { + ID: "worker-a", + Name: "Worker A", + Role: "test worker", + Provider: &mockLLMProvider{response: "result from A", delay: 50 * time.Millisecond}, + Tools: toolReg, + MaxIter: 3, + }, + "worker-b": { + ID: "worker-b", + Name: "Worker B", + Role: "test worker", + Provider: &mockLLMProvider{response: "result from B", delay: 50 * time.Millisecond}, + Tools: toolReg, + MaxIter: 3, + }, + }, + } +} + +func TestAsyncSpawn_Accepted(t *testing.T) { + registry := NewRunRegistry() + announcer := NewAnnouncer(10) + sm := NewSpawnManager(registry, announcer, 5, 10*time.Second) + resolver := newTestResolver() + board := NewBlackboard() + + result := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", + ToAgentID: "worker-a", + Task: "do something", + ParentRunKey: "parent-session", + }, "test", "chat1") + + if result.Status != "accepted" { + t.Fatalf("expected accepted, got %s: %s", result.Status, result.Error) + } + if result.RunID == "" { + t.Error("expected non-empty RunID") + } + + // Wait for completion + time.Sleep(200 * time.Millisecond) + + // Check announcement was delivered + anns := announcer.Drain("parent-session") + if len(anns) == 0 { + t.Fatal("expected at least 1 announcement") + } + if anns[0].AgentID != "worker-a" { + t.Errorf("expected agent worker-a, got %s", 
anns[0].AgentID) + } + if anns[0].Outcome == nil || !anns[0].Outcome.Success { + t.Error("expected successful outcome") + } +} + +func TestAsyncSpawn_ConcurrencyLimit(t *testing.T) { + registry := NewRunRegistry() + announcer := NewAnnouncer(10) + sm := NewSpawnManager(registry, announcer, 2, 10*time.Second) // max 2 concurrent + + slowProvider := &mockLLMProvider{response: "slow result", delay: 500 * time.Millisecond} + toolReg := tools.NewToolRegistry() + resolver := &mockAgentResolver{ + agents: map[string]*AgentInfo{ + "worker": { + ID: "worker", Name: "Worker", Provider: slowProvider, + Tools: toolReg, MaxIter: 3, + }, + }, + } + board := NewBlackboard() + + // Spawn 2 (should succeed — at limit) + r1 := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "task 1", ParentRunKey: "parent", + }, "test", "chat1") + r2 := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "task 2", ParentRunKey: "parent", + }, "test", "chat1") + + if r1.Status != "accepted" || r2.Status != "accepted" { + t.Fatalf("first 2 should be accepted, got r1=%s r2=%s", r1.Status, r2.Status) + } + + // Spawn 3rd (should be rejected — over limit) + r3 := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "task 3", ParentRunKey: "parent", + }, "test", "chat1") + + if r3.Status != "rejected" { + t.Errorf("3rd spawn should be rejected, got %s", r3.Status) + } + + // Wait for first 2 to complete + time.Sleep(700 * time.Millisecond) + + // Now should be able to spawn again + r4 := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "task 4", ParentRunKey: "parent", + }, "test", "chat1") + + if r4.Status != "accepted" { + t.Errorf("4th spawn should be accepted after slots freed, got %s: %s", r4.Status, r4.Error) + } +} + +func 
TestAsyncSpawn_CascadeStop(t *testing.T) { + registry := NewRunRegistry() + announcer := NewAnnouncer(10) + sm := NewSpawnManager(registry, announcer, 5, 10*time.Second) + + slowProvider := &mockLLMProvider{response: "should not complete", delay: 2 * time.Second} + toolReg := tools.NewToolRegistry() + resolver := &mockAgentResolver{ + agents: map[string]*AgentInfo{ + "worker": { + ID: "worker", Name: "Worker", Provider: slowProvider, + Tools: toolReg, MaxIter: 3, + }, + }, + } + board := NewBlackboard() + + r := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "long task", ParentRunKey: "parent", + }, "test", "chat1") + + if r.Status != "accepted" { + t.Fatalf("expected accepted, got %s", r.Status) + } + + // Give goroutine time to start + time.Sleep(50 * time.Millisecond) + + // Verify it's registered + if registry.ActiveCount() == 0 { + t.Fatal("expected at least 1 active run") + } + + // Cascade stop should cancel it + killed := registry.CascadeStop(r.SessionKey) + if killed == 0 { + t.Error("expected cascade stop to kill at least 1 run") + } + + // Wait for goroutine to clean up + time.Sleep(200 * time.Millisecond) +} + +func TestAsyncSpawn_ContextTimeout(t *testing.T) { + registry := NewRunRegistry() + announcer := NewAnnouncer(10) + sm := NewSpawnManager(registry, announcer, 5, 200*time.Millisecond) // very short timeout + + slowProvider := &mockLLMProvider{response: "too slow", delay: 5 * time.Second} + toolReg := tools.NewToolRegistry() + resolver := &mockAgentResolver{ + agents: map[string]*AgentInfo{ + "worker": { + ID: "worker", Name: "Worker", Provider: slowProvider, + Tools: toolReg, MaxIter: 3, + }, + }, + } + board := NewBlackboard() + + sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", ToAgentID: "worker", Task: "slow task", ParentRunKey: "parent", + }, "test", "chat1") + + // Wait for timeout + cleanup + time.Sleep(500 * time.Millisecond) + + 
// Should have an announcement with failure + anns := announcer.Drain("parent") + if len(anns) == 0 { + t.Fatal("expected announcement after timeout") + } + if anns[0].Outcome.Success { + t.Error("expected failure after timeout") + } +} + +func TestAsyncSpawn_ParallelFanOut(t *testing.T) { + registry := NewRunRegistry() + announcer := NewAnnouncer(20) + sm := NewSpawnManager(registry, announcer, 10, 10*time.Second) + resolver := newTestResolver() + board := NewBlackboard() + + // Fan-out: spawn multiple agents in parallel (Google MapReduce pattern) + var results []*SpawnResult + for i := 0; i < 5; i++ { + target := "worker-a" + if i%2 == 1 { + target = "worker-b" + } + r := sm.AsyncSpawn(context.Background(), resolver, board, SpawnRequest{ + FromAgentID: "main", + ToAgentID: target, + Task: fmt.Sprintf("parallel task %d", i), + ParentRunKey: "parent", + }, "test", "chat1") + results = append(results, r) + } + + // All should be accepted + for i, r := range results { + if r.Status != "accepted" { + t.Errorf("spawn %d should be accepted, got %s", i, r.Status) + } + } + + // Fan-in: wait for all to complete and collect results + time.Sleep(500 * time.Millisecond) + + anns := announcer.Drain("parent") + if len(anns) != 5 { + t.Errorf("expected 5 announcements (fan-in), got %d", len(anns)) + } +} + +// TestAnnouncer tests + +func TestAnnouncer_DeliverAndDrain(t *testing.T) { + a := NewAnnouncer(10) + + a.Deliver("session-1", &Announcement{ + RunID: "run-1", + AgentID: "worker-a", + Content: "result 1", + }) + a.Deliver("session-1", &Announcement{ + RunID: "run-2", + AgentID: "worker-b", + Content: "result 2", + }) + + results := a.Drain("session-1") + if len(results) != 2 { + t.Fatalf("expected 2 announcements, got %d", len(results)) + } + + // Drain again should return empty + results2 := a.Drain("session-1") + if len(results2) != 0 { + t.Errorf("expected 0 after drain, got %d", len(results2)) + } +} + +func TestAnnouncer_Pending(t *testing.T) { + a := NewAnnouncer(10) 
+ + if a.Pending("session-1") != 0 { + t.Error("expected 0 pending for new session") + } + + a.Deliver("session-1", &Announcement{RunID: "r1"}) + a.Deliver("session-1", &Announcement{RunID: "r2"}) + + if a.Pending("session-1") != 2 { + t.Errorf("expected 2 pending, got %d", a.Pending("session-1")) + } +} + +func TestAnnouncer_BackPressure(t *testing.T) { + a := NewAnnouncer(2) // tiny buffer + + // Fill buffer + a.Deliver("session-1", &Announcement{RunID: "r1", Content: "first"}) + a.Deliver("session-1", &Announcement{RunID: "r2", Content: "second"}) + + // Overflow — should drop oldest + a.Deliver("session-1", &Announcement{RunID: "r3", Content: "third"}) + + results := a.Drain("session-1") + if len(results) != 2 { + t.Fatalf("expected 2 after back-pressure, got %d", len(results)) + } + // Most recent should be present + hasThird := false + for _, r := range results { + if r.RunID == "r3" { + hasThird = true + } + } + if !hasThird { + t.Error("expected the newest announcement (r3) to be present after back-pressure") + } +} + +func TestAnnouncer_ConcurrentDelivery(t *testing.T) { + a := NewAnnouncer(100) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + a.Deliver("session-1", &Announcement{ + RunID: fmt.Sprintf("r-%d", n), + Content: fmt.Sprintf("result %d", n), + }) + }(i) + } + wg.Wait() + + results := a.Drain("session-1") + if len(results) != 50 { + t.Errorf("expected 50 concurrent deliveries, got %d", len(results)) + } +} + +func TestAnnouncer_Cleanup(t *testing.T) { + a := NewAnnouncer(10) + a.Deliver("session-1", &Announcement{RunID: "r1"}) + a.Cleanup("session-1") + + // After cleanup, pending should be 0 (new channel) + if a.Pending("session-1") != 0 { + t.Error("expected 0 pending after cleanup") + } +} diff --git a/pkg/multiagent/spawn_tool.go b/pkg/multiagent/spawn_tool.go new file mode 100644 index 000000000..2785abb1b --- /dev/null +++ b/pkg/multiagent/spawn_tool.go @@ -0,0 +1,169 @@ +package 
multiagent + +import ( + "context" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// SpawnTool allows an LLM agent to asynchronously spawn a child agent. +// Unlike HandoffTool (synchronous, blocking), SpawnTool returns immediately +// with a run ID. Results are auto-announced back to the parent session. +// +// Pattern: Anthropic's orchestrator-workers + OpenAI Swarm's lightweight handoffs. +type SpawnTool struct { + resolver AgentResolver + board *Blackboard + spawnManager *SpawnManager + fromAgentID string + originChannel string + originChatID string + depth int + visited []string + maxDepth int + parentSessionKey string + allowlistChecker AllowlistChecker +} + +// NewSpawnTool creates a spawn tool bound to a source agent. +func NewSpawnTool(resolver AgentResolver, board *Blackboard, spawnManager *SpawnManager, fromAgentID string) *SpawnTool { + return &SpawnTool{ + resolver: resolver, + board: board, + spawnManager: spawnManager, + fromAgentID: fromAgentID, + originChannel: "cli", + originChatID: "direct", + } +} + +func (t *SpawnTool) Name() string { return "spawn_agent" } + +func (t *SpawnTool) Description() string { + agents := t.resolver.ListAgents() + if len(agents) <= 1 { + return "Spawn a child agent asynchronously. Returns immediately — result auto-announces back. No other agents currently available." + } + + var sb strings.Builder + sb.WriteString("Spawn a child agent asynchronously. Returns immediately with a run ID — the result will auto-announce back when complete. Available agents:\n") + for _, a := range agents { + if a.ID == t.fromAgentID { + continue + } + fmt.Fprintf(&sb, "- %s", a.ID) + if a.Name != "" { + fmt.Fprintf(&sb, " (%s)", a.Name) + } + if a.Role != "" { + fmt.Fprintf(&sb, ": %s", a.Role) + } + sb.WriteString("\n") + } + sb.WriteString("\nUse 'list_spawns' to check status. 
Results are auto-delivered — no need to poll.") + return sb.String() +} + +func (t *SpawnTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "agent_id": map[string]any{ + "type": "string", + "description": "The ID of the agent to spawn", + }, + "capability": map[string]any{ + "type": "string", + "description": "Route to an agent with this capability instead of by ID", + }, + "task": map[string]any{ + "type": "string", + "description": "The task for the spawned agent", + }, + "context": map[string]any{ + "type": "object", + "description": "Optional key-value context to share via blackboard", + }, + }, + "required": []string{"task"}, + } +} + +// SetBoard implements BoardAware. +func (t *SpawnTool) SetBoard(board *Blackboard) { + t.board = board +} + +// SetContext implements ContextualTool. +func (t *SpawnTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +// SetAllowlistChecker sets the allowlist checker for spawn permissions. +func (t *SpawnTool) SetAllowlistChecker(checker AllowlistChecker) { + t.allowlistChecker = checker +} + +// SetRunRegistry sets registry and parent key for cascade tracking. 
+func (t *SpawnTool) SetRunRegistry(registry *RunRegistry, parentSessionKey string) { + t.parentSessionKey = parentSessionKey +} + +func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + agentID, _ := args["agent_id"].(string) + capability, _ := args["capability"].(string) + task, _ := args["task"].(string) + + if task == "" { + return tools.ErrorResult("task is required") + } + + // Resolve agent + if agentID == "" && capability != "" { + matches := FindAgentsByCapability(t.resolver, capability) + if len(matches) == 0 { + return tools.ErrorResult(fmt.Sprintf("no agent found with capability %q", capability)) + } + agentID = matches[0].ID + } + if agentID == "" { + return tools.ErrorResult("agent_id or capability is required") + } + + // Allowlist check + if t.allowlistChecker != nil && !t.allowlistChecker.CanHandoff(t.fromAgentID, agentID) { + return tools.ErrorResult(fmt.Sprintf("spawn from %q to %q not allowed by policy", t.fromAgentID, agentID)) + } + + // Parse context + var contextMap map[string]string + if ctxRaw, ok := args["context"].(map[string]any); ok { + contextMap = make(map[string]string, len(ctxRaw)) + for k, v := range ctxRaw { + contextMap[k] = fmt.Sprintf("%v", v) + } + } + + result := t.spawnManager.AsyncSpawn(ctx, t.resolver, t.board, SpawnRequest{ + FromAgentID: t.fromAgentID, + ToAgentID: agentID, + Task: task, + Context: contextMap, + Depth: t.depth, + Visited: t.visited, + MaxDepth: t.maxDepth, + ParentRunKey: t.parentSessionKey, + }, t.originChannel, t.originChatID) + + if result.Status != "accepted" { + return tools.ErrorResult(fmt.Sprintf("Spawn rejected: %s", result.Error)) + } + + return &tools.ToolResult{ + ForLLM: fmt.Sprintf("Agent %q spawned (run_id: %s). It runs asynchronously — the result will auto-announce back when complete. 
Continue with other work.", agentID, result.RunID), + ForUser: fmt.Sprintf("Spawned agent %q (run: %s)", agentID, result.RunID), + } +} diff --git a/pkg/providers/auth_rotation.go b/pkg/providers/auth_rotation.go new file mode 100644 index 000000000..eaef1d631 --- /dev/null +++ b/pkg/providers/auth_rotation.go @@ -0,0 +1,185 @@ +package providers + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// AuthProfile represents a single API key with rotation metadata. +type AuthProfile struct { + ID string // unique identifier (e.g. "openrouter:0") + APIKey string +} + +// AuthRotator manages round-robin selection across multiple API keys, +// with per-key cooldown tracking via CooldownTracker. +type AuthRotator struct { + profiles []AuthProfile + cooldown *CooldownTracker + mu sync.Mutex + lastUsed map[string]time.Time +} + +// NewAuthRotator creates a rotator for the given profiles. +// Uses the provided CooldownTracker for per-key cooldown state. +func NewAuthRotator(profiles []AuthProfile, cooldown *CooldownTracker) *AuthRotator { + lastUsed := make(map[string]time.Time, len(profiles)) + for _, p := range profiles { + lastUsed[p.ID] = time.Time{} // never used + } + return &AuthRotator{ + profiles: profiles, + cooldown: cooldown, + lastUsed: lastUsed, + } +} + +// NextAvailable returns the best available profile using round-robin +// (oldest lastUsed first), skipping profiles in cooldown. +// Returns nil if all profiles are in cooldown. 
+func (r *AuthRotator) NextAvailable() *AuthProfile { + r.mu.Lock() + defer r.mu.Unlock() + + var best *AuthProfile + var bestTime time.Time + first := true + + for i := range r.profiles { + p := &r.profiles[i] + if !r.cooldown.IsAvailable(p.ID) { + continue + } + lu := r.lastUsed[p.ID] + if first || lu.Before(bestTime) { + best = p + bestTime = lu + first = false + } + } + + if best != nil { + r.lastUsed[best.ID] = time.Now() + } + return best +} + +// MarkFailure records a failure for a specific profile. +func (r *AuthRotator) MarkFailure(profileID string, reason FailoverReason) { + r.cooldown.MarkFailure(profileID, reason) + logger.WarnCF("auth_rotation", "Profile marked as failed", map[string]interface{}{ + "profile_id": profileID, + "reason": string(reason), + "remaining": r.cooldown.CooldownRemaining(profileID).Round(time.Second).String(), + }) +} + +// MarkSuccess resets counters for a specific profile. +func (r *AuthRotator) MarkSuccess(profileID string) { + r.cooldown.MarkSuccess(profileID) +} + +// AvailableCount returns the number of profiles not in cooldown. +func (r *AuthRotator) AvailableCount() int { + count := 0 + for _, p := range r.profiles { + if r.cooldown.IsAvailable(p.ID) { + count++ + } + } + return count +} + +// ProfileCount returns the total number of profiles. +func (r *AuthRotator) ProfileCount() int { + return len(r.profiles) +} + +// AuthRotatingProvider wraps multiple LLM providers (one per API key) +// and rotates between them using AuthRotator. +type AuthRotatingProvider struct { + providers map[string]LLMProvider // profileID -> provider + rotator *AuthRotator + model string // default model from first provider +} + +// NewAuthRotatingProvider creates a rotating provider. +// factory is called once per profile to create the underlying provider. 
+func NewAuthRotatingProvider( + profiles []AuthProfile, + cooldown *CooldownTracker, + factory func(apiKey string) LLMProvider, +) *AuthRotatingProvider { + providerMap := make(map[string]LLMProvider, len(profiles)) + var defaultModel string + for _, p := range profiles { + prov := factory(p.APIKey) + providerMap[p.ID] = prov + if defaultModel == "" { + defaultModel = prov.GetDefaultModel() + } + } + + rotator := NewAuthRotator(profiles, cooldown) + + logger.InfoCF("auth_rotation", "Auth rotation initialized", map[string]interface{}{ + "profiles": len(profiles), + }) + + return &AuthRotatingProvider{ + providers: providerMap, + rotator: rotator, + model: defaultModel, + } +} + +// Chat selects the best available profile and delegates to its provider. +// On failure, marks the profile and returns the error (FallbackChain handles retry). +func (p *AuthRotatingProvider) Chat( + ctx context.Context, + messages []Message, + tools []ToolDefinition, + model string, + opts map[string]interface{}, +) (*LLMResponse, error) { + profile := p.rotator.NextAvailable() + if profile == nil { + return nil, fmt.Errorf("all auth profiles in cooldown (%d total)", p.rotator.ProfileCount()) + } + + provider := p.providers[profile.ID] + resp, err := provider.Chat(ctx, messages, tools, model, opts) + + if err != nil { + // Classify and record failure against this specific profile. + if failErr := ClassifyError(err, profile.ID, model); failErr != nil && failErr.IsRetriable() { + p.rotator.MarkFailure(profile.ID, failErr.Reason) + } + return nil, err + } + + p.rotator.MarkSuccess(profile.ID) + return resp, nil +} + +// GetDefaultModel returns the default model from the underlying providers. +func (p *AuthRotatingProvider) GetDefaultModel() string { + return p.model +} + +// BuildAuthProfiles creates AuthProfile entries from a list of API keys. +// Profile IDs follow the pattern "provider:N" (e.g. "openrouter:0"). 
+func BuildAuthProfiles(providerName string, apiKeys []string) []AuthProfile { + profiles := make([]AuthProfile, len(apiKeys)) + for i, key := range apiKeys { + profiles[i] = AuthProfile{ + ID: fmt.Sprintf("%s:%d", providerName, i), + APIKey: key, + } + } + return profiles +} diff --git a/pkg/providers/auth_rotation_test.go b/pkg/providers/auth_rotation_test.go new file mode 100644 index 000000000..9c6ef74f4 --- /dev/null +++ b/pkg/providers/auth_rotation_test.go @@ -0,0 +1,343 @@ +package providers + +import ( + "context" + "fmt" + "sync" + "testing" + "time" +) + +func TestAuthRotator_NextAvailable_RoundRobin(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + {ID: "p:2", APIKey: "key-2"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + // First call should return earliest (all have same lastUsed = zero) + p1 := rotator.NextAvailable() + if p1 == nil { + t.Fatal("expected a profile, got nil") + } + + // Second call should return a different profile (p1 now has newest lastUsed) + p2 := rotator.NextAvailable() + if p2 == nil { + t.Fatal("expected a profile, got nil") + } + if p2.ID == p1.ID { + t.Errorf("round-robin should select different profile, got same: %s", p2.ID) + } + + // Third call should return the remaining profile + p3 := rotator.NextAvailable() + if p3 == nil { + t.Fatal("expected a profile, got nil") + } + if p3.ID == p1.ID || p3.ID == p2.ID { + t.Errorf("expected third unique profile, got %s (p1=%s, p2=%s)", p3.ID, p1.ID, p2.ID) + } +} + +func TestAuthRotator_NextAvailable_SkipsCooldown(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + // Put p:0 in cooldown + cooldown.MarkFailure("p:0", FailoverRateLimit) + + // Should skip p:0 and return p:1 + p := rotator.NextAvailable() + if p == nil { + 
t.Fatal("expected a profile, got nil") + } + if p.ID != "p:1" { + t.Errorf("expected p:1 (p:0 in cooldown), got %s", p.ID) + } +} + +func TestAuthRotator_NextAvailable_AllInCooldown(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + // Put both in cooldown + cooldown.MarkFailure("p:0", FailoverRateLimit) + cooldown.MarkFailure("p:1", FailoverBilling) + + p := rotator.NextAvailable() + if p != nil { + t.Errorf("expected nil when all in cooldown, got %s", p.ID) + } +} + +func TestAuthRotator_MarkSuccess_ResetsCooldown(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + // Put in cooldown + cooldown.MarkFailure("p:0", FailoverRateLimit) + if cooldown.IsAvailable("p:0") { + t.Fatal("should be in cooldown after failure") + } + + // Mark success resets + rotator.MarkSuccess("p:0") + if !cooldown.IsAvailable("p:0") { + t.Fatal("should be available after success") + } +} + +func TestAuthRotator_AvailableCount(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + {ID: "p:2", APIKey: "key-2"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + if rotator.AvailableCount() != 3 { + t.Errorf("expected 3 available, got %d", rotator.AvailableCount()) + } + + cooldown.MarkFailure("p:1", FailoverRateLimit) + if rotator.AvailableCount() != 2 { + t.Errorf("expected 2 available, got %d", rotator.AvailableCount()) + } +} + +func TestBuildAuthProfiles(t *testing.T) { + keys := []string{"sk-key1", "sk-key2", "sk-key3"} + profiles := BuildAuthProfiles("openrouter", keys) + + if len(profiles) != 3 { + t.Fatalf("expected 3 profiles, got %d", len(profiles)) + } + if profiles[0].ID != "openrouter:0" { + t.Errorf("profiles[0].ID = 
%q, want %q", profiles[0].ID, "openrouter:0") + } + if profiles[2].APIKey != "sk-key3" { + t.Errorf("profiles[2].APIKey = %q, want %q", profiles[2].APIKey, "sk-key3") + } +} + +// mockRotatingProvider tracks which provider was called. +type mockRotatingProvider struct { + apiKey string + callCount int + mu sync.Mutex + failErr error +} + +func (m *mockRotatingProvider) Chat(_ context.Context, _ []Message, _ []ToolDefinition, _ string, _ map[string]interface{}) (*LLMResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.callCount++ + if m.failErr != nil { + return nil, m.failErr + } + return &LLMResponse{Content: "ok from " + m.apiKey}, nil +} + +func (m *mockRotatingProvider) GetDefaultModel() string { return "mock" } + +func TestAuthRotatingProvider_RotatesOnSuccess(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + } + + providers := make(map[string]*mockRotatingProvider) + cooldown := NewCooldownTracker() + factory := func(apiKey string) LLMProvider { + p := &mockRotatingProvider{apiKey: apiKey} + providers[apiKey] = p + return p + } + + rp := NewAuthRotatingProvider(profiles, cooldown, factory) + + // First call + resp, err := rp.Chat(context.Background(), nil, nil, "model", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil { + t.Fatal("expected response, got nil") + } + + // Second call should use different provider (round-robin) + _, err = rp.Chat(context.Background(), nil, nil, "model", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Both providers should have been called once + p0 := providers["key-0"] + p1 := providers["key-1"] + if p0.callCount+p1.callCount != 2 { + t.Errorf("expected 2 total calls, got %d + %d", p0.callCount, p1.callCount) + } + if p0.callCount == 0 || p1.callCount == 0 { + t.Errorf("expected both providers called, got p0=%d, p1=%d", p0.callCount, p1.callCount) + } +} + +func TestAuthRotatingProvider_MarksFailure(t 
*testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + } + + cooldown := NewCooldownTracker() + factory := func(apiKey string) LLMProvider { + return &mockRotatingProvider{ + apiKey: apiKey, + failErr: fmt.Errorf("429 too many requests"), + } + } + + rp := NewAuthRotatingProvider(profiles, cooldown, factory) + + // First call fails — should mark p:0 failure + _, err := rp.Chat(context.Background(), nil, nil, "model", nil) + if err == nil { + t.Fatal("expected error") + } + + // p:0 should now be in cooldown + if cooldown.IsAvailable("p:0") { + t.Error("p:0 should be in cooldown after rate limit failure") + } + + // p:1 should still be available + if !cooldown.IsAvailable("p:1") { + t.Error("p:1 should still be available") + } +} + +func TestAuthRotatingProvider_AllInCooldown(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + } + + cooldown := NewCooldownTracker() + cooldown.MarkFailure("p:0", FailoverRateLimit) + + factory := func(apiKey string) LLMProvider { + return &mockRotatingProvider{apiKey: apiKey} + } + + rp := NewAuthRotatingProvider(profiles, cooldown, factory) + + _, err := rp.Chat(context.Background(), nil, nil, "model", nil) + if err == nil { + t.Fatal("expected error when all profiles in cooldown") + } + if err.Error() != "all auth profiles in cooldown (1 total)" { + t.Errorf("unexpected error: %v", err) + } +} + +func TestAuthRotatingProvider_SingleKey_NoCooldownRotation(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-only"}, + } + + cooldown := NewCooldownTracker() + factory := func(apiKey string) LLMProvider { + return &mockRotatingProvider{apiKey: apiKey} + } + + rp := NewAuthRotatingProvider(profiles, cooldown, factory) + + // Multiple calls should all succeed using the single key + for i := 0; i < 5; i++ { + _, err := rp.Chat(context.Background(), nil, nil, "model", nil) + if err != nil { + t.Fatalf("call %d: unexpected error: %v", i, err) + 
} + } +} + +func TestAuthRotator_ConcurrentAccess(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + {ID: "p:2", APIKey: "key-2"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + var wg sync.WaitGroup + seen := sync.Map{} + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + p := rotator.NextAvailable() + if p != nil { + seen.Store(p.ID, true) + } + }() + } + + wg.Wait() + + // All profiles should have been used + count := 0 + seen.Range(func(_, _ interface{}) bool { + count++ + return true + }) + if count != 3 { + t.Errorf("expected all 3 profiles used concurrently, got %d", count) + } +} + +func TestAuthRotator_BillingCooldown_LongerDuration(t *testing.T) { + profiles := []AuthProfile{ + {ID: "p:0", APIKey: "key-0"}, + {ID: "p:1", APIKey: "key-1"}, + } + + cooldown := NewCooldownTracker() + rotator := NewAuthRotator(profiles, cooldown) + + // Mark billing failure (should have 5h cooldown) + rotator.MarkFailure("p:0", FailoverBilling) + + remaining := cooldown.CooldownRemaining("p:0") + // Billing cooldown should be >= 4.5 hours (5h minus some time elapsed) + if remaining < 4*time.Hour { + t.Errorf("billing cooldown should be ~5h, got %v", remaining) + } + + // Standard failure should have much shorter cooldown + rotator.MarkFailure("p:1", FailoverRateLimit) + remaining2 := cooldown.CooldownRemaining("p:1") + if remaining2 > 2*time.Minute { + t.Errorf("standard cooldown should be ~1min, got %v", remaining2) + } +} diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go index e39cfe32b..c4a9f61f5 100644 --- a/pkg/providers/factory.go +++ b/pkg/providers/factory.go @@ -26,7 +26,9 @@ const ( type providerSelection struct { providerType providerType + providerName string // resolved provider name (e.g. 
"openrouter", "anthropic") apiKey string + apiKeys []string // multiple keys for auth rotation (nil = single key) apiBase string proxy string model string @@ -120,7 +122,9 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { } } case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { + if cfg.Providers.OpenRouter.APIKey != "" || len(cfg.Providers.OpenRouter.APIKeys) > 0 { + sel.providerName = "openrouter" + sel.apiKeys = cfg.Providers.OpenRouter.ResolveAPIKeys() sel.apiKey = cfg.Providers.OpenRouter.APIKey sel.proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { @@ -227,6 +231,8 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): + sel.providerName = "openrouter" + sel.apiKeys = cfg.Providers.OpenRouter.ResolveAPIKeys() sel.apiKey = cfg.Providers.OpenRouter.APIKey sel.proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { @@ -307,7 +313,9 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { sel.apiBase = cfg.Providers.VLLM.APIBase sel.proxy = cfg.Providers.VLLM.Proxy default: - if cfg.Providers.OpenRouter.APIKey != "" { + if cfg.Providers.OpenRouter.APIKey != "" || len(cfg.Providers.OpenRouter.APIKeys) > 0 { + sel.providerName = "openrouter" + sel.apiKeys = cfg.Providers.OpenRouter.ResolveAPIKeys() sel.apiKey = cfg.Providers.OpenRouter.APIKey sel.proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { @@ -322,7 +330,7 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { } if sel.providerType == providerTypeHTTPCompat { - if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + if sel.apiKey == "" && len(sel.apiKeys) == 0 && !strings.HasPrefix(model, "bedrock/") { return providerSelection{}, fmt.Errorf("no API key configured for 
provider (model: %s)", model) } if sel.apiBase == "" { @@ -355,6 +363,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case providerTypeGitHubCopilot: return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model) default: + // Auth rotation: wrap with AuthRotatingProvider if multiple keys configured. + if len(sel.apiKeys) > 1 { + profiles := BuildAuthProfiles(sel.providerName, sel.apiKeys) + cooldown := NewCooldownTracker() + factory := func(apiKey string) LLMProvider { + return NewHTTPProvider(apiKey, sel.apiBase, sel.proxy) + } + return NewAuthRotatingProvider(profiles, cooldown, factory), nil + } return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil } } diff --git a/pkg/routing/agent_id.go b/pkg/routing/agent_id.go index bcf2f0dc0..a6e6e881e 100644 --- a/pkg/routing/agent_id.go +++ b/pkg/routing/agent_id.go @@ -5,6 +5,7 @@ import ( "strings" ) +// Agent ID defaults and constraints. const ( DefaultAgentID = "main" DefaultMainKey = "main" diff --git a/pkg/routing/route.go b/pkg/routing/route.go index 9eb060c53..9b69bf366 100644 --- a/pkg/routing/route.go +++ b/pkg/routing/route.go @@ -21,8 +21,8 @@ type ResolvedRoute struct { AgentID string Channel string AccountID string - SessionKey string - MainSessionKey string + SessionKey string `json:"session_key"` //nolint:gosec // G117: not a secret, this is a session identifier + MainSessionKey string `json:"main_session_key"` //nolint:gosec // G117: not a secret, this is a session identifier MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" } @@ -141,7 +141,7 @@ func matchesAccountID(matchAccountID, actual string) bool { if trimmed == "*" { return true } - return strings.ToLower(trimmed) == strings.ToLower(actual) + return strings.EqualFold(trimmed, actual) } func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { diff --git 
a/pkg/routing/session_key.go b/pkg/routing/session_key.go index e12f0d1d8..cb2469a10 100644 --- a/pkg/routing/session_key.go +++ b/pkg/routing/session_key.go @@ -8,6 +8,7 @@ import ( // DMScope controls DM session isolation granularity. type DMScope string +// DM scope constants control session isolation granularity. const ( DMScopeMain DMScope = "main" DMScopePerPeer DMScope = "per-peer" @@ -86,6 +87,8 @@ func BuildAgentPeerSessionKey(params SessionKeyParams) string { if peerID != "" { return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) } + default: + // DMScopeMain or unrecognized: fall through to main session key } return BuildAgentMainSessionKey(agentID) } diff --git a/pkg/tools/groups.go b/pkg/tools/groups.go new file mode 100644 index 000000000..247cc7850 --- /dev/null +++ b/pkg/tools/groups.go @@ -0,0 +1,39 @@ +package tools + +// DefaultToolGroups maps group references (e.g. "group:fs") to tool names. +// Groups provide a convenient shorthand for tool policies so operators +// can allow/deny entire categories without listing every tool name. +var DefaultToolGroups = map[string][]string{ + "group:fs": {"read_file", "write_file", "edit_file", "append_file", "list_dir"}, + "group:web": {"web_search", "web_fetch"}, + "group:exec": {"exec"}, + "group:hw": {"i2c", "spi"}, + "group:comms": {"message", "spawn"}, + "group:agents": {"blackboard", "handoff", "list_agents"}, +} + +// ResolveToolNames expands group refs (e.g. "group:fs") and individual tool +// names into a deduplicated list of concrete tool names. +// Unknown group refs are treated as individual tool names (pass-through). 
+func ResolveToolNames(refs []string) []string { + seen := make(map[string]struct{}, len(refs)) + result := make([]string, 0, len(refs)) + + for _, ref := range refs { + if tools, ok := DefaultToolGroups[ref]; ok { + for _, name := range tools { + if _, dup := seen[name]; !dup { + seen[name] = struct{}{} + result = append(result, name) + } + } + } else { + if _, dup := seen[ref]; !dup { + seen[ref] = struct{}{} + result = append(result, ref) + } + } + } + + return result +} diff --git a/pkg/tools/groups_test.go b/pkg/tools/groups_test.go new file mode 100644 index 000000000..90fcda8b8 --- /dev/null +++ b/pkg/tools/groups_test.go @@ -0,0 +1,68 @@ +package tools + +import ( + "sort" + "testing" +) + +func TestResolveToolNames_GroupExpansion(t *testing.T) { + result := ResolveToolNames([]string{"group:fs"}) + expected := []string{"read_file", "write_file", "edit_file", "append_file", "list_dir"} + + if len(result) != len(expected) { + t.Fatalf("len = %d, want %d: %v", len(result), len(expected), result) + } + sort.Strings(result) + sort.Strings(expected) + for i := range expected { + if result[i] != expected[i] { + t.Errorf("result[%d] = %q, want %q", i, result[i], expected[i]) + } + } +} + +func TestResolveToolNames_IndividualTool(t *testing.T) { + result := ResolveToolNames([]string{"exec"}) + if len(result) != 1 || result[0] != "exec" { + t.Errorf("result = %v, want [exec]", result) + } +} + +func TestResolveToolNames_Mixed(t *testing.T) { + result := ResolveToolNames([]string{"group:web", "exec"}) + expected := map[string]bool{"web_search": true, "web_fetch": true, "exec": true} + if len(result) != len(expected) { + t.Fatalf("len = %d, want %d: %v", len(result), len(expected), result) + } + for _, name := range result { + if !expected[name] { + t.Errorf("unexpected tool: %q", name) + } + } +} + +func TestResolveToolNames_Dedup(t *testing.T) { + result := ResolveToolNames([]string{"group:exec", "exec"}) + if len(result) != 1 { + t.Errorf("expected 1 (deduped), got 
%d: %v", len(result), result) + } +} + +func TestResolveToolNames_UnknownGroup(t *testing.T) { + result := ResolveToolNames([]string{"group:nonexistent"}) + // Unknown group ref treated as a literal tool name + if len(result) != 1 || result[0] != "group:nonexistent" { + t.Errorf("result = %v, want [group:nonexistent]", result) + } +} + +func TestResolveToolNames_Empty(t *testing.T) { + result := ResolveToolNames(nil) + if len(result) != 0 { + t.Errorf("result = %v, want empty", result) + } + result = ResolveToolNames([]string{}) + if len(result) != 0 { + t.Errorf("result = %v, want empty", result) + } +} diff --git a/pkg/tools/hooks.go b/pkg/tools/hooks.go new file mode 100644 index 000000000..87e32a609 --- /dev/null +++ b/pkg/tools/hooks.go @@ -0,0 +1,17 @@ +package tools + +import "context" + +// ToolHook allows intercepting tool execution for policy enforcement, +// loop detection, logging, or other cross-cutting concerns. +// +// Hooks are called by the ToolRegistry around tool execution: +// - BeforeExecute: called before the tool runs. Return non-nil error to block execution. +// - AfterExecute: called after the tool completes (even on error). Cannot block. +// +// Multiple hooks are executed in registration order. If any BeforeExecute returns +// an error, subsequent hooks and the tool itself are skipped. +type ToolHook interface { + BeforeExecute(ctx context.Context, toolName string, args map[string]interface{}) error + AfterExecute(ctx context.Context, toolName string, args map[string]interface{}, result *ToolResult) +} diff --git a/pkg/tools/hooks_test.go b/pkg/tools/hooks_test.go new file mode 100644 index 000000000..bde4b4548 --- /dev/null +++ b/pkg/tools/hooks_test.go @@ -0,0 +1,184 @@ +package tools + +import ( + "context" + "errors" + "testing" +) + +// testHook records calls and optionally blocks execution. 
+type testHook struct { + beforeCalls []string + afterCalls []string + blockTool string // if non-empty, block this tool name +} + +func (h *testHook) BeforeExecute(_ context.Context, toolName string, _ map[string]interface{}) error { + h.beforeCalls = append(h.beforeCalls, toolName) + if h.blockTool != "" && toolName == h.blockTool { + return errors.New("blocked by test hook") + } + return nil +} + +func (h *testHook) AfterExecute(_ context.Context, toolName string, _ map[string]interface{}, _ *ToolResult) { + h.afterCalls = append(h.afterCalls, toolName) +} + +// dummyTool is a minimal tool for hook testing. +type dummyTool struct { + name string +} + +func (d *dummyTool) Name() string { return d.name } +func (d *dummyTool) Description() string { return "test tool" } +func (d *dummyTool) Parameters() map[string]interface{} { return nil } +func (d *dummyTool) Execute(_ context.Context, _ map[string]interface{}) *ToolResult { + return NewToolResult("ok") +} + +func TestToolHook_BeforeAndAfterCalled(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "test_tool"}) + + hook := &testHook{} + reg.AddHook(hook) + + reg.Execute(context.Background(), "test_tool", nil) + + if len(hook.beforeCalls) != 1 || hook.beforeCalls[0] != "test_tool" { + t.Errorf("beforeCalls = %v, want [test_tool]", hook.beforeCalls) + } + if len(hook.afterCalls) != 1 || hook.afterCalls[0] != "test_tool" { + t.Errorf("afterCalls = %v, want [test_tool]", hook.afterCalls) + } +} + +func TestToolHook_BlocksExecution(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "blocked_tool"}) + + hook := &testHook{blockTool: "blocked_tool"} + reg.AddHook(hook) + + result := reg.Execute(context.Background(), "blocked_tool", nil) + + if !result.IsError { + t.Error("expected error result when hook blocks") + } + if len(hook.beforeCalls) != 1 { + t.Errorf("beforeCalls count = %d, want 1", len(hook.beforeCalls)) + } + // AfterExecute should still be called (for 
observability) + if len(hook.afterCalls) != 1 { + t.Errorf("afterCalls count = %d, want 1 (observability)", len(hook.afterCalls)) + } +} + +func TestToolHook_MultipleHooks(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "multi"}) + + hook1 := &testHook{} + hook2 := &testHook{} + reg.AddHook(hook1) + reg.AddHook(hook2) + + reg.Execute(context.Background(), "multi", nil) + + if len(hook1.beforeCalls) != 1 || len(hook2.beforeCalls) != 1 { + t.Error("expected both hooks to be called") + } +} + +func TestToolHook_FirstBlockStopsChain(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "chain_test"}) + + hook1 := &testHook{blockTool: "chain_test"} + hook2 := &testHook{} + reg.AddHook(hook1) + reg.AddHook(hook2) + + result := reg.Execute(context.Background(), "chain_test", nil) + + if !result.IsError { + t.Error("expected error when first hook blocks") + } + // hook1 should have been called, hook2's Before should NOT + if len(hook1.beforeCalls) != 1 { + t.Error("hook1 before should have been called") + } + if len(hook2.beforeCalls) != 0 { + t.Error("hook2 before should NOT have been called (chain stopped)") + } +} + +// TestToolHook_AfterExecuteRunsForAllHooksOnBlock verifies that when a BeforeExecute +// hook blocks execution, AfterExecute is still invoked on ALL registered hooks +// (not just the blocking one) for observability purposes. 
+func TestToolHook_AfterExecuteRunsForAllHooksOnBlock(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "observed_tool"}) + + hook1 := &testHook{blockTool: "observed_tool"} + hook2 := &testHook{} // does not block, but should still get AfterExecute + reg.AddHook(hook1) + reg.AddHook(hook2) + + result := reg.Execute(context.Background(), "observed_tool", nil) + + if !result.IsError { + t.Error("expected error result when hook1 blocks") + } + // BeforeExecute: hook1 called, hook2 NOT called (chain stopped) + if len(hook1.beforeCalls) != 1 { + t.Errorf("hook1.beforeCalls = %d, want 1", len(hook1.beforeCalls)) + } + if len(hook2.beforeCalls) != 0 { + t.Errorf("hook2.beforeCalls = %d, want 0 (chain stopped)", len(hook2.beforeCalls)) + } + // AfterExecute: BOTH hooks called (inner loop over all hooks for observability) + if len(hook1.afterCalls) != 1 { + t.Errorf("hook1.afterCalls = %d, want 1", len(hook1.afterCalls)) + } + if len(hook2.afterCalls) != 1 { + t.Errorf("hook2.afterCalls = %d, want 1 (AfterExecute runs for all hooks even on block)", len(hook2.afterCalls)) + } +} + +// TestToolHook_NotFoundToolSkipsHooks verifies that hooks are not called when +// a tool does not exist in the registry. +func TestToolHook_NotFoundToolSkipsHooks(t *testing.T) { + reg := NewToolRegistry() + // Do NOT register the tool + + hook := &testHook{} + reg.AddHook(hook) + + result := reg.Execute(context.Background(), "ghost_tool", nil) + + if !result.IsError { + t.Error("expected error for unknown tool") + } + // Hooks should not be called when the tool doesn't exist (early return before hook loop) + if len(hook.beforeCalls) != 0 { + t.Errorf("hook.beforeCalls = %d, want 0 for missing tool", len(hook.beforeCalls)) + } + if len(hook.afterCalls) != 0 { + t.Errorf("hook.afterCalls = %d, want 0 for missing tool", len(hook.afterCalls)) + } +} + +// TestToolHook_NoHooksSucceeds verifies that a tool executes normally with no hooks registered. 
+func TestToolHook_NoHooksSucceeds(t *testing.T) { + reg := NewToolRegistry() + reg.Register(&dummyTool{name: "plain_tool"}) + // No hooks added + + result := reg.Execute(context.Background(), "plain_tool", nil) + + if result.IsError { + t.Errorf("expected success with no hooks, got error: %s", result.ForLLM) + } +} diff --git a/pkg/tools/loop_detector.go b/pkg/tools/loop_detector.go new file mode 100644 index 000000000..c67b5040e --- /dev/null +++ b/pkg/tools/loop_detector.go @@ -0,0 +1,377 @@ +package tools + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Session context key for per-session loop detection isolation. + +type loopDetectorContextKey struct{} + +// WithSessionKey returns a context carrying the given session key for loop detection. +func WithSessionKey(ctx context.Context, key string) context.Context { + return context.WithValue(ctx, loopDetectorContextKey{}, key) +} + +func sessionKeyFromContext(ctx context.Context) string { + if key, ok := ctx.Value(loopDetectorContextKey{}).(string); ok { + return key + } + return "_default" +} + +// Default thresholds matching OpenClaw's production values. +const ( + DefaultHistorySize = 30 + DefaultWarningThreshold = 10 + DefaultCriticalThreshold = 20 + DefaultCircuitBreakerThreshold = 30 +) + +// LoopDetectorConfig configures loop detection thresholds. +type LoopDetectorConfig struct { + HistorySize int // sliding window size (default 30) + WarningThreshold int // generic repeat warn level (default 10) + CriticalThreshold int // block execution level (default 20) + CircuitBreakerThreshold int // global emergency stop (default 30) + EnableGenericRepeat bool // detect any repeated tool+args + EnablePingPong bool // detect alternating A,B,A,B patterns +} + +// DefaultLoopDetectorConfig returns production-ready defaults. 
+func DefaultLoopDetectorConfig() LoopDetectorConfig { + return LoopDetectorConfig{ + HistorySize: DefaultHistorySize, + WarningThreshold: DefaultWarningThreshold, + CriticalThreshold: DefaultCriticalThreshold, + CircuitBreakerThreshold: DefaultCircuitBreakerThreshold, + EnableGenericRepeat: true, + EnablePingPong: true, + } +} + +// LoopVerdict is the result of loop analysis. +type LoopVerdict struct { + Blocked bool + Warning bool + Reason string + Count int +} + +// toolCallRecord represents a single tool call in the sliding window. +type toolCallRecord struct { + ToolName string + ArgsHash string + ResultHash string // filled after execution via AfterExecute + Timestamp time.Time +} + +// sessionState holds per-session detection state. +type sessionState struct { + history []toolCallRecord + mu sync.Mutex +} + +// LoopDetector implements ToolHook for detecting repetitive tool call patterns. +// It uses per-session state keyed by the session key in context.Context. +type LoopDetector struct { + config LoopDetectorConfig + sessions sync.Map // sessionKey -> *sessionState +} + +// NewLoopDetector creates a loop detector with the given configuration. +// Zero or negative thresholds are replaced with production defaults. 
+func NewLoopDetector(config LoopDetectorConfig) *LoopDetector { + if config.HistorySize <= 0 { + config.HistorySize = DefaultHistorySize + } + if config.WarningThreshold <= 0 { + config.WarningThreshold = DefaultWarningThreshold + } + if config.CriticalThreshold <= 0 { + config.CriticalThreshold = DefaultCriticalThreshold + } + if config.CircuitBreakerThreshold <= 0 { + config.CircuitBreakerThreshold = DefaultCircuitBreakerThreshold + } + return &LoopDetector{config: config} +} + +func (d *LoopDetector) getSession(key string) *sessionState { + if v, ok := d.sessions.Load(key); ok { + return v.(*sessionState) + } + s := &sessionState{} + actual, _ := d.sessions.LoadOrStore(key, s) + return actual.(*sessionState) +} + +// BeforeExecute checks for loops before tool execution. +// Returns an error to block execution if a critical loop is detected. +func (d *LoopDetector) BeforeExecute(ctx context.Context, toolName string, args map[string]interface{}) error { + sessionKey := sessionKeyFromContext(ctx) + state := d.getSession(sessionKey) + argsHash := hashArgs(args) + + state.mu.Lock() + defer state.mu.Unlock() + + // Check for loops before recording this call + verdict := d.detect(state, toolName, argsHash) + + // Record this call + state.history = append(state.history, toolCallRecord{ + ToolName: toolName, + ArgsHash: argsHash, + Timestamp: time.Now(), + }) + // Trim sliding window + if len(state.history) > d.config.HistorySize { + state.history = state.history[len(state.history)-d.config.HistorySize:] + } + + if verdict.Blocked { + logger.WarnCF("loop_detector", "Loop blocked", + map[string]interface{}{ + "tool": toolName, + "reason": verdict.Reason, + "count": verdict.Count, + "session": sessionKey, + }) + return fmt.Errorf("loop detected: %s (count: %d)", verdict.Reason, verdict.Count) + } + + if verdict.Warning { + logger.WarnCF("loop_detector", "Loop warning", + map[string]interface{}{ + "tool": toolName, + "reason": verdict.Reason, + "count": verdict.Count, + 
"session": sessionKey, + }) + } + + return nil +} + +// AfterExecute records the tool result hash for no-progress detection. +func (d *LoopDetector) AfterExecute(ctx context.Context, toolName string, args map[string]interface{}, result *ToolResult) { + sessionKey := sessionKeyFromContext(ctx) + state := d.getSession(sessionKey) + argsHash := hashArgs(args) + resultHash := hashResult(result) + + state.mu.Lock() + defer state.mu.Unlock() + + // Find the most recent matching record without a result hash + for i := len(state.history) - 1; i >= 0; i-- { + rec := &state.history[i] + if rec.ToolName == toolName && rec.ArgsHash == argsHash && rec.ResultHash == "" { + rec.ResultHash = resultHash + break + } + } +} + +// ResetSession clears the detection state for a session. +func (d *LoopDetector) ResetSession(sessionKey string) { + d.sessions.Delete(sessionKey) +} + +// detect runs all detection engines and returns the highest-severity verdict. +func (d *LoopDetector) detect(state *sessionState, toolName, argsHash string) LoopVerdict { + // 1. Global circuit breaker: no-progress on same tool+args + noProgressStreak := d.getNoProgressStreak(state, toolName, argsHash) + if noProgressStreak >= d.config.CircuitBreakerThreshold { + return LoopVerdict{ + Blocked: true, + Reason: "circuit breaker: repeated calls with identical results", + Count: noProgressStreak, + } + } + + // 2. Generic repeat detection (any tool+args combination) + if d.config.EnableGenericRepeat { + count := d.countRepeats(state, toolName, argsHash) + if count >= d.config.CriticalThreshold { + return LoopVerdict{ + Blocked: true, + Reason: "tool call repeated too many times", + Count: count, + } + } + if count >= d.config.WarningThreshold { + return LoopVerdict{ + Warning: true, + Reason: "possible loop: tool call repeated", + Count: count, + } + } + } + + // 3. 
Ping-pong detection (alternating A,B,A,B pattern) + if d.config.EnablePingPong { + streak := d.getPingPongStreak(state, toolName, argsHash) + if streak >= d.config.CriticalThreshold && d.hasPingPongNoProgress(state, streak) { + return LoopVerdict{ + Blocked: true, + Reason: "ping-pong loop with no progress", + Count: streak, + } + } + if streak >= d.config.WarningThreshold { + return LoopVerdict{ + Warning: true, + Reason: "possible ping-pong pattern", + Count: streak, + } + } + } + + return LoopVerdict{} +} + +// countRepeats counts how many times this tool+args combination appears in history. +func (d *LoopDetector) countRepeats(state *sessionState, toolName, argsHash string) int { + count := 0 + for _, rec := range state.history { + if rec.ToolName == toolName && rec.ArgsHash == argsHash { + count++ + } + } + return count +} + +// getPingPongStreak detects alternating A,B,A,B patterns. +// Returns the alternation streak length (number of entries in the alternating tail). +func (d *LoopDetector) getPingPongStreak(state *sessionState, toolName, argsHash string) int { + h := state.history + if len(h) < 2 { + return 0 + } + + currentSig := toolName + ":" + argsHash + lastSig := h[len(h)-1].ToolName + ":" + h[len(h)-1].ArgsHash + + // Same signature as last call — not alternation + if currentSig == lastSig { + return 0 + } + + // Count alternating tail backwards from the end. + // History ends with ...lastSig. Current call would be currentSig. 
+ // Pattern: ...currentSig, lastSig, currentSig, lastSig + streak := 1 // count the last entry + for i := len(h) - 2; i >= 0; i-- { + sig := h[i].ToolName + ":" + h[i].ArgsHash + // Positions from end: even indices should match lastSig, odd should match currentSig + distFromEnd := len(h) - 1 - i + var expected string + if distFromEnd%2 == 0 { + expected = lastSig + } else { + expected = currentSig + } + if sig != expected { + break + } + streak++ + } + + return streak +} + +// getNoProgressStreak counts consecutive tail calls with same tool+args AND same result. +func (d *LoopDetector) getNoProgressStreak(state *sessionState, toolName, argsHash string) int { + h := state.history + if len(h) == 0 { + return 0 + } + + // Find the most recent matching entry with a result hash + var referenceHash string + for i := len(h) - 1; i >= 0; i-- { + if h[i].ToolName == toolName && h[i].ArgsHash == argsHash && h[i].ResultHash != "" { + referenceHash = h[i].ResultHash + break + } + } + if referenceHash == "" { + return 0 // no completed calls to compare + } + + // Count consecutive matching entries from the tail + streak := 0 + for i := len(h) - 1; i >= 0; i-- { + rec := h[i] + if rec.ToolName != toolName || rec.ArgsHash != argsHash { + break + } + if rec.ResultHash == "" { + // Call recorded but not yet completed — include in streak (conservative) + streak++ + continue + } + if rec.ResultHash != referenceHash { + break // result changed = progress + } + streak++ + } + + return streak +} + +// hasPingPongNoProgress checks that both sides of a ping-pong have stable (unchanged) results. 
+func (d *LoopDetector) hasPingPongNoProgress(state *sessionState, streak int) bool { + h := state.history + if len(h) < 4 || streak < 4 { + return false + } + + // Check the last 4 entries: [A, B, A, B] + // Side A = indices -4, -2; Side B = indices -3, -1 + tail := h[len(h)-4:] + if tail[0].ResultHash == "" || tail[1].ResultHash == "" || + tail[2].ResultHash == "" || tail[3].ResultHash == "" { + return false // need result hashes on all 4 + } + + sideAStable := tail[0].ResultHash == tail[2].ResultHash + sideBStable := tail[1].ResultHash == tail[3].ResultHash + return sideAStable && sideBStable +} + +// hashArgs produces a deterministic hash of tool arguments. +// Go's json.Marshal sorts map keys alphabetically, ensuring stability. +func hashArgs(args map[string]interface{}) string { + if len(args) == 0 { + return "empty" + } + data, err := json.Marshal(args) + if err != nil { + return "error" + } + h := sha256.Sum256(data) + return fmt.Sprintf("%x", h[:8]) // 16 hex chars — sufficient for dedup +} + +// hashResult produces a hash of the tool result for no-progress detection. +func hashResult(result *ToolResult) string { + if result == nil { + return "nil" + } + content := result.ForLLM + if len(content) > 1024 { + content = content[:1024] // cap for hashing performance + } + h := sha256.Sum256([]byte(content)) + return fmt.Sprintf("%x", h[:8]) +} diff --git a/pkg/tools/loop_detector_test.go b/pkg/tools/loop_detector_test.go new file mode 100644 index 000000000..b25dbc903 --- /dev/null +++ b/pkg/tools/loop_detector_test.go @@ -0,0 +1,526 @@ +package tools + +import ( + "context" + "fmt" + "testing" +) + +// loopTestTool is a minimal tool that returns configurable results. 
+type loopTestTool struct { + name string + result string +} + +func (t *loopTestTool) Name() string { return t.name } +func (t *loopTestTool) Description() string { return "loop test tool" } +func (t *loopTestTool) Parameters() map[string]interface{} { return nil } +func (t *loopTestTool) Execute(_ context.Context, _ map[string]interface{}) *ToolResult { + return NewToolResult(t.result) +} + +// --- Context key tests --- + +func TestWithSessionKey(t *testing.T) { + ctx := context.Background() + if got := sessionKeyFromContext(ctx); got != "_default" { + t.Errorf("expected _default, got %q", got) + } + + ctx = WithSessionKey(ctx, "session-123") + if got := sessionKeyFromContext(ctx); got != "session-123" { + t.Errorf("expected session-123, got %q", got) + } +} + +// --- Hash tests --- + +func TestHashArgs_Deterministic(t *testing.T) { + args := map[string]interface{}{"path": "/tmp/file.txt", "content": "hello"} + h1 := hashArgs(args) + h2 := hashArgs(args) + if h1 != h2 { + t.Errorf("hashArgs not deterministic: %s != %s", h1, h2) + } +} + +func TestHashArgs_Empty(t *testing.T) { + if got := hashArgs(nil); got != "empty" { + t.Errorf("expected 'empty' for nil args, got %q", got) + } + if got := hashArgs(map[string]interface{}{}); got != "empty" { + t.Errorf("expected 'empty' for empty args, got %q", got) + } +} + +func TestHashArgs_DifferentArgs(t *testing.T) { + h1 := hashArgs(map[string]interface{}{"a": "1"}) + h2 := hashArgs(map[string]interface{}{"a": "2"}) + if h1 == h2 { + t.Error("expected different hashes for different args") + } +} + +func TestHashResult_NilResult(t *testing.T) { + if got := hashResult(nil); got != "nil" { + t.Errorf("expected 'nil', got %q", got) + } +} + +func TestHashResult_Deterministic(t *testing.T) { + r := &ToolResult{ForLLM: "output data"} + h1 := hashResult(r) + h2 := hashResult(r) + if h1 != h2 { + t.Errorf("hashResult not deterministic: %s != %s", h1, h2) + } +} + +// --- Config validation tests --- + +func 
TestNewLoopDetector_DefaultConfig(t *testing.T) { + d := NewLoopDetector(DefaultLoopDetectorConfig()) + if d.config.HistorySize != 30 { + t.Errorf("HistorySize = %d, want 30", d.config.HistorySize) + } + if d.config.WarningThreshold != 10 { + t.Errorf("WarningThreshold = %d, want 10", d.config.WarningThreshold) + } + if d.config.CriticalThreshold != 20 { + t.Errorf("CriticalThreshold = %d, want 20", d.config.CriticalThreshold) + } + if d.config.CircuitBreakerThreshold != 30 { + t.Errorf("CircuitBreakerThreshold = %d, want 30", d.config.CircuitBreakerThreshold) + } +} + +func TestNewLoopDetector_FixesZeroThresholds(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 0, // zero → default 10 + CriticalThreshold: -1, // negative → default 20 + CircuitBreakerThreshold: 0, // zero → default 30 + }) + if d.config.WarningThreshold != DefaultWarningThreshold { + t.Errorf("WarningThreshold = %d, want %d", d.config.WarningThreshold, DefaultWarningThreshold) + } + if d.config.CriticalThreshold != DefaultCriticalThreshold { + t.Errorf("CriticalThreshold = %d, want %d", d.config.CriticalThreshold, DefaultCriticalThreshold) + } + if d.config.CircuitBreakerThreshold != DefaultCircuitBreakerThreshold { + t.Errorf("CircuitBreakerThreshold = %d, want %d", d.config.CircuitBreakerThreshold, DefaultCircuitBreakerThreshold) + } +} + +func TestNewLoopDetector_RespectsPositiveThresholds(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 100, + CriticalThreshold: 3, + CircuitBreakerThreshold: 5, + }) + // All positive values should be kept as-is + if d.config.WarningThreshold != 100 { + t.Errorf("WarningThreshold = %d, want 100", d.config.WarningThreshold) + } + if d.config.CriticalThreshold != 3 { + t.Errorf("CriticalThreshold = %d, want 3", d.config.CriticalThreshold) + } + if d.config.CircuitBreakerThreshold != 5 { + t.Errorf("CircuitBreakerThreshold = %d, want 5", d.config.CircuitBreakerThreshold) + } +} + +// --- Generic repeat 
detection --- + +func TestLoopDetector_BelowWarning_NoBlock(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 5, + CriticalThreshold: 10, + CircuitBreakerThreshold: 20, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "test") + args := map[string]interface{}{"key": "val"} + + // Call 4 times (below warning=5): all should pass + for i := 0; i < 4; i++ { + if err := d.BeforeExecute(ctx, "read_file", args); err != nil { + t.Fatalf("call %d: unexpected block: %v", i, err) + } + } +} + +func TestLoopDetector_AtWarning_NoBlock(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 6, + CircuitBreakerThreshold: 12, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "test") + args := map[string]interface{}{"key": "val"} + + // Warning is informational — should NOT block + for i := 0; i < 5; i++ { + if err := d.BeforeExecute(ctx, "read_file", args); err != nil { + t.Fatalf("call %d: unexpected block at warning level: %v", i, err) + } + } +} + +func TestLoopDetector_AtCritical_Blocks(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 6, + CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "test") + args := map[string]interface{}{"key": "val"} + + // First 6 calls should pass (history counts 0..5 before each check) + for i := 0; i < 6; i++ { + if err := d.BeforeExecute(ctx, "read_file", args); err != nil { + t.Fatalf("call %d: unexpected block: %v", i, err) + } + } + + // 7th call: history has 6 entries, check sees count=6 >= critical=6 → block + if err := d.BeforeExecute(ctx, "read_file", args); err == nil { + t.Fatal("expected block at critical threshold, got nil") + } +} + +func TestLoopDetector_DifferentTools_NoConflict(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 6, + 
CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "test") + + // Alternate between two tools: neither should hit threshold + for i := 0; i < 10; i++ { + tool := "read_file" + if i%2 == 1 { + tool = "write_file" + } + if err := d.BeforeExecute(ctx, tool, nil); err != nil { + t.Fatalf("call %d (%s): unexpected block: %v", i, tool, err) + } + } +} + +func TestLoopDetector_GenericRepeatDisabled(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 6, + CircuitBreakerThreshold: 100, // high so circuit breaker doesn't fire + EnableGenericRepeat: false, + }) + ctx := WithSessionKey(context.Background(), "test") + + // Should never block from generic repeat when disabled + for i := 0; i < 20; i++ { + if err := d.BeforeExecute(ctx, "read_file", nil); err != nil { + t.Fatalf("call %d: block with generic repeat disabled: %v", i, err) + } + } +} + +// --- Ping-pong detection --- + +func TestLoopDetector_PingPong_Detected(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 4, + CriticalThreshold: 8, + CircuitBreakerThreshold: 100, + EnableGenericRepeat: false, // isolate ping-pong + EnablePingPong: true, + }) + ctx := WithSessionKey(context.Background(), "test") + argsA := map[string]interface{}{"file": "a.txt"} + argsB := map[string]interface{}{"file": "b.txt"} + + // Build alternation pattern: A, B, A, B, ... 
+ // With result tracking for no-progress evidence + for i := 0; i < 20; i++ { + var tool string + var args map[string]interface{} + if i%2 == 0 { + tool = "read_file" + args = argsA + } else { + tool = "write_file" + args = argsB + } + err := d.BeforeExecute(ctx, tool, args) + // Record identical result to establish no-progress + d.AfterExecute(ctx, tool, args, &ToolResult{ForLLM: fmt.Sprintf("result_%s", tool)}) + + if err != nil { + // Should eventually block + if i < 8 { + t.Fatalf("blocked too early at call %d: %v", i, err) + } + return // success: blocked at or after critical threshold + } + } + t.Fatal("ping-pong was never blocked after 20 alternating calls") +} + +func TestLoopDetector_PingPong_WithProgress_NoBlock(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 4, + CriticalThreshold: 8, + CircuitBreakerThreshold: 100, + EnableGenericRepeat: false, + EnablePingPong: true, + }) + ctx := WithSessionKey(context.Background(), "test") + argsA := map[string]interface{}{"file": "a.txt"} + argsB := map[string]interface{}{"file": "b.txt"} + + // Build alternation but with CHANGING results (progress) + for i := 0; i < 20; i++ { + var tool string + var args map[string]interface{} + if i%2 == 0 { + tool = "read_file" + args = argsA + } else { + tool = "write_file" + args = argsB + } + if err := d.BeforeExecute(ctx, tool, args); err != nil { + t.Fatalf("blocked at call %d despite progress: %v", i, err) + } + // Different result each time = progress + d.AfterExecute(ctx, tool, args, &ToolResult{ForLLM: fmt.Sprintf("result_%d", i)}) + } +} + +func TestLoopDetector_PingPongDisabled(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 6, + CircuitBreakerThreshold: 100, + EnableGenericRepeat: false, + EnablePingPong: false, + }) + ctx := WithSessionKey(context.Background(), "test") + + for i := 0; i < 20; i++ { + tool := "read_file" + if i%2 == 1 { + tool = "write_file" + } + if err := 
d.BeforeExecute(ctx, tool, nil); err != nil { + t.Fatalf("call %d: block with ping-pong disabled: %v", i, err) + } + d.AfterExecute(ctx, tool, nil, &ToolResult{ForLLM: "same"}) + } +} + +// --- No-progress / circuit breaker --- + +func TestLoopDetector_CircuitBreaker_NoProgress(t *testing.T) { + threshold := 8 + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 100, // high so generic repeat doesn't fire + CriticalThreshold: 100, + CircuitBreakerThreshold: threshold, + EnableGenericRepeat: false, + EnablePingPong: false, + }) + ctx := WithSessionKey(context.Background(), "test") + args := map[string]interface{}{"file": "/tmp/stuck"} + + for i := 0; i < threshold+5; i++ { + err := d.BeforeExecute(ctx, "read_file", args) + // Record identical result each time + d.AfterExecute(ctx, "read_file", args, &ToolResult{ForLLM: "same output"}) + + if err != nil { + if i < threshold { + t.Fatalf("circuit breaker fired too early at call %d", i) + } + return // success + } + } + t.Fatal("circuit breaker never fired") +} + +func TestLoopDetector_CircuitBreaker_WithProgress_NoBlock(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 100, + CriticalThreshold: 100, + CircuitBreakerThreshold: 5, + EnableGenericRepeat: false, + }) + ctx := WithSessionKey(context.Background(), "test") + + for i := 0; i < 20; i++ { + if err := d.BeforeExecute(ctx, "exec", nil); err != nil { + t.Fatalf("call %d: blocked despite progress: %v", i, err) + } + // Different result each time + d.AfterExecute(ctx, "exec", nil, &ToolResult{ForLLM: fmt.Sprintf("output_%d", i)}) + } +} + +// --- Session isolation --- + +func TestLoopDetector_SessionIsolation(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 5, + CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + + ctxA := WithSessionKey(context.Background(), "session-A") + ctxB := WithSessionKey(context.Background(), "session-B") + + // Fill session A to 
near-critical + for i := 0; i < 4; i++ { + if err := d.BeforeExecute(ctxA, "read_file", nil); err != nil { + t.Fatalf("session A call %d: unexpected block: %v", i, err) + } + } + + // Session B should be unaffected + for i := 0; i < 4; i++ { + if err := d.BeforeExecute(ctxB, "read_file", nil); err != nil { + t.Fatalf("session B call %d: blocked by session A state: %v", i, err) + } + } +} + +// --- ResetSession --- + +func TestLoopDetector_ResetSession(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 5, + CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "reset-test") + + // Fill to near-critical + for i := 0; i < 4; i++ { + d.BeforeExecute(ctx, "read_file", nil) + } + + // Reset session + d.ResetSession("reset-test") + + // Should be able to call again without hitting threshold + for i := 0; i < 4; i++ { + if err := d.BeforeExecute(ctx, "read_file", nil); err != nil { + t.Fatalf("call %d after reset: unexpected block: %v", i, err) + } + } +} + +// --- Sliding window --- + +func TestLoopDetector_SlidingWindow_EvictsOld(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + HistorySize: 5, // tiny window + WarningThreshold: 3, + CriticalThreshold: 5, + CircuitBreakerThreshold: 10, + EnableGenericRepeat: true, + }) + ctx := WithSessionKey(context.Background(), "window-test") + + // Fill 5 entries with read_file + for i := 0; i < 5; i++ { + d.BeforeExecute(ctx, "read_file", nil) + } + + // Now call different tools to push old entries out of window + for i := 0; i < 5; i++ { + d.BeforeExecute(ctx, fmt.Sprintf("tool_%d", i), nil) + } + + // read_file should no longer be in window — calling it should not trigger anything + if err := d.BeforeExecute(ctx, "read_file", nil); err != nil { + t.Fatalf("expected no block after window eviction, got: %v", err) + } +} + +// --- Integration with ToolRegistry --- + +func TestLoopDetector_IntegrationWithRegistry(t 
*testing.T) { + reg := NewToolRegistry() + reg.Register(&loopTestTool{name: "stuck_tool", result: "same"}) + + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 5, + CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + reg.AddHook(d) + + ctx := WithSessionKey(context.Background(), "integration") + + // Should succeed initially + for i := 0; i < 5; i++ { + result := reg.Execute(ctx, "stuck_tool", nil) + if result.IsError { + t.Fatalf("call %d: unexpected error: %s", i, result.ForLLM) + } + } + + // 6th call should be blocked + result := reg.Execute(ctx, "stuck_tool", nil) + if !result.IsError { + t.Fatal("expected block at critical threshold via registry integration") + } +} + +// --- AfterExecute records result --- + +func TestLoopDetector_AfterExecute_RecordsResult(t *testing.T) { + d := NewLoopDetector(DefaultLoopDetectorConfig()) + ctx := WithSessionKey(context.Background(), "after-test") + args := map[string]interface{}{"x": "1"} + + d.BeforeExecute(ctx, "test_tool", args) + d.AfterExecute(ctx, "test_tool", args, &ToolResult{ForLLM: "result"}) + + // Verify result was recorded + state := d.getSession("after-test") + state.mu.Lock() + defer state.mu.Unlock() + + if len(state.history) != 1 { + t.Fatalf("history len = %d, want 1", len(state.history)) + } + if state.history[0].ResultHash == "" { + t.Error("result hash not recorded by AfterExecute") + } +} + +// --- Default session key --- + +func TestLoopDetector_DefaultSessionKey(t *testing.T) { + d := NewLoopDetector(LoopDetectorConfig{ + WarningThreshold: 3, + CriticalThreshold: 5, + CircuitBreakerThreshold: 15, + EnableGenericRepeat: true, + }) + + // No session key in context — should use "_default" + ctx := context.Background() + for i := 0; i < 4; i++ { + if err := d.BeforeExecute(ctx, "test", nil); err != nil { + t.Fatalf("call %d: unexpected block: %v", i, err) + } + } +} diff --git a/pkg/tools/policy.go b/pkg/tools/policy.go new file mode 100644 index 
000000000..5d4cac945 --- /dev/null +++ b/pkg/tools/policy.go @@ -0,0 +1,44 @@ +package tools + +// ToolPolicy represents one layer of allow/deny filtering. +type ToolPolicy struct { + Allow []string + Deny []string +} + +// ApplyPolicy filters a registry in-place. +// Allow (if non-empty): only listed tools survive. +// Deny: listed tools removed from whatever remains. +func ApplyPolicy(reg *ToolRegistry, policy ToolPolicy) { + allowNames := ResolveToolNames(policy.Allow) + denyNames := ResolveToolNames(policy.Deny) + + // Allow-list: if non-empty, remove everything not in the allow set + if len(allowNames) > 0 { + allowSet := make(map[string]struct{}, len(allowNames)) + for _, name := range allowNames { + allowSet[name] = struct{}{} + } + for _, name := range reg.List() { + if _, ok := allowSet[name]; !ok { + reg.Remove(name) + } + } + } + + // Deny-list: remove listed tools + for _, name := range denyNames { + reg.Remove(name) + } +} + +// DepthDenyList returns tools to deny at a given depth. 
+// depth 0: nil (main agent, full access) +// depth >= maxDepth: spawn/handoff/list_agents denied (leaf, no further chaining) +// between 0 and maxDepth: nil (mid-chain, full access) +func DepthDenyList(depth, maxDepth int) []string { + if depth >= maxDepth { + return []string{"spawn", "handoff", "list_agents"} + } + return nil +} diff --git a/pkg/tools/policy_test.go b/pkg/tools/policy_test.go new file mode 100644 index 000000000..d5b05578d --- /dev/null +++ b/pkg/tools/policy_test.go @@ -0,0 +1,179 @@ +package tools + +import ( + "sort" + "testing" +) + +func setupTestRegistry(names ...string) *ToolRegistry { + reg := NewToolRegistry() + for _, name := range names { + reg.Register(&dummyTool{name: name}) + } + return reg +} + +func registryNames(reg *ToolRegistry) []string { + names := reg.List() + sort.Strings(names) + return names +} + +func TestApplyPolicy_AllowOnly(t *testing.T) { + reg := setupTestRegistry("read_file", "write_file", "exec", "web_search") + ApplyPolicy(reg, ToolPolicy{Allow: []string{"read_file", "exec"}}) + + names := registryNames(reg) + if len(names) != 2 { + t.Fatalf("count = %d, want 2: %v", len(names), names) + } + if names[0] != "exec" || names[1] != "read_file" { + t.Errorf("names = %v, want [exec, read_file]", names) + } +} + +func TestApplyPolicy_DenyOnly(t *testing.T) { + reg := setupTestRegistry("read_file", "write_file", "exec", "web_search") + ApplyPolicy(reg, ToolPolicy{Deny: []string{"exec", "web_search"}}) + + names := registryNames(reg) + if len(names) != 2 { + t.Fatalf("count = %d, want 2: %v", len(names), names) + } + if names[0] != "read_file" || names[1] != "write_file" { + t.Errorf("names = %v, want [read_file, write_file]", names) + } +} + +func TestApplyPolicy_AllowAndDeny(t *testing.T) { + reg := setupTestRegistry("read_file", "write_file", "exec", "web_search") + ApplyPolicy(reg, ToolPolicy{ + Allow: []string{"read_file", "write_file", "exec"}, + Deny: []string{"exec"}, + }) + + names := registryNames(reg) + if 
len(names) != 2 { + t.Fatalf("count = %d, want 2: %v", len(names), names) + } + if names[0] != "read_file" || names[1] != "write_file" { + t.Errorf("names = %v, want [read_file, write_file]", names) + } +} + +func TestApplyPolicy_EmptyPolicy(t *testing.T) { + reg := setupTestRegistry("read_file", "write_file", "exec") + ApplyPolicy(reg, ToolPolicy{}) + + if reg.Count() != 3 { + t.Errorf("count = %d, want 3 (no-op)", reg.Count()) + } +} + +func TestApplyPolicy_GroupRefs(t *testing.T) { + reg := setupTestRegistry("read_file", "write_file", "edit_file", "append_file", "list_dir", "web_search", "web_fetch", "exec") + ApplyPolicy(reg, ToolPolicy{Deny: []string{"group:web"}}) + + names := registryNames(reg) + for _, name := range names { + if name == "web_search" || name == "web_fetch" { + t.Errorf("web tool %q should have been denied", name) + } + } + if reg.Count() != 6 { + t.Errorf("count = %d, want 6", reg.Count()) + } +} + +func TestDepthDenyList_Zero(t *testing.T) { + result := DepthDenyList(0, 3) + if result != nil { + t.Errorf("depth 0 should return nil, got %v", result) + } +} + +func TestDepthDenyList_AtMax(t *testing.T) { + result := DepthDenyList(3, 3) + expected := []string{"spawn", "handoff", "list_agents"} + if len(result) != len(expected) { + t.Fatalf("len = %d, want %d", len(result), len(expected)) + } + for i, name := range expected { + if result[i] != name { + t.Errorf("result[%d] = %q, want %q", i, result[i], name) + } + } +} + +func TestDepthDenyList_BelowMax(t *testing.T) { + result := DepthDenyList(1, 3) + if result != nil { + t.Errorf("mid-chain should return nil, got %v", result) + } +} + +func TestRegistryClone(t *testing.T) { + reg := setupTestRegistry("tool_a", "tool_b", "tool_c") + cloned := reg.Clone() + + // Same tools + if cloned.Count() != 3 { + t.Fatalf("cloned count = %d, want 3", cloned.Count()) + } + + // Independent: removing from clone doesn't affect original + cloned.Remove("tool_b") + if cloned.Count() != 2 { + t.Errorf("cloned 
count after remove = %d, want 2", cloned.Count()) + } + if reg.Count() != 3 { + t.Errorf("original count after clone remove = %d, want 3", reg.Count()) + } +} + +func TestRegistryRemove(t *testing.T) { + reg := setupTestRegistry("tool_a", "tool_b") + reg.Remove("tool_a") + + if reg.Count() != 1 { + t.Fatalf("count = %d, want 1", reg.Count()) + } + if _, ok := reg.Get("tool_a"); ok { + t.Error("tool_a should have been removed") + } + if _, ok := reg.Get("tool_b"); !ok { + t.Error("tool_b should still exist") + } + + // Remove nonexistent tool — no panic + reg.Remove("nonexistent") + if reg.Count() != 1 { + t.Errorf("count = %d, want 1 after removing nonexistent", reg.Count()) + } +} + +func TestPolicyPipeline_Compose(t *testing.T) { + // Simulate: global allow → per-agent deny → depth deny + reg := setupTestRegistry( + "read_file", "write_file", "exec", + "web_search", "spawn", "handoff", "list_agents", + ) + + // Layer 1: per-agent policy (deny web) + ApplyPolicy(reg, ToolPolicy{Deny: []string{"web_search"}}) + + // Layer 2: depth policy (leaf: deny spawn/handoff/list_agents) + denyList := DepthDenyList(3, 3) // at max depth + ApplyPolicy(reg, ToolPolicy{Deny: denyList}) + + names := registryNames(reg) + expected := map[string]bool{"read_file": true, "write_file": true, "exec": true} + if len(names) != len(expected) { + t.Fatalf("count = %d, want %d: %v", len(names), len(expected), names) + } + for _, name := range names { + if !expected[name] { + t.Errorf("unexpected tool: %q", name) + } + } +} diff --git a/pkg/tools/process_scope.go b/pkg/tools/process_scope.go new file mode 100644 index 000000000..d454416ea --- /dev/null +++ b/pkg/tools/process_scope.go @@ -0,0 +1,141 @@ +package tools + +import ( + "os" + "sync" + "syscall" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// ProcessScope tracks PIDs per session key, providing namespace-like isolation +// for exec tool processes. 
Inspired by Google Kubernetes pod process isolation
+// and Linux cgroups — agents can only see and kill their own processes.
+//
+// Concurrency: the sync.Map handles per-session bucket creation; each bucket
+// (pidSet) guards its own PID map with its own mutex, so different sessions
+// never contend on a shared lock.
+type ProcessScope struct {
+	pids sync.Map // sessionKey -> *pidSet
+}
+
+// pidSet is one session's set of live PIDs, guarded by its own mutex.
+type pidSet struct {
+	mu   sync.Mutex
+	pids map[int]bool
+}
+
+// NewProcessScope creates a new process scope tracker.
+func NewProcessScope() *ProcessScope {
+	return &ProcessScope{}
+}
+
+// Register adds a PID to a session's scope.
+func (ps *ProcessScope) Register(sessionKey string, pid int) {
+	set := ps.getOrCreate(sessionKey)
+	set.mu.Lock()
+	defer set.mu.Unlock()
+	set.pids[pid] = true
+}
+
+// Deregister removes a PID from a session's scope (process exited normally).
+// No-op for a session that has never registered anything.
+func (ps *ProcessScope) Deregister(sessionKey string, pid int) {
+	v, ok := ps.pids.Load(sessionKey)
+	if !ok {
+		return
+	}
+	set := v.(*pidSet)
+	set.mu.Lock()
+	defer set.mu.Unlock()
+	delete(set.pids, pid)
+}
+
+// Owns returns true if the given PID belongs to the session's scope.
+func (ps *ProcessScope) Owns(sessionKey string, pid int) bool {
+	v, ok := ps.pids.Load(sessionKey)
+	if !ok {
+		return false
+	}
+	set := v.(*pidSet)
+	set.mu.Lock()
+	defer set.mu.Unlock()
+	return set.pids[pid]
+}
+
+// ListPIDs returns all live PIDs for a session (filters out already-exited processes).
+// Dead PIDs are pruned from the tracked set as a side effect.
+func (ps *ProcessScope) ListPIDs(sessionKey string) []int {
+	v, ok := ps.pids.Load(sessionKey)
+	if !ok {
+		return nil
+	}
+	set := v.(*pidSet)
+	set.mu.Lock()
+	defer set.mu.Unlock()
+
+	var live []int
+	for pid := range set.pids {
+		// Check if process is still running (Unix signal 0).
+		// Deleting map entries while ranging is well-defined in Go.
+		if isProcessAlive(pid) {
+			live = append(live, pid)
+		} else {
+			delete(set.pids, pid)
+		}
+	}
+	return live
+}
+
+// KillAll kills all processes owned by a session. Returns number killed.
+// Used during cascade stop to clean up spawned processes.
+func (ps *ProcessScope) KillAll(sessionKey string) int {
+	v, ok := ps.pids.Load(sessionKey)
+	if !ok {
+		return 0
+	}
+	set := v.(*pidSet)
+	set.mu.Lock()
+	defer set.mu.Unlock()
+
+	killed := 0
+	for pid := range set.pids {
+		if err := killProcess(pid); err == nil {
+			killed++
+		}
+		// Forget the PID either way; a signal failure usually means the
+		// process is already gone.
+		delete(set.pids, pid)
+	}
+
+	if killed > 0 {
+		logger.InfoCF("process_scope", "Killed session processes", map[string]interface{}{
+			"session_key": sessionKey,
+			"killed":      killed,
+		})
+	}
+	return killed
+}
+
+// Cleanup removes all tracking for a session.
+func (ps *ProcessScope) Cleanup(sessionKey string) {
+	ps.pids.Delete(sessionKey)
+}
+
+// getOrCreate returns the session's pidSet, creating it on first use.
+// LoadOrStore makes concurrent first registrations converge on one bucket.
+func (ps *ProcessScope) getOrCreate(sessionKey string) *pidSet {
+	if v, ok := ps.pids.Load(sessionKey); ok {
+		return v.(*pidSet)
+	}
+	set := &pidSet{pids: make(map[int]bool)}
+	actual, _ := ps.pids.LoadOrStore(sessionKey, set)
+	return actual.(*pidSet)
+}
+
+// isProcessAlive checks if a process is still running via signal 0 (Unix).
+// NOTE(review): a zombie still answers signal 0, so an exited-but-unreaped
+// child counts as alive until it is reaped (see the KillAll test).
+func isProcessAlive(pid int) bool {
+	proc, err := os.FindProcess(pid)
+	if err != nil {
+		return false
+	}
+	err = proc.Signal(syscall.Signal(0))
+	return err == nil
+}
+
+// killProcess sends SIGTERM to a process. No SIGKILL escalation is
+// implemented here; callers that need a guaranteed kill must escalate
+// themselves.
+func killProcess(pid int) error {
+	proc, err := os.FindProcess(pid)
+	if err != nil {
+		return err
+	}
+	// SIGTERM only — there is no SIGKILL escalation in this helper.
+	return proc.Signal(syscall.SIGTERM)
+}
diff --git a/pkg/tools/process_scope_test.go b/pkg/tools/process_scope_test.go
new file mode 100644
index 000000000..ac74419ba
--- /dev/null
+++ b/pkg/tools/process_scope_test.go
@@ -0,0 +1,136 @@
+package tools
+
+import (
+	"os"
+	"os/exec"
+	"testing"
+)
+
+// TestProcessScope_RegisterAndOwns checks that ownership is tracked per
+// session key: a PID is visible only to the session that registered it.
+func TestProcessScope_RegisterAndOwns(t *testing.T) {
+	ps := NewProcessScope()
+
+	ps.Register("session-1", 12345)
+	ps.Register("session-1", 12346)
+	ps.Register("session-2", 99999)
+
+	if !ps.Owns("session-1", 12345) {
+		t.Error("session-1 should own PID 12345")
+	}
+	if !ps.Owns("session-1", 12346) {
+		t.Error("session-1 should own PID 12346")
+	}
+	if ps.Owns("session-1", 99999) {
+		t.Error("session-1 should NOT own PID 99999")
+	}
+	if !ps.Owns("session-2", 99999) {
+		t.Error("session-2 should own PID 99999")
+	}
+}
+
+// TestProcessScope_Deregister checks a deregistered PID is no longer owned.
+func TestProcessScope_Deregister(t *testing.T) {
+	ps := NewProcessScope()
+
+	ps.Register("session-1", 12345)
+	ps.Deregister("session-1", 12345)
+
+	if ps.Owns("session-1", 12345) {
+		t.Error("should not own after deregister")
+	}
+}
+
+// TestProcessScope_CrossSessionIsolation checks that two sessions cannot
+// see each other's registered processes.
+func TestProcessScope_CrossSessionIsolation(t *testing.T) {
+	ps := NewProcessScope()
+
+	ps.Register("session-a", 100)
+	ps.Register("session-b", 200)
+
+	if ps.Owns("session-a", 200) {
+		t.Error("session-a should not see session-b's processes")
+	}
+	if ps.Owns("session-b", 100) {
+		t.Error("session-b should not see session-a's processes")
+	}
+}
+
+// TestProcessScope_ListPIDs_FiltersDeadProcesses checks ListPIDs keeps the
+// (alive) current process and drops an almost-certainly-nonexistent PID.
+func TestProcessScope_ListPIDs_FiltersDeadProcesses(t *testing.T) {
+	ps := NewProcessScope()
+
+	// Register current PID (alive) and a fake PID (dead)
+	ps.Register("session-1", os.Getpid())
+	ps.Register("session-1", 999999999) // almost certainly not a real PID
+
+	live := ps.ListPIDs("session-1")
+
+	// Current process should be in the list
+	found := false
+	for _, pid := range live {
+		if pid == os.Getpid() {
+			found = true
+		}
+		if pid == 999999999 {
+			t.Error("dead PID should have been filtered out")
+		}
+	}
+	if !found {
+		t.Error("current process PID should be in live list")
+	}
+}
+
+// TestProcessScope_KillAll starts a real child, kills it via KillAll, then
+// reaps it so the signal-0 liveness probe reflects reality.
+func TestProcessScope_KillAll(t *testing.T) {
+	ps := NewProcessScope()
+
+	// Start a real process we can kill
+	cmd := exec.Command("sleep", "60")
+	if err := cmd.Start(); err != nil {
+		t.Skipf("cannot start test process: %v", err)
+	}
+	pid := cmd.Process.Pid
+
+	ps.Register("session-1", pid)
+
+	killed := ps.KillAll("session-1")
+	if killed != 1 {
+		t.Errorf("expected 1 killed, got %d", killed)
+	}
+
+	// Reap the child process to prevent zombie (zombie still responds to signal 0).
+	// cmd.Wait() blocks until the process exits and is reaped by the OS.
+	err := cmd.Wait()
+	if err == nil {
+		t.Error("expected wait to return non-nil error after SIGTERM")
+	}
+
+	// After reaping, process should no longer be in the process table
+	if isProcessAlive(pid) {
+		t.Error("process should have been killed")
+	}
+}
+
+// TestProcessScope_Cleanup checks Cleanup drops all tracking for a session.
+func TestProcessScope_Cleanup(t *testing.T) {
+	ps := NewProcessScope()
+
+	ps.Register("session-1", os.Getpid())
+	ps.Cleanup("session-1")
+
+	if ps.Owns("session-1", os.Getpid()) {
+		t.Error("should not own after cleanup")
+	}
+}
+
+// TestProcessScope_EmptySession checks every accessor is safe for a session
+// key that was never registered.
+func TestProcessScope_EmptySession(t *testing.T) {
+	ps := NewProcessScope()
+
+	if ps.Owns("nonexistent", 12345) {
+		t.Error("nonexistent session should not own anything")
+	}
+
+	pids := ps.ListPIDs("nonexistent")
+	if len(pids) != 0 {
+		t.Error("nonexistent session should have no PIDs")
+	}
+
+	killed := ps.KillAll("nonexistent")
+	if killed != 0 {
+		t.Error("killing nonexistent session should kill 0")
+	}
+}
diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go
index c8cf92863..1e47ee267 100644
--- a/pkg/tools/registry.go
+++ b/pkg/tools/registry.go
@@ -12,6 +12,7 @@ import (
 type ToolRegistry struct {
 	tools map[string]Tool
+	hooks []ToolHook
 	mu    sync.RWMutex
 }
 
@@ -21,6 +22,14 @@ func NewToolRegistry() *ToolRegistry {
 	}
 }
 
+// AddHook registers a hook that will be called around tool
executions.
+// Hooks are called in registration order.
+func (r *ToolRegistry) AddHook(hook ToolHook) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.hooks = append(r.hooks, hook)
+}
+
 func (r *ToolRegistry) Register(tool Tool) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
@@ -71,6 +80,27 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
 		})
 	}
 
+	// Run BeforeExecute hooks — any hook returning error blocks execution.
+	// Snapshot the slice header under RLock; AddHook appends under the write
+	// lock, so iterating the snapshot afterwards is race-free.
+	r.mu.RLock()
+	hooks := r.hooks
+	r.mu.RUnlock()
+
+	for _, hook := range hooks {
+		if err := hook.BeforeExecute(ctx, name, args); err != nil {
+			logger.WarnCF("tool", "Hook blocked tool execution",
+				map[string]interface{}{
+					"tool":  name,
+					"error": err.Error(),
+				})
+			blocked := ErrorResult(fmt.Sprintf("tool %q blocked by hook: %s", name, err.Error()))
+			// Still run AfterExecute hooks for observability.
+			// Deliberate: ALL hooks get AfterExecute here, including ones
+			// whose BeforeExecute never ran because an earlier hook blocked.
+			for _, h := range hooks {
+				h.AfterExecute(ctx, name, args, blocked)
+			}
+			return blocked
+		}
+	}
+
 	start := time.Now()
 	result := tool.Execute(ctx, args)
 	duration := time.Since(start)
@@ -98,6 +128,11 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
 		})
 	}
 
+	// Run AfterExecute hooks for observability
+	for _, hook := range hooks {
+		hook.AfterExecute(ctx, name, args, result)
+	}
+
 	return result
 }
 
@@ -156,6 +191,31 @@ func (r *ToolRegistry) List() []string {
 	return names
 }
 
+// Clone creates a shallow copy of the registry.
+// Tool instances are shared (not deep-copied), only the map and hooks slice are copied.
+// This is used for depth-based policy: clone → remove denied tools → pass to RunToolLoop.
+func (r *ToolRegistry) Clone() *ToolRegistry {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	cloned := &ToolRegistry{
+		tools: make(map[string]Tool, len(r.tools)),
+		hooks: make([]ToolHook, len(r.hooks)),
+	}
+	for name, tool := range r.tools {
+		cloned.tools[name] = tool
+	}
+	copy(cloned.hooks, r.hooks)
+	return cloned
+}
+
+// Remove unregisters a tool by name. No-op if the name is not registered.
+func (r *ToolRegistry) Remove(name string) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	delete(r.tools, name)
+}
+
 // Count returns the number of registered tools.
 func (r *ToolRegistry) Count() int {
 	r.mu.RLock()