Skip to content

Commit cc84e25

Browse files
committed
feat: add robustness features and update tests
- Add retry logic with exponential backoff (tenacity) - Add conversation truncation to prevent token overflow - Add tool execution timeout (30s) - Update tests for new robustness features - Update README with robustness docs and parallel execution
1 parent e5c7afd commit cc84e25

5 files changed

Lines changed: 185 additions & 25 deletions

File tree

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ A CLI-based customer support agent powered by OpenAI with SQL database and RAG t
1111
- **Policy Search**: RAG-powered search across company policies (returns, shipping, warranty)
1212
- **Conversation History**: Maintains context across multiple exchanges
1313
- **Multi-Agent SQL**: Uses generator and reviewer agents for accurate SQL queries
14+
- **Robustness**: Retry logic, conversation truncation, and tool timeouts
15+
- **CI/CD**: GitHub Actions with Python 3.11/3.12 matrix testing
1416

1517
## Architecture
1618

@@ -87,6 +89,18 @@ Think step by step:
8789

8890
This approach leverages the model's chain-of-thought capabilities for more accurate and explainable responses.
8991

92+
### Parallel Tool Execution
93+
94+
When the LLM returns multiple tool calls in a single response, they are executed **concurrently** using `asyncio.gather()` for better performance:
95+
96+
```python
97+
# Multiple tools run in parallel
98+
results = await asyncio.gather(*[
99+
self._execute_tool(item.name, item.arguments)
100+
for item in output if item.type == "function_call"
101+
])
102+
```
103+
90104
### Component Overview
91105

92106
| Component | Description |
@@ -230,6 +244,16 @@ uv run pytest --cov=. --cov-report=html
230244
2. Add the OpenAI function schema in `tools/definitions.py`
231245
3. Register the handler in `tools/router.py`
232246

247+
### Robustness Features
248+
249+
The agent includes several reliability improvements:
250+
251+
| Feature | Description |
252+
|---------|-------------|
253+
| **Retry Logic** | 3 attempts with exponential backoff for transient API failures |
254+
| **Conversation Truncation** | Keeps last 40 items (~20 turns) to prevent token overflow |
255+
| **Tool Timeout** | 30-second timeout prevents hanging on slow tool execution |
256+
233257
## Sample Data
234258

235259
The agent comes pre-seeded with:

agent/core.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
import json
44
from typing import Any, Callable, cast
55

6-
from openai import AsyncOpenAI
6+
from openai import AsyncOpenAI, APIError, APIConnectionError, RateLimitError
77
from openai.types.responses import ResponseInputItemParam, EasyInputMessageParam
8+
from tenacity import (
9+
retry,
10+
stop_after_attempt,
11+
wait_exponential,
12+
retry_if_exception_type,
13+
)
814

915
from tools import TOOLS, handle_tool_call
1016

@@ -14,6 +20,10 @@
1420
# Callback type for agent activity notifications
1521
AgentCallback = Callable[[str, str, dict], None]
1622

23+
# Configuration
24+
MAX_CONVERSATION_ITEMS = 40 # ~20 turns (user + assistant)
25+
TOOL_TIMEOUT_SECONDS = 30.0
26+
1727

