Skip to content

Commit 62f986b

Browse files
authored
Merge branch 'main' into resume-from-tool-use
2 parents 7d7f1f2 + 999e654 commit 62f986b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4255
-129
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
name: PR Size Labeler
2+
3+
on:
4+
pull_request_target:
5+
branches: main
6+
7+
jobs:
8+
label-size:
9+
runs-on: ubuntu-latest
10+
permissions:
11+
pull-requests: write
12+
issues: write
13+
steps:
14+
- name: Calculate PR size and apply label
15+
uses: actions/github-script@v8
16+
with:
17+
script: |
18+
const pr = context.payload.pull_request;
19+
const totalChanges = pr.additions + pr.deletions;
20+
21+
// Remove existing size labels
22+
const labels = await github.rest.issues.listLabelsOnIssue({
23+
owner: context.repo.owner,
24+
repo: context.repo.repo,
25+
issue_number: context.payload.pull_request.number
26+
});
27+
28+
for (const label of labels.data) {
29+
if (label.name.startsWith('size/')) {
30+
await github.rest.issues.removeLabel({
31+
owner: context.repo.owner,
32+
repo: context.repo.repo,
33+
issue_number: context.payload.pull_request.number,
34+
name: label.name
35+
});
36+
}
37+
}
38+
39+
// Determine and apply new size label
40+
let sizeLabel;
41+
if (totalChanges <= 20) sizeLabel = 'size/xs';
42+
else if (totalChanges <= 100) sizeLabel = 'size/s';
43+
else if (totalChanges <= 500) sizeLabel = 'size/m';
44+
else if (totalChanges <= 1000) sizeLabel = 'size/l';
45+
else {
46+
sizeLabel = 'size/xl';
47+
}
48+
49+
await github.rest.issues.addLabels({
50+
owner: context.repo.owner,
51+
repo: context.repo.repo,
52+
issue_number: context.payload.pull_request.number,
53+
labels: [sizeLabel]
54+
});
55+
56+
if (sizeLabel === 'size/xl') {
57+
core.setFailed(`PR is too large (${totalChanges} lines). Please split into smaller PRs.`);
58+
}

src/strands/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,12 @@
55
from .tools.decorator import tool
66
from .types.tools import ToolContext
77

8-
__all__ = ["Agent", "agent", "models", "tool", "types", "telemetry", "ToolContext"]
8+
__all__ = [
9+
"Agent",
10+
"agent",
11+
"models",
12+
"tool",
13+
"ToolContext",
14+
"types",
15+
"telemetry",
16+
]

