1818from .util ._types import MaybeAwaitable
1919
2020if TYPE_CHECKING :
21- from .agent import Agent
21+ from .agent import Agent , AgentBase
2222
2323
2424# The handoff input type is the type of data passed when the agent is called via a handoff.
2525THandoffInput = TypeVar ("THandoffInput" , default = Any )
2626
27+ # The agent type that the handoff returns
28+ TAgent = TypeVar ("TAgent" , bound = "AgentBase[Any]" , default = "Agent[Any]" )
29+
2730OnHandoffWithInput = Callable [[RunContextWrapper [Any ], THandoffInput ], Any ]
2831OnHandoffWithoutInput = Callable [[RunContextWrapper [Any ]], Any ]
2932
@@ -52,7 +55,7 @@ class HandoffInputData:
5255
5356
5457@dataclass
55- class Handoff (Generic [TContext ]):
58+ class Handoff (Generic [TContext , TAgent ]):
5659 """A handoff is when an agent delegates a task to another agent.
5760 For example, in a customer support scenario you might have a "triage agent" that determines
5861 which agent should handle the user's request, and sub-agents that specialize in different
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
6972 """The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
7073 """
7174
72- on_invoke_handoff : Callable [[RunContextWrapper [Any ], str ], Awaitable [Agent [ TContext ] ]]
75+ on_invoke_handoff : Callable [[RunContextWrapper [Any ], str ], Awaitable [TAgent ]]
7376 """The function that invokes the handoff. The parameters passed are:
7477 1. The handoff run context
7578 2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
100103 True, as it increases the likelihood of correct JSON input.
101104 """
102105
103- is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True
106+ is_enabled : bool | Callable [[RunContextWrapper [Any ], AgentBase [Any ]], MaybeAwaitable [bool ]] = (
107+ True
108+ )
104109 """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105110 agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106111 a handoff based on your context/state."""
107112
108- def get_transfer_message (self , agent : Agent [Any ]) -> str :
113+ def get_transfer_message (self , agent : AgentBase [Any ]) -> str :
109114 return json .dumps ({"assistant" : agent .name })
110115
111116 @classmethod
112- def default_tool_name (cls , agent : Agent [Any ]) -> str :
117+ def default_tool_name (cls , agent : AgentBase [Any ]) -> str :
113118 return _transforms .transform_string_function_style (f"transfer_to_{ agent .name } " )
114119
115120 @classmethod
116- def default_tool_description (cls , agent : Agent [Any ]) -> str :
121+ def default_tool_description (cls , agent : AgentBase [Any ]) -> str :
117122 return (
118123 f"Handoff to the { agent .name } agent to handle the request. "
119124 f"{ agent .handoff_description or '' } "
@@ -128,7 +133,7 @@ def handoff(
128133 tool_description_override : str | None = None ,
129134 input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
130135 is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
131- ) -> Handoff [TContext ]: ...
136+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
132137
133138
134139@overload
@@ -141,7 +146,7 @@ def handoff(
141146 tool_name_override : str | None = None ,
142147 input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
143148 is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
144- ) -> Handoff [TContext ]: ...
149+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
145150
146151
147152@overload
@@ -153,7 +158,7 @@ def handoff(
153158 tool_name_override : str | None = None ,
154159 input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
155160 is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
156- ) -> Handoff [TContext ]: ...
161+ ) -> Handoff [TContext , Agent [ TContext ] ]: ...
157162
158163
159164def handoff (
@@ -163,8 +168,9 @@ def handoff(
163168 on_handoff : OnHandoffWithInput [THandoffInput ] | OnHandoffWithoutInput | None = None ,
164169 input_type : type [THandoffInput ] | None = None ,
165170 input_filter : Callable [[HandoffInputData ], HandoffInputData ] | None = None ,
166- is_enabled : bool | Callable [[RunContextWrapper [Any ], Agent [Any ]], MaybeAwaitable [bool ]] = True ,
167- ) -> Handoff [TContext ]:
171+ is_enabled : bool
172+ | Callable [[RunContextWrapper [Any ], Agent [TContext ]], MaybeAwaitable [bool ]] = True ,
173+ ) -> Handoff [TContext , Agent [TContext ]]:
168174 """Create a handoff from an agent.
169175
170176 Args:
@@ -202,7 +208,7 @@ def handoff(
202208
203209 async def _invoke_handoff (
204210 ctx : RunContextWrapper [Any ], input_json : str | None = None
205- ) -> Agent [Any ]:
211+ ) -> Agent [TContext ]:
206212 if input_type is not None and type_adapter is not None :
207213 if input_json is None :
208214 _error_tracing .attach_error_to_current_span (
@@ -239,12 +245,24 @@ async def _invoke_handoff(
239245 # If there is a need, we can make this configurable in the future
240246 input_json_schema = ensure_strict_json_schema (input_json_schema )
241247
248+ async def _is_enabled (ctx : RunContextWrapper [Any ], agent_base : AgentBase [Any ]) -> bool :
249+ from .agent import Agent
250+
251+ assert callable (is_enabled ), "is_enabled must be non-null here"
252+ assert isinstance (agent_base , Agent ), "Can't handoff to a non-Agent"
253+ result = is_enabled (ctx , agent_base )
254+
255+ if inspect .isawaitable (result ):
256+ return await result
257+
258+ return result
259+
242260 return Handoff (
243261 tool_name = tool_name ,
244262 tool_description = tool_description ,
245263 input_json_schema = input_json_schema ,
246264 on_invoke_handoff = _invoke_handoff ,
247265 input_filter = input_filter ,
248266 agent_name = agent .name ,
249- is_enabled = is_enabled ,
267+ is_enabled = _is_enabled if callable ( is_enabled ) else is_enabled ,
250268 )
0 commit comments