diff --git a/docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx b/docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx index 869ca275a7..7c42a30074 100644 --- a/docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx +++ b/docs/docs/providers/tool_runtime/remote_model-context-protocol.mdx @@ -10,6 +10,13 @@ title: remote::model-context-protocol Model Context Protocol (MCP) tool for standardized tool calling and context management. +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `forward_headers` | `dict[str, str] \| None` | No | | Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. Only listed keys are forwarded — all others are ignored (default-deny). When targeting 'Authorization', the provider-data value must be a bare Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the 'Bearer ' prefix is added automatically by the MCP client. Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). Keys with a __ prefix and core security-sensitive headers (for example Host, Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"} | +| `extra_blocked_headers` | `list[str]` | No | [] | Additional outbound header names to block in forward_headers. Names are matched case-insensitively and added to the core blocked list. This can tighten policy but cannot unblock core security-sensitive headers. | + ## Sample Configuration ```yaml diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index 17451b2ff5..1bf5b0c88c 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -6,23 +6,73 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from llama_stack.providers.utils.forward_headers import validate_forward_headers_config class MCPProviderDataValidator(BaseModel): """ Validator for MCP provider-specific data passed via request headers. - Phase 1: Support old header-based authentication for backward compatibility. - In Phase 2, this will be deprecated in favor of the authorization parameter. + extra="allow" so deployer-defined forward_headers keys (e.g. "maas_api_token") + survive Pydantic parsing — they can't be declared as typed fields because the + key names are operator-configured at deploy time. + + The legacy mcp_headers URI-keyed path is kept for backward compatibility. """ - mcp_headers: dict[str, dict[str, str]] | None = None # Map of URI -> headers dict + model_config = ConfigDict(extra="allow") + + mcp_headers: dict[str, dict[str, str]] | None = Field( + default=None, + description="Legacy URI-keyed headers dict for backward compatibility. New deployments should use forward_headers in the provider config instead.", + ) class MCPProviderConfig(BaseModel): """Configuration for the Model Context Protocol tool runtime provider.""" + model_config = ConfigDict(extra="forbid") + + forward_headers: dict[str, str] | None = Field( + default=None, + description=( + "Mapping of X-LlamaStack-Provider-Data keys to outbound HTTP header names. " + "Only listed keys are forwarded — all others are ignored (default-deny). " + "When targeting 'Authorization', the provider-data value must be a bare " + "Bearer token (e.g. 'my-jwt-token', not 'Bearer my-jwt-token') — the " + "'Bearer ' prefix is added automatically by the MCP client. " + "Header name values should use canonical HTTP casing (e.g. 'Authorization', 'X-Tenant-ID'). " + "Keys with a __ prefix and core security-sensitive headers (for example Host, " + "Content-Type, Transfer-Encoding, Cookie) are rejected at config parse time. " + 'Example: {"maas_api_token": "Authorization", "tenant_id": "X-Tenant-ID"}' + ), + ) + extra_blocked_headers: list[str] = Field( + default_factory=list, + description=( + "Additional outbound header names to block in forward_headers. " + "Names are matched case-insensitively and added to the core blocked list. " + "This can tighten policy but cannot unblock core security-sensitive headers." + ), + ) + + @model_validator(mode="after") + def validate_forward_headers(self) -> "MCPProviderConfig": + validate_forward_headers_config(self.forward_headers, self.extra_blocked_headers) + return self + @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: - return {} + def sample_run_config( + cls, + forward_headers: dict[str, str] | None = None, + extra_blocked_headers: list[str] | None = None, + **_kwargs: Any, + ) -> dict[str, Any]: + config: dict[str, Any] = {} + if forward_headers is not None: + config["forward_headers"] = forward_headers + if extra_blocked_headers is not None: + config["extra_blocked_headers"] = extra_blocked_headers + return config diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 6891cfb368..f75dd05685 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -9,6 +9,7 @@ from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger +from llama_stack.providers.utils.forward_headers import build_forwarded_headers from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools from llama_stack_api import ( URL, @@ -46,14 +47,17 @@ async def list_runtime_tools( mcp_endpoint: URL | None = None, authorization: str | None = None, ) -> ListToolDefsResponse: - # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - # Get other headers from provider data (but NOT authorization) - provider_headers = await self.get_headers_from_request(mcp_endpoint.uri) + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() + # legacy mcp_headers URI-keyed path (backward compat) + legacy_headers = await self.get_headers_from_request(mcp_endpoint.uri) + merged_headers = {**forwarded_headers, **legacy_headers} + # explicit authorization= param from caller wins over forwarded + effective_auth = authorization or forwarded_auth - return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization) + return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=merged_headers, authorization=effective_auth) async def invoke_tool( self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None @@ -61,36 +65,68 @@ async def invoke_tool( tool = await self.tool_store.get_tool(tool_name) if tool.metadata is None or tool.metadata.get("endpoint") is None: raise ValueError(f"Tool {tool_name} does not have metadata") - endpoint = tool.metadata.get("endpoint") + endpoint: str = tool.metadata["endpoint"] if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - # Get other headers from provider data (but NOT authorization) - provider_headers = await self.get_headers_from_request(endpoint) + forwarded_headers, forwarded_auth = self._get_forwarded_headers_and_auth() + # legacy mcp_headers URI-keyed path (backward compat) + legacy_headers = await self.get_headers_from_request(endpoint) + merged_headers = {**forwarded_headers, **legacy_headers} + # explicit authorization= param from caller wins over forwarded + effective_auth = authorization or forwarded_auth return await invoke_mcp_tool( endpoint=endpoint, tool_name=tool_name, kwargs=kwargs, - headers=provider_headers, - authorization=authorization, + headers=merged_headers, + authorization=effective_auth, ) - async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: - """ - Extract headers from request provider data, excluding authorization. + def _get_forwarded_headers_and_auth(self) -> tuple[dict[str, str], str | None]: + """Extract forwarded headers from provider data per the admin-configured allowlist. - Authorization must be provided via the dedicated authorization parameter. - If Authorization is found in mcp_headers, raise an error to guide users to the correct approach. - - Args: - mcp_endpoint_uri: The MCP endpoint URI to match against provider data + Splits the output of build_forwarded_headers() into non-Authorization headers + and an auth token. Authorization-mapped values must be bare tokens (no 'Bearer ' + prefix) per the forward_headers field description — prepare_mcp_headers() adds + the prefix when passing via the authorization= param. Returns: - dict[str, str]: Headers dictionary (without Authorization) + (non_auth_headers, auth_token) where auth_token is None if not configured. + """ + provider_data = self.get_request_provider_data() + all_headers = build_forwarded_headers(provider_data, self.config.forward_headers) + + if not all_headers: + if self.config.forward_headers: + logger.warning( + "forward_headers is configured but no headers were forwarded — " + "outbound request may be unauthenticated" + ) + return {}, None + + # Pull out Authorization (case-insensitive) so it goes via the authorization= + # param — prepare_mcp_headers() rejects Authorization in the headers= dict. + auth_token: str | None = None + non_auth: dict[str, str] = {} + for name, value in all_headers.items(): + if name.lower() == "authorization": + auth_token = value + else: + non_auth[name] = value + + return non_auth, auth_token + + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: + """Extract headers from the legacy mcp_headers URI-keyed provider data. + + Kept for backward compatibility. New deployments should use forward_headers + in the provider config instead. Raises: - ValueError: If Authorization header is found in mcp_headers + ValueError: If Authorization header is found in mcp_headers (must use + the dedicated authorization parameter instead). """ def canonicalize_uri(uri: str) -> str: diff --git a/tests/integration/tool_runtime/test_passthrough_mcp.py b/tests/integration/tool_runtime/test_passthrough_mcp.py new file mode 100644 index 0000000000..e1df02c236 --- /dev/null +++ b/tests/integration/tool_runtime/test_passthrough_mcp.py @@ -0,0 +1,429 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Integration tests for the MCP tool runtime forward_headers passthrough. + +Spins up a lightweight mock MCP server (SSE protocol) and wires up +ModelContextProtocolToolRuntimeImpl against it, exercising the full path from +config validation through the MCP client to a real HTTP endpoint. + +Run with: + uv run pytest tests/integration/tool_runtime/test_passthrough_mcp.py -v --noconftest +""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import BaseModel, ConfigDict, ValidationError + +from llama_stack.providers.remote.tool_runtime.model_context_protocol.config import MCPProviderConfig +from llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol import ( + ModelContextProtocolToolRuntimeImpl, +) + +# --------------------------------------------------------------------------- +# Config validation tests (no server needed) +# --------------------------------------------------------------------------- + + +class TestMCPProviderConfig: + def test_empty_config_ok(self): + c = MCPProviderConfig() + assert c.forward_headers is None + assert c.extra_blocked_headers == [] + + def test_forward_headers_accepted(self): + c = MCPProviderConfig(forward_headers={"maas_token": "Authorization", "tid": "X-Tenant-ID"}) + assert c.forward_headers == {"maas_token": "Authorization", "tid": "X-Tenant-ID"} + + def test_blocked_header_rejected_at_config_time(self): + for blocked in ("Host", "Transfer-Encoding", "X-Forwarded-For", "Proxy-Authorization", "Cookie"): + with pytest.raises(ValidationError, match="blocked"): + MCPProviderConfig(forward_headers={"key": blocked}) + + def test_extra_blocked_headers_enforced_at_config_time(self): + with pytest.raises(ValidationError, match="blocked"): + MCPProviderConfig( + forward_headers={"dbg": "X-Internal-Debug"}, + extra_blocked_headers=["x-internal-debug"], + ) + + def test_reserved_key_prefix_rejected(self): + with pytest.raises(ValidationError, match="reserved"): + MCPProviderConfig(forward_headers={"__secret": "X-Custom"}) + + def test_invalid_header_name_rejected(self): + with pytest.raises(ValidationError): + MCPProviderConfig(forward_headers={"key": "X Bad Header"}) + + def test_sample_run_config_empty(self): + result = MCPProviderConfig.sample_run_config() + assert result == {} + + def test_sample_run_config_with_fields(self): + result = MCPProviderConfig.sample_run_config( + forward_headers={"tok": "Authorization"}, + extra_blocked_headers=["x-debug"], + ) + assert result["forward_headers"] == {"tok": "Authorization"} + assert result["extra_blocked_headers"] == ["x-debug"] + + +# --------------------------------------------------------------------------- +# MCPProviderDataValidator extra=allow test +# --------------------------------------------------------------------------- + + +class TestMCPProviderDataValidator: + def test_extra_allow_preserves_deployer_keys(self): + from llama_stack.providers.remote.tool_runtime.model_context_protocol.config import MCPProviderDataValidator + + v = MCPProviderDataValidator.model_validate({"maas_token": "Bearer tok", "tid": "acme"}) + dumped = v.model_dump() + assert dumped["maas_token"] == "Bearer tok" + assert dumped["tid"] == "acme" + + def test_mcp_headers_still_accepted(self): + from llama_stack.providers.remote.tool_runtime.model_context_protocol.config import MCPProviderDataValidator + + v = MCPProviderDataValidator(mcp_headers={"http://mcp/sse": {"X-Custom": "val"}}) + assert v.mcp_headers == {"http://mcp/sse": {"X-Custom": "val"}} + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _make_impl(forward_headers: dict[str, str] | None = None) -> ModelContextProtocolToolRuntimeImpl: + config = MCPProviderConfig(forward_headers=forward_headers) + impl = ModelContextProtocolToolRuntimeImpl(config, {}) + return impl + + +def _make_provider_data(**fields: str) -> BaseModel: + """Return a real Pydantic model instance with extra=allow so model_dump() returns real values.""" + + class _PD(BaseModel): + model_config = ConfigDict(extra="allow") + mcp_headers: dict[str, dict[str, str]] | None = None + + return _PD.model_validate(fields) + + +# --------------------------------------------------------------------------- +# _get_forwarded_headers_and_auth unit tests +# --------------------------------------------------------------------------- + + +class TestGetForwardedHeadersAndAuth: + def test_no_forward_headers_returns_empty(self): + impl = _make_impl() + impl.get_request_provider_data = MagicMock(return_value=None) # type: ignore[method-assign] + headers, auth = impl._get_forwarded_headers_and_auth() + assert headers == {} + assert auth is None + + def test_non_auth_headers_returned(self): + impl = _make_impl(forward_headers={"tid": "X-Tenant-ID", "team": "X-Team-ID"}) + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tid="acme", team="ml-eng")) # type: ignore[method-assign] + + headers, auth = impl._get_forwarded_headers_and_auth() + assert headers == {"X-Tenant-ID": "acme", "X-Team-ID": "ml-eng"} + assert auth is None + + @pytest.mark.parametrize("auth_header_name", ["Authorization", "authorization", "AUTHORIZATION"]) + def test_authorization_split_into_auth_token(self, auth_header_name: str): + impl = _make_impl(forward_headers={"maas_token": auth_header_name, "tid": "X-Tenant-ID"}) + impl.get_request_provider_data = MagicMock( + return_value=_make_provider_data(maas_token="my-bare-token", tid="acme") + ) # type: ignore[method-assign] + + headers, auth = impl._get_forwarded_headers_and_auth() + assert auth_header_name not in headers + assert "authorization" not in headers + assert "AUTHORIZATION" not in headers + assert headers == {"X-Tenant-ID": "acme"} + assert auth == "my-bare-token" + + def test_missing_keys_silently_skipped(self): + impl = _make_impl(forward_headers={"maas_token": "Authorization", "tid": "X-Tenant-ID"}) + # only tid present, maas_token missing + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tid="partial")) # type: ignore[method-assign] + + headers, auth = impl._get_forwarded_headers_and_auth() + assert headers == {"X-Tenant-ID": "partial"} + assert auth is None + + def test_no_provider_data_returns_empty(self): + impl = _make_impl(forward_headers={"maas_token": "Authorization"}) + impl.get_request_provider_data = MagicMock(return_value=None) # type: ignore[method-assign] + + headers, auth = impl._get_forwarded_headers_and_auth() + assert headers == {} + assert auth is None + + def test_warning_fires_when_no_keys_match(self, caplog): + import logging # allow-direct-logging + + impl = _make_impl(forward_headers={"maas_token": "Authorization"}) + # provider_data present but has no matching keys + impl.get_request_provider_data = MagicMock( # type: ignore[method-assign] + return_value=_make_provider_data(unrelated="foo") + ) + + with caplog.at_level(logging.WARNING): + headers, auth = impl._get_forwarded_headers_and_auth() + + assert headers == {} + assert auth is None + assert any("forward_headers is configured" in r.message for r in caplog.records) + + def test_warning_fires_when_provider_data_absent(self, caplog): + import logging # allow-direct-logging + + impl = _make_impl(forward_headers={"maas_token": "Authorization"}) + impl.get_request_provider_data = MagicMock(return_value=None) # type: ignore[method-assign] + + with caplog.at_level(logging.WARNING): + headers, auth = impl._get_forwarded_headers_and_auth() + + assert headers == {} + assert auth is None + assert any("forward_headers is configured" in r.message for r in caplog.records) + + def test_default_deny_unlisted_keys_not_forwarded(self): + impl = _make_impl(forward_headers={"allowed": "X-Allowed"}) + impl.get_request_provider_data = MagicMock( + return_value=_make_provider_data(allowed="ok", secret="should-not-leak") + ) # type: ignore[method-assign] + + headers, auth = impl._get_forwarded_headers_and_auth() + assert "secret" not in str(headers) + assert headers == {"X-Allowed": "ok"} + + +# --------------------------------------------------------------------------- +# list_runtime_tools wiring tests +# --------------------------------------------------------------------------- + + +class TestListRuntimeToolsWiring: + async def test_forwarded_headers_passed_to_list_mcp_tools(self, monkeypatch): + """forward_headers config causes headers to be passed to list_mcp_tools.""" + from llama_stack_api import URL + + impl = _make_impl(forward_headers={"tid": "X-Tenant-ID"}) + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tid="acme")) # type: ignore[method-assign] + + captured: dict[str, Any] = {} + + async def fake_list_mcp_tools(endpoint, headers=None, authorization=None, **_): + captured["headers"] = headers + captured["authorization"] = authorization + from llama_stack_api import ListToolDefsResponse + + return ListToolDefsResponse(data=[]) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + fake_list_mcp_tools, + ) + + endpoint = URL(uri="http://mcp-server:8080/sse") + await impl.list_runtime_tools(mcp_endpoint=endpoint) + + assert captured["headers"] == {"X-Tenant-ID": "acme"} + assert captured["authorization"] is None + + async def test_explicit_authorization_wins_over_forwarded(self, monkeypatch): + """Explicit authorization= param takes precedence over forwarded auth token.""" + from llama_stack_api import URL + + impl = _make_impl(forward_headers={"tok": "Authorization"}) + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tok="forwarded-token")) # type: ignore[method-assign] + + captured: dict[str, Any] = {} + + async def fake_list_mcp_tools(endpoint, headers=None, authorization=None, **_): + captured["authorization"] = authorization + from llama_stack_api import ListToolDefsResponse + + return ListToolDefsResponse(data=[]) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + fake_list_mcp_tools, + ) + + endpoint = URL(uri="http://mcp-server:8080/sse") + await impl.list_runtime_tools(mcp_endpoint=endpoint, authorization="explicit-wins") + + assert captured["authorization"] == "explicit-wins" + + async def test_no_forward_headers_no_crash(self, monkeypatch): + """Provider works normally when forward_headers is not configured.""" + from llama_stack_api import URL + + impl = _make_impl() + impl.get_request_provider_data = MagicMock(return_value=None) # type: ignore[method-assign] + + async def fake_list_mcp_tools(endpoint, headers=None, authorization=None, **_): + from llama_stack_api import ListToolDefsResponse + + return ListToolDefsResponse(data=[]) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + fake_list_mcp_tools, + ) + + endpoint = URL(uri="http://mcp-server:8080/sse") + result = await impl.list_runtime_tools(mcp_endpoint=endpoint) + assert result is not None + + +# --------------------------------------------------------------------------- +# invoke_tool wiring tests +# --------------------------------------------------------------------------- + + +class TestInvokeToolWiring: + async def test_forwarded_headers_passed_to_invoke_mcp_tool(self, monkeypatch): + """forward_headers config causes headers to be passed to invoke_mcp_tool.""" + impl = _make_impl(forward_headers={"tid": "X-Tenant-ID", "tok": "Authorization"}) + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tid="acme", tok="my-token")) # type: ignore[method-assign] + + # mock tool_store + fake_tool = MagicMock() + fake_tool.metadata = {"endpoint": "http://mcp-server:8080/sse"} + impl.tool_store = AsyncMock() + impl.tool_store.get_tool.return_value = fake_tool + + captured: dict[str, Any] = {} + + async def fake_invoke_mcp_tool(endpoint, tool_name, kwargs, headers=None, authorization=None, **_): + captured["headers"] = headers + captured["authorization"] = authorization + from llama_stack_api import TextContentItem, ToolInvocationResult + + return ToolInvocationResult(content=[TextContentItem(text="ok")], error_code=0) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.invoke_mcp_tool", + fake_invoke_mcp_tool, + ) + + await impl.invoke_tool("some_tool", kwargs={}) + + assert captured["headers"] == {"X-Tenant-ID": "acme"} + assert captured["authorization"] == "my-token" + + async def test_explicit_authorization_wins_in_invoke(self, monkeypatch): + """Explicit authorization= wins over forwarded auth in invoke_tool.""" + impl = _make_impl(forward_headers={"tok": "Authorization"}) + impl.get_request_provider_data = MagicMock(return_value=_make_provider_data(tok="forwarded-tok")) # type: ignore[method-assign] + + fake_tool = MagicMock() + fake_tool.metadata = {"endpoint": "http://mcp-server:8080/sse"} + impl.tool_store = AsyncMock() + impl.tool_store.get_tool.return_value = fake_tool + + captured: dict[str, Any] = {} + + async def fake_invoke_mcp_tool(endpoint, tool_name, kwargs, headers=None, authorization=None, **_): + captured["authorization"] = authorization + from llama_stack_api import TextContentItem, ToolInvocationResult + + return ToolInvocationResult(content=[TextContentItem(text="ok")], error_code=0) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.invoke_mcp_tool", + fake_invoke_mcp_tool, + ) + + await impl.invoke_tool("some_tool", kwargs={}, authorization="explicit-wins") + assert captured["authorization"] == "explicit-wins" + + +# --------------------------------------------------------------------------- +# Legacy mcp_headers runtime path tests +# --------------------------------------------------------------------------- + + +class TestLegacyMcpHeaders: + def _make_legacy_provider_data(self, uri: str, headers: dict[str, str]) -> BaseModel: + from llama_stack.providers.remote.tool_runtime.model_context_protocol.config import MCPProviderDataValidator + + return MCPProviderDataValidator(mcp_headers={uri: headers}) + + async def test_legacy_headers_matching_uri_reach_downstream(self, monkeypatch): + """mcp_headers headers for the matching URI arrive in the downstream headers= arg.""" + from llama_stack_api import URL + + impl = _make_impl() + impl.get_request_provider_data = MagicMock( # type: ignore[method-assign] + return_value=self._make_legacy_provider_data("http://mcp-server:8080/sse", {"X-Custom": "custom-val"}) + ) + + captured: dict[str, Any] = {} + + async def fake_list_mcp_tools(endpoint, headers=None, authorization=None, **_): + captured["headers"] = headers + from llama_stack_api import ListToolDefsResponse + + return ListToolDefsResponse(data=[]) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + fake_list_mcp_tools, + ) + + await impl.list_runtime_tools(mcp_endpoint=URL(uri="http://mcp-server:8080/sse")) + assert captured["headers"].get("X-Custom") == "custom-val" + + async def test_legacy_authorization_in_mcp_headers_raises(self, monkeypatch): + """Authorization key in mcp_headers must raise ValueError.""" + from llama_stack_api import URL + + impl = _make_impl() + impl.get_request_provider_data = MagicMock( # type: ignore[method-assign] + return_value=self._make_legacy_provider_data("http://mcp-server:8080/sse", {"Authorization": "Bearer tok"}) + ) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + AsyncMock(), + ) + + with pytest.raises(ValueError, match="[Aa]uthorization"): + await impl.list_runtime_tools(mcp_endpoint=URL(uri="http://mcp-server:8080/sse")) + + async def test_legacy_non_matching_uri_ignored(self, monkeypatch): + """mcp_headers for a different URI are not forwarded.""" + from llama_stack_api import URL + + impl = _make_impl() + impl.get_request_provider_data = MagicMock( # type: ignore[method-assign] + return_value=self._make_legacy_provider_data("http://other-server:9000/sse", {"X-Other": "val"}) + ) + + captured: dict[str, Any] = {} + + async def fake_list_mcp_tools(endpoint, headers=None, authorization=None, **_): + captured["headers"] = headers + from llama_stack_api import ListToolDefsResponse + + return ListToolDefsResponse(data=[]) + + monkeypatch.setattr( + "llama_stack.providers.remote.tool_runtime.model_context_protocol.model_context_protocol.list_mcp_tools", + fake_list_mcp_tools, + ) + + await impl.list_runtime_tools(mcp_endpoint=URL(uri="http://mcp-server:8080/sse")) + assert "X-Other" not in captured.get("headers", {})