Skip to content

Commit bab1270

Browse files
authored
hooks - before tool call event - cancel tool (#964)
1 parent 08dc4ae commit bab1270

File tree

7 files changed

+140
-4
lines changed

7 files changed

+140
-4
lines changed

src/strands/hooks/events.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ class BeforeToolCallEvent(HookEvent):
9797
to change which tool gets executed. This may be None if tool lookup failed.
9898
tool_use: The tool parameters that will be passed to selected_tool.
9999
invocation_state: Keyword arguments that will be passed to the tool.
100+
cancel_tool: A user defined message that when set, will cancel the tool call.
101+
The message will be placed into a tool result with an error status. If set to `True`, Strands will cancel
102+
the tool call and use a default cancel message.
100103
"""
101104

102105
selected_tool: Optional[AgentTool]
103106
tool_use: ToolUse
104107
invocation_state: dict[str, Any]
108+
cancel_tool: bool | str = False
105109

106110
def _can_write(self, name: str) -> bool:
107-
return name in ["selected_tool", "tool_use"]
111+
return name in ["cancel_tool", "selected_tool", "tool_use"]
108112

109113

110114
@dataclass
@@ -124,13 +128,15 @@ class AfterToolCallEvent(HookEvent):
124128
invocation_state: Keyword arguments that were passed to the tool
125129
result: The result of the tool invocation. Either a ToolResult on success
126130
or an Exception if the tool execution failed.
131+
cancel_message: The cancellation message if the user cancelled the tool call.
127132
"""
128133

129134
selected_tool: Optional[AgentTool]
130135
tool_use: ToolUse
131136
invocation_state: dict[str, Any]
132137
result: ToolResult
133138
exception: Optional[Exception] = None
139+
cancel_message: str | None = None
134140

135141
def _can_write(self, name: str) -> bool:
136142
return name == "result"

src/strands/tools/executors/_executor.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...hooks import AfterToolCallEvent, BeforeToolCallEvent
1515
from ...telemetry.metrics import Trace
1616
from ...telemetry.tracer import get_tracer
17-
from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent
17+
from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent
1818
from ...types.content import Message
1919
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse
2020

@@ -81,6 +81,31 @@ async def _stream(
8181
)
8282
)
8383

