diff --git a/README.md b/README.md index c5009f33..938ce999 100644 --- a/README.md +++ b/README.md @@ -349,3 +349,46 @@ fastmcp_tool = to_fastmcp(add) mcp = FastMCP("Math", tools=[fastmcp_tool]) mcp.run(transport="stdio") ``` + +## Passing InjectedToolArg to an MCP Tool + +By using the LangChain MCP Adapter on both the server and client sides, you can use `InjectedToolArg` to hide certain parameters from the LLM. + +```python +# server.py +from typing import Annotated +from langchain_core.tools import InjectedToolArg, tool + +data = { + 'user_0': 'Spike' +} + +@tool +async def get_user_pet_name(user_id: Annotated[str, InjectedToolArg]) -> str: + """Returns the user's pet name""" + + return data[user_id] + +fastmcp_tool = to_fastmcp(get_user_pet_name) +mcp = FastMCP("Pets", tools=[fastmcp_tool]) +mcp.run(transport="stdio") +``` + +And the user ID can be passed as part of the input, without the LLM's knowledge: + +```python +# client.py + +client = MultiServerMCPClient( + ... +) + +tools = await client.get_tools() +agent = create_react_agent("openai:gpt-4.1", tools) +response = await agent.ainvoke( + { + "messages": "What is my dog's name?", + "user_id": "user_0" + } +) +``` \ No newline at end of file diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index a8bf0ebe..c942a96f 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -15,6 +15,7 @@ from langchain_core.tools.base import get_all_basemodel_annotations from mcp import ClientSession from mcp.server.fastmcp.tools import Tool as FastMCPTool +from mcp.server.fastmcp.server import Context from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.types import ( AudioContent, @@ -25,7 +26,7 @@ TextContent, ) from mcp.types import Tool as MCPTool -from pydantic import BaseModel, create_model +from pydantic import BaseModel, create_model, TypeAdapter from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks from langchain_mcp_adapters.hooks import 
CallToolRequestSpec, Hooks, ToolHookContext @@ -44,6 +45,8 @@ def get_runtime() -> None: """no-op runtime getter.""" return +META_KEY_INJECT_ARGS_VALUE = 'langchain/injectedArgsValue' +META_KEY_INJECT_ARGS_SCHEMA = 'langchain/injectedArgsSchema' NonTextContent = ImageContent | AudioContent | ResourceLink | EmbeddedResource MAX_ITERATIONS = 1000 @@ -153,9 +156,20 @@ def convert_mcp_tool_to_langchain_tool( async def call_tool( **arguments: dict[str, Any], ) -> tuple[str | list[str], list[NonTextContent] | None]: + meta = _build_tool_calling_meta( + tool, + **arguments + ) + + arguments = _clear_injected_tool_args_from_arguments( + meta, + **arguments + ) + mcp_callbacks = ( callbacks.to_mcp_format( - context=CallbackContext(server_name=server_name, tool_name=tool.name) + context=CallbackContext( + server_name=server_name, tool_name=tool.name) ) if callbacks is not None else _MCPCallbacks() @@ -223,12 +237,14 @@ async def call_tool( tool_name, tool_args, progress_callback=mcp_callbacks.progress_callback, + meta=meta ) else: call_tool_result = await session.call_tool( tool_name, tool_args, progress_callback=mcp_callbacks.progress_callback, + meta=meta ) if call_tool_result is None: @@ -291,7 +307,8 @@ async def load_mcp_tools( raise ValueError(msg) mcp_callbacks = ( - callbacks.to_mcp_format(context=CallbackContext(server_name=server_name)) + callbacks.to_mcp_format( + context=CallbackContext(server_name=server_name)) if callbacks is not None else _MCPCallbacks() ) @@ -322,15 +339,8 @@ async def load_mcp_tools( ] -def _get_injected_args(tool: BaseTool) -> list[str]: - """Get the list of injected argument names from a LangChain tool. - - Args: - tool: The LangChain tool to inspect. - - Returns: - A list of injected argument names. 
- """ +def _get_injected_args_schema(tool: BaseTool) -> dict[str, dict[str, str]]: + schemas: dict[str, dict[str, str]] = {} def _is_injected_arg_type(type_: type) -> bool: return any( @@ -339,11 +349,15 @@ def _is_injected_arg_type(type_: type) -> bool: for arg in get_args(type_)[1:] ) - return [ - field - for field, field_info in get_all_basemodel_annotations(tool.args_schema).items() - if _is_injected_arg_type(field_info) - ] + for field_name, field_info in get_all_basemodel_annotations(tool.args_schema).items(): + if not _is_injected_arg_type(field_info): + continue + + field_type = get_args(field_info)[0] + + schemas[field_name] = TypeAdapter(field_type).json_schema() + + return schemas def to_fastmcp(tool: BaseTool) -> FastMCPTool: @@ -366,11 +380,14 @@ def to_fastmcp(tool: BaseTool) -> FastMCPTool: ) raise TypeError(msg) + meta = _build_fastmcp_tool_meta_from_injected_args_schema(tool) + parameters = tool.tool_call_schema.model_json_schema() field_definitions = { field: (field_info.annotation, field_info) for field, field_info in tool.tool_call_schema.model_fields.items() } + arg_model = create_model( f"{tool.name}Arguments", **field_definitions, __base__=ArgModelBase ) @@ -378,13 +395,16 @@ def to_fastmcp(tool: BaseTool) -> FastMCPTool: # We'll use an Any type for the function return type. 
# We're providing the parameters separately - async def fn(**arguments: dict[str, Any]) -> Any: # noqa: ANN401 - return await tool.ainvoke(arguments) - injected_args = _get_injected_args(tool) - if len(injected_args) > 0: - msg = "LangChain tools with injected arguments are not supported" - raise NotImplementedError(msg) + async def fn(context: Context, **arguments: dict[str, Any]) -> Any: # noqa: ANN401 + if context is not None and context.request_context.meta is not None: + injected_args_value = context.request_context.meta.model_dump().get( + META_KEY_INJECT_ARGS_VALUE, {}) + + for arg_name, arg_value in injected_args_value.items(): + arguments[arg_name] = arg_value + + return await tool.ainvoke(arguments) return FastMCPTool( fn=fn, @@ -393,4 +413,55 @@ async def fn(**arguments: dict[str, Any]) -> Any: # noqa: ANN401 parameters=parameters, fn_metadata=fn_metadata, is_async=True, + context_kwarg='context', + meta=meta ) + + +def _build_fastmcp_tool_meta_from_injected_args_schema(tool: BaseTool) -> dict[str, Any] | None: + injected_args_schema = _get_injected_args_schema(tool) + + if injected_args_schema: + return { + META_KEY_INJECT_ARGS_SCHEMA: injected_args_schema + } + + +def _build_tool_calling_meta(tool: MCPTool, **arguments: dict[str, Any]) -> dict[str, Any] | None: + """ + Discovers InjectedToolArg entries from the remote MCP Tool and gets their values from **arguments + + Such values will be sent to the remote tool through MCP's _meta[META_KEY_INJECT_ARGS_VALUE] + """ + injected_args_schema: dict[str, Any] | None = None + meta: dict[str, Any] | None = None + + if tool.meta is not None and META_KEY_INJECT_ARGS_SCHEMA in tool.meta: + injected_args_schema = tool.meta.get(META_KEY_INJECT_ARGS_SCHEMA) + + if injected_args_schema: + meta = { + META_KEY_INJECT_ARGS_VALUE: {} + } + + for arg_name, arg_value in arguments.items(): + if arg_name not in injected_args_schema: + continue + + meta[META_KEY_INJECT_ARGS_VALUE][arg_name] = arg_value + + return meta + + +def 
_clear_injected_tool_args_from_arguments(meta: dict[str, Any] | None, **arguments: dict[str, Any]) -> dict[str, Any]: + """ + Remove from **arguments the args that will be sent through MCP's _meta[META_KEY_INJECT_ARGS_VALUE] + """ + if meta is None or META_KEY_INJECT_ARGS_VALUE not in meta: + return arguments + + return { + arg_name: arg_value + for arg_name, arg_value in arguments.items() + if arg_name not in meta[META_KEY_INJECT_ARGS_VALUE] + } diff --git a/pyproject.toml b/pyproject.toml index 593fa971..258a4e74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "langchain-core>=0.3.36,<2.0.0", - "mcp>=1.9.2", + "mcp>=1.19.0", "typing-extensions>=4.14.0", ] diff --git a/tests/test_tools.py b/tests/test_tools.py index 279da867..24aa2588 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -14,6 +14,7 @@ TextContent, TextResourceContents, ToolAnnotations, + RequestParams ) from mcp.types import Tool as MCPTool from pydantic import BaseModel @@ -24,6 +25,8 @@ convert_mcp_tool_to_langchain_tool, load_mcp_tools, to_fastmcp, + META_KEY_INJECT_ARGS_SCHEMA, + META_KEY_INJECT_ARGS_VALUE ) from tests.utils import run_streamable_http @@ -68,7 +71,8 @@ def test_convert_multiple_text_contents(): def test_convert_with_non_text_content(): # Test with non-text content - image_content = ImageContent(type="image", mimeType="image/png", data="base64data") + image_content = ImageContent( + type="image", mimeType="image/png", data="base64data") resource_content = EmbeddedResource( type="resource", resource=TextResourceContents( @@ -141,7 +145,7 @@ async def test_convert_mcp_tool_to_langchain_tool(): # Verify session.call_tool was called with correct arguments session.call_tool.assert_called_once_with( - "test_tool", {"param1": "test", "param2": 42}, progress_callback=None + "test_tool", {"param1": "test", "param2": 42}, progress_callback=None, meta=None ) # Verify result @@ -174,19 +178,22 @@ async 
def test_load_mcp_tools(): inputSchema=tool_input_schema, ), ] - session.list_tools.return_value = MagicMock(tools=mcp_tools, nextCursor=None) + session.list_tools.return_value = MagicMock( + tools=mcp_tools, nextCursor=None) # Mock call_tool to return different results for different tools - async def mock_call_tool(tool_name, arguments, progress_callback=None): + async def mock_call_tool(tool_name, arguments, progress_callback=None, meta=None): if tool_name == "tool1": return CallToolResult( content=[ - TextContent(type="text", text=f"tool1 result with {arguments}") + TextContent( + type="text", text=f"tool1 result with {arguments}") ], isError=False, ) return CallToolResult( - content=[TextContent(type="text", text=f"tool2 result with {arguments}")], + content=[TextContent( + type="text", text=f"tool2 result with {arguments}")], isError=False, ) @@ -346,11 +353,6 @@ async def test_convert_langchain_tool_to_fastmcp_tool(tool_instance): assert await fastmcp_tool.run(arguments=arguments) == 3 -def test_convert_langchain_tool_to_fastmcp_tool_with_injection(): - with pytest.raises(NotImplementedError): - to_fastmcp(add_with_injection) - - def _create_status_server(): server = FastMCP(port=8182) @@ -380,7 +382,8 @@ def custom_httpx_client_factory( timeout=timeout or httpx.Timeout(30.0), auth=auth, # Custom configuration - limits=httpx.Limits(max_keepalive_connections=5, max_connections=10), + limits=httpx.Limits(max_keepalive_connections=5, + max_connections=10), ) with run_streamable_http(_create_status_server, 8182): @@ -433,7 +436,8 @@ def custom_httpx_client_factory( timeout=timeout or httpx.Timeout(30.0), auth=auth, # Custom configuration for SSE - limits=httpx.Limits(max_keepalive_connections=3, max_connections=5), + limits=httpx.Limits( + max_keepalive_connections=3, max_connections=5), ) with run_streamable_http(_create_info_server, 8183): @@ -507,7 +511,8 @@ async def test_convert_mcp_tool_metadata_variants(): _meta={"source": "unit-test", "version": 1}, ) 
lc_tool_meta = convert_mcp_tool_to_langchain_tool(session, mcp_tool_meta) - assert lc_tool_meta.metadata == {"_meta": {"source": "unit-test", "version": 1}} + assert lc_tool_meta.metadata == { + "_meta": {"source": "unit-test", "version": 1}} mcp_tool_both = MCPTool( name="t_both", @@ -526,3 +531,102 @@ "openWorldHint": None, "_meta": {"flag": True}, } + + +async def test_injected_args_schema_extraction(): + """ + Tests that InjectedToolArg schema will be sent to the MCP server through _meta + """ + mcp_tool = to_fastmcp(add_with_injection) + + assert META_KEY_INJECT_ARGS_SCHEMA in mcp_tool.meta + assert mcp_tool.meta[META_KEY_INJECT_ARGS_SCHEMA] == { + 'injected_arg': {'type': 'string'}} + + +async def test_injected_args_value_are_sent(): + """ + Test that LangChain Tools send InjectedToolArg through MCP _meta + """ + session = AsyncMock() + session.call_tool.return_value = CallToolResult( + content=[TextContent(type="text", text="tool result")], + isError=False, + ) + + tool_input_schema = { + "properties": { + "param1": {"title": "Param1", "type": "string"}, + "param2": {"title": "Param2", "type": "integer"}, + }, + "required": ["param1", "param2"], + "title": "ToolSchema", + "type": "object", + } + + mcp_tool = MCPTool( + name="test_tool", + description="", + inputSchema=tool_input_schema, + _meta={ + META_KEY_INJECT_ARGS_SCHEMA: { + 'injected_arg': { + 'type': 'string' + } + } + } + ) + + lc_tool = convert_mcp_tool_to_langchain_tool( + session=session, + tool=mcp_tool + ) + + await lc_tool.ainvoke( + { + 'param1': 'test', + 'param2': 42, + 'injected_arg': 'bar' + } + ) + + session.call_tool.assert_called_once_with( + "test_tool", + { + "param1": "test", + "param2": 42 + }, + progress_callback=None, + meta={ + META_KEY_INJECT_ARGS_VALUE: { + 'injected_arg': 'bar' + } + } + ) + + +async def test_injected_args_are_passed_to_lctool(): + """ + Asserts that when an MCPTool receives LangChain InjectedToolArg 
through _meta it is correctly passed to the underlying LC Tool + """ + mcp_tool = to_fastmcp(add_with_injection) + + mocked_context = AsyncMock() + mocked_context.request_context.meta = RequestParams.Meta( + **{ + META_KEY_INJECT_ARGS_VALUE: { + 'injected_arg': 'bar' + } + }, + progressToken=None + ) + + result = await mcp_tool.run( + arguments={ + 'a': 1, + 'b': 1 + }, + context=mocked_context + ) + + assert result == 2