@@ -26,9 +26,7 @@ async def test_argless_function():
26
26
tool = function_tool (argless_function )
27
27
assert tool .name == "argless_function"
28
28
29
- result = await tool .on_invoke_tool (
30
- ToolContext (context = None , tool_name = tool .name , tool_call_id = "1" ), ""
31
- )
29
+ result = await tool .on_invoke_tool (ToolContext (context = None , tool_call_id = "1" ), "" )
32
30
assert result == "ok"
33
31
34
32
@@ -41,13 +39,11 @@ async def test_argless_with_context():
41
39
tool = function_tool (argless_with_context )
42
40
assert tool .name == "argless_with_context"
43
41
44
- result = await tool .on_invoke_tool (ToolContext (None , tool_name = tool . name , tool_call_id = "1" ), "" )
42
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), "" )
45
43
assert result == "ok"
46
44
47
45
# Extra JSON should not raise an error
48
- result = await tool .on_invoke_tool (
49
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1}'
50
- )
46
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1}' )
51
47
assert result == "ok"
52
48
53
49
@@ -60,19 +56,15 @@ async def test_simple_function():
60
56
tool = function_tool (simple_function , failure_error_function = None )
61
57
assert tool .name == "simple_function"
62
58
63
- result = await tool .on_invoke_tool (
64
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1}'
65
- )
59
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1}' )
66
60
assert result == 6
67
61
68
- result = await tool .on_invoke_tool (
69
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1, "b": 2}'
70
- )
62
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"a": 1, "b": 2}' )
71
63
assert result == 3
72
64
73
65
# Missing required argument should raise an error
74
66
with pytest .raises (ModelBehaviorError ):
75
- await tool .on_invoke_tool (ToolContext (None , tool_name = tool . name , tool_call_id = "1" ), "" )
67
+ await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), "" )
76
68
77
69
78
70
class Foo (BaseModel ):
@@ -100,9 +92,7 @@ async def test_complex_args_function():
100
92
"bar" : Bar (x = "hello" , y = 10 ),
101
93
}
102
94
)
103
- result = await tool .on_invoke_tool (
104
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
105
- )
95
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
106
96
assert result == "6 hello10 hello"
107
97
108
98
valid_json = json .dumps (
@@ -111,9 +101,7 @@ async def test_complex_args_function():
111
101
"bar" : Bar (x = "hello" , y = 10 ),
112
102
}
113
103
)
114
- result = await tool .on_invoke_tool (
115
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
116
- )
104
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
117
105
assert result == "3 hello10 hello"
118
106
119
107
valid_json = json .dumps (
@@ -123,16 +111,12 @@ async def test_complex_args_function():
123
111
"baz" : "world" ,
124
112
}
125
113
)
126
- result = await tool .on_invoke_tool (
127
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
128
- )
114
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), valid_json )
129
115
assert result == "3 hello10 world"
130
116
131
117
# Missing required argument should raise an error
132
118
with pytest .raises (ModelBehaviorError ):
133
- await tool .on_invoke_tool (
134
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"foo": {"a": 1}}'
135
- )
119
+ await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"foo": {"a": 1}}' )
136
120
137
121
138
122
def test_function_config_overrides ():
@@ -192,9 +176,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
192
176
assert tool .params_json_schema [key ] == value
193
177
assert tool .strict_json_schema
194
178
195
- result = await tool .on_invoke_tool (
196
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"data": "hello"}'
197
- )
179
+ result = await tool .on_invoke_tool (ToolContext (None , tool_call_id = "1" ), '{"data": "hello"}' )
198
180
assert result == "hello_done"
199
181
200
182
tool_not_strict = FunctionTool (
@@ -209,8 +191,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
209
191
assert "additionalProperties" not in tool_not_strict .params_json_schema
210
192
211
193
result = await tool_not_strict .on_invoke_tool (
212
- ToolContext (None , tool_name = tool_not_strict .name , tool_call_id = "1" ),
213
- '{"data": "hello", "bar": "baz"}' ,
194
+ ToolContext (None , tool_call_id = "1" ), '{"data": "hello", "bar": "baz"}'
214
195
)
215
196
assert result == "hello_done"
216
197
@@ -221,7 +202,7 @@ def my_func(a: int, b: int = 5):
221
202
raise ValueError ("test" )
222
203
223
204
tool = function_tool (my_func )
224
- ctx = ToolContext (None , tool_name = tool . name , tool_call_id = "1" )
205
+ ctx = ToolContext (None , tool_call_id = "1" )
225
206
226
207
result = await tool .on_invoke_tool (ctx , "" )
227
208
assert "Invalid JSON" in str (result )
@@ -245,7 +226,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
245
226
return f"error_{ error .__class__ .__name__ } "
246
227
247
228
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
248
- ctx = ToolContext (None , tool_name = tool . name , tool_call_id = "1" )
229
+ ctx = ToolContext (None , tool_call_id = "1" )
249
230
250
231
result = await tool .on_invoke_tool (ctx , "" )
251
232
assert result == "error_ModelBehaviorError"
@@ -269,7 +250,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
269
250
return f"error_{ error .__class__ .__name__ } "
270
251
271
252
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
272
- ctx = ToolContext (None , tool_name = tool . name , tool_call_id = "1" )
253
+ ctx = ToolContext (None , tool_call_id = "1" )
273
254
274
255
result = await tool .on_invoke_tool (ctx , "" )
275
256
assert result == "error_ModelBehaviorError"
0 commit comments