Skip to content

Commit 50969a4

Browse files
feat(steering): add experimental steering for modular prompting (#1280)
--------- Co-authored-by: John Tristan <[email protected]>
1 parent 9fa818e commit 50969a4

File tree

27 files changed

+1695
-100
lines changed

27 files changed

+1695
-100
lines changed

src/strands/agent/state.py

Lines changed: 3 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,6 @@
11
"""Agent state management."""
22

3-
import copy
4-
import json
5-
from typing import Any, Dict, Optional
3+
from ..types.json_dict import JSONSerializableDict
64

7-
8-
class AgentState:
9-
"""Represents an Agent's stateful information outside of context provided to a model.
10-
11-
Provides a key-value store for agent state with JSON serialization validation and persistence support.
12-
Key features:
13-
- JSON serialization validation on assignment
14-
- Get/set/delete operations
15-
"""
16-
17-
def __init__(self, initial_state: Optional[Dict[str, Any]] = None):
18-
"""Initialize AgentState."""
19-
self._state: Dict[str, Dict[str, Any]]
20-
if initial_state:
21-
self._validate_json_serializable(initial_state)
22-
self._state = copy.deepcopy(initial_state)
23-
else:
24-
self._state = {}
25-
26-
def set(self, key: str, value: Any) -> None:
27-
"""Set a value in the state.
28-
29-
Args:
30-
key: The key to store the value under
31-
value: The value to store (must be JSON serializable)
32-
33-
Raises:
34-
ValueError: If key is invalid, or if value is not JSON serializable
35-
"""
36-
self._validate_key(key)
37-
self._validate_json_serializable(value)
38-
39-
self._state[key] = copy.deepcopy(value)
40-
41-
def get(self, key: Optional[str] = None) -> Any:
42-
"""Get a value or entire state.
43-
44-
Args:
45-
key: The key to retrieve (if None, returns entire state object)
46-
47-
Returns:
48-
The stored value, entire state dict, or None if not found
49-
"""
50-
if key is None:
51-
return copy.deepcopy(self._state)
52-
else:
53-
# Return specific key
54-
return copy.deepcopy(self._state.get(key))
55-
56-
def delete(self, key: str) -> None:
57-
"""Delete a specific key from the state.
58-
59-
Args:
60-
key: The key to delete
61-
"""
62-
self._validate_key(key)
63-
64-
self._state.pop(key, None)
65-
66-
def _validate_key(self, key: str) -> None:
67-
"""Validate that a key is valid.
68-
69-
Args:
70-
key: The key to validate
71-
72-
Raises:
73-
ValueError: If key is invalid
74-
"""
75-
if key is None:
76-
raise ValueError("Key cannot be None")
77-
if not isinstance(key, str):
78-
raise ValueError("Key must be a string")
79-
if not key.strip():
80-
raise ValueError("Key cannot be empty")
81-
82-
def _validate_json_serializable(self, value: Any) -> None:
83-
"""Validate that a value is JSON serializable.
84-
85-
Args:
86-
value: The value to validate
87-
88-
Raises:
89-
ValueError: If value is not JSON serializable
90-
"""
91-
try:
92-
json.dumps(value)
93-
except (TypeError, ValueError) as e:
94-
raise ValueError(
95-
f"Value is not JSON serializable: {type(value).__name__}. "
96-
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
97-
) from e
5+
# Type alias for agent state
6+
AgentState = JSONSerializableDict

src/strands/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This module implements experimental features that are subject to change in future revisions without notice.
44
"""
55

6-
from . import tools
6+
from . import steering, tools
77
from .agent_config import config_to_agent
88

9-
__all__ = ["config_to_agent", "tools"]
9+
__all__ = ["config_to_agent", "tools", "steering"]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Steering system for Strands agents.
2+
3+
Provides contextual guidance for agents through modular prompting with progressive disclosure.
4+
Instead of front-loading all instructions, steering handlers provide just-in-time feedback
5+
based on local context data populated by context callbacks.
6+
7+
Core components:
8+
9+
- SteeringHandler: Base class for guidance logic with local context
10+
- SteeringContextCallback: Protocol for context update functions
11+
- SteeringContextProvider: Protocol for multi-event context providers
12+
- SteeringAction: Proceed/Guide/Interrupt decisions
13+
14+
Usage:
15+
handler = LLMSteeringHandler(system_prompt="...")
16+
agent = Agent(tools=[...], hooks=[handler])
17+
"""
18+
19+
# Core primitives
20+
# Context providers
21+
from .context_providers.ledger_provider import (
22+
LedgerAfterToolCall,
23+
LedgerBeforeToolCall,
24+
LedgerProvider,
25+
)
26+
from .core.action import Guide, Interrupt, Proceed, SteeringAction
27+
from .core.context import SteeringContextCallback, SteeringContextProvider
28+
from .core.handler import SteeringHandler
29+
30+
# Handler implementations
31+
from .handlers.llm import LLMPromptMapper, LLMSteeringHandler
32+
33+
__all__ = [
34+
"SteeringAction",
35+
"Proceed",
36+
"Guide",
37+
"Interrupt",
38+
"SteeringHandler",
39+
"SteeringContextCallback",
40+
"SteeringContextProvider",
41+
"LedgerBeforeToolCall",
42+
"LedgerAfterToolCall",
43+
"LedgerProvider",
44+
"LLMSteeringHandler",
45+
"LLMPromptMapper",
46+
]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Context providers for steering evaluation."""
2+
3+
from .ledger_provider import (
4+
LedgerAfterToolCall,
5+
LedgerBeforeToolCall,
6+
LedgerProvider,
7+
)
8+
9+
__all__ = [
10+
"LedgerAfterToolCall",
11+
"LedgerBeforeToolCall",
12+
"LedgerProvider",
13+
]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Ledger context provider for comprehensive agent activity tracking.
2+
3+
Tracks complete agent activity ledger including tool calls, conversation history,
4+
and timing information. This comprehensive audit trail enables steering handlers
5+
to make informed guidance decisions based on agent behavior patterns and history.
6+
7+
Data captured:
8+
9+
- Tool call history with inputs, outputs, timing, success/failure
10+
- Conversation messages and agent responses
11+
- Session metadata and timing information
12+
- Error patterns and recovery attempts
13+
14+
Usage:
15+
Use as context provider functions or mix into steering handlers.
16+
"""
17+
18+
import logging
19+
from datetime import datetime
20+
from typing import Any
21+
22+
from ....hooks.events import AfterToolCallEvent, BeforeToolCallEvent
23+
from ..core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class LedgerBeforeToolCall(SteeringContextCallback[BeforeToolCallEvent]):
29+
"""Context provider for ledger tracking before tool calls."""
30+
31+
def __init__(self) -> None:
32+
"""Initialize the ledger provider."""
33+
self.session_start = datetime.now().isoformat()
34+
35+
def __call__(self, event: BeforeToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None:
36+
"""Update ledger before tool call."""
37+
ledger = steering_context.data.get("ledger") or {}
38+
39+
if not ledger:
40+
ledger = {
41+
"session_start": self.session_start,
42+
"tool_calls": [],
43+
"conversation_history": [],
44+
"session_metadata": {},
45+
}
46+
47+
tool_call_entry = {
48+
"timestamp": datetime.now().isoformat(),
49+
"tool_name": event.tool_use.get("name"),
50+
"tool_args": event.tool_use.get("arguments", {}),
51+
"status": "pending",
52+
}
53+
ledger["tool_calls"].append(tool_call_entry)
54+
steering_context.data.set("ledger", ledger)
55+
56+
57+
class LedgerAfterToolCall(SteeringContextCallback[AfterToolCallEvent]):
58+
"""Context provider for ledger tracking after tool calls."""
59+
60+
def __call__(self, event: AfterToolCallEvent, steering_context: SteeringContext, **kwargs: Any) -> None:
61+
"""Update ledger after tool call."""
62+
ledger = steering_context.data.get("ledger") or {}
63+
64+
if ledger.get("tool_calls"):
65+
last_call = ledger["tool_calls"][-1]
66+
last_call.update(
67+
{
68+
"completion_timestamp": datetime.now().isoformat(),
69+
"status": event.result["status"],
70+
"result": event.result["content"],
71+
"error": str(event.exception) if event.exception else None,
72+
}
73+
)
74+
steering_context.data.set("ledger", ledger)
75+
76+
77+
class LedgerProvider(SteeringContextProvider):
78+
"""Combined ledger context provider for both before and after tool calls."""
79+
80+
def context_providers(self, **kwargs: Any) -> list[SteeringContextCallback]:
81+
"""Return ledger context providers with shared state."""
82+
return [
83+
LedgerBeforeToolCall(),
84+
LedgerAfterToolCall(),
85+
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Core steering system interfaces and base classes."""
2+
3+
from .action import Guide, Interrupt, Proceed, SteeringAction
4+
from .handler import SteeringHandler
5+
6+
__all__ = ["SteeringAction", "Proceed", "Guide", "Interrupt", "SteeringHandler"]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""SteeringAction types for steering evaluation results.
2+
3+
Defines structured outcomes from steering handlers that determine how tool calls
4+
should be handled. SteeringActions enable modular prompting by providing just-in-time
5+
feedback rather than front-loading all instructions in monolithic prompts.
6+
7+
Flow:
8+
SteeringHandler.steer() → SteeringAction → BeforeToolCallEvent handling
9+
↓ ↓ ↓
10+
Evaluate context Action type Tool execution modified
11+
12+
SteeringAction types:
13+
Proceed: Tool executes immediately (no intervention needed)
14+
Guide: Tool cancelled, agent receives contextual feedback to explore alternatives
15+
Interrupt: Tool execution paused for human input via interrupt system
16+
17+
Extensibility:
18+
New action types can be added to the union. Always handle the default
19+
case in pattern matching to maintain backward compatibility.
20+
"""
21+
22+
from typing import Annotated, Literal
23+
24+
from pydantic import BaseModel, Field
25+
26+
27+
class Proceed(BaseModel):
28+
"""Allow tool to execute immediately without intervention.
29+
30+
The tool call proceeds as planned. The reason provides context
31+
for logging and debugging purposes.
32+
"""
33+
34+
type: Literal["proceed"] = "proceed"
35+
reason: str
36+
37+
38+
class Guide(BaseModel):
39+
"""Cancel tool and provide contextual feedback for agent to explore alternatives.
40+
41+
The tool call is cancelled and the agent receives the reason as contextual
42+
feedback to help them consider alternative approaches while maintaining
43+
adaptive reasoning capabilities.
44+
"""
45+
46+
type: Literal["guide"] = "guide"
47+
reason: str
48+
49+
50+
class Interrupt(BaseModel):
51+
"""Pause tool execution for human input via interrupt system.
52+
53+
The tool call is paused and human input is requested through Strands'
54+
interrupt system. The human can approve or deny the operation, and their
55+
decision determines whether the tool executes or is cancelled.
56+
"""
57+
58+
type: Literal["interrupt"] = "interrupt"
59+
reason: str
60+
61+
62+
# SteeringAction union - extensible for future action types
63+
# IMPORTANT: Always handle the default case when pattern matching
64+
# to maintain backward compatibility as new action types are added
65+
SteeringAction = Annotated[Proceed | Guide | Interrupt, Field(discriminator="type")]

0 commit comments

Comments
 (0)