42 changes: 42 additions & 0 deletions README.md
@@ -349,3 +349,45 @@ fastmcp_tool = to_fastmcp(add)
mcp = FastMCP("Math", tools=[fastmcp_tool])
mcp.run(transport="stdio")
```

## Passing `InjectedToolArg` to an MCP Tool

By using the LangChain MCP adapters on both the server and the client side, you can use `InjectedToolArg` to hide certain parameters from the LLM: they are stripped from the schema the model sees and passed out of band instead.

```python
# server.py
from typing import Annotated

from langchain_core.tools import InjectedToolArg, tool
from mcp.server.fastmcp import FastMCP

from langchain_mcp_adapters.tools import to_fastmcp

data = {
    "user_0": "Spike"
}


@tool
async def get_user_pet_name(user_id: Annotated[str, InjectedToolArg]) -> str:
    """Return the user's pet name."""
    return data[user_id]


fastmcp_tool = to_fastmcp(get_user_pet_name)
mcp = FastMCP("Pets", tools=[fastmcp_tool])
mcp.run(transport="stdio")
```
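
Because `user_id` is annotated with `InjectedToolArg`, `to_fastmcp` leaves it out of the input schema advertised over MCP and records it instead in the tool's `_meta` under `langchain/injectedArgsSchema`. A minimal sketch of how you could check this on the server side, assuming the `get_user_pet_name` tool defined above (the attribute access on the resulting `FastMCPTool` is shown for illustration only):

```python
# Sketch: inspect what to_fastmcp() produces for a tool with an injected argument.
fastmcp_tool = to_fastmcp(get_user_pet_name)

# The JSON schema advertised to MCP clients should not list "user_id".
print(fastmcp_tool.parameters)

# The injected arguments are instead described in the tool's meta under
# "langchain/injectedArgsSchema", so a LangChain MCP client can recognize them.
print(fastmcp_tool.meta)
```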

The user ID can then be passed as part of the input, without the LLM's knowledge:

```python
# client.py
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.prebuilt import create_react_agent

client = MultiServerMCPClient(
    ...
)

tools = await client.get_tools()
agent = create_react_agent("openai:gpt-4.1", tools)
response = await agent.ainvoke(
    {
        "messages": "What is my dog's name?",
        "user_id": "user_0",
    }
)
```
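
Under the hood, the client-side adapter removes the injected argument from the visible tool-call arguments and forwards its value in the call's `_meta` instead; the server-side wrapper reads it back from the request context and re-injects it before invoking the original LangChain tool. A rough sketch of the two `_meta` payloads involved in the example above (illustrative values, not output captured from the adapters):

```python
# Advertised by the server in the MCP tool's _meta: which arguments are injected,
# and the JSON schema of each (here produced from the `str` annotation).
advertised_meta = {
    "langchain/injectedArgsSchema": {
        "user_id": {"type": "string"},
    },
}

# Sent by the client in the tool call's _meta, after stripping "user_id" from the
# visible call arguments.
call_meta = {
    "langchain/injectedArgsValue": {"user_id": "user_0"},
}

print(advertised_meta, call_meta)
```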
117 changes: 94 additions & 23 deletions langchain_mcp_adapters/tools.py
@@ -15,6 +15,7 @@
from langchain_core.tools.base import get_all_basemodel_annotations
from mcp import ClientSession
from mcp.server.fastmcp.tools import Tool as FastMCPTool
from mcp.server.fastmcp.server import Context
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.types import (
AudioContent,
@@ -25,7 +26,7 @@
TextContent,
)
from mcp.types import Tool as MCPTool
from pydantic import BaseModel, create_model
from pydantic import BaseModel, TypeAdapter, create_model

from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
from langchain_mcp_adapters.hooks import CallToolRequestSpec, Hooks, ToolHookContext
@@ -44,6 +45,8 @@ def get_runtime() -> None:
"""no-op runtime getter."""
return

META_KEY_INJECT_ARGS_VALUE = "langchain/injectedArgsValue"
META_KEY_INJECT_ARGS_SCHEMA = "langchain/injectedArgsSchema"

NonTextContent = ImageContent | AudioContent | ResourceLink | EmbeddedResource
MAX_ITERATIONS = 1000
@@ -153,9 +156,20 @@ def convert_mcp_tool_to_langchain_tool(
async def call_tool(
**arguments: dict[str, Any],
) -> tuple[str | list[str], list[NonTextContent] | None]:
meta = _build_tool_calling_meta(tool, **arguments)

arguments = _clear_injected_tool_args_from_arguments(meta, **arguments)

mcp_callbacks = (
callbacks.to_mcp_format(
context=CallbackContext(server_name=server_name, tool_name=tool.name)
context=CallbackContext(
server_name=server_name, tool_name=tool.name)
)
if callbacks is not None
else _MCPCallbacks()
@@ -223,12 +237,14 @@ async def call_tool(
tool_name,
tool_args,
progress_callback=mcp_callbacks.progress_callback,
meta=meta
)
else:
call_tool_result = await session.call_tool(
tool_name,
tool_args,
progress_callback=mcp_callbacks.progress_callback,
meta=meta
)

if call_tool_result is None:
@@ -291,7 +307,8 @@ async def load_mcp_tools(
raise ValueError(msg)

mcp_callbacks = (
callbacks.to_mcp_format(context=CallbackContext(server_name=server_name))
callbacks.to_mcp_format(
context=CallbackContext(server_name=server_name))
if callbacks is not None
else _MCPCallbacks()
)
@@ -322,15 +339,8 @@
]


def _get_injected_args(tool: BaseTool) -> list[str]:
"""Get the list of injected argument names from a LangChain tool.

Args:
tool: The LangChain tool to inspect.

Returns:
A list of injected argument names.
"""
def _get_injected_args_schema(tool: BaseTool) -> dict[str, dict[str, Any]]:
"""Map each injected argument of a LangChain tool to the JSON schema of its type."""
schemas: dict[str, dict[str, Any]] = {}

def _is_injected_arg_type(type_: type) -> bool:
return any(
@@ -339,11 +349,15 @@ def _is_injected_arg_type(type_: type) -> bool:
for arg in get_args(type_)[1:]
)

return [
field
for field, field_info in get_all_basemodel_annotations(tool.args_schema).items()
if _is_injected_arg_type(field_info)
]
for field_name, field_info in get_all_basemodel_annotations(tool.args_schema).items():
if not _is_injected_arg_type(field_info):
continue

field_type = get_args(field_info)[0]

schemas[field_name] = TypeAdapter(field_type).json_schema()

return schemas


def to_fastmcp(tool: BaseTool) -> FastMCPTool:
@@ -366,25 +380,31 @@ def to_fastmcp(tool: BaseTool) -> FastMCPTool:
)
raise TypeError(msg)

meta = _build_fastmcp_tool_meta_from_injected_args_schema(tool)

parameters = tool.tool_call_schema.model_json_schema()
field_definitions = {
field: (field_info.annotation, field_info)
for field, field_info in tool.tool_call_schema.model_fields.items()
}

arg_model = create_model(
f"{tool.name}Arguments", **field_definitions, __base__=ArgModelBase
)
fn_metadata = FuncMetadata(arg_model=arg_model)

# We'll use an Any type for the function return type.
# We're providing the parameters separately
async def fn(**arguments: dict[str, Any]) -> Any: # noqa: ANN401
return await tool.ainvoke(arguments)

injected_args = _get_injected_args(tool)
if len(injected_args) > 0:
msg = "LangChain tools with injected arguments are not supported"
raise NotImplementedError(msg)
async def fn(context: Context, **arguments: dict[str, Any]) -> Any: # noqa: ANN401
if context is not None and context.request_context.meta is not None:
injected_args_value = context.request_context.meta.model_dump().get(
META_KEY_INJECT_ARGS_VALUE, {})

for arg_name, arg_value in injected_args_value.items():
arguments[arg_name] = arg_value

return await tool.ainvoke(arguments)

return FastMCPTool(
fn=fn,
@@ -393,4 +413,55 @@ async def fn(**arguments: dict[str, Any]) -> Any:  # noqa: ANN401
parameters=parameters,
fn_metadata=fn_metadata,
is_async=True,
context_kwarg="context",
meta=meta,
)


def _build_fastmcp_tool_meta_from_injected_args_schema(
    tool: BaseTool,
) -> dict[str, Any] | None:
    """Build the FastMCP tool _meta advertising the tool's injected-args schema, if any."""
    injected_args_schema = _get_injected_args_schema(tool)

    if injected_args_schema:
        return {META_KEY_INJECT_ARGS_SCHEMA: injected_args_schema}

    return None


def _build_tool_calling_meta(
    tool: MCPTool, **arguments: dict[str, Any]
) -> dict[str, Any] | None:
    """Collect values for the remote MCP tool's injected arguments from **arguments.

    The remote tool advertises its injected arguments via
    _meta[META_KEY_INJECT_ARGS_SCHEMA]; the matching values are sent back to it
    through the call's _meta[META_KEY_INJECT_ARGS_VALUE].
    """
    injected_args_schema: dict[str, Any] | None = None
    meta: dict[str, Any] | None = None

    if tool.meta is not None and META_KEY_INJECT_ARGS_SCHEMA in tool.meta:
        injected_args_schema = tool.meta.get(META_KEY_INJECT_ARGS_SCHEMA)

    if injected_args_schema:
        meta = {META_KEY_INJECT_ARGS_VALUE: {}}

        for arg_name, arg_value in arguments.items():
            if arg_name not in injected_args_schema:
                continue

            meta[META_KEY_INJECT_ARGS_VALUE][arg_name] = arg_value

    return meta


def _clear_injected_tool_args_from_arguments(
    meta: dict[str, Any] | None, **arguments: dict[str, Any]
) -> dict[str, Any]:
    """Drop from **arguments the args sent via _meta[META_KEY_INJECT_ARGS_VALUE] instead."""
    if meta is None or META_KEY_INJECT_ARGS_VALUE not in meta:
        return arguments

    return {
        arg_name: arg_value
        for arg_name, arg_value in arguments.items()
        if arg_name not in meta[META_KEY_INJECT_ARGS_VALUE]
    }
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"langchain-core>=0.3.36,<2.0.0",
"mcp>=1.9.2",
"mcp>=1.19.0",
"typing-extensions>=4.14.0",
]
