Skip to content

Commit 78e74b5

Browse files
wukathcopybara-github
authored andcommitted
feat: Add require_confirmation param for MCP tool/toolset
This allows users to require human approval for using MCP tools. PiperOrigin-RevId: 819800747
1 parent d82c492 commit 78e74b5

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed

contributing/samples/mcp_sse_agent/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
'get_file_info',
5454
'list_allowed_directories',
5555
],
56+
require_confirmation=True,
5657
)
5758
],
5859
)

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
from __future__ import annotations
1616

1717
import base64
18+
import inspect
1819
import logging
20+
from typing import Any
21+
from typing import Callable
1922
from typing import Optional
23+
from typing import Union
2024
import warnings
2125

2226
from fastapi.openapi.models import APIKeyIn
@@ -70,6 +74,7 @@ def __init__(
7074
mcp_session_manager: MCPSessionManager,
7175
auth_scheme: Optional[AuthScheme] = None,
7276
auth_credential: Optional[AuthCredential] = None,
77+
require_confirmation: Union[bool, Callable[..., bool]] = False,
7378
):
7479
"""Initializes an MCPTool.
7580
@@ -81,6 +86,10 @@ def __init__(
8186
mcp_session_manager: The MCP session manager to use for communication.
8287
auth_scheme: The authentication scheme to use.
8388
auth_credential: The authentication credential to use.
89+
require_confirmation: Whether this tool requires confirmation. A boolean
90+
or a callable that takes the function's arguments and returns a
91+
boolean. If the callable returns True, the tool will require
92+
confirmation from the user.
8493
8594
Raises:
8695
ValueError: If mcp_tool or mcp_session_manager is None.
@@ -96,6 +105,7 @@ def __init__(
96105
)
97106
self._mcp_tool = mcp_tool
98107
self._mcp_session_manager = mcp_session_manager
108+
self._require_confirmation = require_confirmation
99109

100110
@override
101111
def _get_declaration(self) -> FunctionDeclaration:
@@ -116,6 +126,57 @@ def raw_mcp_tool(self) -> McpBaseTool:
116126
"""Returns the raw MCP tool."""
117127
return self._mcp_tool
118128

129+
async def _invoke_callable(
130+
self, target: Callable[..., Any], args_to_call: dict[str, Any]
131+
) -> Any:
132+
"""Invokes a callable, handling both sync and async cases."""
133+
134+
# Functions are callable objects, but not all callable objects are functions
135+
# checking coroutine function is not enough. We also need to check whether
136+
# Callable's __call__ function is a coroutine funciton
137+
is_async = inspect.iscoroutinefunction(target) or (
138+
hasattr(target, "__call__")
139+
and inspect.iscoroutinefunction(target.__call__)
140+
)
141+
if is_async:
142+
return await target(**args_to_call)
143+
else:
144+
return target(**args_to_call)
145+
146+
@override
147+
async def run_async(
148+
self, *, args: dict[str, Any], tool_context: ToolContext
149+
) -> Any:
150+
if isinstance(self._require_confirmation, Callable):
151+
require_confirmation = await self._invoke_callable(
152+
self._require_confirmation, args
153+
)
154+
else:
155+
require_confirmation = bool(self._require_confirmation)
156+
157+
if require_confirmation:
158+
if not tool_context.tool_confirmation:
159+
args_to_show = args.copy()
160+
if "tool_context" in args_to_show:
161+
args_to_show.pop("tool_context")
162+
163+
tool_context.request_confirmation(
164+
hint=(
165+
f"Please approve or reject the tool call {self.name}() by"
166+
" responding with a FunctionResponse with an expected"
167+
" ToolConfirmation payload."
168+
),
169+
)
170+
return {
171+
"error": (
172+
"This tool call requires confirmation, please approve or"
173+
" reject."
174+
)
175+
}
176+
elif not tool_context.tool_confirmation.confirmed:
177+
return {"error": "This tool call is rejected."}
178+
return await super().run_async(args=args, tool_context=tool_context)
179+
119180
@retry_on_closed_resource
120181
@override
121182
async def _run_async_impl(

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import logging
1818
import sys
19+
from typing import Callable
20+
from typing import Dict
1921
from typing import List
2022
from typing import Optional
2123
from typing import TextIO
@@ -104,6 +106,7 @@ def __init__(
104106
errlog: TextIO = sys.stderr,
105107
auth_scheme: Optional[AuthScheme] = None,
106108
auth_credential: Optional[AuthCredential] = None,
109+
require_confirmation: Union[bool, Callable[..., bool]] = False,
107110
):
108111
"""Initializes the MCPToolset.
109112
@@ -124,6 +127,9 @@ def __init__(
124127
errlog: TextIO stream for error logging.
125128
auth_scheme: The auth scheme of the tool for tool calling
126129
auth_credential: The auth credential of the tool for tool calling
130+
require_confirmation: Whether tools in this toolset require
131+
confirmation. Can be a single boolean or a callable to apply to all
132+
tools.
127133
"""
128134
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
129135

@@ -140,6 +146,7 @@ def __init__(
140146
)
141147
self._auth_scheme = auth_scheme
142148
self._auth_credential = auth_credential
149+
self._require_confirmation = require_confirmation
143150

144151
@retry_on_closed_resource
145152
async def get_tools(
@@ -169,6 +176,7 @@ async def get_tools(
169176
mcp_session_manager=self._mcp_session_manager,
170177
auth_scheme=self._auth_scheme,
171178
auth_credential=self._auth_credential,
179+
require_confirmation=self._require_confirmation,
172180
)
173181

174182
if self._is_tool_selected(mcp_tool, readonly_context):

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,89 @@ async def test_get_headers_api_key_error_logging(self):
549549
in logged_message
550550
)
551551

552+
@pytest.mark.asyncio
553+
async def test_run_async_require_confirmation_true_no_confirmation(self):
554+
"""Test require_confirmation=True with no confirmation in context."""
555+
tool = MCPTool(
556+
mcp_tool=self.mock_mcp_tool,
557+
mcp_session_manager=self.mock_session_manager,
558+
require_confirmation=True,
559+
)
560+
tool_context = Mock(spec=ToolContext)
561+
tool_context.tool_confirmation = None
562+
tool_context.request_confirmation = Mock()
563+
args = {"param1": "test_value"}
564+
565+
result = await tool.run_async(args=args, tool_context=tool_context)
566+
567+
assert result == {
568+
"error": (
569+
"This tool call requires confirmation, please approve or reject."
570+
)
571+
}
572+
tool_context.request_confirmation.assert_called_once()
573+
574+
@pytest.mark.asyncio
575+
async def test_run_async_require_confirmation_true_rejected(self):
576+
"""Test require_confirmation=True with rejection in context."""
577+
tool = MCPTool(
578+
mcp_tool=self.mock_mcp_tool,
579+
mcp_session_manager=self.mock_session_manager,
580+
require_confirmation=True,
581+
)
582+
tool_context = Mock(spec=ToolContext)
583+
tool_context.tool_confirmation = Mock(confirmed=False)
584+
args = {"param1": "test_value"}
585+
586+
result = await tool.run_async(args=args, tool_context=tool_context)
587+
588+
assert result == {"error": "This tool call is rejected."}
589+
590+
@pytest.mark.asyncio
591+
async def test_run_async_require_confirmation_true_confirmed(self):
592+
"""Test require_confirmation=True with confirmation in context."""
593+
tool = MCPTool(
594+
mcp_tool=self.mock_mcp_tool,
595+
mcp_session_manager=self.mock_session_manager,
596+
require_confirmation=True,
597+
)
598+
tool_context = Mock(spec=ToolContext)
599+
tool_context.tool_confirmation = Mock(confirmed=True)
600+
args = {"param1": "test_value"}
601+
602+
with patch(
603+
"google.adk.tools.base_authenticated_tool.BaseAuthenticatedTool.run_async",
604+
new_callable=AsyncMock,
605+
) as mock_super_run_async:
606+
await tool.run_async(args=args, tool_context=tool_context)
607+
mock_super_run_async.assert_called_once_with(
608+
args=args, tool_context=tool_context
609+
)
610+
611+
@pytest.mark.asyncio
612+
async def test_run_async_require_confirmation_callable_true_no_confirmation(
613+
self,
614+
):
615+
"""Test require_confirmation=callable with no confirmation in context."""
616+
tool = MCPTool(
617+
mcp_tool=self.mock_mcp_tool,
618+
mcp_session_manager=self.mock_session_manager,
619+
require_confirmation=lambda **kwargs: True,
620+
)
621+
tool_context = Mock(spec=ToolContext)
622+
tool_context.tool_confirmation = None
623+
tool_context.request_confirmation = Mock()
624+
args = {"param1": "test_value"}
625+
626+
result = await tool.run_async(args=args, tool_context=tool_context)
627+
628+
assert result == {
629+
"error": (
630+
"This tool call requires confirmation, please approve or reject."
631+
)
632+
}
633+
tool_context.request_confirmation.assert_called_once()
634+
552635
def test_init_validation(self):
553636
"""Test that initialization validates required parameters."""
554637
# This test ensures that the MCPTool properly handles its dependencies

0 commit comments

Comments
 (0)