Skip to content

Commit 37e99b5

Browse files
committed
Back out "enhancement: Add tool_name to ToolContext to support shared tool handlers (#1043)"
Original commit changeset: befe19d
1 parent e651a29 commit 37e99b5

File tree

7 files changed

+21
-61
lines changed

7 files changed

+21
-61
lines changed

docs/ref/tool_context.md

Lines changed: 0 additions & 3 deletions
This file was deleted.

docs/tools.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ Sometimes, you don't want to use a Python function as a tool. You can directly c
180180
- `name`
181181
- `description`
182182
- `params_json_schema`, which is the JSON schema for the arguments
183-
- `on_invoke_tool`, which is an async function that receives a [`ToolContext`][agents.tool_context.ToolContext] and the arguments as a JSON string, and must return the tool output as a string.
183+
- `on_invoke_tool`, which is an async function that receives the context and the arguments as a JSON string, and must return the tool output as a string.
184184

185185
```python
186186
from typing import Any

mkdocs.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ plugins:
9292
- ref/lifecycle.md
9393
- ref/items.md
9494
- ref/run_context.md
95-
- ref/tool_context.md
9695
- ref/usage.md
9796
- ref/exceptions.md
9897
- ref/guardrail.md

src/agents/_run_impl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -548,11 +548,7 @@ async def run_single_tool(
548548
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
549549
) -> Any:
550550
with function_span(func_tool.name) as span_fn:
551-
tool_context = ToolContext.from_agent_context(
552-
context_wrapper,
553-
tool_call.call_id,
554-
tool_call=tool_call,
555-
)
551+
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
556552
if config.trace_include_sensitive_data:
557553
span_fn.span_data.input = tool_call.arguments
558554
try:

src/agents/tool_context.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from dataclasses import dataclass, field, fields
2-
from typing import Any, Optional
3-
4-
from openai.types.responses import ResponseFunctionToolCall
2+
from typing import Any
53

64
from .run_context import RunContextWrapper, TContext
75

@@ -10,26 +8,16 @@ def _assert_must_pass_tool_call_id() -> str:
108
raise ValueError("tool_call_id must be passed to ToolContext")
119

1210

13-
def _assert_must_pass_tool_name() -> str:
14-
raise ValueError("tool_name must be passed to ToolContext")
15-
16-
1711
@dataclass
1812
class ToolContext(RunContextWrapper[TContext]):
1913
"""The context of a tool call."""
2014

21-
tool_name: str = field(default_factory=_assert_must_pass_tool_name)
22-
"""The name of the tool being invoked."""
23-
2415
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
2516
"""The ID of the tool call."""
2617

2718
@classmethod
2819
def from_agent_context(
29-
cls,
30-
context: RunContextWrapper[TContext],
31-
tool_call_id: str,
32-
tool_call: Optional[ResponseFunctionToolCall] = None,
20+
cls, context: RunContextWrapper[TContext], tool_call_id: str
3321
) -> "ToolContext":
3422
"""
3523
Create a ToolContext from a RunContextWrapper.
@@ -38,5 +26,4 @@ def from_agent_context(
3826
base_values: dict[str, Any] = {
3927
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
4028
}
41-
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42-
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
29+
return cls(tool_call_id=tool_call_id, **base_values)

tests/test_function_tool.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ async def test_argless_function():
2626
tool = function_tool(argless_function)
2727
assert tool.name == "argless_function"
2828

29-
result = await tool.on_invoke_tool(
30-
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
31-
)
29+
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
3230
assert result == "ok"
3331

3432

@@ -41,13 +39,11 @@ async def test_argless_with_context():
4139
tool = function_tool(argless_with_context)
4240
assert tool.name == "argless_with_context"
4341

44-
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
42+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
4543
assert result == "ok"
4644

4745
# Extra JSON should not raise an error
48-
result = await tool.on_invoke_tool(
49-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
50-
)
46+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
5147
assert result == "ok"
5248

5349

@@ -60,19 +56,15 @@ async def test_simple_function():
6056
tool = function_tool(simple_function, failure_error_function=None)
6157
assert tool.name == "simple_function"
6258

63-
result = await tool.on_invoke_tool(
64-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
65-
)
59+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
6660
assert result == 6
6761

68-
result = await tool.on_invoke_tool(
69-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
70-
)
62+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
7163
assert result == 3
7264

7365
# Missing required argument should raise an error
7466
with pytest.raises(ModelBehaviorError):
75-
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
67+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
7668

7769

7870
class Foo(BaseModel):
@@ -100,9 +92,7 @@ async def test_complex_args_function():
10092
"bar": Bar(x="hello", y=10),
10193
}
10294
)
103-
result = await tool.on_invoke_tool(
104-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
105-
)
95+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
10696
assert result == "6 hello10 hello"
10797

10898
valid_json = json.dumps(
@@ -111,9 +101,7 @@ async def test_complex_args_function():
111101
"bar": Bar(x="hello", y=10),
112102
}
113103
)
114-
result = await tool.on_invoke_tool(
115-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
116-
)
104+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
117105
assert result == "3 hello10 hello"
118106

119107
valid_json = json.dumps(
@@ -123,16 +111,12 @@ async def test_complex_args_function():
123111
"baz": "world",
124112
}
125113
)
126-
result = await tool.on_invoke_tool(
127-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
128-
)
114+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
129115
assert result == "3 hello10 world"
130116

131117
# Missing required argument should raise an error
132118
with pytest.raises(ModelBehaviorError):
133-
await tool.on_invoke_tool(
134-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
135-
)
119+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
136120

137121

138122
def test_function_config_overrides():
@@ -192,9 +176,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
192176
assert tool.params_json_schema[key] == value
193177
assert tool.strict_json_schema
194178

195-
result = await tool.on_invoke_tool(
196-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
197-
)
179+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
198180
assert result == "hello_done"
199181

200182
tool_not_strict = FunctionTool(
@@ -209,8 +191,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
209191
assert "additionalProperties" not in tool_not_strict.params_json_schema
210192

211193
result = await tool_not_strict.on_invoke_tool(
212-
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
213-
'{"data": "hello", "bar": "baz"}',
194+
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
214195
)
215196
assert result == "hello_done"
216197

@@ -221,7 +202,7 @@ def my_func(a: int, b: int = 5):
221202
raise ValueError("test")
222203

223204
tool = function_tool(my_func)
224-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
205+
ctx = ToolContext(None, tool_call_id="1")
225206

226207
result = await tool.on_invoke_tool(ctx, "")
227208
assert "Invalid JSON" in str(result)
@@ -245,7 +226,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
245226
return f"error_{error.__class__.__name__}"
246227

247228
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
248-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
229+
ctx = ToolContext(None, tool_call_id="1")
249230

250231
result = await tool.on_invoke_tool(ctx, "")
251232
assert result == "error_ModelBehaviorError"
@@ -269,7 +250,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
269250
return f"error_{error.__class__.__name__}"
270251

271252
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
272-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
253+
ctx = ToolContext(None, tool_call_id="1")
273254

274255
result = await tool.on_invoke_tool(ctx, "")
275256
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616

1717

1818
def ctx_wrapper() -> ToolContext[DummyContext]:
19-
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
19+
return ToolContext(context=DummyContext(), tool_call_id="1")
2020

2121

2222
@function_tool

0 commit comments

Comments
 (0)