Skip to content

Commit 36ca4f1

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add on_tool_error_callback in LlmAgent
PiperOrigin-RevId: 824598502
1 parent 496f8cd commit 36ca4f1

File tree

3 files changed

+248
-14
lines changed

3 files changed

+248
-14
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@
105105
list[_SingleAfterToolCallback],
106106
]
107107

108+
_SingleOnToolErrorCallback: TypeAlias = Callable[
109+
[BaseTool, dict[str, Any], ToolContext, Exception],
110+
Union[Awaitable[Optional[dict]], Optional[dict]],
111+
]
112+
113+
OnToolErrorCallback: TypeAlias = Union[
114+
_SingleOnToolErrorCallback,
115+
list[_SingleOnToolErrorCallback],
116+
]
117+
108118
InstructionProvider: TypeAlias = Callable[
109119
[ReadonlyContext], Union[str, Awaitable[str]]
110120
]
@@ -381,6 +391,21 @@ class LlmAgent(BaseAgent):
381391
tool_context: ToolContext,
382392
tool_response: The response from the tool.
383393
394+
Returns:
395+
When present, the returned dict will be used as tool result.
396+
"""
397+
on_tool_error_callback: Optional[OnToolErrorCallback] = None
398+
"""Callback or list of callbacks to be called when a tool call encounters an error.
399+
400+
When a list of callbacks is provided, the callbacks will be called in the
401+
order they are listed until a callback does not return None.
402+
403+
Args:
404+
tool: The tool to be called.
405+
args: The arguments to the tool.
406+
tool_context: ToolContext,
407+
error: The error from the tool call.
408+
384409
Returns:
385410
When present, the returned dict will be used as tool result.
386411
"""
@@ -582,6 +607,20 @@ def canonical_after_tool_callbacks(
582607
return self.after_tool_callback
583608
return [self.after_tool_callback]
584609

610+
@property
611+
def canonical_on_tool_error_callbacks(
612+
self,
613+
) -> list[OnToolErrorCallback]:
614+
"""The resolved self.on_tool_error_callback field as a list of OnToolErrorCallback.
615+
616+
This method is only for use by Agent Development Kit.
617+
"""
618+
if not self.on_tool_error_callback:
619+
return []
620+
if isinstance(self.on_tool_error_callback, list):
621+
return self.on_tool_error_callback
622+
return [self.on_tool_error_callback]
623+
585624
@property
586625
def _llm_flow(self) -> BaseLlmFlow:
587626
if (

src/google/adk/flows/llm_flows/functions.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,40 @@ async def _execute_single_function_call_async(
275275
tool_confirmation: Optional[ToolConfirmation] = None,
276276
) -> Optional[Event]:
277277
"""Execute a single function call with thread safety for state modifications."""
278+
279+
async def _run_on_tool_error_callbacks(
280+
*,
281+
tool: BaseTool,
282+
tool_args: dict[str, Any],
283+
tool_context: ToolContext,
284+
error: Exception,
285+
) -> Optional[dict[str, Any]]:
286+
"""Runs the on_tool_error_callbacks for the given tool."""
287+
error_response = (
288+
await invocation_context.plugin_manager.run_on_tool_error_callback(
289+
tool=tool,
290+
tool_args=tool_args,
291+
tool_context=tool_context,
292+
error=error,
293+
)
294+
)
295+
if error_response is not None:
296+
return error_response
297+
298+
for callback in agent.canonical_on_tool_error_callbacks:
299+
error_response = callback(
300+
tool=tool,
301+
args=tool_args,
302+
tool_context=tool_context,
303+
error=error,
304+
)
305+
if inspect.isawaitable(error_response):
306+
error_response = await error_response
307+
if error_response is not None:
308+
return error_response
309+
310+
return None
311+
278312
# Do not use "args" as the variable name, because it is a reserved keyword
279313
# in python debugger.
280314
# Make a deep copy to avoid being modified.
@@ -290,13 +324,11 @@ async def _execute_single_function_call_async(
290324
tool = _get_tool(function_call, tools_dict)
291325
except ValueError as tool_error:
292326
tool = BaseTool(name=function_call.name, description='Tool not found')
293-
error_response = (
294-
await invocation_context.plugin_manager.run_on_tool_error_callback(
295-
tool=tool,
296-
tool_args=function_args,
297-
tool_context=tool_context,
298-
error=tool_error,
299-
)
327+
error_response = await _run_on_tool_error_callbacks(
328+
tool=tool,
329+
tool_args=function_args,
330+
tool_context=tool_context,
331+
error=tool_error,
300332
)
301333
if error_response is not None:
302334
return __build_response_event(
@@ -335,13 +367,11 @@ async def _run_with_trace():
335367
tool, args=function_args, tool_context=tool_context
336368
)
337369
except Exception as tool_error:
338-
error_response = (
339-
await invocation_context.plugin_manager.run_on_tool_error_callback(
340-
tool=tool,
341-
tool_args=function_args,
342-
tool_context=tool_context,
343-
error=tool_error,
344-
)
370+
error_response = await _run_on_tool_error_callbacks(
371+
tool=tool,
372+
tool_args=function_args,
373+
tool_context=tool_context,
374+
error=tool_error,
345375
)
346376
if error_response is not None:
347377
function_response = error_response

tests/unittests/flows/llm_flows/test_tool_callbacks.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.genai import types
2121
from google.genai.types import Part
2222
from pydantic import BaseModel
23+
import pytest
2324

2425
from ... import testing_utils
2526

@@ -28,7 +29,13 @@ def simple_function(input_str: str) -> str:
2829
return {'result': input_str}
2930

3031

32+
def simple_function_with_error() -> str:
33+
raise SystemError('simple_function_with_error')
34+
35+
3136
class MockBeforeToolCallback(BaseModel):
37+
"""Mock before tool callback."""
38+
3239
mock_response: dict[str, object]
3340
modify_tool_request: bool = False
3441

@@ -45,6 +52,8 @@ def __call__(
4552

4653

4754
class MockAfterToolCallback(BaseModel):
55+
"""Mock after tool callback."""
56+
4857
mock_response: dict[str, object]
4958
modify_tool_request: bool = False
5059
modify_tool_response: bool = False
@@ -65,13 +74,32 @@ def __call__(
6574
return self.mock_response
6675

6776

77+
class MockOnToolErrorCallback(BaseModel):
78+
"""Mock on tool error callback."""
79+
80+
mock_response: dict[str, object]
81+
modify_tool_response: bool = False
82+
83+
def __call__(
84+
self,
85+
tool: BaseTool,
86+
args: dict[str, Any],
87+
tool_context: ToolContext,
88+
error: Exception,
89+
) -> dict[str, object]:
90+
if self.modify_tool_response:
91+
return self.mock_response
92+
return None
93+
94+
6895
def noop_callback(
6996
**kwargs,
7097
) -> dict[str, object]:
7198
pass
7299

73100

74101
def test_before_tool_callback():
102+
"""Test that the before_tool_callback is called before the tool is called."""
75103
responses = [
76104
types.Part.from_function_call(name='simple_function', args={}),
77105
'response1',
@@ -100,6 +128,7 @@ def test_before_tool_callback():
100128

101129

102130
def test_before_tool_callback_noop():
131+
"""Test that the before_tool_callback is a no-op when not overridden."""
103132
responses = [
104133
types.Part.from_function_call(
105134
name='simple_function', args={'input_str': 'simple_function_call'}
@@ -134,6 +163,7 @@ def test_before_tool_callback_noop():
134163

135164

136165
def test_before_tool_callback_modify_tool_request():
166+
"""Test that the before_tool_callback modifies the tool request."""
137167
responses = [
138168
types.Part.from_function_call(name='simple_function', args={}),
139169
'response1',
@@ -164,6 +194,7 @@ def test_before_tool_callback_modify_tool_request():
164194

165195

166196
def test_after_tool_callback():
197+
"""Test that the after_tool_callback is called after the tool is called."""
167198
responses = [
168199
types.Part.from_function_call(
169200
name='simple_function', args={'input_str': 'simple_function_call'}
@@ -199,6 +230,7 @@ def test_after_tool_callback():
199230

200231

201232
def test_after_tool_callback_noop():
233+
"""Test that the after_tool_callback is a no-op when not overridden."""
202234
responses = [
203235
types.Part.from_function_call(
204236
name='simple_function', args={'input_str': 'simple_function_call'}
@@ -233,6 +265,7 @@ def test_after_tool_callback_noop():
233265

234266

235267
def test_after_tool_callback_modify_tool_response():
268+
"""Test that the after_tool_callback modifies the tool response."""
236269
responses = [
237270
types.Part.from_function_call(
238271
name='simple_function', args={'input_str': 'simple_function_call'}
@@ -267,3 +300,135 @@ def test_after_tool_callback_modify_tool_response():
267300
),
268301
('root_agent', 'response1'),
269302
]
303+
304+
305+
async def test_on_tool_error_callback_tool_not_found_noop():
306+
"""Test that the on_tool_error_callback is a no-op when the tool is not found."""
307+
responses = [
308+
types.Part.from_function_call(
309+
name='nonexistent_function',
310+
args={'input_str': 'simple_function_call'},
311+
),
312+
'response1',
313+
]
314+
mock_model = testing_utils.MockModel.create(responses=responses)
315+
agent = Agent(
316+
name='root_agent',
317+
model=mock_model,
318+
on_tool_error_callback=noop_callback,
319+
tools=[simple_function],
320+
)
321+
322+
runner = testing_utils.InMemoryRunner(agent)
323+
with pytest.raises(ValueError):
324+
await runner.run_async('test')
325+
326+
327+
def test_on_tool_error_callback_tool_not_found_modify_tool_response():
328+
"""Test that the on_tool_error_callback modifies the tool response when the tool is not found."""
329+
responses = [
330+
types.Part.from_function_call(
331+
name='nonexistent_function',
332+
args={'input_str': 'simple_function_call'},
333+
),
334+
'response1',
335+
]
336+
mock_model = testing_utils.MockModel.create(responses=responses)
337+
agent = Agent(
338+
name='root_agent',
339+
model=mock_model,
340+
on_tool_error_callback=MockOnToolErrorCallback(
341+
mock_response={'result': 'on_tool_error_callback_response'},
342+
modify_tool_response=True,
343+
),
344+
tools=[simple_function],
345+
)
346+
347+
runner = testing_utils.InMemoryRunner(agent)
348+
assert testing_utils.simplify_events(runner.run('test')) == [
349+
(
350+
'root_agent',
351+
Part.from_function_call(
352+
name='nonexistent_function',
353+
args={'input_str': 'simple_function_call'},
354+
),
355+
),
356+
(
357+
'root_agent',
358+
Part.from_function_response(
359+
name='nonexistent_function',
360+
response={'result': 'on_tool_error_callback_response'},
361+
),
362+
),
363+
('root_agent', 'response1'),
364+
]
365+
366+
367+
async def test_on_tool_error_callback_tool_error_noop():
368+
"""Test that the on_tool_error_callback is a no-op when the tool returns an error."""
369+
responses = [
370+
types.Part.from_function_call(
371+
name='simple_function_with_error',
372+
args={},
373+
),
374+
'response1',
375+
]
376+
mock_model = testing_utils.MockModel.create(responses=responses)
377+
agent = Agent(
378+
name='root_agent',
379+
model=mock_model,
380+
on_tool_error_callback=noop_callback,
381+
tools=[simple_function_with_error],
382+
)
383+
384+
runner = testing_utils.InMemoryRunner(agent)
385+
with pytest.raises(SystemError):
386+
await runner.run_async('test')
387+
388+
389+
def test_on_tool_error_callback_tool_error_modify_tool_response():
390+
"""Test that the on_tool_error_callback modifies the tool response when the tool returns an error."""
391+
392+
async def async_on_tool_error_callback(
393+
tool: BaseTool,
394+
args: dict[str, Any],
395+
tool_context: ToolContext,
396+
error: Exception,
397+
) -> dict[str, object]:
398+
if tool.name == 'simple_function_with_error':
399+
return {'result': 'async_on_tool_error_callback_response'}
400+
return None
401+
402+
responses = [
403+
types.Part.from_function_call(
404+
name='simple_function_with_error',
405+
args={},
406+
),
407+
'response1',
408+
]
409+
mock_model = testing_utils.MockModel.create(responses=responses)
410+
agent = Agent(
411+
name='root_agent',
412+
model=mock_model,
413+
on_tool_error_callback=async_on_tool_error_callback,
414+
tools=[simple_function_with_error],
415+
)
416+
417+
runner = testing_utils.InMemoryRunner(agent)
418+
assert testing_utils.simplify_events(runner.run('test')) == [
419+
(
420+
'root_agent',
421+
Part.from_function_call(
422+
name='simple_function_with_error',
423+
args={},
424+
),
425+
),
426+
(
427+
'root_agent',
428+
Part.from_function_response(
429+
name='simple_function_with_error',
430+
response={'result': 'async_on_tool_error_callback_response'},
431+
),
432+
),
433+
('root_agent', 'response1'),
434+
]

0 commit comments

Comments
 (0)