84+
if before_event.cancel_tool:
85+
cancel_message = (
86+
before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user"
87+
)
88+
yield ToolCancelEvent(tool_use, cancel_message)
89+
90+
cancel_result: ToolResult = {
91+
"toolUseId": str(tool_use.get("toolUseId")),
92+
"status": "error",
93+
"content": [{"text": cancel_message}],
94+
}
95+
after_event = agent.hooks.invoke_callbacks(
96+
AfterToolCallEvent(
97+
agent=agent,
98+
tool_use=tool_use,
99+
invocation_state=invocation_state,
100+
selected_tool=None,
101+
result=cancel_result,
102+
cancel_message=cancel_message,
103+
)
104+
)
105+
yield ToolResultEvent(after_event.result)
106+
tool_results.append(after_event.result)
107+
return
108+
84109
try:
85110
selected_tool = before_event.selected_tool
86111
tool_use = before_event.tool_use
@@ -123,7 +148,7 @@ async def _stream(
123148
# so that we don't needlessly yield ToolStreamEvents for non-generator callbacks.
124149
# In which case, as soon as we get a ToolResultEvent we're done and for ToolStreamEvent
125150
# we yield it directly; all other cases (non-sdk AgentTools), we wrap events in
126-
# ToolStreamEvent and the last even is just the result
151+
# ToolStreamEvent and the last event is just the result.
127152

128153
if isinstance(event, ToolResultEvent):
129154
# below the last "event" must point to the tool_result

src/strands/types/_events.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,29 @@ def tool_use_id(self) -> str:
298298
return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId"))
299299

300300

301+
class ToolCancelEvent(TypedEvent):
302+
"""Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook."""
303+
304+
def __init__(self, tool_use: ToolUse, message: str) -> None:
305+
"""Initialize with tool streaming data.
306+
307+
Args:
308+
tool_use: Information about the tool being cancelled
309+
message: The tool cancellation message
310+
"""
311+
super().__init__({"tool_cancel_event": {"tool_use": tool_use, "message": message}})
312+
313+
@property
314+
def tool_use_id(self) -> str:
315+
"""The id of the tool cancelled."""
316+
return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancelled_event")).get("tool_use")).get("toolUseId"))
317+
318+
@property
319+
def message(self) -> str:
320+
"""The tool cancellation message."""
321+
return cast(str, self["message"])
322+
323+
301324
class ModelMessageEvent(TypedEvent):
302325
"""Event emitted when the model invocation has completed.
303326

tests/strands/tools/executors/test_executor.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent
88
from strands.telemetry.metrics import Trace
99
from strands.tools.executors._executor import ToolExecutor
10-
from strands.types._events import ToolResultEvent, ToolStreamEvent
10+
from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent
1111
from strands.types.tools import ToolUse
1212

1313

@@ -215,3 +215,38 @@ async def test_executor_stream_with_trace(
215215

216216
cycle_trace.add_child.assert_called_once()
217217
assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)
218+
219+
220+
@pytest.mark.parametrize(
221+
("cancel_tool", "cancel_message"),
222+
[(True, "tool cancelled by user"), ("user cancel message", "user cancel message")],
223+
)
224+
@pytest.mark.asyncio
225+
async def test_executor_stream_cancel(
226+
cancel_tool, cancel_message, executor, agent, tool_results, invocation_state, alist
227+
):
228+
def cancel_callback(event):
229+
event.cancel_tool = cancel_tool
230+
return event
231+
232+
agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback)
233+
tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
234+
235+
stream = executor._stream(agent, tool_use, tool_results, invocation_state)
236+
237+
tru_events = await alist(stream)
238+
exp_events = [
239+
ToolCancelEvent(tool_use, cancel_message),
240+
ToolResultEvent(
241+
{
242+
"toolUseId": "1",
243+
"status": "error",
244+
"content": [{"text": cancel_message}],
245+
},
246+
),
247+
]
248+
assert tru_events == exp_events
249+
250+
tru_results = tool_results
251+
exp_results = [exp_events[-1].tool_result]
252+
assert tru_results == exp_results
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from strands.hooks import BeforeToolCallEvent, HookProvider
4+
5+
6+
@pytest.fixture
7+
def cancel_hook():
8+
class Hook(HookProvider):
9+
def register_hooks(self, registry):
10+
registry.add_callback(BeforeToolCallEvent, self.cancel)
11+
12+
def cancel(self, event):
13+
event.cancel_tool = "cancelled tool call"
14+
15+
return Hook()

tests_integ/tools/executors/test_concurrent.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23

34
import pytest
45

@@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
5960
{"name": "time_tool", "event": "end"},
6061
]
6162
assert tru_events == exp_events
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
67+
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
68+
69+
exp_message = "cancelled tool call"
70+
tru_message = ""
71+
async for event in agent.stream_async("What is the time in New York?"):
72+
if "tool_cancel_event" in event:
73+
tru_message = event["tool_cancel_event"]["message"]
74+
75+
assert tru_message == exp_message
76+
assert len(tool_events) == 0
77+
assert exp_message in json.dumps(agent.messages)

tests_integ/tools/executors/test_sequential.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23

34
import pytest
45

@@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
5960
{"name": "weather_tool", "event": "end"},
6061
]
6162
assert tru_events == exp_events
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
67+
agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
68+
69+
exp_message = "cancelled tool call"
70+
tru_message = ""
71+
async for event in agent.stream_async("What is the time in New York?"):
72+
if "tool_cancel_event" in event:
73+
tru_message = event["tool_cancel_event"]["message"]
74+
75+
assert tru_message == exp_message
76+
assert len(tool_events) == 0
77+
assert exp_message in json.dumps(agent.messages)

0 commit comments

Comments
 (0)