1828
INSTRUCTIONS = """You are a customer support agent for an e-commerce company.
1929
@@ -75,14 +85,11 @@ async def chat(self, user_message: str) -> str:
7585
user_msg: EasyInputMessageParam = {"role": "user", "content": user_message}
7686
self.conversation.append(user_msg)
7787

78-
# Call OpenAI Responses API
79-
response = await self.client.responses.create(
80-
model=self.model,
81-
instructions=INSTRUCTIONS,
82-
input=self.conversation,
83-
tools=TOOLS,
84-
reasoning={"effort": "medium"},
85-
)
88+
# Truncate conversation if too long to avoid token limits
89+
self._truncate_conversation()
90+
91+
# Call OpenAI Responses API with retry
92+
response = await self._call_api()
8693

8794
# Process output items - handle function calls
8895
while self._has_function_calls(response.output):
@@ -95,19 +102,35 @@ async def chat(self, user_message: str) -> str:
95102
)
96103
self.conversation.extend(tool_outputs)
97104

98-
response = await self.client.responses.create(
99-
model=self.model,
100-
instructions=INSTRUCTIONS,
101-
input=self.conversation,
102-
tools=TOOLS,
103-
reasoning={"effort": "medium"},
104-
)
105+
response = await self._call_api()
105106

106107
# Add final response to conversation history
107108
self.conversation.extend(cast(list[ResponseInputItemParam], response.output))
108109

109110
return response.output_text
110111

112+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    retry=retry_if_exception_type((APIError, APIConnectionError, RateLimitError)),
    reraise=True,
)
async def _call_api(self):
    """Issue a single Responses API request with automatic retries.

    Transient failures (APIError, APIConnectionError, RateLimitError) are
    retried up to 3 attempts with exponential backoff between 1s and 10s;
    the last exception is re-raised to the caller.
    """
    # Assemble the request once so the call site stays readable.
    request_kwargs = {
        "model": self.model,
        "instructions": INSTRUCTIONS,
        "input": self.conversation,
        "tools": TOOLS,
        "reasoning": {"effort": "medium"},
    }
    return await self.client.responses.create(**request_kwargs)
127+
128+
def _truncate_conversation(self) -> None:
    """Drop the oldest history entries once the cap is exceeded.

    Keeps only the newest MAX_CONVERSATION_ITEMS items so the prompt
    stays within model token limits; shorter histories are untouched.

    NOTE(review): truncating by raw count can strip a function_call while
    keeping its function_call_output (or vice versa) — confirm the
    Responses API tolerates orphaned tool items at the history head.
    """
    excess = len(self.conversation) - MAX_CONVERSATION_ITEMS
    if excess > 0:
        # Rebind to a fresh list holding only the most recent items.
        self.conversation = self.conversation[excess:]
133+
111134
def _has_function_calls(self, output: list[Any]) -> bool:
112135
"""Check if output contains any function calls."""
113136
return any(item.type == "function_call" for item in output)
@@ -139,13 +162,19 @@ async def _process_function_calls(
139162
)
140163

141164
async def _execute_tool(self, name: str, arguments: str) -> str:
    """Run one tool call, bounded by TOOL_TIMEOUT_SECONDS.

    Parses the JSON *arguments*, fires the optional on_tool_call
    callback, and dispatches to handle_tool_call. On timeout an error
    string is returned instead of raising, so the agent loop continues.

    NOTE(review): this relies on `asyncio` being imported at module
    level — the diff adding this code does not add the import; verify
    agent/core.py actually imports asyncio.
    """
    parsed = json.loads(arguments)

    # Notify any registered observer before executing.
    callback = self.on_tool_call
    if callback:
        callback(name, parsed)

    try:
        return await asyncio.wait_for(
            handle_tool_call(name, parsed, self.on_agent_activity),
            timeout=TOOL_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        return f"Error: Tool '{name}' timed out after {TOOL_TIMEOUT_SECONDS}s"
149178

150179
def clear_history(self) -> None:
151180
"""Clear conversation history."""

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies = [
1010
"openai>=2.14.0",
1111
"python-dotenv>=1.2.1",
1212
"rich>=14.2.0",
13+
"tenacity>=9.1.2",
1314
]
1415

1516
[dependency-groups]

tests/test_agent_core.py

Lines changed: 111 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Tests for the main SupportAgent."""
22

3+
import asyncio
4+
from typing import cast
35
from unittest.mock import AsyncMock, MagicMock, patch
46

57
import pytest
68

9+
from openai import APIConnectionError
10+
from openai.types.responses import EasyInputMessageParam
11+
712

813
class MockOutputItem:
914
"""Mock for response output items."""
@@ -93,11 +98,10 @@ async def test_handles_function_calls(self):
9398
"function_call",
9499
call_id="call_123",
95100
name="search_policies",
96-
arguments='{"question": "return policy"}'
101+
arguments='{"question": "return policy"}',
97102
)
98103
first_response = MockResponse(
99-
output_text="",
100-
output=[function_call_output]
104+
output_text="", output=[function_call_output]
101105
)
102106

103107
# Second response is final message
@@ -135,9 +139,11 @@ def mock_callback(name, args):
135139
"function_call",
136140
call_id="call_123",
137141
name="search_policies",
138-
arguments='{"question": "test"}'
142+
arguments='{"question": "test"}',
143+
)
144+
first_response = MockResponse(
145+
output_text="", output=[function_call_output]
139146
)
140-
first_response = MockResponse(output_text="", output=[function_call_output])
141147
final_response = MockResponse(output_text="Done")
142148

