Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ title: remote::model-context-protocol

Model Context Protocol (MCP) tool for standardized tool calling and context management.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. Only listed keys are forwarded — all others are ignored (default-deny). When targeting 'Authorization', the provider-data value must be a bare Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the 'Bearer ' prefix is added automatically by the MCP client. Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"} |
| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. |

## Sample Configuration

```yaml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,73 @@

from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, model_validator

from llama_stack.providers.utils.forward_headers import validate_forward_headers_config


class MCPProviderDataValidator(BaseModel):
"""
Validator for MCP provider-specific data passed via request headers.

Phase 1: Support old header-based authentication for backward compatibility.
In Phase 2, this will be deprecated in favor of the authorization parameter.
extra="allow" so deployer-defined forward_headers keys (e.g. "maas_api_token")
survive Pydantic parsing — they can't be declared as typed fields because the
key names are operator-configured at deploy time.

The legacy mcp_headers URI-keyed path is kept for backward compatibility.
"""

mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict
model_config = ConfigDict(extra="allow")

mcp_headers: dict[str, dict[str, str]] | None = Field(
default=None,
description="Legacy URI-keyed headers dict for backward compatibility. New deployments should use forward_headers in the provider config instead.",
)


class MCPProviderConfig(BaseModel):
"""Configuration for the Model Context Protocol tool runtime provider."""

model_config = ConfigDict(extra="forbid")

forward_headers: dict[str, str] | None = Field(
default=None,
description=(
"Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. "
"Only listed keys are forwarded — all others are ignored (default-deny). "
"When targeting 'Authorization', the provider-data value must be a bare "
"Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the "
"'Bearer ' prefix is added automatically by the MCP client. "
"Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). "
"Keys with a __ prefix and core security-sensitive headers (for example Host, "
"Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. "
'Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"}'
),
)
extra_blocked_headers: list[str] = Field(
default_factory=list,
description=(
"Additional outbound header names to block in forward_headers. "
"Names are matched case-insensitively and added to the core blocked list. "
"This can tighten policy but cannot unblock core security-sensitive headers."
),
)

@model_validator(mode="after")
def validate_forward_headers(self) -> "MCPProviderConfig":
validate_forward_headers_config(self.forward_headers, self.extra_blocked_headers)
return self

@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}
def sample_run_config(
cls,
forward_headers: dict[str, str] | None = None,
extra_blocked_headers: list[str] | None = None,
**_kwargs: Any,
) -> dict[str, Any]:
config: dict[str, Any] = {}
if forward_headers is not None:
config["forward_headers"] = forward_headers
if extra_blocked_headers is not None:
config["extra_blocked_headers"] = extra_blocked_headers
return config
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.utils.forward_headers import build_forwarded_headers
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from llama_stack_api import (
URL,
Expand Down Expand Up @@ -46,51 +47,86 @@ async def list_runtime_tools(
mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")

# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
# legacy mcp_headers URI-keyed path (backward compat)
legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri)
merged_headers = {**forwarded_headers, **legacy_headers}
# explicit authorization= param from caller wins over forwarded
effective_auth = authorization or forwarded_auth

return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth)

async def invoke_tool(
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
if tool.metadata is None or tool.metadata.get("endpoint") is None:
raise ValueError(f"Tool {tool_name} does not have metadata")
endpoint = tool.metadata.get("endpoint")
endpoint: str = tool.metadata["endpoint"]
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")

# Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(endpoint)
forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth()
# legacy mcp_headers URI-keyed path (backward compat)
legacy_headers = await self.get_headers_from_request(endpoint)
merged_headers = {**forwarded_headers, **legacy_headers}
# explicit authorization= param from caller wins over forwarded
effective_auth = authorization or forwarded_auth

return await invoke_mcp_tool(
endpoint=endpoint,
tool_name=tool_name,
kwargs=kwargs,
headers=provider_headers,
authorization=authorization,
headers=merged_headers,
authorization=effective_auth,
)

async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
"""
Extract headers from request provider data, excluding authorization.
def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]:
"""Extract forwarded headers from provider data per the admin-configured allowlist.

Authorization must be provided via the dedicated authorization parameter.
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.

Args:
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
Splits the output of build_forwarded_headers() into non-Authorization headers
and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer '
prefix) per the forward_headers field description — prepare_mcp_headers() adds
the prefix when passing via the authorization= param.

Returns:
dict[str, str]: Headers dictionary (without Authorization)
(non_auth_headers, auth_token) where auth_token is None if not configured.
"""
provider_data = self.get_request_provider_data()
all_headers = build_forwarded_headers(provider_data, self.config.forward_headers)

if not all_headers:
if self.config.forward_headers:
logger.warning(
"forward_headers is configured but no headers were forwarded — "
"outbound request may be unauthenticated"
)
return {}, None

# Pull out Authorization (case-insensitive) so it goes via the authorization=
# param — prepare_mcp_headers() rejects Authorization in the headers= dict.
auth_token: str | None = None
non_auth: dict[str, str] = {}
for name, value in all_headers.items():
if name.lower() == "authorization":
auth_token = value
else:
non_auth[name] = value

return non_auth, auth_token

async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
"""Extract headers from the legacy mcp_headers URI-keyed provider data.

Kept for backward compatibility. New deployments should use forward_headers
in the provider config instead.

Raises:
ValueError: If Authorization header is found in mcp_headers
ValueError: If Authorization header is found in mcp_headers (must use
the dedicated authorization parameter instead).
"""

def canonicalize_uri(uri: str) -> str:
Expand Down
Loading
Loading