2020from google .genai import types
2121from google .genai .types import Part
2222from pydantic import BaseModel
23+ import pytest
2324
2425from ... import testing_utils
2526
@@ -28,7 +29,13 @@ def simple_function(input_str: str) -> str:
2829 return {'result' : input_str }
2930
3031
32+ def simple_function_with_error () -> str :
33+ raise SystemError ('simple_function_with_error' )
34+
35+
3136class MockBeforeToolCallback (BaseModel ):
37+ """Mock before tool callback."""
38+
3239 mock_response : dict [str , object ]
3340 modify_tool_request : bool = False
3441
@@ -45,6 +52,8 @@ def __call__(
4552
4653
4754class MockAfterToolCallback (BaseModel ):
55+ """Mock after tool callback."""
56+
4857 mock_response : dict [str , object ]
4958 modify_tool_request : bool = False
5059 modify_tool_response : bool = False
@@ -65,13 +74,32 @@ def __call__(
6574 return self .mock_response
6675
6776
77+ class MockOnToolErrorCallback (BaseModel ):
78+ """Mock on tool error callback."""
79+
80+ mock_response : dict [str , object ]
81+ modify_tool_response : bool = False
82+
83+ def __call__ (
84+ self ,
85+ tool : BaseTool ,
86+ args : dict [str , Any ],
87+ tool_context : ToolContext ,
88+ error : Exception ,
89+ ) -> dict [str , object ]:
90+ if self .modify_tool_response :
91+ return self .mock_response
92+ return None
93+
94+
6895def noop_callback (
6996 ** kwargs ,
7097) -> dict [str , object ]:
7198 pass
7299
73100
74101def test_before_tool_callback ():
102+ """Test that the before_tool_callback is called before the tool is called."""
75103 responses = [
76104 types .Part .from_function_call (name = 'simple_function' , args = {}),
77105 'response1' ,
@@ -100,6 +128,7 @@ def test_before_tool_callback():
100128
101129
102130def test_before_tool_callback_noop ():
131+ """Test that the before_tool_callback is a no-op when not overridden."""
103132 responses = [
104133 types .Part .from_function_call (
105134 name = 'simple_function' , args = {'input_str' : 'simple_function_call' }
@@ -134,6 +163,7 @@ def test_before_tool_callback_noop():
134163
135164
136165def test_before_tool_callback_modify_tool_request ():
166+ """Test that the before_tool_callback modifies the tool request."""
137167 responses = [
138168 types .Part .from_function_call (name = 'simple_function' , args = {}),
139169 'response1' ,
@@ -164,6 +194,7 @@ def test_before_tool_callback_modify_tool_request():
164194
165195
166196def test_after_tool_callback ():
197+ """Test that the after_tool_callback is called after the tool is called."""
167198 responses = [
168199 types .Part .from_function_call (
169200 name = 'simple_function' , args = {'input_str' : 'simple_function_call' }
@@ -199,6 +230,7 @@ def test_after_tool_callback():
199230
200231
201232def test_after_tool_callback_noop ():
233+ """Test that the after_tool_callback is a no-op when not overridden."""
202234 responses = [
203235 types .Part .from_function_call (
204236 name = 'simple_function' , args = {'input_str' : 'simple_function_call' }
@@ -233,6 +265,7 @@ def test_after_tool_callback_noop():
233265
234266
235267def test_after_tool_callback_modify_tool_response ():
268+ """Test that the after_tool_callback modifies the tool response."""
236269 responses = [
237270 types .Part .from_function_call (
238271 name = 'simple_function' , args = {'input_str' : 'simple_function_call' }
@@ -267,3 +300,135 @@ def test_after_tool_callback_modify_tool_response():
267300 ),
268301 ('root_agent' , 'response1' ),
269302 ]
303+
304+
305+ async def test_on_tool_error_callback_tool_not_found_noop ():
306+ """Test that the on_tool_error_callback is a no-op when the tool is not found."""
307+ responses = [
308+ types .Part .from_function_call (
309+ name = 'nonexistent_function' ,
310+ args = {'input_str' : 'simple_function_call' },
311+ ),
312+ 'response1' ,
313+ ]
314+ mock_model = testing_utils .MockModel .create (responses = responses )
315+ agent = Agent (
316+ name = 'root_agent' ,
317+ model = mock_model ,
318+ on_tool_error_callback = noop_callback ,
319+ tools = [simple_function ],
320+ )
321+
322+ runner = testing_utils .InMemoryRunner (agent )
323+ with pytest .raises (ValueError ):
324+ await runner .run_async ('test' )
325+
326+
327+ def test_on_tool_error_callback_tool_not_found_modify_tool_response ():
328+ """Test that the on_tool_error_callback modifies the tool response when the tool is not found."""
329+ responses = [
330+ types .Part .from_function_call (
331+ name = 'nonexistent_function' ,
332+ args = {'input_str' : 'simple_function_call' },
333+ ),
334+ 'response1' ,
335+ ]
336+ mock_model = testing_utils .MockModel .create (responses = responses )
337+ agent = Agent (
338+ name = 'root_agent' ,
339+ model = mock_model ,
340+ on_tool_error_callback = MockOnToolErrorCallback (
341+ mock_response = {'result' : 'on_tool_error_callback_response' },
342+ modify_tool_response = True ,
343+ ),
344+ tools = [simple_function ],
345+ )
346+
347+ runner = testing_utils .InMemoryRunner (agent )
348+ assert testing_utils .simplify_events (runner .run ('test' )) == [
349+ (
350+ 'root_agent' ,
351+ Part .from_function_call (
352+ name = 'nonexistent_function' ,
353+ args = {'input_str' : 'simple_function_call' },
354+ ),
355+ ),
356+ (
357+ 'root_agent' ,
358+ Part .from_function_response (
359+ name = 'nonexistent_function' ,
360+ response = {'result' : 'on_tool_error_callback_response' },
361+ ),
362+ ),
363+ ('root_agent' , 'response1' ),
364+ ]
365+
366+
367+ async def test_on_tool_error_callback_tool_error_noop ():
368+ """Test that the on_tool_error_callback is a no-op when the tool returns an error."""
369+ responses = [
370+ types .Part .from_function_call (
371+ name = 'simple_function_with_error' ,
372+ args = {},
373+ ),
374+ 'response1' ,
375+ ]
376+ mock_model = testing_utils .MockModel .create (responses = responses )
377+ agent = Agent (
378+ name = 'root_agent' ,
379+ model = mock_model ,
380+ on_tool_error_callback = noop_callback ,
381+ tools = [simple_function_with_error ],
382+ )
383+
384+ runner = testing_utils .InMemoryRunner (agent )
385+ with pytest .raises (SystemError ):
386+ await runner .run_async ('test' )
387+
388+
389+ def test_on_tool_error_callback_tool_error_modify_tool_response ():
390+ """Test that the on_tool_error_callback modifies the tool response when the tool returns an error."""
391+
392+ async def async_on_tool_error_callback (
393+ tool : BaseTool ,
394+ args : dict [str , Any ],
395+ tool_context : ToolContext ,
396+ error : Exception ,
397+ ) -> dict [str , object ]:
398+ if tool .name == 'simple_function_with_error' :
399+ return {'result' : 'async_on_tool_error_callback_response' }
400+ return None
401+
402+ responses = [
403+ types .Part .from_function_call (
404+ name = 'simple_function_with_error' ,
405+ args = {},
406+ ),
407+ 'response1' ,
408+ ]
409+ mock_model = testing_utils .MockModel .create (responses = responses )
410+ agent = Agent (
411+ name = 'root_agent' ,
412+ model = mock_model ,
413+ on_tool_error_callback = async_on_tool_error_callback ,
414+ tools = [simple_function_with_error ],
415+ )
416+
417+ runner = testing_utils .InMemoryRunner (agent )
418+ assert testing_utils .simplify_events (runner .run ('test' )) == [
419+ (
420+ 'root_agent' ,
421+ Part .from_function_call (
422+ name = 'simple_function_with_error' ,
423+ args = {},
424+ ),
425+ ),
426+ (
427+ 'root_agent' ,
428+ Part .from_function_response (
429+ name = 'simple_function_with_error' ,
430+ response = {'result' : 'async_on_tool_error_callback_response' },
431+ ),
432+ ),
433+ ('root_agent' , 'response1' ),
434+ ]
0 commit comments