143149
mock_client.responses.create = AsyncMock(
@@ -167,13 +173,13 @@ async def test_multiple_function_calls(self):
167173
"function_call",
168174
call_id="call_1",
169175
name="search_policies",
170-
arguments='{"question": "returns"}'
176+
arguments='{"question": "returns"}',
171177
)
172178
call2 = MockOutputItem(
173179
"function_call",
174180
call_id="call_2",
175181
name="query_orders_database",
176-
arguments='{"query": "SELECT 1"}'
182+
arguments='{"query": "SELECT 1"}',
177183
)
178184
first_response = MockResponse(output_text="", output=[call1, call2])
179185
final_response = MockResponse(output_text="Here's the info")
@@ -242,3 +248,101 @@ def test_has_function_calls_mixed(self):
242248

243249
assert agent._has_function_calls(output) is True
244250

251+
def test_truncate_conversation_when_over_limit(self):
    """A conversation past the cap is cut down to exactly the cap."""
    with patch("agent.core.AsyncOpenAI"):
        from agent.core import SupportAgent, MAX_CONVERSATION_ITEMS

        agent = SupportAgent()
        # Fill history well past the configured limit.
        for idx in range(MAX_CONVERSATION_ITEMS + 10):
            entry: EasyInputMessageParam = {"role": "user", "content": f"msg {idx}"}
            agent.conversation.append(entry)

        agent._truncate_conversation()

        assert len(agent.conversation) == MAX_CONVERSATION_ITEMS
266+
def test_truncate_conversation_keeps_recent(self):
    """Truncation drops the oldest entries and preserves the newest."""
    with patch("agent.core.AsyncOpenAI"):
        from agent.core import SupportAgent, MAX_CONVERSATION_ITEMS

        agent = SupportAgent()
        extra = 5
        for idx in range(MAX_CONVERSATION_ITEMS + extra):
            entry: EasyInputMessageParam = {"role": "user", "content": f"msg {idx}"}
            agent.conversation.append(entry)

        agent._truncate_conversation()

        # The first `extra` messages are gone; the remainder is intact.
        oldest = cast(EasyInputMessageParam, agent.conversation[0])
        newest = cast(EasyInputMessageParam, agent.conversation[-1])
        assert oldest.get("content") == "msg 5"
        assert newest.get("content") == f"msg {MAX_CONVERSATION_ITEMS + 4}"
285+
def test_truncate_conversation_no_op_when_under_limit(self):
    """A short conversation is left completely untouched."""
    with patch("agent.core.AsyncOpenAI"):
        from agent.core import SupportAgent

        agent = SupportAgent()
        for text in ("msg 1", "msg 2"):
            entry: EasyInputMessageParam = {"role": "user", "content": text}
            agent.conversation.append(entry)

        agent._truncate_conversation()

        assert len(agent.conversation) == 2
300+
@pytest.mark.asyncio
async def test_tool_timeout_returns_error(self):
    """A tool exceeding the timeout yields an error string, not a hang."""
    with (
        patch("agent.core.AsyncOpenAI") as mock_openai,
        patch("agent.core.handle_tool_call") as mock_handle,
        patch("agent.core.TOOL_TIMEOUT_SECONDS", 0.01),
    ):
        mock_openai.return_value = MagicMock()

        # Handler that cannot finish within the (patched) 0.01s budget.
        async def never_finishes(*args, **kwargs):
            await asyncio.sleep(1)
            return "result"

        mock_handle.side_effect = never_finishes

        from agent.core import SupportAgent

        agent = SupportAgent()
        outcome = await agent._execute_tool(
            "search_policies", '{"question": "test"}'
        )

        assert "timed out" in outcome
325+
@pytest.mark.asyncio
async def test_api_retry_on_transient_error(self):
    """Transient connection errors are retried until the call succeeds.

    NOTE(review): the real tenacity backoff (min=1s) makes this test
    sleep ~3s; consider patching the retry wait to speed it up.
    """
    with patch("agent.core.AsyncOpenAI") as mock_openai:
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Two connection failures, then a good response.
        ok_response = MockResponse(output_text="Success after retry")
        attempts = [
            APIConnectionError(request=MagicMock()),
            APIConnectionError(request=MagicMock()),
            ok_response,
        ]
        mock_client.responses.create = AsyncMock(side_effect=attempts)

        from agent.core import SupportAgent

        agent = SupportAgent()
        reply = await agent.chat("Hello")

        assert reply == "Success after retry"
        assert mock_client.responses.create.call_count == 3

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)