Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/basic/lifecycle_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"
f"### {self.event_counter}: Tool {tool.name} started. name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined]
)

async def on_tool_end(
self, context: RunContextWrapper, agent: Agent, tool: Tool, result: str
) -> None:
self.event_counter += 1
print(
f"### {self.event_counter}: Tool {tool.name} ended with result {result}. Usage: {self._usage_to_str(context.usage)}"
f"### {self.event_counter}: Tool {tool.name} finished. result={result}, name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined]
)

async def on_handoff(
Expand Down Expand Up @@ -128,19 +128,19 @@ async def main() -> None:
### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 4: Tool random_number started. name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 5: Tool random_number finished. result=107, name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 12: Tool multiply_by_two started. name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 13: Tool multiply_by_two finished. result=214, name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
### 16: Agent Multiply Agent ended with output number=214. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
Done!

"""
2 changes: 2 additions & 0 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
usage=self._context_wrapper.usage,
tool_name=event.name,
tool_call_id=event.call_id,
tool_arguments=event.arguments,
)
result = await func_tool.on_invoke_tool(tool_context, event.arguments)

Expand All @@ -432,6 +433,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
usage=self._context_wrapper.usage,
tool_name=event.name,
tool_call_id=event.call_id,
tool_arguments=event.arguments,
)

# Execute the handoff to get the new agent
Expand Down
15 changes: 14 additions & 1 deletion src/agents/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")


def _assert_must_pass_tool_arguments() -> str:
raise ValueError("tool_arguments must be passed to ToolContext")


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""
Expand All @@ -24,6 +28,9 @@ class ToolContext(RunContextWrapper[TContext]):
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
"""The raw arguments string of the tool call."""

@classmethod
def from_agent_context(
cls,
Expand All @@ -39,4 +46,10 @@ def from_agent_context(
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
tool_args = (
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments()
)

return cls(
tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values
)
14 changes: 12 additions & 2 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ async def fake_run(
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1")
tool_context = ToolContext(
context=None,
tool_name="story_tool",
tool_call_id="call_1",
tool_arguments='{"input": "hello"}',
)
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')

assert output == "Hello world"
Expand Down Expand Up @@ -374,7 +379,12 @@ async def extractor(result) -> str:
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2")
tool_context = ToolContext(
context=None,
tool_name="summary_tool",
tool_call_id="call_2",
tool_arguments='{"input": "summarize this"}',
)
output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}')

assert output == "custom output"
51 changes: 36 additions & 15 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def test_argless_function():
assert tool.name == "argless_function"

result = await tool.on_invoke_tool(
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)
assert result == "ok"

Expand All @@ -41,12 +41,15 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)
assert result == "ok"

# Extra JSON should not raise an error
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
'{"a": 1}',
)
assert result == "ok"

Expand All @@ -61,18 +64,22 @@ async def test_simple_function():
assert tool.name == "simple_function"

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
'{"a": 1}',
)
assert result == 6

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'),
'{"a": 1, "b": 2}',
)
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
)


class Foo(BaseModel):
Expand Down Expand Up @@ -101,7 +108,8 @@ async def test_complex_args_function():
}
)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
valid_json,
)
assert result == "6 hello10 hello"

Expand All @@ -112,7 +120,8 @@ async def test_complex_args_function():
}
)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
valid_json,
)
assert result == "3 hello10 hello"

Expand All @@ -124,14 +133,18 @@ async def test_complex_args_function():
}
)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
valid_json,
)
assert result == "3 hello10 world"

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
ToolContext(
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}'
),
'{"foo": {"a": 1}}',
)


Expand Down Expand Up @@ -193,7 +206,10 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert tool.strict_json_schema

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
ToolContext(
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}'
),
'{"data": "hello"}',
)
assert result == "hello_done"

Expand All @@ -209,7 +225,12 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert "additionalProperties" not in tool_not_strict.params_json_schema

result = await tool_not_strict.on_invoke_tool(
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
ToolContext(
None,
tool_name=tool_not_strict.name,
tool_call_id="1",
tool_arguments='{"data": "hello", "bar": "baz"}',
),
'{"data": "hello", "bar": "baz"}',
)
assert result == "hello_done"
Expand All @@ -221,7 +242,7 @@ def my_func(a: int, b: int = 5):
raise ValueError("test")

tool = function_tool(my_func)
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")

result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
Expand All @@ -245,7 +266,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand All @@ -269,7 +290,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand Down
4 changes: 3 additions & 1 deletion tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def __init__(self):


def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
return ToolContext(
context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments=""
)


@function_tool
Expand Down