Skip to content

Commit a07a34e

Browse files
committed
add forward_headers passthrough to remote::model-context-protocol
1 parent 6c2b51c commit a07a34e

File tree

4 files changed

+432
-24
lines changed

4 files changed

+432
-24
lines changed

docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ title: remote::model-context-protocol
1010

1111
Model Context Protocol (MCP) tool for standardized tool calling and context management.
1212

13+
## Configuration
14+
15+
| Field | Type | Required | Default | Description |
16+
|-------|------|----------|---------|-------------|
17+
| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. Only listed keys are forwarded — all others are ignored (default-deny). When targeting 'Authorization', the provider-data value must be a bare Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the 'Bearer ' prefix is added automatically by the MCP client. Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"} |
18+
| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. |
19+
1320
## Sample Configuration
1421

1522
```yaml

src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,68 @@
66

77
from typing import Any
88

9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, ConfigDict, Field, model_validator
10+
11+
from llama_stack.providers.utils.forward_headers import validate_forward_headers_config
1012

1113

1214
class MCPProviderDataValidator(BaseModel):
1315
"""
1416
Validator for MCP provider-specific data passed via request headers.
1517
16-
Phase 1: Support old header-based authentication for backward compatibility.
17-
In Phase 2, this will be deprecated in favor of the authorization parameter.
18+
extra="allow" so deployer-defined forward_headers keys (e.g. "maas_api_token")
19+
survive Pydantic parsing — they can't be declared as typed fields because the
20+
key names are operator-configured at deploy time.
21+
22+
The legacy mcp_headers URI-keyed path is kept for backward compatibility.
1823
"""
1924

25+
model_config = ConfigDict(extra="allow")
26+
2027
mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
2128

2229

2330
class MCPProviderConfig(BaseModel):
31+
model_config = ConfigDict(extra="forbid")
32+
33+
forward_headers: dict[str, str] | None = Field(
34+
default=None,
35+
description=(
36+
"Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. "
37+
"Only listed keys are forwarded — all others are ignored (default-deny). "
38+
"When targeting 'Authorization', the provider-data value must be a bare "
39+
"Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the "
40+
"'Bearer ' prefix is added automatically by the MCP client. "
41+
"Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). "
42+
"Keys with a __ prefix and core security-sensitive headers (for example Host, "
43+
"Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. "
44+
'Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"}'
45+
),
46+
)
47+
extra_blocked_headers: list[str] = Field(
48+
default_factory=list,
49+
description=(
50+
"Additional outbound header names to block in forward_headers. "
51+
"Names are matched case-insensitively and added to the core blocked list. "
52+
"This can tighten policy but cannot unblock core security-sensitive headers."
53+
),
54+
)
55+
56+
@model_validator(mode="after")
57+
def validate_forward_headers(self) -> "MCPProviderConfig":
58+
validate_forward_headers_config(self.forward_headers, self.extra_blocked_headers)
59+
return self
60+
2461
@classmethod
25-
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
26-
return {}
62+
def sample_run_config(
63+
cls,
64+
forward_headers: dict[str, str] | None = None,
65+
extra_blocked_headers: list[str] | None = None,
66+
**_kwargs: Any,
67+
) -> dict[str, Any]:
68+
config: dict[str, Any] = {}
69+
if forward_headers:
70+
config["forward_headers"] = forward_headers
71+
if extra_blocked_headers:
72+
config["extra_blocked_headers"] = extra_blocked_headers
73+
return config

src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from llama_stack.core.request_headers import NeedsRequestProviderData
1111
from llama_stack.log import get_logger
12+
from llama_stack.providers.utils.forward_headers import build_forwarded_headers
1213
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
1314
from llama_stack_api import (
1415
URL,
@@ -44,51 +45,86 @@ async def list_runtime_tools(
4445
mcp_endpoint: URL | None = None,
4546
authorization: str | None = None,
4647
) -> ListToolDefsResponse:
47-
# this endpoint should be retrieved by getting the tool group right?
4848
if mcp_endpoint is None:
4949
raise ValueError("mcp_endpoint is required")
5050

51-
# Get other headers from provider data (but NOT authorization)
52-
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
51+
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
52+
# legacy mcp_headers URI-keyed path (backward compat)
53+
legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri)
54+
merged_headers = {**forwarded_headers, **legacy_headers}
55+
# explicit authorization= param from caller wins over forwarded
56+
effective_auth = authorization or forwarded_auth
5357

54-
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
58+
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth)
5559

5660
async def invoke_tool(
5761
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
5862
) -> ToolInvocationResult:
5963
tool = await self.tool_store.get_tool(tool_name)
6064
if tool.metadata is None or tool.metadata.get("endpoint") is None:
6165
raise ValueError(f"Tool {tool_name} does not have metadata")
62-
endpoint = tool.metadata.get("endpoint")
66+
endpoint: str = tool.metadata["endpoint"]
6367
if urlparse(endpoint).scheme not in ("http", "https"):
6468
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
6569

66-
# Get other headers from provider data (but NOT authorization)
67-
provider_headers = await self.get_headers_from_request(endpoint)
70+
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
71+
# legacy mcp_headers URI-keyed path (backward compat)
72+
legacy_headers = await self.get_headers_from_request(endpoint)
73+
merged_headers = {**forwarded_headers, **legacy_headers}
74+
# explicit authorization= param from caller wins over forwarded
75+
effective_auth = authorization or forwarded_auth
6876

6977
return await invoke_mcp_tool(
7078
endpoint=endpoint,
7179
tool_name=tool_name,
7280
kwargs=kwargs,
73-
headers=provider_headers,
74-
authorization=authorization,
81+
headers=merged_headers,
82+
authorization=effective_auth,
7583
)
7684

77-
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
78-
"""
79-
Extract headers from request provider data, excluding authorization.
85+
def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]:
86+
"""Extract forwarded headers from provider data per the admin-configured allowlist.
8087
81-
Authorization must be provided via the dedicated authorization parameter.
82-
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
83-
84-
Args:
85-
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
88+
Splits the output of build_forwarded_headers() into non-Authorization headers
89+
and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer '
90+
prefix) per the forward_headers field description — prepare_mcp_headers() adds
91+
the prefix when passing via the authorization= param.
8692
8793
Returns:
88-
dict[str, str]: Headers dictionary (without Authorization)
94+
(non_auth_headers, auth_token) where auth_token is None if not configured.
95+
"""
96+
provider_data = self.get_request_provider_data()
97+
all_headers = build_forwarded_headers(provider_data, self.config.forward_headers)
98+
99+
if not all_headers:
100+
if self.config.forward_headers and provider_data is not None:
101+
logger.warning(
102+
"forward_headers is configured but no matching keys found in provider data — "
103+
"outbound request may be unauthenticated"
104+
)
105+
return {}, None
106+
107+
# Pull out Authorization (case-insensitive) so it goes via the authorization=
108+
# param — prepare_mcp_headers() rejects Authorization in the headers= dict.
109+
auth_token: str | None = None
110+
non_auth: dict[str, str] = {}
111+
for name, value in all_headers.items():
112+
if name.lower() == "authorization":
113+
auth_token = value
114+
else:
115+
non_auth[name] = value
116+
117+
return non_auth, auth_token
118+
119+
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
120+
"""Extract headers from the legacy mcp_headers URI-keyed provider data.
121+
122+
Kept for backward compatibility. New deployments should use forward_headers
123+
in the provider config instead.
89124
90125
Raises:
91-
ValueError: If Authorization header is found in mcp_headers
126+
ValueError: If Authorization header is found in mcp_headers (must use
127+
the dedicated authorization parameter instead).
92128
"""
93129

94130
def canonicalize_uri(uri: str) -> str:

0 commit comments

Comments
 (0)