src/strands/agent/agent.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from ..tools.executors import ConcurrentToolExecutor
5151
from ..tools.executors._executor import ToolExecutor
5252
from ..tools.registry import ToolRegistry
53+
from ..tools.structured_output._structured_output_context import StructuredOutputContext
5354
from ..tools.watcher import ToolWatcher
5455
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
5556
from ..types.agent import AgentInput
@@ -216,6 +217,7 @@ def __init__(
216217
messages: Optional[Messages] = None,
217218
tools: Optional[list[Union[str, dict[str, str], Any]]] = None,
218219
system_prompt: Optional[str] = None,
220+
structured_output_model: Optional[Type[BaseModel]] = None,
219221
callback_handler: Optional[
220222
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
221223
] = _DEFAULT_CALLBACK_HANDLER,
@@ -251,6 +253,10 @@ def __init__(
251253
If provided, only these tools will be available. If None, all tools will be available.
252254
system_prompt: System prompt to guide model behavior.
253255
If None, the model will behave according to its default settings.
256+
structured_output_model: Pydantic model type(s) for structured output.
257+
When specified, all agent calls will attempt to return structured output of this type.
258+
This can be overridden on the agent invocation.
259+
Defaults to None (no structured output).
254260
callback_handler: Callback for processing events as they happen during agent execution.
255261
If not provided (using the default), a new PrintingCallbackHandler instance is created.
256262
If explicitly set to None, null_callback_handler is used.
@@ -280,8 +286,8 @@ def __init__(
280286
"""
281287
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
282288
self.messages = messages if messages is not None else []
283-
284289
self.system_prompt = system_prompt
290+
self._default_structured_output_model = structured_output_model
285291
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
286292
self.name = name or _DEFAULT_AGENT_NAME
287293
self.description = description
@@ -383,7 +389,12 @@ def tool_names(self) -> list[str]:
383389
return list(all_tools.keys())
384390

385391
def __call__(
386-
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
392+
self,
393+
prompt: AgentInput = None,
394+
*,
395+
invocation_state: dict[str, Any] | None = None,
396+
structured_output_model: Type[BaseModel] | None = None,
397+
**kwargs: Any,
387398
) -> AgentResult:
388399
"""Process a natural language prompt through the agent's event loop.
389400
@@ -400,6 +411,7 @@ def __call__(
400411
- list[Message]: Complete messages with roles
401412
- None: Use existing conversation history
402413
invocation_state: Additional parameters to pass through the event loop.
414+
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
403415
**kwargs: Additional parameters to pass through the event loop.[Deprecating]
404416
405417
Returns:
@@ -409,17 +421,27 @@ def __call__(
409421
- message: The final message from the model
410422
- metrics: Performance metrics from the event loop
411423
- state: The final state of the event loop
424+
- structured_output: Parsed structured output when structured_output_model was specified
412425
"""
413426

414427
def execute() -> AgentResult:
415-
return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs))
428+
return asyncio.run(
429+
self.invoke_async(
430+
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
431+
)
432+
)
416433

417434
with ThreadPoolExecutor() as executor:
418435
future = executor.submit(execute)
419436
return future.result()
420437

421438
async def invoke_async(
422-
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
439+
self,
440+
prompt: AgentInput = None,
441+
*,
442+
invocation_state: dict[str, Any] | None = None,
443+
structured_output_model: Type[BaseModel] | None = None,
444+
**kwargs: Any,
423445
) -> AgentResult:
424446
"""Process a natural language prompt through the agent's event loop.
425447
@@ -436,6 +458,7 @@ async def invoke_async(
436458
- list[Message]: Complete messages with roles
437459
- None: Use existing conversation history
438460
invocation_state: Additional parameters to pass through the event loop.
461+
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
439462
**kwargs: Additional parameters to pass through the event loop.[Deprecating]
440463
441464
Returns:
@@ -446,7 +469,9 @@ async def invoke_async(
446469
- metrics: Performance metrics from the event loop
447470
- state: The final state of the event loop
448471
"""
449-
events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs)
472+
events = self.stream_async(
473+
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
474+
)
450475
async for event in events:
451476
_ = event
452477

@@ -473,6 +498,13 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->
473498
Raises:
474499
ValueError: If no conversation history or prompt is provided.
475500
"""
501+
warnings.warn(
502+
"Agent.structured_output method is deprecated."
503+
" You should pass in `structured_output_model` directly into the agent invocation."
504+
" see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/",
505+
category=DeprecationWarning,
506+
stacklevel=2,
507+
)
476508

477509
def execute() -> T:
478510
return asyncio.run(self.structured_output_async(output_model, prompt))
@@ -501,6 +533,13 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
501533
if self._interrupt_state.activated:
502534
raise RuntimeError("cannot call structured output during interrupt")
503535

536+
warnings.warn(
537+
"Agent.structured_output_async method is deprecated."
538+
" You should pass in `structured_output_model` directly into the agent invocation."
539+
" see: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/structured-output/",
540+
category=DeprecationWarning,
541+
stacklevel=2,
542+
)
504543
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
505544
with self.tracer.tracer.start_as_current_span(
506545
"execute_structured_output", kind=trace_api.SpanKind.CLIENT
@@ -545,7 +584,12 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
545584
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
546585

547586
async def stream_async(
548-
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
587+
self,
588+
prompt: AgentInput = None,
589+
*,
590+
invocation_state: dict[str, Any] | None = None,
591+
structured_output_model: Type[BaseModel] | None = None,
592+
**kwargs: Any,
549593
) -> AsyncIterator[Any]:
550594
"""Process a natural language prompt and yield events as an async iterator.
551595
@@ -562,6 +606,7 @@ async def stream_async(
562606
- list[Message]: Complete messages with roles
563607
- None: Use existing conversation history
564608
invocation_state: Additional parameters to pass through the event loop.
609+
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
565610
**kwargs: Additional parameters to pass to the event loop.[Deprecating]
566611
567612
Yields:
@@ -606,7 +651,7 @@ async def stream_async(
606651

607652
with trace_api.use_span(self.trace_span):
608653
try:
609-
events = self._run_loop(messages, invocation_state=merged_state)
654+
events = self._run_loop(messages, merged_state, structured_output_model)
610655

611656
async for event in events:
612657
event.prepare(invocation_state=merged_state)
@@ -658,12 +703,18 @@ def _resume_interrupt(self, prompt: AgentInput) -> None:
658703

659704
self._interrupt_state.interrupts[interrupt_id].response = interrupt_response
660705

661-
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
706+
async def _run_loop(
707+
self,
708+
messages: Messages,
709+
invocation_state: dict[str, Any],
710+
structured_output_model: Type[BaseModel] | None = None,
711+
) -> AsyncGenerator[TypedEvent, None]:
662712
"""Execute the agent's event loop with the given message and parameters.
663713
664714
Args:
665715
messages: The input messages to add to the conversation.
666716
invocation_state: Additional parameters to pass to the event loop.
717+
structured_output_model: Optional Pydantic model type for structured output.
667718
668719
Yields:
669720
Events from the event loop cycle.
@@ -676,8 +727,12 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any])
676727
for message in messages:
677728
self._append_message(message)
678729

730+
structured_output_context = StructuredOutputContext(
731+
structured_output_model or self._default_structured_output_model
732+
)
733+
679734
# Execute the event loop cycle with retry logic for context limits
680-
events = self._execute_event_loop_cycle(invocation_state)
735+
events = self._execute_event_loop_cycle(invocation_state, structured_output_context)
681736
async for event in events:
682737
# Signal from the model provider that the message sent by the user should be redacted,
683738
# likely due to a guardrail.
@@ -698,24 +753,33 @@ async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any])
698753
self.conversation_manager.apply_management(self)
699754
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
700755

701-
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
756+
async def _execute_event_loop_cycle(
757+
self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None
758+
) -> AsyncGenerator[TypedEvent, None]:
702759
"""Execute the event loop cycle with retry logic for context window limits.
703760
704761
This internal method handles the execution of the event loop cycle and implements
705762
retry logic for handling context window overflow exceptions by reducing the
706763
conversation context and retrying.
707764
765+
Args:
766+
invocation_state: Additional parameters to pass to the event loop.
767+
structured_output_context: Optional structured output context for this invocation.
768+
708769
Yields:
709770
Events of the loop cycle.
710771
"""
711772
# Add `Agent` to invocation_state to keep backwards-compatibility
712773
invocation_state["agent"] = self
713774

775+
if structured_output_context:
776+
structured_output_context.register_tool(self.tool_registry)
777+
714778
try:
715-
# Execute the main event loop cycle
716779
events = event_loop_cycle(
717780
agent=self,
718781
invocation_state=invocation_state,
782+
structured_output_context=structured_output_context,
719783
)
720784
async for event in events:
721785
yield event
@@ -728,10 +792,14 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
728792
if self._session_manager:
729793
self._session_manager.sync_agent(self)
730794

731-
events = self._execute_event_loop_cycle(invocation_state)
795+
events = self._execute_event_loop_cycle(invocation_state, structured_output_context)
732796
async for event in events:
733797
yield event
734798

799+
finally:
800+
if structured_output_context:
801+
structured_output_context.cleanup(self.tool_registry)
802+
735803
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
736804
if self._interrupt_state.activated:
737805
return []

src/strands/agent/agent_result.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any, Sequence
7+
from typing import Any, Sequence, cast
8+
9+
from pydantic import BaseModel
810

911
from ..interrupt import Interrupt
1012
from ..telemetry.metrics import EventLoopMetrics
@@ -22,13 +24,15 @@ class AgentResult:
2224
metrics: Performance metrics collected during processing.
2325
state: Additional state information from the event loop.
2426
interrupts: List of interrupts if raised by user.
27+
structured_output: Parsed structured output when structured_output_model was specified.
2528
"""
2629

2730
stop_reason: StopReason
2831
message: Message
2932
metrics: EventLoopMetrics
3033
state: Any
3134
interrupts: Sequence[Interrupt] | None = None
35+
structured_output: BaseModel | None = None
3236

3337
def __str__(self) -> str:
3438
"""Get the agent's last message as a string.
@@ -46,3 +50,34 @@ def __str__(self) -> str:
4650
if isinstance(item, dict) and "text" in item:
4751
result += item.get("text", "") + "\n"
4852
return result
53+
54+
@classmethod
55+
def from_dict(cls, data: dict[str, Any]) -> "AgentResult":
56+
"""Rehydrate an AgentResult from persisted JSON.
57+
58+
Args:
59+
data: Dictionary containing the serialized AgentResult data
60+
Returns:
61+
AgentResult instance
62+
Raises:
63+
TypeError: If the data format is invalid@
64+
"""
65+
if data.get("type") != "agent_result":
66+
raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}")
67+
68+
message = cast(Message, data.get("message"))
69+
stop_reason = cast(StopReason, data.get("stop_reason"))
70+
71+
return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={})
72+
73+
def to_dict(self) -> dict[str, Any]:
74+
"""Convert this AgentResult to JSON-serializable dictionary.
75+
76+
Returns:
77+
Dictionary containing serialized AgentResult data
78+
"""
79+
return {
80+
"type": "agent_result",
81+
"message": self.message,
82+
"stop_reason": self.stop_reason,
83+
}

0 commit comments

Comments
 (0)