|
1 | 1 | """Tests for the main SupportAgent.""" |
2 | 2 |
|
| 3 | +import asyncio |
| 4 | +from typing import cast |
3 | 5 | from unittest.mock import AsyncMock, MagicMock, patch |
4 | 6 |
|
5 | 7 | import pytest |
6 | 8 |
|
| 9 | +from openai import APIConnectionError |
| 10 | +from openai.types.responses import EasyInputMessageParam |
| 11 | + |
7 | 12 |
|
8 | 13 | class MockOutputItem: |
9 | 14 | """Mock for response output items.""" |
@@ -93,11 +98,10 @@ async def test_handles_function_calls(self): |
93 | 98 | "function_call", |
94 | 99 | call_id="call_123", |
95 | 100 | name="search_policies", |
96 | | - arguments='{"question": "return policy"}' |
| 101 | + arguments='{"question": "return policy"}', |
97 | 102 | ) |
98 | 103 | first_response = MockResponse( |
99 | | - output_text="", |
100 | | - output=[function_call_output] |
| 104 | + output_text="", output=[function_call_output] |
101 | 105 | ) |
102 | 106 |
|
103 | 107 | # Second response is final message |
@@ -135,9 +139,11 @@ def mock_callback(name, args): |
135 | 139 | "function_call", |
136 | 140 | call_id="call_123", |
137 | 141 | name="search_policies", |
138 | | - arguments='{"question": "test"}' |
| 142 | + arguments='{"question": "test"}', |
| 143 | + ) |
| 144 | + first_response = MockResponse( |
| 145 | + output_text="", output=[function_call_output] |
139 | 146 | ) |
140 | | - first_response = MockResponse(output_text="", output=[function_call_output]) |
141 | 147 | final_response = MockResponse(output_text="Done") |
142 | 148 |
|
143 | 149 | mock_client.responses.create = AsyncMock( |
@@ -167,13 +173,13 @@ async def test_multiple_function_calls(self): |
167 | 173 | "function_call", |
168 | 174 | call_id="call_1", |
169 | 175 | name="search_policies", |
170 | | - arguments='{"question": "returns"}' |
| 176 | + arguments='{"question": "returns"}', |
171 | 177 | ) |
172 | 178 | call2 = MockOutputItem( |
173 | 179 | "function_call", |
174 | 180 | call_id="call_2", |
175 | 181 | name="query_orders_database", |
176 | | - arguments='{"query": "SELECT 1"}' |
| 182 | + arguments='{"query": "SELECT 1"}', |
177 | 183 | ) |
178 | 184 | first_response = MockResponse(output_text="", output=[call1, call2]) |
179 | 185 | final_response = MockResponse(output_text="Here's the info") |
@@ -242,3 +248,101 @@ def test_has_function_calls_mixed(self): |
242 | 248 |
|
243 | 249 | assert agent._has_function_calls(output) is True |
244 | 250 |
|
| 251 | + def test_truncate_conversation_when_over_limit(self): |
| 252 | + """Should truncate conversation when over MAX_CONVERSATION_ITEMS.""" |
| 253 | + with patch("agent.core.AsyncOpenAI"): |
| 254 | + from agent.core import SupportAgent, MAX_CONVERSATION_ITEMS |
| 255 | + |
| 256 | + agent = SupportAgent() |
| 257 | + # Add more items than the limit |
| 258 | + for i in range(MAX_CONVERSATION_ITEMS + 10): |
| 259 | + msg: EasyInputMessageParam = {"role": "user", "content": f"msg {i}"} |
| 260 | + agent.conversation.append(msg) |
| 261 | + |
| 262 | + agent._truncate_conversation() |
| 263 | + |
| 264 | + assert len(agent.conversation) == MAX_CONVERSATION_ITEMS |
| 265 | + |
| 266 | + def test_truncate_conversation_keeps_recent(self): |
| 267 | + """Should keep most recent messages when truncating.""" |
| 268 | + with patch("agent.core.AsyncOpenAI"): |
| 269 | + from agent.core import SupportAgent, MAX_CONVERSATION_ITEMS |
| 270 | + |
| 271 | + agent = SupportAgent() |
| 272 | + # Add numbered messages |
| 273 | + for i in range(MAX_CONVERSATION_ITEMS + 5): |
| 274 | + msg: EasyInputMessageParam = {"role": "user", "content": f"msg {i}"} |
| 275 | + agent.conversation.append(msg) |
| 276 | + |
| 277 | + agent._truncate_conversation() |
| 278 | + |
| 279 | + # Should have kept the last MAX_CONVERSATION_ITEMS messages |
| 280 | + first_item = cast(EasyInputMessageParam, agent.conversation[0]) |
| 281 | + last_item = cast(EasyInputMessageParam, agent.conversation[-1]) |
| 282 | + assert first_item.get("content") == "msg 5" |
| 283 | + assert last_item.get("content") == f"msg {MAX_CONVERSATION_ITEMS + 4}" |
| 284 | + |
| 285 | + def test_truncate_conversation_no_op_when_under_limit(self): |
| 286 | + """Should not truncate when under limit.""" |
| 287 | + with patch("agent.core.AsyncOpenAI"): |
| 288 | + from agent.core import SupportAgent |
| 289 | + |
| 290 | + agent = SupportAgent() |
| 291 | + msg1: EasyInputMessageParam = {"role": "user", "content": "msg 1"} |
| 292 | + msg2: EasyInputMessageParam = {"role": "user", "content": "msg 2"} |
| 293 | + agent.conversation.append(msg1) |
| 294 | + agent.conversation.append(msg2) |
| 295 | + |
| 296 | + agent._truncate_conversation() |
| 297 | + |
| 298 | + assert len(agent.conversation) == 2 |
| 299 | + |
| 300 | + @pytest.mark.asyncio |
| 301 | + async def test_tool_timeout_returns_error(self): |
| 302 | + """Should return error message when tool times out.""" |
| 303 | + with patch("agent.core.AsyncOpenAI") as mock_openai: |
| 304 | + with patch("agent.core.handle_tool_call") as mock_handle: |
| 305 | + with patch("agent.core.TOOL_TIMEOUT_SECONDS", 0.01): |
| 306 | + mock_client = MagicMock() |
| 307 | + mock_openai.return_value = mock_client |
| 308 | + |
| 309 | + # Make handle_tool_call hang |
| 310 | + async def slow_handler(*args, **kwargs): |
| 311 | + await asyncio.sleep(1) |
| 312 | + return "result" |
| 313 | + |
| 314 | + mock_handle.side_effect = slow_handler |
| 315 | + |
| 316 | + from agent.core import SupportAgent |
| 317 | + |
| 318 | + agent = SupportAgent() |
| 319 | + result = await agent._execute_tool( |
| 320 | + "search_policies", '{"question": "test"}' |
| 321 | + ) |
| 322 | + |
| 323 | + assert "timed out" in result |
| 324 | + |
| 325 | + @pytest.mark.asyncio |
| 326 | + async def test_api_retry_on_transient_error(self): |
| 327 | + """Should retry API calls on transient errors.""" |
| 328 | + with patch("agent.core.AsyncOpenAI") as mock_openai: |
| 329 | + mock_client = MagicMock() |
| 330 | + mock_openai.return_value = mock_client |
| 331 | + |
| 332 | + # Fail twice, then succeed |
| 333 | + mock_response = MockResponse(output_text="Success after retry") |
| 334 | + mock_client.responses.create = AsyncMock( |
| 335 | + side_effect=[ |
| 336 | + APIConnectionError(request=MagicMock()), |
| 337 | + APIConnectionError(request=MagicMock()), |
| 338 | + mock_response, |
| 339 | + ] |
| 340 | + ) |
| 341 | + |
| 342 | + from agent.core import SupportAgent |
| 343 | + |
| 344 | + agent = SupportAgent() |
| 345 | + result = await agent.chat("Hello") |
| 346 | + |
| 347 | + assert result == "Success after retry" |
| 348 | + assert mock_client.responses.create.call_count == 3 |
0 commit comments