diff --git a/langchain_mcp_adapters/client.py b/langchain_mcp_adapters/client.py index 06c27e6..8d34f5d 100644 --- a/langchain_mcp_adapters/client.py +++ b/langchain_mcp_adapters/client.py @@ -5,15 +5,16 @@ """ import asyncio -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from types import TracebackType from typing import Any from langchain_core.documents.base import Blob from langchain_core.messages import AIMessage, HumanMessage -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, ToolException from mcp import ClientSession +from pydantic import ValidationError from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks from langchain_mcp_adapters.interceptors import ToolCallInterceptor @@ -143,12 +144,22 @@ async def session( await session.initialize() yield session - async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: + async def get_tools( + self, + *, + server_name: str | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, + ) -> list[BaseTool]: """Get a list of all tools from all connected servers. Args: server_name: Optional name of the server to get tools from. If `None`, all tools from all servers will be returned. + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. !!! note @@ -171,6 +182,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: callbacks=self.callbacks, server_name=server_name, tool_interceptors=self.tool_interceptors, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, ) all_tools: list[BaseTool] = [] @@ -183,6 +196,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]: callbacks=self.callbacks, server_name=name, tool_interceptors=self.tool_interceptors, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, ) ) load_mcp_tool_tasks.append(load_mcp_tool_task) diff --git a/langchain_mcp_adapters/tools.py b/langchain_mcp_adapters/tools.py index 6141f09..3bb47dc 100644 --- a/langchain_mcp_adapters/tools.py +++ b/langchain_mcp_adapters/tools.py @@ -25,7 +25,7 @@ TextContent, ) from mcp.types import Tool as MCPTool -from pydantic import BaseModel, create_model +from pydantic import BaseModel, ValidationError, create_model from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks from langchain_mcp_adapters.interceptors import ( @@ -161,6 +161,10 @@ def convert_mcp_tool_to_langchain_tool( callbacks: Callbacks | None = None, tool_interceptors: list[ToolCallInterceptor] | None = None, server_name: str | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, ) -> BaseTool: """Convert an MCP tool to a LangChain tool. @@ -174,6 +178,8 @@ def convert_mcp_tool_to_langchain_tool( callbacks: Optional callbacks for handling notifications and events tool_interceptors: Optional list of interceptors for tool call processing server_name: Name of the server this tool belongs to + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. Returns: a LangChain tool @@ -303,6 +309,8 @@ async def execute_tool(request: MCPToolCallRequest) -> MCPToolCallResult: coroutine=call_tool, response_format="content_and_artifact", metadata=metadata, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, ) @@ -313,6 +321,10 @@ async def load_mcp_tools( callbacks: Callbacks | None = None, tool_interceptors: list[ToolCallInterceptor] | None = None, server_name: str | None = None, + handle_tool_error: bool | str | Callable[[ToolException], str] | None = False, + handle_validation_error: ( + bool | str | Callable[[ValidationError], str] | None + ) = False, ) -> list[BaseTool]: """Load all available MCP tools and convert them to LangChain [tools](https://docs.langchain.com/oss/python/langchain/tools). @@ -322,6 +334,8 @@ async def load_mcp_tools( callbacks: Optional `Callbacks` for handling notifications and events. tool_interceptors: Optional list of interceptors for tool call processing. server_name: Name of the server these tools belong to. + handle_tool_error: Optional error handler for tool execution errors. + handle_validation_error: Optional error handler for validation errors. Returns: List of LangChain [tools](https://docs.langchain.com/oss/python/langchain/tools). @@ -361,6 +375,8 @@ async def load_mcp_tools( callbacks=callbacks, tool_interceptors=tool_interceptors, server_name=server_name, + handle_tool_error=handle_tool_error, + handle_validation_error=handle_validation_error, ) for tool in tools ]