Skip to content

Refactor agent into AgentBase+Agent #1044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from openai import AsyncOpenAI

from . import _config
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent import Agent, AgentBase, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
Expand Down Expand Up @@ -160,6 +160,7 @@ def enable_verbose_stdout_logging():

__all__ = [
"Agent",
"AgentBase",
"ToolsToFinalOutputFunction",
"ToolsToFinalOutputResult",
"Runner",
Expand Down
61 changes: 33 additions & 28 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,36 @@ class MCPConfig(TypedDict):


@dataclass
class Agent(Generic[TContext]):
class AgentBase:
"""Base class for `Agent` and `RealtimeAgent`."""

# The only field declared here without a default value.
name: str
"""The name of the agent."""

handoff_description: str | None = None
"""A description of the agent. This is used when the agent is used as a handoff, so that an
LLM knows what it does and when to invoke it.
"""

# default_factory=list builds a fresh list per instance instead of sharing one mutable default.
tools: list[Tool] = field(default_factory=list)
"""A list of tools that the agent can use."""

# Same per-instance list pattern as `tools`.
mcp_servers: list[MCPServer] = field(default_factory=list)
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
the agent can use. Every time the agent runs, it will include tools from these servers in the
list of available tools.

NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
longer needed.
"""

# MCPConfig is a TypedDict; the factory calls it with no arguments, so each instance
# gets its own new config mapping.
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
"""Configuration for MCP servers."""


@dataclass
class Agent(AgentBase, Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.

We strongly recommend passing `instructions`, which is the "system prompt" for the agent. In
Expand All @@ -76,10 +105,9 @@ class Agent(Generic[TContext]):

Agents are generic on the context type. The context is a (mutable) object you create. It is
passed to tool functions, handoffs, guardrails, etc.
"""

name: str
"""The name of the agent."""
See `AgentBase` for base parameters that are shared with `RealtimeAgent`s.
"""

instructions: (
str
Expand All @@ -103,11 +131,6 @@ class Agent(Generic[TContext]):
usable with OpenAI models, using the Responses API.
"""

handoff_description: str | None = None
"""A description of the agent. This is used when the agent is used as a handoff, so that an
LLM knows what it does and when to invoke it.
"""

handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
Expand All @@ -125,22 +148,6 @@ class Agent(Generic[TContext]):
"""Configures model-specific tuning parameters (e.g. temperature, top_p).
"""

tools: list[Tool] = field(default_factory=list)
"""A list of tools that the agent can use."""

mcp_servers: list[MCPServer] = field(default_factory=list)
"""A list of [Model Context Protocol](https://modelcontextprotocol.io/) servers that
the agent can use. Every time the agent runs, it will include tools from these servers in the
list of available tools.

NOTE: You are expected to manage the lifecycle of these servers. Specifically, you must call
`server.connect()` before passing it to the agent, and `server.cleanup()` when the server is no
longer needed.
"""

mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
"""Configuration for MCP servers."""

input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
"""A list of checks that run in parallel to the agent's execution, before generating a
response. Runs only if the agent is the first agent in the chain.
Expand Down Expand Up @@ -256,9 +263,7 @@ async def get_prompt(
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)

async def get_mcp_tools(
self, run_context: RunContextWrapper[TContext]
) -> list[Tool]:
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(
Expand Down
43 changes: 26 additions & 17 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from typing import Any, Generic

from .agent import Agent
from typing_extensions import TypeVar

from .agent import Agent, AgentBase
from .run_context import RunContextWrapper, TContext
from .tool import Tool

TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)


class RunHooks(Generic[TContext]):
class RunHooksBase(Generic[TContext, TAgent]):
"""A class that receives callbacks on various lifecycle events in an agent run. Subclass and
override the methods you need.
"""

async def on_agent_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
) -> None:
async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
"""Called before the agent is invoked. Called each time the current agent changes."""
pass

async def on_agent_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
output: Any,
) -> None:
"""Called when the agent produces a final output."""
Expand All @@ -28,16 +30,16 @@ async def on_agent_end(
async def on_handoff(
self,
context: RunContextWrapper[TContext],
from_agent: Agent[TContext],
to_agent: Agent[TContext],
from_agent: TAgent,
to_agent: TAgent,
) -> None:
"""Called when a handoff occurs."""
pass

async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
tool: Tool,
) -> None:
"""Called before a tool is invoked."""
Expand All @@ -46,30 +48,30 @@ async def on_tool_start(
async def on_tool_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
tool: Tool,
result: str,
) -> None:
"""Called after a tool is invoked."""
pass


class AgentHooks(Generic[TContext]):
class AgentHooksBase(Generic[TContext, TAgent]):
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
set this on `agent.hooks` to receive events for that specific agent.

Subclass and override the methods you need.
"""

async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
"""Called before the agent is invoked. Called each time the running agent is changed to this
agent."""
pass

async def on_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
output: Any,
) -> None:
"""Called when the agent produces a final output."""
Expand All @@ -78,8 +80,8 @@ async def on_end(
async def on_handoff(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
source: Agent[TContext],
agent: TAgent,
source: TAgent,
) -> None:
"""Called when the agent is being handed off to. The `source` is the agent that is handing
off to this agent."""
Expand All @@ -88,7 +90,7 @@ async def on_handoff(
async def on_tool_start(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
tool: Tool,
) -> None:
"""Called before a tool is invoked."""
Expand All @@ -97,9 +99,16 @@ async def on_tool_start(
async def on_tool_end(
self,
context: RunContextWrapper[TContext],
agent: Agent[TContext],
agent: TAgent,
tool: Tool,
result: str,
) -> None:
"""Called after a tool is invoked."""
pass


RunHooks = RunHooksBase[TContext, Agent]
"""Run hooks when using `Agent`."""

AgentHooks = AgentHooksBase[TContext, Agent]
"""Agent hooks for `Agent`s."""
10 changes: 5 additions & 5 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .util import ToolFilter, ToolFilterCallable, ToolFilterContext, ToolFilterStatic

if TYPE_CHECKING:
from ..agent import Agent
from ..agent import AgentBase


class MCPServer(abc.ABC):
Expand Down Expand Up @@ -53,7 +53,7 @@ async def cleanup(self):
async def list_tools(
self,
run_context: RunContextWrapper[Any] | None = None,
agent: Agent[Any] | None = None,
agent: AgentBase | None = None,
) -> list[MCPTool]:
"""List the tools available on the server."""
pass
Expand Down Expand Up @@ -117,7 +117,7 @@ async def _apply_tool_filter(
self,
tools: list[MCPTool],
run_context: RunContextWrapper[Any],
agent: Agent[Any],
agent: AgentBase,
) -> list[MCPTool]:
"""Apply the tool filter to the list of tools."""
if self.tool_filter is None:
Expand Down Expand Up @@ -153,7 +153,7 @@ async def _apply_dynamic_tool_filter(
self,
tools: list[MCPTool],
run_context: RunContextWrapper[Any],
agent: Agent[Any],
agent: AgentBase,
) -> list[MCPTool]:
"""Apply dynamic tool filtering using a callable filter function."""

Expand Down Expand Up @@ -244,7 +244,7 @@ async def connect(self):
async def list_tools(
self,
run_context: RunContextWrapper[Any] | None = None,
agent: Agent[Any] | None = None,
agent: AgentBase | None = None,
) -> list[MCPTool]:
"""List the tools available on the server."""
if not self.session:
Expand Down
11 changes: 5 additions & 6 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@

from typing_extensions import NotRequired, TypedDict

from agents.strict_schema import ensure_strict_json_schema

from .. import _debug
from ..exceptions import AgentsException, ModelBehaviorError, UserError
from ..logger import logger
from ..run_context import RunContextWrapper
from ..strict_schema import ensure_strict_json_schema
from ..tool import FunctionTool, Tool
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
from ..util._types import MaybeAwaitable

if TYPE_CHECKING:
from mcp.types import Tool as MCPTool

from ..agent import Agent
from ..agent import AgentBase
from .server import MCPServer


Expand All @@ -29,7 +28,7 @@ class ToolFilterContext:
run_context: RunContextWrapper[Any]
"""The current run context."""

agent: "Agent[Any]"
agent: "AgentBase"
"""The agent that is requesting the tool list."""

server_name: str
Expand Down Expand Up @@ -100,7 +99,7 @@ async def get_all_function_tools(
servers: list["MCPServer"],
convert_schemas_to_strict: bool,
run_context: RunContextWrapper[Any],
agent: "Agent[Any]",
agent: "AgentBase",
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
Expand All @@ -126,7 +125,7 @@ async def get_function_tools(
server: "MCPServer",
convert_schemas_to_strict: bool,
run_context: RunContextWrapper[Any],
agent: "Agent[Any]",
agent: "AgentBase",
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

Expand Down
11 changes: 5 additions & 6 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
from .util._types import MaybeAwaitable

if TYPE_CHECKING:

from .agent import Agent
from .agent import Agent, AgentBase

ToolParams = ParamSpec("ToolParams")

Expand Down Expand Up @@ -88,7 +87,7 @@ class FunctionTool:
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""

is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""
Expand Down Expand Up @@ -301,7 +300,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
Expand All @@ -316,7 +315,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
Expand All @@ -331,7 +330,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
Expand Down
Loading