Skip to content

Commit 65a68a5

Browse files
committed
add forward_headers passthrough to remote::model-context-protocol
1 parent da09c44 commit 65a68a5

4 files changed

Lines changed: 547 additions & 25 deletions

File tree

docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ title: remote::model-context-protocol
1010

1111
Model Context Protocol (MCP) tool for standardized tool calling and context management.
1212

13+
## Configuration
14+
15+
| Field | Type | Required | Default | Description |
16+
|-------|------|----------|---------|-------------|
17+
| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. Only listed keys are forwarded — all others are ignored (default-deny). When targeting 'Authorization', the provider-data value must be a bare Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the 'Bearer ' prefix is added automatically by the MCP client. Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"} |
18+
| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. |
19+
1320
## Sample Configuration
1421

1522
```yaml

src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,73 @@
66

77
from typing import Any
88

9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, ConfigDict, Field, model_validator
10+
11+
from llama_stack.providers.utils.forward_headers import validate_forward_headers_config
1012

1113

1214
class MCPProviderDataValidator(BaseModel):
1315
"""
1416
Validator for MCP provider-specific data passed via request headers.
1517
16-
Phase 1: Support old header-based authentication for backward compatibility.
17-
In Phase 2, this will be deprecated in favor of the authorization parameter.
18+
extra="allow" so deployer-defined forward_headers keys (e.g. "maas_api_token")
19+
survive Pydantic parsing — they can't be declared as typed fields because the
20+
key names are operator-configured at deploy time.
21+
22+
The legacy mcp_headers URI-keyed path is kept for backward compatibility.
1823
"""
1924

20-
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
25+
model_config = ConfigDict(extra="allow")
26+
27+
mcp_headers: dict[str, dict[str, str]] | None = Field(
28+
default=None,
29+
description="Legacy URI-keyed headers dict for backward compatibility. New deployments should use forward_headers in the provider config instead.",
30+
)
2131

2232

2333
class MCPProviderConfig(BaseModel):
2434
"""Configuration for the Model Context Protocol tool runtime provider."""
2535

36+
model_config = ConfigDict(extra="forbid")
37+
38+
forward_headers: dict[str, str] | None = Field(
39+
default=None,
40+
description=(
41+
"Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. "
42+
"Only listed keys are forwarded — all others are ignored (default-deny). "
43+
"When targeting 'Authorization', the provider-data value must be a bare "
44+
"Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the "
45+
"'Bearer ' prefix is added automatically by the MCP client. "
46+
"Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). "
47+
"Keys with a __ prefix and core security-sensitive headers (for example Host, "
48+
"Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. "
49+
'Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"}'
50+
),
51+
)
52+
extra_blocked_headers: list[str] = Field(
53+
default_factory=list,
54+
description=(
55+
"Additional outbound header names to block in forward_headers. "
56+
"Names are matched case-insensitively and added to the core blocked list. "
57+
"This can tighten policy but cannot unblock core security-sensitive headers."
58+
),
59+
)
60+
61+
@model_validator(mode="after")
62+
def validate_forward_headers(self) -> "MCPProviderConfig":
63+
validate_forward_headers_config(self.forward_headers, self.extra_blocked_headers)
64+
return self
65+
2666
@classmethod
27-
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
28-
return {}
67+
def sample_run_config(
68+
cls,
69+
forward_headers: dict[str, str] | None = None,
70+
extra_blocked_headers: list[str] | None = None,
71+
**_kwargs: Any,
72+
) -> dict[str, Any]:
73+
config: dict[str, Any] = {}
74+
if forward_headers is not None:
75+
config["forward_headers"] = forward_headers
76+
if extra_blocked_headers is not None:
77+
config["extra_blocked_headers"] = extra_blocked_headers
78+
return config

src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from llama_stack.core.request_headers import NeedsRequestProviderData
1111
from llama_stack.log import get_logger
12+
from llama_stack.providers.utils.forward_headers import build_forwarded_headers
1213
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
1314
from llama_stack_api import (
1415
URL,
@@ -46,51 +47,86 @@ async def list_runtime_tools(
4647
mcp_endpoint: URL | None = None,
4748
authorization: str | None = None,
4849
) -> ListToolDefsResponse:
49-
# this endpoint should be retrieved by getting the tool group right?
5050
if mcp_endpoint is None:
5151
raise ValueError("mcp_endpoint is required")
5252

53-
# Get other headers from provider data (but NOT authorization)
54-
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
53+
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
54+
# legacy mcp_headers URI-keyed path (backward compat)
55+
legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri)
56+
merged_headers = {**forwarded_headers, **legacy_headers}
57+
# explicit authorization= param from caller wins over forwarded
58+
effective_auth = authorization or forwarded_auth
5559

56-
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
60+
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth)
5761

5862
async def invoke_tool(
5963
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
6064
) -> ToolInvocationResult:
6165
tool = await self.tool_store.get_tool(tool_name)
6266
if tool.metadata is None or tool.metadata.get("endpoint") is None:
6367
raise ValueError(f"Tool {tool_name} does not have metadata")
64-
endpoint = tool.metadata.get("endpoint")
68+
endpoint: str = tool.metadata["endpoint"]
6569
if urlparse(endpoint).scheme not in ("http", "https"):
6670
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
6771

68-
# Get other headers from provider data (but NOT authorization)
69-
provider_headers = await self.get_headers_from_request(endpoint)
72+
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
73+
# legacy mcp_headers URI-keyed path (backward compat)
74+
legacy_headers = await self.get_headers_from_request(endpoint)
75+
merged_headers = {**forwarded_headers, **legacy_headers}
76+
# explicit authorization= param from caller wins over forwarded
77+
effective_auth = authorization or forwarded_auth
7078

7179
return await invoke_mcp_tool(
7280
endpoint=endpoint,
7381
tool_name=tool_name,
7482
kwargs=kwargs,
75-
headers=provider_headers,
76-
authorization=authorization,
83+
headers=merged_headers,
84+
authorization=effective_auth,
7785
)
7886

79-
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
80-
"""
81-
Extract headers from request provider data, excluding authorization.
87+
def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]:
88+
"""Extract forwarded headers from provider data per the admin-configured allowlist.
8289
83-
Authorization must be provided via the dedicated authorization parameter.
84-
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
85-
86-
Args:
87-
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
90+
Splits the output of build_forwarded_headers() into non-Authorization headers
91+
and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer '
92+
prefix) per the forward_headers field description — prepare_mcp_headers() adds
93+
the prefix when passing via the authorization= param.
8894
8995
Returns:
90-
dict[str, str]: Headers dictionary (without Authorization)
96+
(non_auth_headers, auth_token) where auth_token is None if not configured.
97+
"""
98+
provider_data = self.get_request_provider_data()
99+
all_headers = build_forwarded_headers(provider_data, self.config.forward_headers)
100+
101+
if not all_headers:
102+
if self.config.forward_headers:
103+
logger.warning(
104+
"forward_headers is configured but no headers were forwarded — "
105+
"outbound request may be unauthenticated"
106+
)
107+
return {}, None
108+
109+
# Pull out Authorization (case-insensitive) so it goes via the authorization=
110+
# param — prepare_mcp_headers() rejects Authorization in the headers= dict.
111+
auth_token: str | None = None
112+
non_auth: dict[str, str] = {}
113+
for name, value in all_headers.items():
114+
if name.lower() == "authorization":
115+
auth_token = value
116+
else:
117+
non_auth[name] = value
118+
119+
return non_auth, auth_token
120+
121+
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
122+
"""Extract headers from the legacy mcp_headers URI-keyed provider data.
123+
124+
Kept for backward compatibility. New deployments should use forward_headers
125+
in the provider config instead.
91126
92127
Raises:
93-
ValueError: If Authorization header is found in mcp_headers
128+
ValueError: If Authorization header is found in mcp_headers (must use
129+
the dedicated authorization parameter instead).
94130
"""
95131

96132
def canonicalize_uri(uri: str) -> str:

0 commit comments

Comments
 (0)