Skip to content

Commit fca208b

Browse files
authored
feat: pass A2A request context metadata as invocation state (#1854)
1 parent 2da3f7c commit fca208b

File tree

2 files changed

+140
-2
lines changed

2 files changed

+140
-2
lines changed

src/strands/multiagent/a2a/executor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,12 @@ async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater
128128
self._current_artifact_id = str(uuid.uuid4())
129129
self._is_first_chunk = True
130130

131+
# Pass the A2A RequestContext through invocation state so downstream
132+
# tools and hooks can access request metadata, task info, configuration, etc.
133+
invocation_state: dict[str, Any] = {"a2a_request_context": context}
134+
131135
try:
132-
async for event in self.agent.stream_async(content_blocks):
136+
async for event in self.agent.stream_async(content_blocks, invocation_state=invocation_state):
133137
await self._handle_streaming_event(event, updater)
134138
except Exception:
135139
logger.exception("Error in streaming execution")

tests/strands/multiagent/a2a/test_executor.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests for the StrandsA2AExecutor class."""
22

33
import base64
4+
from typing import Any
45
from unittest.mock import AsyncMock, MagicMock, patch
56

67
import pytest
@@ -1196,4 +1197,137 @@ async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent):
11961197
assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc"
11971198
assert mock_updater.add_artifact.call_args[1]["append"] is True
11981199
assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True
1199-
mock_updater.complete.assert_called_once()
1200+
1201+
1202+
# Tests for invocation state propagation from A2A request context
1203+
1204+
1205+
def _setup_streaming_context(
1206+
mock_strands_agent: MagicMock,
1207+
mock_request_context: MagicMock,
1208+
) -> None:
1209+
"""Set up common mocks for invocation state streaming tests.
1210+
1211+
Args:
1212+
mock_strands_agent: The mock Strands Agent.
1213+
mock_request_context: The mock RequestContext.
1214+
"""
1215+
1216+
async def mock_stream(content_blocks: list, **kwargs: Any) -> Any:
1217+
yield {"result": MagicMock(spec=SAAgentResult)}
1218+
1219+
mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream)
1220+
1221+
# Set up message with a text part
1222+
mock_text_part = MagicMock(spec=TextPart)
1223+
mock_text_part.text = "test input"
1224+
mock_part = MagicMock()
1225+
mock_part.root = mock_text_part
1226+
mock_message = MagicMock()
1227+
mock_message.parts = [mock_part]
1228+
mock_request_context.message = mock_message
1229+
1230+
1231+
@pytest.mark.asyncio
1232+
async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue):
1233+
"""Test that the full RequestContext is passed as a2a_request_context in invocation state."""
1234+
mock_task = MagicMock()
1235+
mock_task.id = "task-42"
1236+
mock_task.context_id = "ctx-99"
1237+
mock_request_context.current_task = mock_task
1238+
mock_request_context.metadata = {"caller": "test-client"}
1239+
1240+
_setup_streaming_context(mock_strands_agent, mock_request_context)
1241+
1242+
executor = StrandsA2AExecutor(mock_strands_agent)
1243+
await executor.execute(mock_request_context, mock_event_queue)
1244+
1245+
mock_strands_agent.stream_async.assert_called_once()
1246+
call_kwargs = mock_strands_agent.stream_async.call_args[1]
1247+
invocation_state = call_kwargs["invocation_state"]
1248+
1249+
assert invocation_state is not None
1250+
assert invocation_state["a2a_request_context"] is mock_request_context
1251+
1252+
1253+
@pytest.mark.asyncio
1254+
async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, mock_event_queue):
1255+
"""Test that metadata is accessible through the RequestContext in invocation state."""
1256+
test_metadata = {"caller": "test-client", "session": "abc-123"}
1257+
mock_request_context.metadata = test_metadata
1258+
mock_task = MagicMock()
1259+
mock_task.id = "task-1"
1260+
mock_task.context_id = "ctx-1"
1261+
mock_request_context.current_task = mock_task
1262+
1263+
_setup_streaming_context(mock_strands_agent, mock_request_context)
1264+
1265+
executor = StrandsA2AExecutor(mock_strands_agent)
1266+
await executor.execute(mock_request_context, mock_event_queue)
1267+
1268+
call_kwargs = mock_strands_agent.stream_async.call_args[1]
1269+
context = call_kwargs["invocation_state"]["a2a_request_context"]
1270+
1271+
assert context.metadata == test_metadata
1272+
1273+
1274+
@pytest.mark.asyncio
1275+
async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue):
1276+
"""Test that task info is accessible through the RequestContext in invocation state."""
1277+
mock_task = MagicMock()
1278+
mock_task.id = "task-100"
1279+
mock_task.context_id = "ctx-200"
1280+
mock_request_context.current_task = mock_task
1281+
1282+
_setup_streaming_context(mock_strands_agent, mock_request_context)
1283+
1284+
executor = StrandsA2AExecutor(mock_strands_agent)
1285+
await executor.execute(mock_request_context, mock_event_queue)
1286+
1287+
call_kwargs = mock_strands_agent.stream_async.call_args[1]
1288+
context = call_kwargs["invocation_state"]["a2a_request_context"]
1289+
1290+
assert context.current_task.id == "task-100"
1291+
assert context.current_task.context_id == "ctx-200"
1292+
1293+
1294+
@pytest.mark.asyncio
1295+
async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue):
1296+
"""Test that RequestContext is passed even when there is no current task."""
1297+
mock_request_context.current_task = None
1298+
mock_request_context.metadata = {}
1299+
1300+
_setup_streaming_context(mock_strands_agent, mock_request_context)
1301+
1302+
executor = StrandsA2AExecutor(mock_strands_agent)
1303+
1304+
with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task:
1305+
mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx")
1306+
await executor.execute(mock_request_context, mock_event_queue)
1307+
1308+
call_kwargs = mock_strands_agent.stream_async.call_args[1]
1309+
invocation_state = call_kwargs["invocation_state"]
1310+
1311+
assert invocation_state["a2a_request_context"] is mock_request_context
1312+
1313+
1314+
@pytest.mark.asyncio
1315+
async def test_invocation_state_with_a2a_compliant_streaming(
1316+
mock_strands_agent, mock_request_context, mock_event_queue
1317+
):
1318+
"""Test that invocation state is passed correctly in A2A-compliant streaming mode."""
1319+
mock_task = MagicMock()
1320+
mock_task.id = "task-compliant"
1321+
mock_task.context_id = "ctx-compliant"
1322+
mock_request_context.current_task = mock_task
1323+
1324+
_setup_streaming_context(mock_strands_agent, mock_request_context)
1325+
1326+
executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True)
1327+
await executor.execute(mock_request_context, mock_event_queue)
1328+
1329+
call_kwargs = mock_strands_agent.stream_async.call_args[1]
1330+
invocation_state = call_kwargs["invocation_state"]
1331+
1332+
assert invocation_state is not None
1333+
assert invocation_state["a2a_request_context"] is mock_request_context

0 commit comments

Comments
 (0)