|
1 | 1 | """Tests for the StrandsA2AExecutor class.""" |
2 | 2 |
|
3 | 3 | import base64 |
| 4 | +from typing import Any |
4 | 5 | from unittest.mock import AsyncMock, MagicMock, patch |
5 | 6 |
|
6 | 7 | import pytest |
@@ -1196,4 +1197,137 @@ async def test_a2a_compliant_handle_result_not_first_chunk(mock_strands_agent): |
1196 | 1197 | assert mock_updater.add_artifact.call_args[1]["artifact_id"] == "artifact-abc" |
1197 | 1198 | assert mock_updater.add_artifact.call_args[1]["append"] is True |
1198 | 1199 | assert mock_updater.add_artifact.call_args[1]["last_chunk"] is True |
1199 | | - mock_updater.complete.assert_called_once() |
| 1200 | + |
| 1201 | + |
| 1202 | +# Tests for invocation state propagation from A2A request context |
| 1203 | + |
| 1204 | + |
| 1205 | +def _setup_streaming_context( |
| 1206 | + mock_strands_agent: MagicMock, |
| 1207 | + mock_request_context: MagicMock, |
| 1208 | +) -> None: |
| 1209 | + """Set up common mocks for invocation state streaming tests. |
| 1210 | +
|
| 1211 | + Args: |
| 1212 | + mock_strands_agent: The mock Strands Agent. |
| 1213 | + mock_request_context: The mock RequestContext. |
| 1214 | + """ |
| 1215 | + |
| 1216 | + async def mock_stream(content_blocks: list, **kwargs: Any) -> Any: |
| 1217 | + yield {"result": MagicMock(spec=SAAgentResult)} |
| 1218 | + |
| 1219 | + mock_strands_agent.stream_async = MagicMock(side_effect=mock_stream) |
| 1220 | + |
| 1221 | + # Set up message with a text part |
| 1222 | + mock_text_part = MagicMock(spec=TextPart) |
| 1223 | + mock_text_part.text = "test input" |
| 1224 | + mock_part = MagicMock() |
| 1225 | + mock_part.root = mock_text_part |
| 1226 | + mock_message = MagicMock() |
| 1227 | + mock_message.parts = [mock_part] |
| 1228 | + mock_request_context.message = mock_message |
| 1229 | + |
| 1230 | + |
| 1231 | +@pytest.mark.asyncio |
| 1232 | +async def test_invocation_state_contains_request_context(mock_strands_agent, mock_request_context, mock_event_queue): |
| 1233 | + """Test that the full RequestContext is passed as a2a_request_context in invocation state.""" |
| 1234 | + mock_task = MagicMock() |
| 1235 | + mock_task.id = "task-42" |
| 1236 | + mock_task.context_id = "ctx-99" |
| 1237 | + mock_request_context.current_task = mock_task |
| 1238 | + mock_request_context.metadata = {"caller": "test-client"} |
| 1239 | + |
| 1240 | + _setup_streaming_context(mock_strands_agent, mock_request_context) |
| 1241 | + |
| 1242 | + executor = StrandsA2AExecutor(mock_strands_agent) |
| 1243 | + await executor.execute(mock_request_context, mock_event_queue) |
| 1244 | + |
| 1245 | + mock_strands_agent.stream_async.assert_called_once() |
| 1246 | + call_kwargs = mock_strands_agent.stream_async.call_args[1] |
| 1247 | + invocation_state = call_kwargs["invocation_state"] |
| 1248 | + |
| 1249 | + assert invocation_state is not None |
| 1250 | + assert invocation_state["a2a_request_context"] is mock_request_context |
| 1251 | + |
| 1252 | + |
| 1253 | +@pytest.mark.asyncio |
| 1254 | +async def test_invocation_state_context_exposes_metadata(mock_strands_agent, mock_request_context, mock_event_queue): |
| 1255 | + """Test that metadata is accessible through the RequestContext in invocation state.""" |
| 1256 | + test_metadata = {"caller": "test-client", "session": "abc-123"} |
| 1257 | + mock_request_context.metadata = test_metadata |
| 1258 | + mock_task = MagicMock() |
| 1259 | + mock_task.id = "task-1" |
| 1260 | + mock_task.context_id = "ctx-1" |
| 1261 | + mock_request_context.current_task = mock_task |
| 1262 | + |
| 1263 | + _setup_streaming_context(mock_strands_agent, mock_request_context) |
| 1264 | + |
| 1265 | + executor = StrandsA2AExecutor(mock_strands_agent) |
| 1266 | + await executor.execute(mock_request_context, mock_event_queue) |
| 1267 | + |
| 1268 | + call_kwargs = mock_strands_agent.stream_async.call_args[1] |
| 1269 | + context = call_kwargs["invocation_state"]["a2a_request_context"] |
| 1270 | + |
| 1271 | + assert context.metadata == test_metadata |
| 1272 | + |
| 1273 | + |
| 1274 | +@pytest.mark.asyncio |
| 1275 | +async def test_invocation_state_context_exposes_task_info(mock_strands_agent, mock_request_context, mock_event_queue): |
| 1276 | + """Test that task info is accessible through the RequestContext in invocation state.""" |
| 1277 | + mock_task = MagicMock() |
| 1278 | + mock_task.id = "task-100" |
| 1279 | + mock_task.context_id = "ctx-200" |
| 1280 | + mock_request_context.current_task = mock_task |
| 1281 | + |
| 1282 | + _setup_streaming_context(mock_strands_agent, mock_request_context) |
| 1283 | + |
| 1284 | + executor = StrandsA2AExecutor(mock_strands_agent) |
| 1285 | + await executor.execute(mock_request_context, mock_event_queue) |
| 1286 | + |
| 1287 | + call_kwargs = mock_strands_agent.stream_async.call_args[1] |
| 1288 | + context = call_kwargs["invocation_state"]["a2a_request_context"] |
| 1289 | + |
| 1290 | + assert context.current_task.id == "task-100" |
| 1291 | + assert context.current_task.context_id == "ctx-200" |
| 1292 | + |
| 1293 | + |
| 1294 | +@pytest.mark.asyncio |
| 1295 | +async def test_invocation_state_context_when_no_task(mock_strands_agent, mock_request_context, mock_event_queue): |
| 1296 | + """Test that RequestContext is passed even when there is no current task.""" |
| 1297 | + mock_request_context.current_task = None |
| 1298 | + mock_request_context.metadata = {} |
| 1299 | + |
| 1300 | + _setup_streaming_context(mock_strands_agent, mock_request_context) |
| 1301 | + |
| 1302 | + executor = StrandsA2AExecutor(mock_strands_agent) |
| 1303 | + |
| 1304 | + with patch("strands.multiagent.a2a.executor.new_task") as mock_new_task: |
| 1305 | + mock_new_task.return_value = MagicMock(id="generated-id", context_id="generated-ctx") |
| 1306 | + await executor.execute(mock_request_context, mock_event_queue) |
| 1307 | + |
| 1308 | + call_kwargs = mock_strands_agent.stream_async.call_args[1] |
| 1309 | + invocation_state = call_kwargs["invocation_state"] |
| 1310 | + |
| 1311 | + assert invocation_state["a2a_request_context"] is mock_request_context |
| 1312 | + |
| 1313 | + |
| 1314 | +@pytest.mark.asyncio |
| 1315 | +async def test_invocation_state_with_a2a_compliant_streaming( |
| 1316 | + mock_strands_agent, mock_request_context, mock_event_queue |
| 1317 | +): |
| 1318 | + """Test that invocation state is passed correctly in A2A-compliant streaming mode.""" |
| 1319 | + mock_task = MagicMock() |
| 1320 | + mock_task.id = "task-compliant" |
| 1321 | + mock_task.context_id = "ctx-compliant" |
| 1322 | + mock_request_context.current_task = mock_task |
| 1323 | + |
| 1324 | + _setup_streaming_context(mock_strands_agent, mock_request_context) |
| 1325 | + |
| 1326 | + executor = StrandsA2AExecutor(mock_strands_agent, enable_a2a_compliant_streaming=True) |
| 1327 | + await executor.execute(mock_request_context, mock_event_queue) |
| 1328 | + |
| 1329 | + call_kwargs = mock_strands_agent.stream_async.call_args[1] |
| 1330 | + invocation_state = call_kwargs["invocation_state"] |
| 1331 | + |
| 1332 | + assert invocation_state is not None |
| 1333 | + assert invocation_state["a2a_request_context"] is mock_request_context |
0 commit comments