|
9 | 9 |
|
10 | 10 | from llama_stack.core.request_headers import NeedsRequestProviderData |
11 | 11 | from llama_stack.log import get_logger |
| 12 | +from llama_stack.providers.utils.forward_headers import build_forwarded_headers |
12 | 13 | from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools |
13 | 14 | from llama_stack_api import ( |
14 | 15 | URL, |
@@ -46,51 +47,86 @@ async def list_runtime_tools( |
46 | 47 | mcp_endpoint: URL | None = None, |
47 | 48 | authorization: str | None = None, |
48 | 49 | ) -> ListToolDefsResponse: |
49 | | - # this endpoint should be retrieved by getting the tool group right? |
50 | 50 | if mcp_endpoint is None: |
51 | 51 | raise ValueError("mcp_endpoint is required") |
52 | 52 |
|
53 | | - # Get other headers from provider data (but NOT authorization) |
54 | | - provider_headers = await self.get_headers_from_request(mcp_endpoint.uri) |
| 53 | + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() |
| 54 | + # legacy mcp_headers URI-keyed path (backward compat) |
| 55 | + legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri) |
| 56 | + merged_headers = {**forwarded_headers, **legacy_headers} |
| 57 | + # explicit authorization= param from caller wins over forwarded |
| 58 | + effective_auth = authorization or forwarded_auth |
55 | 59 |
|
56 | | - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization) |
| 60 | + return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth) |
57 | 61 |
|
58 | 62 | async def invoke_tool( |
59 | 63 | self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None |
60 | 64 | ) -> ToolInvocationResult: |
61 | 65 | tool = await self.tool_store.get_tool(tool_name) |
62 | 66 | if tool.metadata is None or tool.metadata.get("endpoint") is None: |
63 | 67 | raise ValueError(f"Tool {tool_name} does not have metadata") |
64 | | - endpoint = tool.metadata.get("endpoint") |
| 68 | + endpoint: str = tool.metadata["endpoint"] |
65 | 69 | if urlparse(endpoint).scheme not in ("http", "https"): |
66 | 70 | raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") |
67 | 71 |
|
68 | | - # Get other headers from provider data (but NOT authorization) |
69 | | - provider_headers = await self.get_headers_from_request(endpoint) |
| 72 | + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() |
| 73 | + # legacy mcp_headers URI-keyed path (backward compat) |
| 74 | + legacy_headers = await self.get_headers_from_request(endpoint) |
| 75 | + merged_headers = {**forwarded_headers, **legacy_headers} |
| 76 | + # explicit authorization= param from caller wins over forwarded |
| 77 | + effective_auth = authorization or forwarded_auth |
70 | 78 |
|
71 | 79 | return await invoke_mcp_tool( |
72 | 80 | endpoint=endpoint, |
73 | 81 | tool_name=tool_name, |
74 | 82 | kwargs=kwargs, |
75 | | - headers=provider_headers, |
76 | | - authorization=authorization, |
| 83 | + headers=merged_headers, |
| 84 | + authorization=effective_auth, |
77 | 85 | ) |
78 | 86 |
|
79 | | - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: |
80 | | - """ |
81 | | - Extract headers from request provider data, excluding authorization. |
| 87 | + def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]: |
| 88 | + """Extract forwarded headers from provider data per the admin-configured allowlist. |
82 | 89 |
|
83 | | - Authorization must be provided via the dedicated authorization parameter. |
84 | | - If Authorization is found in mcp_headers, raise an error to guide users to the correct approach. |
85 | | -
|
86 | | - Args: |
87 | | - mcp_endpoint_uri: The MCP endpoint URI to match against provider data |
| 90 | + Splits the output of build_forwarded_headers() into non-Authorization headers |
| 91 | + and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer ' |
| 92 | + prefix) per the forward_headers field description — prepare_mcp_headers() adds |
| 93 | + the prefix when passing via the authorization= param. |
88 | 94 |
|
89 | 95 | Returns: |
90 | | - dict[str, str]: Headers dictionary (without Authorization) |
| 96 | + (non_auth_headers, auth_token) where auth_token is None if not configured. |
| 97 | + """ |
| 98 | + provider_data = self.get_request_provider_data() |
| 99 | + all_headers = build_forwarded_headers(provider_data, self.config.forward_headers) |
| 100 | + |
| 101 | + if not all_headers: |
| 102 | + if self.config.forward_headers: |
| 103 | + logger.warning( |
| 104 | + "forward_headers is configured but no headers were forwarded — " |
| 105 | + "outbound request may be unauthenticated" |
| 106 | + ) |
| 107 | + return {}, None |
| 108 | + |
| 109 | + # Pull out Authorization (case-insensitive) so it goes via the authorization= |
| 110 | + # param — prepare_mcp_headers() rejects Authorization in the headers= dict. |
| 111 | + auth_token: str | None = None |
| 112 | + non_auth: dict[str, str] = {} |
| 113 | + for name, value in all_headers.items(): |
| 114 | + if name.lower() == "authorization": |
| 115 | + auth_token = value |
| 116 | + else: |
| 117 | + non_auth[name] = value |
| 118 | + |
| 119 | + return non_auth, auth_token |
| 120 | + |
| 121 | + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: |
| 122 | + """Extract headers from the legacy mcp_headers URI-keyed provider data. |
| 123 | +
|
| 124 | + Kept for backward compatibility. New deployments should use forward_headers |
| 125 | + in the provider config instead. |
91 | 126 |
|
92 | 127 | Raises: |
93 | | - ValueError: If Authorization header is found in mcp_headers |
| 128 | + ValueError: If Authorization header is found in mcp_headers (must use |
| 129 | + the dedicated authorization parameter instead). |
94 | 130 | """ |
95 | 131 |
|
96 | 132 | def canonicalize_uri(uri: str) -> str: |
|
0 commit comments