Skip to content

Commit 72d0d75

Browse files
vrtnisseratch
andauthored
Retry: Add tool_name to ToolContext for generic tool handlers (#1110)
This is a follow-up to pr #1043 The original changes were reverted due to missing updates in RealtimeSession, which caused runtime test failures. This PR: - Reapplies the `tool_name` and `tool_call_id` additions to `ToolContext`. - Updates `RealtimeSession._handle_tool_call` to instantiate `ToolContext` with `tool_name=event.name` and `tool_call_id=event.call_id`. - Adjusts tests as needed so that all 533 tests (including old-version Python 3.9) pass cleanly. Closes #1030 --------- Co-authored-by: Kazuhiro Sera <[email protected]>
1 parent 2da6194 commit 72d0d75

File tree

7 files changed

+70
-23
lines changed

7 files changed

+70
-23
lines changed

docs/tools.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c
180180
- `name`
181181
- `description`
182182
- `params_json_schema`, which is the JSON schema for the arguments
183-
- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string.
183+
- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string.
184184

185185
```python
186186
from typing import Any

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ plugins:
9292
- ref/lifecycle.md
9393
- ref/items.md
9494
- ref/run_context.md
95+
- ref/tool_context.md
9596
- ref/usage.md
9697
- ref/exceptions.md
9798
- ref/guardrail.md

src/agents/_run_impl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ async def run_single_tool(
548548
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
549549
) -> Any:
550550
with function_span(func_tool.name) as span_fn:
551-
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
551+
tool_context = ToolContext.from_agent_context(
552+
context_wrapper,
553+
tool_call.call_id,
554+
tool_call=tool_call,
555+
)
552556
if config.trace_include_sensitive_data:
553557
span_fn.span_data.input = tool_call.arguments
554558
try:

src/agents/realtime/session.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,12 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
269269
)
270270

271271
func_tool = function_map[event.name]
272-
tool_context = ToolContext.from_agent_context(self._context_wrapper, event.call_id)
272+
tool_context = ToolContext(
273+
context=self._context_wrapper.context,
274+
usage=self._context_wrapper.usage,
275+
tool_name=event.name,
276+
tool_call_id=event.call_id,
277+
)
273278
result = await func_tool.on_invoke_tool(tool_context, event.arguments)
274279

275280
await self._model.send_event(
@@ -288,7 +293,12 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
288293
)
289294
elif event.name in handoff_map:
290295
handoff = handoff_map[event.name]
291-
tool_context = ToolContext.from_agent_context(self._context_wrapper, event.call_id)
296+
tool_context = ToolContext(
297+
context=self._context_wrapper.context,
298+
usage=self._context_wrapper.usage,
299+
tool_name=event.name,
300+
tool_call_id=event.call_id,
301+
)
292302

293303
# Execute the handoff to get the new agent
294304
result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments)

src/agents/tool_context.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dataclasses import dataclass, field, fields
2-
from typing import Any
2+
from typing import Any, Optional
3+
4+
from openai.types.responses import ResponseFunctionToolCall
35

46
from .run_context import RunContextWrapper, TContext
57

@@ -8,16 +10,26 @@ def _assert_must_pass_tool_call_id() -> str:
810
raise ValueError("tool_call_id must be passed to ToolContext")
911

1012

13+
def _assert_must_pass_tool_name() -> str:
14+
raise ValueError("tool_name must be passed to ToolContext")
15+
16+
1117
@dataclass
1218
class ToolContext(RunContextWrapper[TContext]):
1319
"""The context of a tool call."""
1420

21+
tool_name: str = field(default_factory=_assert_must_pass_tool_name)
22+
"""The name of the tool being invoked."""
23+
1524
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
1625
"""The ID of the tool call."""
1726

1827
@classmethod
1928
def from_agent_context(
20-
cls, context: RunContextWrapper[TContext], tool_call_id: str
29+
cls,
30+
context: RunContextWrapper[TContext],
31+
tool_call_id: str,
32+
tool_call: Optional[ResponseFunctionToolCall] = None,
2133
) -> "ToolContext":
2234
"""
2335
Create a ToolContext from a RunContextWrapper.
@@ -26,4 +38,5 @@ def from_agent_context(
2638
base_values: dict[str, Any] = {
2739
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
2840
}
29-
return cls(tool_call_id=tool_call_id, **base_values)
41+
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42+
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)

tests/test_function_tool.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ async def test_argless_function():
2626
tool = function_tool(argless_function)
2727
assert tool.name == "argless_function"
2828

29-
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
29+
result = await tool.on_invoke_tool(
30+
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
31+
)
3032
assert result == "ok"
3133

3234

@@ -39,11 +41,13 @@ async def test_argless_with_context():
3941
tool = function_tool(argless_with_context)
4042
assert tool.name == "argless_with_context"
4143

42-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
44+
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
4345
assert result == "ok"
4446

4547
# Extra JSON should not raise an error
46-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
48+
result = await tool.on_invoke_tool(
49+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
50+
)
4751
assert result == "ok"
4852

4953

@@ -56,15 +60,19 @@ async def test_simple_function():
5660
tool = function_tool(simple_function, failure_error_function=None)
5761
assert tool.name == "simple_function"
5862

59-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
63+
result = await tool.on_invoke_tool(
64+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
65+
)
6066
assert result == 6
6167

62-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
68+
result = await tool.on_invoke_tool(
69+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
70+
)
6371
assert result == 3
6472

6573
# Missing required argument should raise an error
6674
with pytest.raises(ModelBehaviorError):
67-
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
75+
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
6876

6977

7078
class Foo(BaseModel):
@@ -92,7 +100,9 @@ async def test_complex_args_function():
92100
"bar": Bar(x="hello", y=10),
93101
}
94102
)
95-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
103+
result = await tool.on_invoke_tool(
104+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
105+
)
96106
assert result == "6 hello10 hello"
97107

98108
valid_json = json.dumps(
@@ -101,7 +111,9 @@ async def test_complex_args_function():
101111
"bar": Bar(x="hello", y=10),
102112
}
103113
)
104-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
114+
result = await tool.on_invoke_tool(
115+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
116+
)
105117
assert result == "3 hello10 hello"
106118

107119
valid_json = json.dumps(
@@ -111,12 +123,16 @@ async def test_complex_args_function():
111123
"baz": "world",
112124
}
113125
)
114-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
126+
result = await tool.on_invoke_tool(
127+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
128+
)
115129
assert result == "3 hello10 world"
116130

117131
# Missing required argument should raise an error
118132
with pytest.raises(ModelBehaviorError):
119-
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
133+
await tool.on_invoke_tool(
134+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
135+
)
120136

121137

122138
def test_function_config_overrides():
@@ -176,7 +192,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
176192
assert tool.params_json_schema[key] == value
177193
assert tool.strict_json_schema
178194

179-
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
195+
result = await tool.on_invoke_tool(
196+
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
197+
)
180198
assert result == "hello_done"
181199

182200
tool_not_strict = FunctionTool(
@@ -191,7 +209,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
191209
assert "additionalProperties" not in tool_not_strict.params_json_schema
192210

193211
result = await tool_not_strict.on_invoke_tool(
194-
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
212+
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
213+
'{"data": "hello", "bar": "baz"}',
195214
)
196215
assert result == "hello_done"
197216

@@ -202,7 +221,7 @@ def my_func(a: int, b: int = 5):
202221
raise ValueError("test")
203222

204223
tool = function_tool(my_func)
205-
ctx = ToolContext(None, tool_call_id="1")
224+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
206225

207226
result = await tool.on_invoke_tool(ctx, "")
208227
assert "Invalid JSON" in str(result)
@@ -226,7 +245,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
226245
return f"error_{error.__class__.__name__}"
227246

228247
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
229-
ctx = ToolContext(None, tool_call_id="1")
248+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
230249

231250
result = await tool.on_invoke_tool(ctx, "")
232251
assert result == "error_ModelBehaviorError"
@@ -250,7 +269,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
250269
return f"error_{error.__class__.__name__}"
251270

252271
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
253-
ctx = ToolContext(None, tool_call_id="1")
272+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
254273

255274
result = await tool.on_invoke_tool(ctx, "")
256275
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616

1717

1818
def ctx_wrapper() -> ToolContext[DummyContext]:
19-
return ToolContext(context=DummyContext(), tool_call_id="1")
19+
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
2020

2121

2222
@function_tool

0 commit comments

Comments
 (0)