|
9 | 9 |
|
10 | 10 | from llama_stack.core.request_headers import NeedsRequestProviderData |
11 | 11 | from llama_stack.log import get_logger |
| 12 | +from llama_stack.providers.utils.forward_headers import build_forwarded_headers |
12 | 13 | from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools |
13 | 14 | from llama_stack_api import ( |
14 | 15 | URL, |
@@ -44,51 +45,86 @@ async def list_runtime_tools( |
44 | 45 | mcp_endpoint: URL | None = None, |
45 | 46 | authorization: str | None = None, |
46 | 47 | ) -> ListToolDefsResponse: |
47 | | - # this endpoint should be retrieved by getting the tool group right? |
48 | 48 | if mcp_endpoint is None: |
49 | 49 | raise ValueError("mcp_endpoint is required") |
50 | 50 |
|
51 | | - # Get other headers from provider data (but NOT authorization) |
52 | | - provider_headers = await self.get_headers_from_request(mcp_endpoint.uri) |
| 51 | + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() |
| 52 | + # legacy mcp_headers URI-keyed path (backward compat) |
| 53 | + legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri) |
| 54 | + merged_headers = {**forwarded_headers, **legacy_headers} |
| 55 | + # explicit authorization= param from caller wins over forwarded |
| 56 | + effective_auth = authorization or forwarded_auth |
53 | 57 |
|
54 | | - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization) |
| 58 | + return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth) |
55 | 59 |
|
56 | 60 | async def invoke_tool( |
57 | 61 | self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None |
58 | 62 | ) -> ToolInvocationResult: |
59 | 63 | tool = await self.tool_store.get_tool(tool_name) |
60 | 64 | if tool.metadata is None or tool.metadata.get("endpoint") is None: |
61 | 65 | raise ValueError(f"Tool {tool_name} does not have metadata") |
62 | | - endpoint = tool.metadata.get("endpoint") |
| 66 | + endpoint: str = tool.metadata["endpoint"] |
63 | 67 | if urlparse(endpoint).scheme not in ("http", "https"): |
64 | 68 | raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") |
65 | 69 |
|
66 | | - # Get other headers from provider data (but NOT authorization) |
67 | | - provider_headers = await self.get_headers_from_request(endpoint) |
| 70 | + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() |
| 71 | + # legacy mcp_headers URI-keyed path (backward compat) |
| 72 | + legacy_headers = await self.get_headers_from_request(endpoint) |
| 73 | + merged_headers = {**forwarded_headers, **legacy_headers} |
| 74 | + # explicit authorization= param from caller wins over forwarded |
| 75 | + effective_auth = authorization or forwarded_auth |
68 | 76 |
|
69 | 77 | return await invoke_mcp_tool( |
70 | 78 | endpoint=endpoint, |
71 | 79 | tool_name=tool_name, |
72 | 80 | kwargs=kwargs, |
73 | | - headers=provider_headers, |
74 | | - authorization=authorization, |
| 81 | + headers=merged_headers, |
| 82 | + authorization=effective_auth, |
75 | 83 | ) |
76 | 84 |
|
77 | | - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: |
78 | | - """ |
79 | | - Extract headers from request provider data, excluding authorization. |
| 85 | + def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]: |
| 86 | + """Extract forwarded headers from provider data per the admin-configured allowlist. |
80 | 87 |
|
81 | | - Authorization must be provided via the dedicated authorization parameter. |
82 | | - If Authorization is found in mcp_headers, raise an error to guide users to the correct approach. |
83 | | -
|
84 | | - Args: |
85 | | - mcp_endpoint_uri: The MCP endpoint URI to match against provider data |
| 88 | + Splits the output of build_forwarded_headers() into non-Authorization headers |
| 89 | + and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer ' |
| 90 | + prefix) per the forward_headers field description — prepare_mcp_headers() adds |
| 91 | + the prefix when passing via the authorization= param. |
86 | 92 |
|
87 | 93 | Returns: |
88 | | - dict[str, str]: Headers dictionary (without Authorization) |
| 94 | + (non_auth_headers, auth_token) where auth_token is None if not configured. |
| 95 | + """ |
| 96 | + provider_data = self.get_request_provider_data() |
| 97 | + all_headers = build_forwarded_headers(provider_data, self.config.forward_headers) |
| 98 | + |
| 99 | + if not all_headers: |
| 100 | + if self.config.forward_headers and provider_data is not None: |
| 101 | + logger.warning( |
| 102 | + "forward_headers is configured but no matching keys found in provider data — " |
| 103 | + "outbound request may be unauthenticated" |
| 104 | + ) |
| 105 | + return {}, None |
| 106 | + |
| 107 | + # Pull out Authorization (case-insensitive) so it goes via the authorization= |
| 108 | + # param — prepare_mcp_headers() rejects Authorization in the headers= dict. |
| 109 | + auth_token: str | None = None |
| 110 | + non_auth: dict[str, str] = {} |
| 111 | + for name, value in all_headers.items(): |
| 112 | + if name.lower() == "authorization": |
| 113 | + auth_token = value |
| 114 | + else: |
| 115 | + non_auth[name] = value |
| 116 | + |
| 117 | + return non_auth, auth_token |
| 118 | + |
| 119 | + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: |
| 120 | + """Extract headers from the legacy mcp_headers URI-keyed provider data. |
| 121 | +
|
| 122 | + Kept for backward compatibility. New deployments should use forward_headers |
| 123 | + in the provider config instead. |
89 | 124 |
|
90 | 125 | Raises: |
91 | | - ValueError: If Authorization header is found in mcp_headers |
| 126 | + ValueError: If Authorization header is found in mcp_headers (must use |
| 127 | + the dedicated authorization parameter instead). |
92 | 128 | """ |
93 | 129 |
|
94 | 130 | def canonicalize_uri(uri: str) -> str: |
|
0 commit comments