diff --git a/.forge/skills/test-mcp-permissions/SKILL.md b/.forge/skills/test-mcp-permissions/SKILL.md new file mode 100644 index 0000000000..745ff71beb --- /dev/null +++ b/.forge/skills/test-mcp-permissions/SKILL.md @@ -0,0 +1,168 @@ +--- +name: test-mcp-permissions +description: Test the MCP server permission policy feature end-to-end. Use when asked to test MCP permissions, verify that local MCP servers are gated by policy, or validate the allow/deny/prompt behavior for MCP connections introduced in PR #3324. +--- + +# Test MCP Permissions + +This skill validates the MCP server permission policy feature (PR #3324). The feature gates **local-scope** MCP servers (`.mcp.json` in the project directory) through a permission prompt at startup, while **user-scope** servers (`~/.forge/.mcp.json`) are trusted unconditionally. + +## How the Feature Works + +1. **Startup flow**: `UI::request_local_mcp_permissions()` reads the local `.mcp.json` and calls `McpApp::request_mcp_permissions(cfg)` **before** the REPL starts, so the prompt never races with user input. +2. **Permission check**: For each enabled local server, a `PermissionOperation::Mcp` is evaluated against `~/.forge/permissions.yaml` by `PolicyEngine`. +3. **Policy result**: + - `Allow` → server connects silently + - `Deny` → server is filtered out silently + - `Confirm` (no matching rule) → user is prompted +4. **Prompt**: A two-choice `ConfirmPermission` (Accept / Reject) is shown with the server's command/url as a header. **Both** choices are persisted — the user is never asked again for the same server+cwd combination. +5. **Import shortcut**: `/mcp import` auto-persists `Allow` via `allow_mcp_servers()` — importing itself counts as consent, no prompt shown. + +## Permissions File + +`~/.forge/permissions.yaml` — written decisions look like: + +```yaml +# stdio server (Allow) +policies: + - permission: allow + rule: + mcp: + command: npx + args: ["-y", "@github/mcp"] + dir: /path/to/project + +# HTTP server (Deny) + - permission: deny + rule: + mcp: + url: "https://untrusted.example.com/sse" + dir: /path/to/project +``` + +Glob patterns work in all fields (`command: "np*"`, `url: "https://trusted.com/*"`). + +## Test Scenarios + +### Scenario 1 — No permissions.yaml: prompt fires + +```bash +rm -f ~/.forge/permissions.yaml +# Add a local MCP server to .mcp.json in the project dir: +echo '{"mcpServers":{"test-server":{"command":"npx","args":["-y","@github/mcp"]}}}' > .mcp.json +forge +``` + +**Expected:** Prompt appears — `Allow MCP server "test-server" to connect?` with `command: npx` shown as a header line. Choose **Accept**. + +**Verify:** +```bash +cat ~/.forge/permissions.yaml +# → contains: permission: allow, mcp: {command: npx, args: ["-y", "@github/mcp"], dir: } +``` + +--- + +### Scenario 2 — Accept persisted: no prompt on second run + +After Scenario 1 (accepted), restart forge in the same directory. + +**Expected:** No prompt. Server connects silently. + +--- + +### Scenario 3 — Reject persisted: server silently blocked + +Run Scenario 1 again (`rm ~/.forge/permissions.yaml`, restart forge), choose **Reject**. + +**Expected:** Forge starts without the server's tools available. + +**Verify:** +```bash +cat ~/.forge/permissions.yaml +# → contains: permission: deny +``` + +Ask forge to use a tool from that server — it should report it as unavailable. + +--- + +### Scenario 4 — User-scope server: never prompted + +Add a server to `~/.forge/.mcp.json` (user scope, not `.mcp.json` in cwd). + +```bash +rm -f ~/.forge/permissions.yaml +forge +``` + +**Expected:** No prompt. User-scope servers bypass the permission gate and always connect. + +--- + +### Scenario 5 — `mcp import` auto-approves + +```bash +rm -f ~/.forge/permissions.yaml +forge +# Inside forge REPL: +/mcp import +``` + +**Expected:** No permission prompt during import. After import, `~/.forge/permissions.yaml` contains `allow` rules for each imported server. + +--- + +### Scenario 6 — Glob rule pre-set in permissions.yaml + +```bash +cat > ~/.forge/permissions.yaml << 'EOF' +policies: + - permission: allow + rule: + mcp: + command: "np*" +EOF +forge +``` + +**Expected:** No prompt for any stdio server whose command starts with `np` (e.g. `npx`). The glob match skips the prompt entirely. + +--- + +### Scenario 7 — HTTP MCP server + +```bash +rm -f ~/.forge/permissions.yaml +echo '{"mcpServers":{"http-server":{"url":"https://mcp.example.com/sse"}}}' > .mcp.json +forge +``` + +**Expected:** Prompt shows `url: https://mcp.example.com/sse` as header. Accepting writes: +```yaml +- permission: allow + rule: + mcp: + url: "https://mcp.example.com/sse" + dir: +``` + +--- + +## Quick Reset Between Tests + +```bash +rm -f ~/.forge/permissions.yaml +rm -f .mcp.json +``` + +## Key Code Locations + +| What | File | +|---|---| +| Startup permission gate | `crates/forge_main/src/ui.rs:445-457` | +| McpApp orchestration | `crates/forge_app/src/mcp_app.rs` | +| Policy prompt logic | `crates/forge_services/src/policy.rs:218-244` | +| MCP rule matching | `crates/forge_domain/src/policies/rule.rs:111-116` | +| MCP filter (glob match) | `crates/forge_domain/src/policies/rule.rs:159-181` | +| Default permissions | `crates/forge_services/src/permissions.default.yaml` | diff --git a/.forge/skills/test-mcp-permissions/scripts/test.py b/.forge/skills/test-mcp-permissions/scripts/test.py new file mode 100644 index 0000000000..7ab8729e27 --- /dev/null +++ b/.forge/skills/test-mcp-permissions/scripts/test.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +"""End-to-end tests for MCP server permission policy (PR #3324).""" + +import contextlib +import json +import os +import shutil +import subprocess +import sys +import tempfile +import time + +try: + import pexpect +except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"]) + import pexpect + +try: + import yaml +except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", "pyyaml"]) + import yaml + +FORGE_BIN = os.path.abspath(os.environ.get("FORGE_BIN", "forge")) +PASS = "\033[32mPASS\033[0m" +FAIL = "\033[31mFAIL\033[0m" +results: list = [] + + +@contextlib.contextmanager +def scenario_dirs(): + cwd = tempfile.mkdtemp(prefix="forge_mcp_cwd_") + cfg = tempfile.mkdtemp(prefix="forge_mcp_cfg_") + for base in [os.path.expanduser("~/forge"), os.path.expanduser("~/.forge")]: + if os.path.isdir(base): + for name in [".forge.toml", ".config.json", ".credentials.json"]: + src = os.path.join(base, name) + if os.path.exists(src): + shutil.copy2(src, os.path.join(cfg, name)) + break + try: + yield cwd, cfg + finally: + shutil.rmtree(cwd, ignore_errors=True) + shutil.rmtree(cfg, ignore_errors=True) + + +def read_perms(cfg: str) -> dict: + p = os.path.join(cfg, "permissions.yaml") + return yaml.safe_load(open(p)) or {} if os.path.exists(p) else {} + + +def write_perms(cfg: str, data: dict): + with open(os.path.join(cfg, "permissions.yaml"), "w") as f: + yaml.dump(data, f, default_flow_style=False) + + +def write_mcp(path: str, command: str, args=None, key="test-server"): + server: dict = {"command": command} + if args: + server["args"] = args + with open(os.path.join(path, ".mcp.json"), "w") as f: + json.dump({"mcpServers": {key: server}}, f) + + +def spawn(cwd: str, cfg: str) -> pexpect.spawn: + env = {**os.environ, "TERM": "xterm-256color", "COLUMNS": "120", "LINES": "40", "FORGE_CONFIG": cfg} + return pexpect.spawn( + "/bin/sh", args=["-c", f"exec {FORGE_BIN} -p hello 2>&1"], + cwd=cwd, timeout=30, encoding="utf-8", codec_errors="replace", env=env, + ) + + +def accept(child: pexpect.spawn): + child.expect("Allow MCP server", timeout=30) + child.send("\r") + time.sleep(4) + child.close(force=True) + + +def reject(child: pexpect.spawn): + child.expect("Allow MCP server", timeout=30) + for ch in ("\x1b", "[", "B"): # arrow-down as separate bytes for raw-mode TUI + child.send(ch) + time.sleep(0.1) + time.sleep(0.4) + child.send("\r") + time.sleep(4) + child.close(force=True) + + +def no_prompt(child: pexpect.spawn) -> bool: + idx = child.expect(["Allow MCP server", pexpect.TIMEOUT, pexpect.EOF], timeout=15) + child.close(force=True) + return idx != 0 + + +def mcp_rules(perms: dict, permission: str) -> list: + return [ + p for p in perms.get("policies", []) + if isinstance(p, dict) and p.get("permission") == permission + and isinstance(p.get("rule"), dict) and "mcp" in p["rule"] + ] + + +def show_perms(before: dict, after: dict): + def dump(d): + if not d: + print(" (empty)") + return + for line in yaml.dump(d, default_flow_style=False, sort_keys=False).splitlines(): + print(f" {line}") + print(" ┌─ before ────────────────────────") + dump(before) + print(" ├─ after ────────────────────────") + dump(after) + print(" └─────────────────────────────────") + + +def run(name: str, fn): + print(f"\n{'─'*50}\n{name}\n{'─'*50}") + try: + fn() + print(f" {PASS}") + results.append((name, True, None)) + except Exception as e: + print(f" {FAIL} — {e}") + results.append((name, False, str(e))) + + +# --------------------------------------------------------------------------- +# Scenarios +# --------------------------------------------------------------------------- + +def test_accept_writes_allow(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cwd, "echo", ["hello"]) + before = read_perms(cfg) + accept(spawn(cwd, cfg)) + after = read_perms(cfg) + show_perms(before, after) + assert mcp_rules(after, "allow"), "Expected an MCP allow rule" + + +def test_reject_writes_deny(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cwd, "echo", ["hello"]) + before = read_perms(cfg) + reject(spawn(cwd, cfg)) + after = read_perms(cfg) + show_perms(before, after) + assert mcp_rules(after, "deny"), "Expected an MCP deny rule" + + +def test_existing_allow_skips_prompt(): + with scenario_dirs() as (cwd, cfg): + write_perms(cfg, {"policies": [{"permission": "allow", "rule": {"mcp": {"command": "echo"}}}]}) + write_mcp(cwd, "echo", ["hello"]) + before = read_perms(cfg) + assert no_prompt(spawn(cwd, cfg)), "Prompt appeared even though allow rule was pre-written" + show_perms(before, read_perms(cfg)) + + +def test_second_run_skips_prompt(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cwd, "echo", ["hello"]) + + before1 = read_perms(cfg) + accept(spawn(cwd, cfg)) + print(" [run 1]"); show_perms(before1, read_perms(cfg)) + + before2 = read_perms(cfg) + assert no_prompt(spawn(cwd, cfg)), "Prompt appeared on second run — decision was not persisted" + print(" [run 2]"); show_perms(before2, read_perms(cfg)) + + +def test_npx_accept(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cwd, "npx", ["-y", "@modelcontextprotocol/server-filesystem", cwd], key="filesystem") + before = read_perms(cfg) + accept(spawn(cwd, cfg)) + after = read_perms(cfg) + show_perms(before, after) + rules = mcp_rules(after, "allow") + assert rules and rules[0]["rule"]["mcp"].get("command") == "npx", "Expected npx allow rule" + + +def test_npx_reject(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cwd, "npx", ["-y", "@modelcontextprotocol/server-filesystem", cwd], key="filesystem") + before = read_perms(cfg) + reject(spawn(cwd, cfg)) + after = read_perms(cfg) + show_perms(before, after) + rules = mcp_rules(after, "deny") + assert rules and rules[0]["rule"]["mcp"].get("command") == "npx", "Expected npx deny rule" + + +def test_user_scope_never_prompts(): + with scenario_dirs() as (cwd, cfg): + write_mcp(cfg, "echo", ["user-scope"], key="user-server") # inside cfg = user scope + before = read_perms(cfg) + assert no_prompt(spawn(cwd, cfg)), "Prompt appeared for user-scope server" + show_perms(before, read_perms(cfg)) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + print(f"{'='*50}\nMCP Permission Policy — E2E Tests\n{'='*50}") + print(f" binary: {FORGE_BIN}\n") + + run("Accept → allow rule written", test_accept_writes_allow) + run("Reject → deny rule written", test_reject_writes_deny) + run("Pre-existing allow → no prompt", test_existing_allow_skips_prompt) + run("Second run after accept → no prompt", test_second_run_skips_prompt) + run("npx server Accept → allow rule", test_npx_accept) + run("npx server Reject → deny rule", test_npx_reject) + run("User-scope server → never prompts", test_user_scope_never_prompts) + + passed = sum(1 for _, ok, _ in results if ok) + print(f"\n{'='*50}\nSUMMARY\n{'='*50}") + for name, ok, err in results: + print(f" {PASS if ok else FAIL} {name}") + if err: + print(f" {err}") + print(f"\n{passed}/{len(results)} passed") + sys.exit(0 if passed == len(results) else 1) + + +if __name__ == "__main__": + main() diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index aafb112d49..a55f3496f8 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -124,6 +124,15 @@ pub trait API: Sync + Send { /// project directory async fn write_mcp_config(&self, scope: &Scope, config: &McpConfig) -> Result<()>; + /// Prompts for missing permissions for each enabled server in `cfg`. + /// Idempotent — servers with existing decisions are skipped. + /// Call this synchronously at startup before the REPL takes over stdin. + async fn request_mcp_permissions(&self, cfg: McpConfig) -> Result<()>; + + /// Persist `Allow` decisions for the named servers without prompting. + /// Used by `mcp import` to record consent on the user's behalf. + async fn allow_mcp_servers(&self, names: &[ServerName]) -> Result<()>; + /// Retrieves the provider configuration for the specified agent async fn get_agent_provider(&self, agent_id: AgentId) -> anyhow::Result>; diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index b56d485bfd..b210dd820f 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -7,7 +7,7 @@ use forge_app::dto::ToolsOverview; use forge_app::{ AgentProviderResolver, AgentRegistry, AppConfigService, AuthService, CommandInfra, CommandLoaderService, ConversationService, DataGenerationApp, EnvironmentInfra, - FileDiscoveryService, ForgeApp, GitApp, GrpcInfra, McpConfigManager, McpService, + FileDiscoveryService, ForgeApp, GitApp, GrpcInfra, McpApp, McpConfigManager, McpService, ProviderAuthService, ProviderService, Services, User, UserUsage, Walker, WorkspaceService, }; use forge_config::ForgeConfig; @@ -39,6 +39,16 @@ impl ForgeAPI { { ForgeApp::new(self.services.clone()) } + + /// Creates an McpApp instance for MCP permission and connection + /// orchestration. + fn mcp_app(&self) -> McpApp + where + A: Services + EnvironmentInfra, + F: EnvironmentInfra, + { + McpApp::new(self.services.clone()) + } } impl ForgeAPI>, ForgeRepo> { @@ -227,6 +237,14 @@ impl< .map_err(|e| anyhow::anyhow!(e)) } + async fn allow_mcp_servers(&self, names: &[ServerName]) -> Result<()> { + self.mcp_app().allow_mcp_servers(names).await + } + + async fn request_mcp_permissions(&self, cfg: McpConfig) -> Result<()> { + self.mcp_app().request_mcp_permissions(cfg).await + } + async fn execute_shell_command_raw( &self, command: &str, diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index 63d65c83a3..ff884464ae 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -164,6 +164,44 @@ pub trait CommandInfra: Send + Sync { ) -> anyhow::Result; } +/// A prompt shown to the user in a selection widget. +/// +/// `message` is the question displayed on the prompt line. +/// `header` contains zero or more lines shown as non-selectable context rows +/// above the list of options. +#[derive(Debug, Clone, Default)] +pub struct SelectPrompt { + /// The question shown on the prompt line. + pub message: String, + /// Optional context lines rendered above the selectable options. + pub header: Vec, +} + +impl SelectPrompt { + /// Creates a prompt with a message and no header lines. + pub fn new(message: impl Into) -> Self { + Self { message: message.into(), header: Vec::new() } + } + + /// Sets the header lines of the prompt, replacing any previously set lines. + pub fn with_header(mut self, lines: impl IntoIterator>) -> Self { + self.header = lines.into_iter().map(|l| l.into()).collect(); + self + } +} + +impl From<&str> for SelectPrompt { + fn from(s: &str) -> Self { + Self::new(s) + } +} + +impl From for SelectPrompt { + fn from(s: String) -> Self { + Self::new(s) + } +} + #[async_trait::async_trait] pub trait UserInfra: Send + Sync { /// Prompts the user with question @@ -174,19 +212,22 @@ pub trait UserInfra: Send + Sync { /// Returns None if the user interrupts the selection async fn select_one( &self, - message: &str, + prompt: impl Into + Send, options: Vec, ) -> anyhow::Result>; /// Prompts the user to select a single option from an enum that implements /// IntoEnumIterator Returns None if the user interrupts the selection - async fn select_one_enum(&self, message: &str) -> anyhow::Result> + async fn select_one_enum( + &self, + prompt: impl Into + Send, + ) -> anyhow::Result> where T: Clone + std::fmt::Display + Send + 'static + strum::IntoEnumIterator + std::str::FromStr, ::Err: std::fmt::Debug, { let options: Vec = T::iter().collect(); - let selected = self.select_one(message, options).await?; + let selected = self.select_one(prompt, options).await?; Ok(selected) } diff --git a/crates/forge_app/src/lib.rs b/crates/forge_app/src/lib.rs index 66de3e618d..644ace81aa 100644 --- a/crates/forge_app/src/lib.rs +++ b/crates/forge_app/src/lib.rs @@ -15,6 +15,7 @@ mod git_app; mod hooks; mod infra; mod init_conversation_metrics; +mod mcp_app; mod mcp_executor; mod operation; mod orch; @@ -47,6 +48,7 @@ pub use data_gen::*; pub use error::*; pub use git_app::*; pub use infra::*; +pub use mcp_app::*; pub use services::*; pub use template_engine::*; pub use terminal_context::*; diff --git a/crates/forge_app/src/mcp_app.rs b/crates/forge_app/src/mcp_app.rs new file mode 100644 index 0000000000..a361b0b8f8 --- /dev/null +++ b/crates/forge_app/src/mcp_app.rs @@ -0,0 +1,104 @@ +use std::sync::Arc; + +use anyhow::Result; +use forge_domain::*; +use merge::Merge; + +use crate::services::{McpConfigManager, McpService, PolicyService}; +use crate::{EnvironmentInfra, Services}; + +/// McpApp handles MCP permission reconciliation and policy-filtered +/// connections, keeping `McpService` free of any policy awareness. +pub struct McpApp { + services: Arc, +} + +impl McpApp { + /// Creates a new McpApp instance with the provided services. + pub fn new(services: Arc) -> Self { + Self { services } + } +} + +impl> McpApp { + /// Prompts for missing permissions for each enabled server in `cfg`. + /// Idempotent — servers that already have a recorded decision are skipped + /// silently. + /// + /// This is the only place a permission prompt can fire for MCP. Call this + /// synchronously at startup (before the REPL takes over stdin) so prompts + /// don't race with user input. + pub async fn request_mcp_permissions(&self, cfg: McpConfig) -> Result<()> { + let cwd = self.services.get_environment().cwd; + for (name, server) in cfg + .mcp_servers + .into_iter() + .filter(|(_, s)| !s.is_disabled()) + { + let op = PermissionOperation::Mcp { + config: server, + cwd: cwd.clone(), + message: format!("Allow MCP server \"{name}\" to connect?"), + }; + // check_operation_permission handles the prompt + persist. + // The return value is intentionally discarded here; the caller + // just needs all decisions to be recorded before connections start. + let _ = self.services.check_operation_permission(&op).await?; + } + Ok(()) + } + + /// Returns a merged MCP config where user-scope servers are trusted + /// unconditionally and local-scope servers are filtered to those with an + /// explicit `Allow` policy. Never prompts — call + /// [`Self::request_mcp_permissions`] first to ensure decisions exist. + pub async fn permitted_mcp_config(&self) -> Result { + let mut user = self.services.read_mcp_config(Some(&Scope::User)).await?; + let local = self.services.read_mcp_config(Some(&Scope::Local)).await?; + let cwd = self.services.get_environment().cwd; + + let mut filtered_local = McpConfig::default(); + for (name, server) in local.mcp_servers { + if server.is_disabled() { + continue; + } + let op = PermissionOperation::Mcp { + config: server.clone(), + cwd: cwd.clone(), + message: String::new(), + }; + if self.services.is_operation_permitted(&op).await? { + filtered_local.mcp_servers.insert(name, server); + } + } + user.merge(filtered_local); + Ok(user) + } + + /// Lists MCP tools, connecting only servers that have an explicit `Allow` + /// policy for local-scope entries (user-scope are trusted unconditionally). + /// Never prompts. + pub async fn get_mcp_servers(&self) -> Result { + let cfg = self.permitted_mcp_config().await?; + self.services.get_mcp_servers(cfg).await + } + + /// Persist `Allow` decisions for the named servers without prompting. + /// Used by `mcp import` to record consent on the user's behalf — importing + /// is itself an explicit opt-in. + pub async fn allow_mcp_servers(&self, names: &[ServerName]) -> Result<()> { + let cfg = self.services.read_mcp_config(None).await?; + let cwd = self.services.get_environment().cwd; + for name in names { + if let Some(server) = cfg.mcp_servers.get(name) { + let op = PermissionOperation::Mcp { + config: server.clone(), + cwd: cwd.clone(), + message: format!("Connect to MCP server: {name}"), + }; + self.services.allow_operation(&op).await?; + } + } + Ok(()) + } +} diff --git a/crates/forge_app/src/mcp_executor.rs b/crates/forge_app/src/mcp_executor.rs index 21e3d024ba..b7f82612ef 100644 --- a/crates/forge_app/src/mcp_executor.rs +++ b/crates/forge_app/src/mcp_executor.rs @@ -2,15 +2,19 @@ use std::sync::Arc; use forge_domain::{TitleFormat, ToolCallContext, ToolCallFull, ToolName, ToolOutput}; -use crate::McpService; +use crate::{EnvironmentInfra, McpApp, McpService, Services}; pub struct McpExecutor { services: Arc, + /// Shared `McpApp` instance so `permitted_mcp_config` is computed at most + /// once per executor lifetime rather than on every tool call. + mcp_app: McpApp, } -impl McpExecutor { +impl> McpExecutor { pub fn new(services: Arc) -> Self { - Self { services } + let mcp_app = McpApp::new(services.clone()); + Self { services, mcp_app } } pub async fn execute( @@ -22,11 +26,12 @@ impl McpExecutor { .send_tool_input(TitleFormat::info("MCP").sub_title(input.name.as_str())) .await?; - self.services.execute_mcp(input).await + let cfg = self.mcp_app.permitted_mcp_config().await?; + self.services.execute_mcp(input, cfg).await } pub async fn contains_tool(&self, tool_name: &ToolName) -> anyhow::Result { - let mcp_servers = self.services.get_mcp_servers().await?; + let mcp_servers = self.mcp_app.get_mcp_servers().await?; // Convert Claude Code format (mcp__{server}__{tool}) to the internal legacy // format (mcp_{server}_tool_{tool}) before checking, so both name styles match. let legacy = tool_name.to_legacy_mcp_name(); diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 78ab0ca533..3a3dae5965 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -218,8 +218,15 @@ pub trait McpConfigManager: Send + Sync { #[async_trait::async_trait] pub trait McpService: Send + Sync { - async fn get_mcp_servers(&self) -> anyhow::Result; - async fn execute_mcp(&self, call: ToolCallFull) -> anyhow::Result; + /// Connect to and list tools from the given MCP servers. + /// The caller is responsible for filtering `cfg` through any policy + /// gating before calling; this method connects every enabled server + /// in `cfg` without re-checking permissions. + async fn get_mcp_servers(&self, cfg: McpConfig) -> anyhow::Result; + /// Execute a tool call against an already-connected server. The caller is + /// responsible for supplying a pre-filtered `cfg` (same one used for + /// `get_mcp_servers`) so denied servers are never reconnected here. + async fn execute_mcp(&self, call: ToolCallFull, cfg: McpConfig) -> anyhow::Result; /// Refresh the MCP cache by fetching fresh data async fn reload_mcp(&self) -> anyhow::Result<()>; } @@ -485,6 +492,24 @@ pub trait PolicyService: Send + Sync { &self, operation: &forge_domain::PermissionOperation, ) -> anyhow::Result; + + /// Check whether an operation is explicitly permitted by the current + /// policy without prompting the user. Returns `true` only when the policy + /// engine resolves to `Allow`; `Confirm` and `Deny` both return `false`. + /// Use this instead of `check_operation_permission` when interactive + /// prompting must be avoided (e.g. MCP connection authorisation). + async fn is_operation_permitted( + &self, + operation: &forge_domain::PermissionOperation, + ) -> anyhow::Result; + + /// Unconditionally persist an allow policy for the given operation. + /// Used when the user has explicitly opted in (e.g. via `mcp import`) so + /// no interactive confirmation is needed. + async fn allow_operation( + &self, + operation: &forge_domain::PermissionOperation, + ) -> anyhow::Result<()>; } /// Skill fetch service @@ -684,12 +709,12 @@ impl McpConfigManager for I { #[async_trait::async_trait] impl McpService for I { - async fn get_mcp_servers(&self) -> anyhow::Result { - self.mcp_service().get_mcp_servers().await + async fn get_mcp_servers(&self, cfg: McpConfig) -> anyhow::Result { + self.mcp_service().get_mcp_servers(cfg).await } - async fn execute_mcp(&self, call: ToolCallFull) -> anyhow::Result { - self.mcp_service().execute_mcp(call).await + async fn execute_mcp(&self, call: ToolCallFull, cfg: McpConfig) -> anyhow::Result { + self.mcp_service().execute_mcp(call, cfg).await } async fn reload_mcp(&self) -> anyhow::Result<()> { @@ -942,6 +967,22 @@ impl PolicyService for I { .check_operation_permission(operation) .await } + + async fn is_operation_permitted( + &self, + operation: &forge_domain::PermissionOperation, + ) -> anyhow::Result { + self.policy_service() + .is_operation_permitted(operation) + .await + } + + async fn allow_operation( + &self, + operation: &forge_domain::PermissionOperation, + ) -> anyhow::Result<()> { + self.policy_service().allow_operation(operation).await + } } #[async_trait::async_trait] diff --git a/crates/forge_app/src/tool_registry.rs b/crates/forge_app/src/tool_registry.rs index dbfff3da06..4d46bc4e48 100644 --- a/crates/forge_app/src/tool_registry.rs +++ b/crates/forge_app/src/tool_registry.rs @@ -21,7 +21,7 @@ use crate::fmt::content::FormatContent; use crate::mcp_executor::McpExecutor; use crate::tool_executor::ToolExecutor; use crate::{ - AgentRegistry, EnvironmentInfra, McpService, PolicyService, ProviderService, Services, + AgentRegistry, EnvironmentInfra, McpApp, PolicyService, ProviderService, Services, ToolResolver, WorkspaceService, }; @@ -241,7 +241,7 @@ impl> ToolReg } pub async fn tools_overview(&self) -> anyhow::Result { - let mcp_tools = self.services.get_mcp_servers().await?; + let mcp_tools = McpApp::new(self.services.clone()).get_mcp_servers().await?; let agent_tools = self.agent_executor.agent_definitions().await?; // Get agents for template rendering in Task tool description diff --git a/crates/forge_domain/src/mcp.rs b/crates/forge_domain/src/mcp.rs index 53afe53725..809bc073a3 100644 --- a/crates/forge_domain/src/mcp.rs +++ b/crates/forge_domain/src/mcp.rs @@ -15,7 +15,7 @@ pub enum Scope { User, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] #[serde(untagged)] pub enum McpServerConfig { Stdio(McpStdioServer), @@ -65,7 +65,7 @@ impl McpServerConfig { } } -#[derive(Default, Debug, Clone, Serialize, Deserialize, Setters, PartialEq, Hash)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, Setters, PartialEq, Eq, Hash)] #[setters(strip_option, into)] pub struct McpStdioServer { /// Command to execute for starting this MCP server @@ -91,7 +91,7 @@ pub struct McpStdioServer { pub disable: bool, } -#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Hash)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct McpHttpServer { /// Url of the MCP server (auto-detects HTTP vs SSE transport) #[serde(skip_serializing_if = "String::is_empty", alias = "serverUrl")] @@ -144,7 +144,7 @@ impl McpHttpServer { /// Represents the OAuth setting for an MCP server. /// Supports three states: auto-detect (default), explicitly disabled, or /// explicitly configured. -#[derive(Debug, Clone, PartialEq, Hash, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] pub enum McpOAuthSetting { /// No explicit OAuth config - auto-detect via server 401 response #[default] @@ -227,7 +227,7 @@ impl McpOAuthSetting { /// Supports automatic OAuth configuration discovery from server metadata. /// When auth_url/token_url are not provided, Forge will automatically /// discover them using RFC 8414 OAuth 2.0 Authorization Server Metadata. -#[derive(Default, Debug, Clone, Serialize, Deserialize, Setters, PartialEq, Hash)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, Setters, PartialEq, Eq, Hash)] #[setters(strip_option, into)] #[serde(rename_all = "camelCase")] pub struct McpOAuthConfig { diff --git a/crates/forge_domain/src/mcp_servers.rs b/crates/forge_domain/src/mcp_servers.rs index 68d2b1ae8b..7dbd7ad821 100644 --- a/crates/forge_domain/src/mcp_servers.rs +++ b/crates/forge_domain/src/mcp_servers.rs @@ -9,9 +9,8 @@ use crate::{ServerName, ToolDefinition}; /// Simplified cache structure that stores only the essential data. /// Validation and TTL checking are handled by the infrastructure layer /// using cacache's built-in metadata capabilities. -#[derive(Default, Clone, Serialize, Deserialize, Debug, PartialEq, derive_setters::Setters)] +#[derive(Default, Clone, Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "camelCase")] -#[setters(strip_option, into)] pub struct McpServers { /// Successfully loaded MCP servers with their tools servers: HashMap>, diff --git a/crates/forge_domain/src/policies/engine.rs b/crates/forge_domain/src/policies/engine.rs index b89747a906..da7e342982 100644 --- a/crates/forge_domain/src/policies/engine.rs +++ b/crates/forge_domain/src/policies/engine.rs @@ -88,10 +88,16 @@ impl<'a> PolicyEngine<'a> { #[cfg(test)] mod tests { + use std::path::PathBuf; + use pretty_assertions::assert_eq; use super::*; - use crate::{ExecuteRule, Fetch, Permission, Policy, PolicyConfig, ReadRule, Rule, WriteRule}; + use crate::mcp::McpServerConfig; + use crate::{ + ExecuteRule, Fetch, McpFilter, McpRule, Permission, Policy, PolicyConfig, ReadRule, Rule, + WriteRule, + }; fn fixture_workflow_with_read_policy() -> PolicyConfig { PolicyConfig::new().add_policy(Policy::Simple { @@ -201,4 +207,66 @@ mod tests { assert_eq!(actual, Permission::Allow); } + + #[test] + fn test_policy_engine_mcp_unmatched_command_defaults_to_confirm() { + // Rule targets "node" but operation uses "npx" — should not match. + let fixture_workflow = PolicyConfig::new().add_policy(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { + mcp: McpFilter { command: Some("node".to_string()), ..McpFilter::default() }, + }), + }); + let fixture = PolicyEngine::new(&fixture_workflow); + let operation = PermissionOperation::Mcp { + config: McpServerConfig::new_stdio("npx", vec![], None), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: github".to_string(), + }; + + let actual = fixture.can_perform(&operation); + + assert_eq!(actual, Permission::Confirm); + } + + #[test] + fn test_policy_engine_mcp_matching_command_glob_allows() { + let fixture_workflow = PolicyConfig::new().add_policy(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { + mcp: McpFilter { command: Some("np*".to_string()), ..McpFilter::default() }, + }), + }); + let fixture = PolicyEngine::new(&fixture_workflow); + let operation = PermissionOperation::Mcp { + config: McpServerConfig::new_stdio("npx", vec![], None), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: github".to_string(), + }; + + let actual = fixture.can_perform(&operation); + + assert_eq!(actual, Permission::Allow); + } + + #[test] + fn test_policy_engine_mcp_url_rule_does_not_match_stdio() { + // A url-only rule must not match a stdio server. + let fixture_workflow = PolicyConfig::new().add_policy(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { + mcp: McpFilter { url: Some("*".to_string()), ..McpFilter::default() }, + }), + }); + let fixture = PolicyEngine::new(&fixture_workflow); + let operation = PermissionOperation::Mcp { + config: McpServerConfig::new_stdio("npx", vec![], None), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: github".to_string(), + }; + + let actual = fixture.can_perform(&operation); + + assert_eq!(actual, Permission::Confirm); + } } diff --git a/crates/forge_domain/src/policies/operation.rs b/crates/forge_domain/src/policies/operation.rs index 3a99e383dc..46a89d0a50 100644 --- a/crates/forge_domain/src/policies/operation.rs +++ b/crates/forge_domain/src/policies/operation.rs @@ -1,5 +1,7 @@ use std::path::PathBuf; +use crate::mcp::McpServerConfig; + /// Operations that can be performed and need policy checking #[derive(Debug, Clone, PartialEq, Eq)] pub enum PermissionOperation { @@ -23,4 +25,16 @@ pub enum PermissionOperation { cwd: PathBuf, message: String, }, + /// MCP server connection authorization. Evaluated once per server when the + /// MCP service brings up connections; the decision then gates every tool + /// call routed through that server. The `config` field carries either a + /// stdio server (command + args) or an HTTP server (url) — never both. + Mcp { + /// The server configuration — either `Stdio` (command + args) or `Http` + /// (url). + config: McpServerConfig, + /// The current working directory at the time of the operation. + cwd: PathBuf, + message: String, + }, } diff --git a/crates/forge_domain/src/policies/rule.rs b/crates/forge_domain/src/policies/rule.rs index 652dbab8a9..44cf35136b 100644 --- a/crates/forge_domain/src/policies/rule.rs +++ b/crates/forge_domain/src/policies/rule.rs @@ -6,6 +6,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use super::operation::PermissionOperation; +use crate::mcp::McpServerConfig; /// Rule for write operations with a glob pattern #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] @@ -39,6 +40,42 @@ pub struct Fetch { pub dir: Option, } +/// Filter criteria nested inside an [`McpRule`]. All fields are optional; +/// omitting a field means "match any value" for that dimension. Multiple +/// fields are combined with logical AND. +#[derive( + Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, +)] +pub struct McpFilter { + /// Optional glob over the command used to launch a stdio MCP server. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub command: Option, + /// Optional glob patterns over the server's argument list (all must match). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub args: Option>, + /// Optional glob over the URL of an HTTP/SSE MCP server. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub url: Option, + /// Optional working directory glob pattern. `None` matches any directory. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub dir: Option, +} + +/// Rule for MCP server connection authorization. The required `mcp` key +/// identifies this as an MCP rule (analogous to `write:`, `read:`, etc.) and +/// disambiguates it from other rule types in the untagged `Rule` enum. +/// +/// The value is an [`McpFilter`] object whose fields are all optional: +/// an empty object `{}` matches any MCP server; populating fields narrows the +/// match. Stdio servers are matched via `command`/`args`; HTTP servers via +/// `url`. Specifying both `command` and `url` will never match (a server is +/// either stdio or HTTP, not both). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] +pub struct McpRule { + /// Filter criteria for the MCP server. Use `{}` to match any server. + pub mcp: McpFilter, +} + /// Rules that define what operations are covered by a policy #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)] #[serde(untagged)] @@ -51,62 +88,96 @@ pub enum Rule { Execute(ExecuteRule), /// Rule for network fetch operations with a URL pattern Fetch(Fetch), + /// Rule for MCP tool invocations with a tool-name glob pattern + Mcp(McpRule), } impl Rule { /// Check if this rule matches the given operation pub fn matches(&self, operation: &PermissionOperation) -> bool { match (self, operation) { - (Rule::Write(rule), PermissionOperation::Write { path, cwd, message: _ }) => { - let pattern_matches = match_pattern(&rule.write, path); - let dir = match &rule.dir { - Some(wd_pattern) => match_pattern(wd_pattern, cwd), - None => true, /* If no working directory pattern is specified, it matches any - * directory */ - }; - pattern_matches && dir + (Rule::Write(rule), PermissionOperation::Write { path, cwd, .. }) => { + match_pattern(&rule.write, path) && match_dir(&rule.dir, cwd) } - (Rule::Read(rule), PermissionOperation::Read { path, cwd, message: _ }) => { - let pattern_matches = match_pattern(&rule.read, path); - let dir_matches = match &rule.dir { - Some(wd_pattern) => match_pattern(wd_pattern, cwd), - None => true, /* If no working directory pattern is specified, it matches any - * directory */ - }; - pattern_matches && dir_matches + (Rule::Read(rule), PermissionOperation::Read { path, cwd, .. }) => { + match_pattern(&rule.read, path) && match_dir(&rule.dir, cwd) } - (Rule::Execute(rule), PermissionOperation::Execute { command: cmd, cwd }) => { - let command_matches = match_pattern(&rule.command, cmd); - let dir_matches = match &rule.dir { - Some(wd_pattern) => match_pattern(wd_pattern, cwd), - None => true, /* If no working directory pattern is specified, it matches any - * directory */ - }; - command_matches && dir_matches + match_pattern(&rule.command, cmd) && match_dir(&rule.dir, cwd) + } + (Rule::Fetch(rule), PermissionOperation::Fetch { url, cwd, .. }) => { + match_pattern(&rule.url, url) && match_dir(&rule.dir, cwd) } - (Rule::Fetch(rule), PermissionOperation::Fetch { url, cwd, message: _ }) => { - let url_matches = match_pattern(&rule.url, url); - let dir_matches = match &rule.dir { - Some(wd_pattern) => match_pattern(wd_pattern, cwd), - None => true, /* If no working directory pattern is specified, it matches any - * directory */ - }; - url_matches && dir_matches + (Rule::Mcp(rule), PermissionOperation::Mcp { config, cwd, .. }) => { + rule.mcp.matches_config(config) && match_dir(&rule.mcp.dir, cwd) } _ => false, } } } -/// Helper function to match a glob pattern against a path or string +/// Returns true when `opt_pattern` is absent (wildcard) or matches `target`. +fn match_dir>(opt_pattern: &Option, target: P) -> bool { + opt_pattern + .as_deref() + .is_none_or(|pat| match_pattern(pat, target)) +} + +/// Returns true when `pattern` glob-matches `target`. fn match_pattern>(pattern: &str, target: P) -> bool { - match Pattern::new(pattern) { - Ok(glob_pattern) => { - let target_str = target.as_ref().to_string_lossy(); - glob_pattern.matches(&target_str) + Pattern::new(pattern).is_ok_and(|p| p.matches(&target.as_ref().to_string_lossy())) +} + +impl McpFilter { + /// Build a filter that exactly pins `config` — stdio servers are matched by + /// `command` + `args`; HTTP servers by `url`. The `dir` is always set to + /// `cwd` so the rule is scoped to the working directory. + pub fn from_config(config: &McpServerConfig, cwd: &std::path::Path) -> Self { + let dir = Some(cwd.to_string_lossy().into_owned()); + match config { + McpServerConfig::Stdio(s) => Self { + command: Some(s.command.clone()), + args: if s.args.is_empty() { + None + } else { + Some(s.args.clone()) + }, + url: None, + dir, + }, + McpServerConfig::Http(h) => { + Self { command: None, args: None, url: Some(h.url.clone()), dir } + } + } + } + + /// Returns true when this filter is compatible with `config`. + /// + /// A stdio filter (has `command`/`args`, no `url`) only matches stdio + /// servers; an HTTP filter (has `url`, no `command`/`args`) only + /// matches HTTP servers; an empty filter matches both. + fn matches_config(&self, config: &McpServerConfig) -> bool { + match config { + McpServerConfig::Stdio(s) => { + // A url-only rule must not match a stdio server + self.url.is_none() + && self + .command + .as_deref() + .is_none_or(|p| match_pattern(p, &s.command)) + && self.args.as_deref().is_none_or(|patterns| { + patterns + .iter() + .all(|p| s.args.iter().any(|a| match_pattern(p, a))) + }) + } + McpServerConfig::Http(h) => { + // A command/args-only rule must not match an HTTP server + self.command.is_none() + && self.args.is_none() + && self.url.as_deref().is_none_or(|p| match_pattern(p, &h.url)) + } } - Err(_) => false, // Invalid pattern doesn't match anything } } @@ -150,6 +221,32 @@ impl Display for Fetch { } } +impl Display for McpRule { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let filter = &self.mcp; + let mut parts: Vec = Vec::new(); + if let Some(cmd) = &filter.command { + parts.push(format!("command '{cmd}'")); + } + if let Some(args) = &filter.args { + parts.push(format!("args [{}]", args.join(", "))); + } + if let Some(url) = &filter.url { + parts.push(format!("url '{url}'")); + } + let base = if parts.is_empty() { + "mcp server (any)".to_string() + } else { + format!("mcp server with {}", parts.join(", ")) + }; + if let Some(wd) = &filter.dir { + write!(f, "{} in '{wd}'", base) + } else { + write!(f, "{}", base) + } + } +} + impl Display for Rule { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -157,6 +254,7 @@ impl Display for Rule { Rule::Read(rule) => write!(f, "{rule}"), Rule::Execute(rule) => write!(f, "{rule}"), Rule::Fetch(rule) => write!(f, "{rule}"), + Rule::Mcp(rule) => write!(f, "{rule}"), } } } @@ -168,6 +266,7 @@ mod tests { use pretty_assertions::assert_eq; use super::*; + use crate::mcp::McpServerConfig; fn fixture_write_operation() -> PermissionOperation { PermissionOperation::Write { @@ -208,6 +307,30 @@ mod tests { } } + fn fixture_mcp_stdio_operation() -> PermissionOperation { + PermissionOperation::Mcp { + config: McpServerConfig::new_stdio( + "npx", + vec!["-y".to_string(), "@github/mcp".to_string()], + None, + ), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: github".to_string(), + } + } + + fn fixture_mcp_http_operation() -> PermissionOperation { + PermissionOperation::Mcp { + config: McpServerConfig::new_http("https://mcp.example.com/sse"), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: example".to_string(), + } + } + + fn fixture_mcp_rule(filter: McpFilter) -> Rule { + Rule::Mcp(McpRule { mcp: filter }) + } + #[test] fn test_rule_matches_write_operation() { let fixture = Rule::Write(WriteRule { write: "src/**/*.rs".to_string(), dir: None }); @@ -325,4 +448,237 @@ mod tests { assert_eq!(actual, true); } + + // ── MCP stdio tests ────────────────────────────────────────────────────── + + #[test] + fn test_mcp_stdio_empty_filter_matches_any_stdio() { + let fixture = fixture_mcp_rule(McpFilter::default()); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_stdio_command_exact_match() { + let fixture = fixture_mcp_rule(McpFilter { + command: Some("npx".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_stdio_command_glob_match() { + let fixture = fixture_mcp_rule(McpFilter { + command: Some("np*".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_stdio_command_no_match() { + let fixture = fixture_mcp_rule(McpFilter { + command: Some("node".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + #[test] + fn test_mcp_stdio_args_match() { + let fixture = fixture_mcp_rule(McpFilter { + args: Some(vec!["@github/mcp".to_string()]), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_stdio_args_glob_match() { + let fixture = fixture_mcp_rule(McpFilter { + args: Some(vec!["@github/*".to_string()]), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_stdio_args_no_match() { + let fixture = fixture_mcp_rule(McpFilter { + args: Some(vec!["@slack/mcp".to_string()]), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + #[test] + fn test_mcp_stdio_url_rule_does_not_match_stdio_server() { + // A url-only rule must not match a stdio server + let fixture = + fixture_mcp_rule(McpFilter { url: Some("*".to_string()), ..McpFilter::default() }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + // ── MCP HTTP tests ─────────────────────────────────────────────────────── + + #[test] + fn test_mcp_http_empty_filter_matches_any_http() { + let fixture = fixture_mcp_rule(McpFilter::default()); + let operation = fixture_mcp_http_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_http_url_exact_match() { + let fixture = fixture_mcp_rule(McpFilter { + url: Some("https://mcp.example.com/sse".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_http_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_http_url_glob_match() { + let fixture = fixture_mcp_rule(McpFilter { + url: Some("https://mcp.example.com/*".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_http_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_http_url_no_match() { + let fixture = fixture_mcp_rule(McpFilter { + url: Some("https://other.example.com/*".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_http_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + #[test] + fn test_mcp_http_command_rule_does_not_match_http_server() { + // A command-only rule must not match an HTTP server + let fixture = + fixture_mcp_rule(McpFilter { command: Some("*".to_string()), ..McpFilter::default() }); + let operation = fixture_mcp_http_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + // ── Cross-type and dir tests ───────────────────────────────────────────── + + #[test] + fn test_mcp_rule_does_not_match_non_mcp_operation() { + let fixture = fixture_mcp_rule(McpFilter::default()); + let operation = fixture_execute_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + #[test] + fn test_mcp_dir_pattern_matches_stdio() { + let fixture = fixture_mcp_rule(McpFilter { + dir: Some("/home/user/*".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_dir_pattern_no_match_stdio() { + let fixture = fixture_mcp_rule(McpFilter { + dir: Some("/different/path/*".to_string()), + ..McpFilter::default() + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } + + #[test] + fn test_mcp_combined_command_and_dir_match() { + let fixture = fixture_mcp_rule(McpFilter { + command: Some("npx".to_string()), + args: None, + url: None, + dir: Some("/home/user/*".to_string()), + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, true); + } + + #[test] + fn test_mcp_combined_command_and_dir_dir_mismatch() { + let fixture = fixture_mcp_rule(McpFilter { + command: Some("npx".to_string()), + args: None, + url: None, + dir: Some("/different/*".to_string()), + }); + let operation = fixture_mcp_stdio_operation(); + + let actual = fixture.matches(&operation); + + assert_eq!(actual, false); + } } diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index 31f5cb63e5..75dd3fa490 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -260,10 +260,10 @@ impl UserInfra for ForgeInfra { async fn select_one( &self, - message: &str, + prompt: impl Into + Send, options: Vec, ) -> anyhow::Result> { - self.inquire_service.select_one(message, options).await + self.inquire_service.select_one(prompt, options).await } async fn select_many( diff --git a/crates/forge_infra/src/inquire.rs b/crates/forge_infra/src/inquire.rs index f31b8f2a2f..d0a71d9796 100644 --- a/crates/forge_infra/src/inquire.rs +++ b/crates/forge_infra/src/inquire.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use forge_app::UserInfra; +use forge_app::{SelectPrompt, UserInfra}; use forge_select::ForgeWidget; pub struct ForgeInquire; @@ -34,16 +34,23 @@ impl UserInfra for ForgeInquire { async fn select_one( &self, - message: &str, + prompt: impl Into + Send, options: Vec, ) -> Result> { if options.is_empty() { return Ok(None); } - let message = message.to_string(); - self.prompt(move || ForgeWidget::select(&message, options).prompt()) - .await + let SelectPrompt { message, header } = prompt.into(); + self.prompt(move || { + let builder = ForgeWidget::select(&message, options); + if header.is_empty() { + builder.prompt() + } else { + builder.with_help_message(header).prompt() + } + }) + .await } async fn select_many( diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 9ecd50fc41..f416d382ca 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -11,7 +11,7 @@ use convert_case::{Case, Casing}; use forge_api::{ API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest, ChatResponse, CodeRequest, ConfigOperation, Conversation, ConversationId, DeviceCodeRequest, - Event, InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt, + Event, InterruptionReason, ModelId, Provider, ProviderId, Scope, TextMessage, UserPrompt, }; use forge_app::utils::{format_display_path, truncate_key}; use forge_app::{CommitResult, ToolResolver}; @@ -223,6 +223,7 @@ impl A + Send + Sync> UI self.display_banner()?; self.trace_user(); self.hydrate_caches(); + self.request_local_mcp_permissions().await?; Ok(()) } @@ -367,6 +368,7 @@ impl A + Send + Sync> UI self.trace_user(); self.hydrate_caches(); self.init_conversation().await?; + self.request_local_mcp_permissions().await?; // Check for dispatch flag first if let Some(dispatch_json) = self.cli.event.clone() { @@ -443,12 +445,23 @@ impl A + Send + Sync> UI } } + /// Reads the local-scope MCP config and asks the user for permission for + /// each server that does not yet have a recorded decision. Call this + /// synchronously before the REPL takes over stdin so prompts don't race + /// with user input. + async fn request_local_mcp_permissions(&self) -> Result<()> { + let local_cfg = self.api.read_mcp_config(Some(&Scope::Local)).await?; + self.api.request_mcp_permissions(local_cfg).await + } + // Improve startup time by hydrating caches fn hydrate_caches(&self) { let api = self.api.clone(); tokio::spawn(async move { api.get_models().await }); let api = self.api.clone(); - tokio::spawn(async move { api.get_tools().await }); + tokio::spawn(async move { + let _ = api.get_tools().await; + }); let api = self.api.clone(); tokio::spawn(async move { api.get_agent_infos().await }); let api = self.api.clone(); @@ -569,6 +582,10 @@ impl A + Send + Sync> UI // Write back to the specific scope only self.api.write_mcp_config(&scope, &scope_config).await?; + // Importing is an explicit opt-in — persist Allow decisions so + // the user is not prompted on first use. + self.api.allow_mcp_servers(&added_servers).await?; + // Log each added server after successful write for server_name in added_servers { self.writeln_title(TitleFormat::info(format!( @@ -3322,7 +3339,7 @@ impl A + Send + Sync> UI .collect(); match ForgeWidget::select("Select authentication method:", method_names.clone()) - .with_help_message("Use arrow keys to navigate and Enter to select") + .with_help_message(vec!["Use arrow keys to navigate and Enter to select"]) .prompt()? { Some(selected_name) => { diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index 555758c7b5..6e2f7d55bf 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -418,18 +418,21 @@ where async fn select_one( &self, - message: &str, + prompt: impl Into + Send, options: Vec, ) -> anyhow::Result> { - self.infra.select_one(message, options).await + self.infra.select_one(prompt, options).await } - async fn select_one_enum(&self, message: &str) -> anyhow::Result> + async fn select_one_enum( + &self, + prompt: impl Into + Send, + ) -> anyhow::Result> where T: Clone + std::fmt::Display + Send + 'static + strum::IntoEnumIterator + std::str::FromStr, ::Err: std::fmt::Debug, { - self.infra.select_one_enum(message).await + self.infra.select_one_enum(prompt).await } async fn select_many( diff --git a/crates/forge_select/src/select.rs b/crates/forge_select/src/select.rs index ed59b9d4ea..0cbf1099df 100644 --- a/crates/forge_select/src/select.rs +++ b/crates/forge_select/src/select.rs @@ -11,7 +11,7 @@ pub struct SelectBuilder { pub(crate) options: Vec, pub(crate) starting_cursor: Option, pub(crate) default: Option, - pub(crate) help_message: Option<&'static str>, + pub(crate) help_message: Vec, pub(crate) initial_text: Option, pub(crate) header_lines: usize, pub(crate) preview: Option, @@ -43,9 +43,10 @@ impl SelectBuilder { self } - /// Set help message displayed as a header above the list. - pub fn with_help_message(mut self, message: &'static str) -> Self { - self.help_message = Some(message); + /// Set one or more header lines displayed above the list. + /// Each entry becomes a separate non-selectable header row. + pub fn with_help_message(mut self, lines: impl IntoIterator>) -> Self { + self.help_message = lines.into_iter().map(|l| l.into()).collect(); self } @@ -124,9 +125,12 @@ impl SelectBuilder { selector = selector.initial_raw(Some(cursor.to_string())); } - if let Some(help) = self.help_message { - selector.rows.insert(0, SelectRow::header(help)); - selector.header_lines = selector.header_lines.saturating_add(1); + if !self.help_message.is_empty() { + let count = self.help_message.len(); + for line in self.help_message.into_iter().rev() { + selector.rows.insert(0, SelectRow::header(line)); + } + selector.header_lines = selector.header_lines.saturating_add(count); } let selected = selector.prompt()?; diff --git a/crates/forge_select/src/widget.rs b/crates/forge_select/src/widget.rs index ac73b5cd57..075c4b56e9 100644 --- a/crates/forge_select/src/widget.rs +++ b/crates/forge_select/src/widget.rs @@ -18,7 +18,7 @@ impl ForgeWidget { options, starting_cursor: None, default: None, - help_message: None, + help_message: Vec::new(), initial_text: None, header_lines: 0, preview: None, diff --git a/crates/forge_services/src/forge_services.rs b/crates/forge_services/src/forge_services.rs index cd2f775899..d2ba2b0522 100644 --- a/crates/forge_services/src/forge_services.rs +++ b/crates/forge_services/src/forge_services.rs @@ -30,7 +30,7 @@ use crate::tool_services::{ ForgeFsUndo, ForgeFsWrite, ForgeImageRead, ForgePlanCreate, ForgeShell, ForgeSkillFetch, }; -type McpService = ForgeMcpService, F, ::Client>; +type McpService = ForgeMcpService::Client>; type AuthService = ForgeAuthService; /// ForgeApp is the main application container that implements the App trait. @@ -109,8 +109,9 @@ impl< > ForgeServices { pub fn new(infra: Arc) -> Self { + let mcp_service = Arc::new(ForgeMcpService::new(infra.clone())); let mcp_manager = Arc::new(ForgeMcpManager::new(infra.clone())); - let mcp_service = Arc::new(ForgeMcpService::new(mcp_manager.clone(), infra.clone())); + let policy_service = ForgePolicyService::new(infra.clone()); let template_service = Arc::new(ForgeTemplateService::new(infra.clone())); let attachment_service = Arc::new(ForgeChatRequest::new(infra.clone())); let suggestion_service = Arc::new(ForgeDiscoveryService::new(infra.clone())); @@ -133,7 +134,6 @@ impl< Arc::new(ForgeCustomInstructionsService::new(infra.clone())); let agent_registry_service = Arc::new(ForgeAgentRegistryService::new(infra.clone())); let command_loader_service = Arc::new(ForgeCommandLoaderService::new(infra.clone())); - let policy_service = ForgePolicyService::new(infra.clone()); let provider_auth_service = ForgeProviderAuthService::new(infra.clone()); let discovery = Arc::new(FdDefault::new(infra.clone())); let workspace_service = Arc::new(crate::context_engine::ForgeWorkspaceService::new( diff --git a/crates/forge_services/src/mcp/service.rs b/crates/forge_services/src/mcp/service.rs index a96692de62..5c03cfa7c6 100644 --- a/crates/forge_services/src/mcp/service.rs +++ b/crates/forge_services/src/mcp/service.rs @@ -6,9 +6,7 @@ use forge_app::domain::{ McpConfig, McpServerConfig, McpServers, ServerName, ToolCallFull, ToolDefinition, ToolName, ToolOutput, }; -use forge_app::{ - EnvironmentInfra, KVStore, McpClientInfra, McpConfigManager, McpServerInfra, McpService, -}; +use forge_app::{EnvironmentInfra, KVStore, McpClientInfra, McpServerInfra, McpService}; use tokio::sync::{Mutex, RwLock}; use crate::mcp::tool::McpExecutor; @@ -23,12 +21,11 @@ fn generate_mcp_tool_name(server_name: &ServerName, tool_name: &ToolName) -> Too } #[derive(Clone)] -pub struct ForgeMcpService { +pub struct ForgeMcpService { tools: Arc>>>>, failed_servers: Arc>>, previous_config_hash: Arc>, init_lock: Arc>, - manager: Arc, infra: Arc, } @@ -39,20 +36,18 @@ struct ToolHolder { server_name: String, } -impl ForgeMcpService +impl ForgeMcpService where - M: McpConfigManager, - I: McpServerInfra + KVStore + EnvironmentInfra, + I: McpServerInfra + KVStore + EnvironmentInfra + 'static, C: McpClientInfra + Clone, C: From<::Client>, { - pub fn new(manager: Arc, infra: Arc) -> Self { + pub fn new(infra: Arc) -> Self { Self { tools: Default::default(), failed_servers: Default::default(), previous_config_hash: Arc::new(Mutex::new(Default::default())), init_lock: Arc::new(Mutex::new(())), - manager, infra, } } @@ -100,12 +95,10 @@ where Ok(()) } - async fn init_mcp(&self) -> anyhow::Result<()> { - let mcp = self.manager.read_mcp_config(None).await?; - + async fn init_mcp(&self, cfg: McpConfig) -> anyhow::Result<()> { // Fast path: if config is unchanged, skip reinitialization without acquiring // the lock - if !self.is_config_modified(&mcp).await { + if !self.is_config_modified(&cfg).await { return Ok(()); } @@ -114,33 +107,30 @@ where let _guard = self.init_lock.lock().await; // Double-check under the lock: a concurrent caller may have already updated - if !self.is_config_modified(&mcp).await { + if !self.is_config_modified(&cfg).await { return Ok(()); } - self.update_mcp(mcp).await + self.update_mcp(cfg).await } - async fn update_mcp(&self, mcp: McpConfig) -> Result<(), anyhow::Error> { - // Compute the hash early before mcp is consumed, but write it only after - // all connections are established so waiters on init_lock see a consistent - // state. + async fn update_mcp(&self, mcp: McpConfig) -> anyhow::Result<()> { + // Compute the hash early before `mcp` is consumed, but write it only + // after all connections are established so waiters on init_lock see a + // consistent state. let new_hash = mcp.cache_key(); self.clear_tools().await; - - // Clear failed servers map before attempting new connections self.failed_servers.write().await.clear(); let connections: Vec<_> = mcp .mcp_servers .into_iter() - .filter(|v| !v.1.is_disabled()) + .filter(|(_, server)| !server.is_disabled()) .map(|(name, server)| async move { let conn = self .connect(&name, server) .await .context(format!("Failed to initiate MCP server: {name}")); - (name, conn) }) .collect(); @@ -148,17 +138,11 @@ where let results = futures::future::join_all(connections).await; for (server_name, result) in results { - match result { - Ok(_) => {} - Err(error) => { - // Format error with full chain for detailed diagnostics - // Using Debug formatting with alternate flag shows the full error chain - let error_string = format!("{error:?}"); - self.failed_servers - .write() - .await - .insert(server_name.clone(), error_string.clone()); - } + if let Err(error) = result { + self.failed_servers + .write() + .await + .insert(server_name, format!("{error:?}")); } } @@ -170,8 +154,8 @@ where Ok(()) } - async fn list(&self) -> anyhow::Result { - self.init_mcp().await?; + async fn list(&self, cfg: McpConfig) -> anyhow::Result { + self.init_mcp(cfg).await?; let tools = self.tools.read().await; let mut grouped_tools = std::collections::HashMap::new(); @@ -187,13 +171,15 @@ where Ok(McpServers::new(grouped_tools, failures)) } + async fn clear_tools(&self) { self.tools.write().await.clear() } - async fn call(&self, call: ToolCallFull) -> anyhow::Result { - // Ensure MCP connections are initialized before calling tools - self.init_mcp().await?; + async fn call(&self, call: ToolCallFull, cfg: McpConfig) -> anyhow::Result { + // Use the caller-supplied pre-filtered config so only permitted servers + // are (re)connected here. + self.init_mcp(cfg).await?; let tools = self.tools.read().await; @@ -226,32 +212,27 @@ where } #[async_trait::async_trait] -impl McpService - for ForgeMcpService +impl McpService for ForgeMcpService where + I: McpServerInfra + KVStore + EnvironmentInfra + 'static, C: McpClientInfra + Clone, C: From<::Client>, { - async fn get_mcp_servers(&self) -> anyhow::Result { - // Read current configs to compute merged hash - let mcp_config = self.manager.read_mcp_config(None).await?; + async fn get_mcp_servers(&self, cfg: McpConfig) -> anyhow::Result { + let config_hash = cfg.cache_key(); - // Compute unified hash from merged config - let config_hash = mcp_config.cache_key(); - - // Check if cache is valid (exists and not expired) - // Cache is valid, retrieve it if let Some(cache) = self.infra.cache_get::<_, McpServers>(&config_hash).await? { - return Ok(cache.clone()); + return Ok(cache); } - let servers = self.list().await?; + let servers = self.list(cfg).await?; self.infra.cache_set(&config_hash, &servers).await?; + Ok(servers) } - async fn execute_mcp(&self, call: ToolCallFull) -> anyhow::Result { - self.call(call).await + async fn execute_mcp(&self, call: ToolCallFull, cfg: McpConfig) -> anyhow::Result { + self.call(call, cfg).await } async fn reload_mcp(&self) -> anyhow::Result<()> { @@ -266,12 +247,10 @@ mod tests { use fake::{Fake, Faker}; use forge_app::domain::{ - ConfigOperation, Environment, McpConfig, McpServerConfig, Scope, ServerName, ToolCallFull, + ConfigOperation, Environment, McpConfig, McpServerConfig, ServerName, ToolCallFull, ToolDefinition, ToolName, ToolOutput, }; - use forge_app::{ - EnvironmentInfra, KVStore, McpClientInfra, McpConfigManager, McpServerInfra, McpService, - }; + use forge_app::{EnvironmentInfra, KVStore, McpClientInfra, McpServerInfra, McpService}; use forge_config::ForgeConfig; use pretty_assertions::assert_eq; use serde::de::DeserializeOwned; @@ -298,30 +277,6 @@ mod tests { } } - // ── Mock config manager ────────────────────────────────────────────────── - - struct MockMcpManager; - - #[async_trait::async_trait] - impl McpConfigManager for MockMcpManager { - async fn read_mcp_config(&self, _scope: Option<&Scope>) -> anyhow::Result { - let mut servers = BTreeMap::new(); - servers.insert( - ServerName::from("test-server".to_string()), - McpServerConfig::new_stdio("echo", vec![], None), - ); - Ok(McpConfig { mcp_servers: servers }) - } - - async fn write_mcp_config( - &self, - _config: &McpConfig, - _scope: &Scope, - ) -> anyhow::Result<()> { - Ok(()) - } - } - // ── Mock infrastructure ────────────────────────────────────────────────── #[derive(Clone)] @@ -390,8 +345,17 @@ mod tests { // ── Fixture ────────────────────────────────────────────────────────────── - fn fixture() -> ForgeMcpService { - ForgeMcpService::new(Arc::new(MockMcpManager), Arc::new(MockInfra)) + fn fixture() -> ForgeMcpService { + ForgeMcpService::new(Arc::new(MockInfra)) + } + + fn fixture_cfg() -> McpConfig { + let mut servers = BTreeMap::new(); + servers.insert( + ServerName::from("test-server".to_string()), + McpServerConfig::new_stdio("echo", vec![], None), + ); + McpConfig { mcp_servers: servers } } #[test] @@ -445,14 +409,17 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_concurrent_init_does_not_race() { let service = Arc::new(fixture()); + let cfg = fixture_cfg(); let s1 = service.clone(); let s2 = service.clone(); - let (r1, r2) = tokio::join!(s1.get_mcp_servers(), s2.get_mcp_servers()); + let c1 = cfg.clone(); + let c2 = cfg.clone(); + let (r1, r2) = tokio::join!(s1.get_mcp_servers(c1), s2.get_mcp_servers(c2)); r1.unwrap(); r2.unwrap(); - let servers = service.get_mcp_servers().await.unwrap(); + let servers = service.get_mcp_servers(cfg).await.unwrap(); let tool_name = servers .get_servers() .values() @@ -463,7 +430,7 @@ mod tests { .clone(); let call = ToolCallFull::new(tool_name); - let actual = service.execute_mcp(call).await.unwrap(); + let actual = service.execute_mcp(call, fixture_cfg()).await.unwrap(); let expected = ToolOutput::text("mock result"); assert_eq!(actual, expected); } diff --git a/crates/forge_services/src/policy.rs b/crates/forge_services/src/policy.rs index 97621bc04b..4c0751084f 100644 --- a/crates/forge_services/src/policy.rs +++ b/crates/forge_services/src/policy.rs @@ -4,12 +4,12 @@ use std::sync::{Arc, LazyLock}; use anyhow::Context; use bytes::Bytes; use forge_app::domain::{ - ExecuteRule, Fetch, Permission, PermissionOperation, Policy, PolicyConfig, PolicyEngine, - ReadRule, Rule, WriteRule, + ExecuteRule, Fetch, McpFilter, McpRule, Permission, PermissionOperation, Policy, PolicyConfig, + PolicyEngine, ReadRule, Rule, WriteRule, }; use forge_app::{ DirectoryReaderInfra, EnvironmentInfra, FileInfoInfra, FileReaderInfra, FileWriterInfra, - PolicyDecision, PolicyService, UserInfra, + PolicyDecision, PolicyService, SelectPrompt, UserInfra, }; use strum_macros::{Display, EnumIter}; @@ -27,10 +27,24 @@ pub enum PolicyPermission { AcceptAndRemember, } +/// Two-choice prompt for operations where both Accept and Reject are +/// persisted so the user is never asked again. Use this instead of +/// [`PolicyPermission`] when there is no meaningful "one-off allow" path. +#[derive(Debug, Clone, PartialEq, Eq, Display, EnumIter, strum_macros::EnumString)] +enum ConfirmPermission { + /// Allow the operation and remember this choice + #[strum(to_string = "Accept")] + Accept, + /// Deny the operation and remember this choice + #[strum(to_string = "Reject")] + Reject, +} + #[derive(Clone)] pub struct ForgePolicyService { infra: Arc, } + /// Default policies loaded once at startup from the embedded YAML file static DEFAULT_POLICIES: LazyLock = LazyLock::new(|| { let yaml_content = include_str!("./permissions.default.yaml"); @@ -156,6 +170,23 @@ where + DirectoryReaderInfra + UserInfra, { + /// Unconditionally persist an allow policy for the given operation. + async fn allow_operation(&self, operation: &PermissionOperation) -> anyhow::Result<()> { + self.add_policy_for_operation(operation).await.map(|_| ()) + } + + /// Check whether an operation is explicitly permitted by the current + /// policy without prompting the user. `Confirm` is treated as not + /// permitted so callers can handle it themselves (e.g. show a warning). + async fn is_operation_permitted( + &self, + operation: &PermissionOperation, + ) -> anyhow::Result { + let (policies, _) = self.get_or_create_policies().await?; + let engine = PolicyEngine::new(&policies); + Ok(matches!(engine.can_perform(operation), Permission::Allow)) + } + /// Check if an operation is allowed based on policies and handle user /// confirmation async fn check_operation_permission( @@ -172,24 +203,51 @@ where Permission::Allow => Ok(PolicyDecision { allowed: true, path }), Permission::Confirm => { // Request user confirmation using UserInfra - let confirmation_msg = match operation { + let prompt = match operation { PermissionOperation::Read { message, .. } => { - format!("{message}. How would you like to proceed?") + SelectPrompt::new(format!("{message}. How would you like to proceed?")) } PermissionOperation::Write { message, .. } => { - format!("{message}. How would you like to proceed?") + SelectPrompt::new(format!("{message}. How would you like to proceed?")) } PermissionOperation::Execute { .. } => { - "How would you like to proceed?".to_string() + SelectPrompt::new("How would you like to proceed?") } PermissionOperation::Fetch { message, .. } => { - format!("{message}. How would you like to proceed?") + SelectPrompt::new(format!("{message}. How would you like to proceed?")) + } + PermissionOperation::Mcp { message, config, cwd } => { + let header = mcp_config_header(config); + let prompt = SelectPrompt::new(message.clone()).with_header(header); + return match self + .infra + .select_one_enum::(prompt) + .await? + { + Some(ConfirmPermission::Accept) => { + let update_path = self.add_policy_for_operation(operation).await?; + Ok(PolicyDecision { allowed: true, path: update_path.or(path) }) + } + Some(ConfirmPermission::Reject) | None => { + let deny_policy = Policy::Simple { + permission: Permission::Deny, + rule: Rule::Mcp(McpRule { + mcp: McpFilter::from_config(config, cwd), + }), + }; + self.modify_policy(deny_policy).await?; + Ok(PolicyDecision { + allowed: false, + path: Some(self.permissions_path()), + }) + } + }; } }; match self .infra - .select_one_enum::(&confirmation_msg) + .select_one_enum::(prompt) .await? { Some(PolicyPermission::Accept) => Ok(PolicyDecision { allowed: true, path }), @@ -206,6 +264,21 @@ where } } +/// Builds the header lines describing an MCP server's configuration. +fn mcp_config_header(config: &forge_app::domain::McpServerConfig) -> Vec { + use forge_app::domain::McpServerConfig; + match config { + McpServerConfig::Stdio(s) => { + let mut lines = vec![format!("command: {}", s.command)]; + if !s.args.is_empty() { + lines.push(format!("args: {}", s.args.join(" "))); + } + lines + } + McpServerConfig::Http(h) => vec![format!("url: {}", h.url)], + } +} + /// Create a policy for an operation based on its type fn create_policy_for_operation( operation: &PermissionOperation, @@ -262,6 +335,10 @@ fn create_policy_for_operation( }), } } + PermissionOperation::Mcp { config, cwd, .. } => Some(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { mcp: McpFilter::from_config(config, cwd) }), + }), } } @@ -443,4 +520,57 @@ mod tests { assert_eq!(actual, expected); } + + #[test] + fn test_create_policy_for_mcp_stdio_operation() { + let operation = PermissionOperation::Mcp { + config: forge_app::domain::McpServerConfig::new_stdio( + "npx", + vec!["-y".to_string(), "@github/mcp".to_string()], + None, + ), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: github".to_string(), + }; + + let actual = create_policy_for_operation(&operation, None); + + let expected = Some(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { + mcp: McpFilter { + command: Some("npx".to_string()), + args: Some(vec!["-y".to_string(), "@github/mcp".to_string()]), + url: None, + dir: Some("/home/user/project".to_string()), + }, + }), + }); + + assert_eq!(actual, expected); + } + + #[test] + fn test_create_policy_for_mcp_http_operation() { + let operation = PermissionOperation::Mcp { + config: forge_app::domain::McpServerConfig::new_http("https://mcp.example.com/sse"), + cwd: PathBuf::from("/home/user/project"), + message: "Connect to MCP server: example".to_string(), + }; + + let actual = create_policy_for_operation(&operation, None); + + let expected = Some(Policy::Simple { + permission: Permission::Allow, + rule: Rule::Mcp(McpRule { + mcp: McpFilter { + url: Some("https://mcp.example.com/sse".to_string()), + dir: Some("/home/user/project".to_string()), + ..McpFilter::default() + }, + }), + }); + + assert_eq!(actual, expected); + } } diff --git a/crates/forge_services/src/tool_services/followup.rs b/crates/forge_services/src/tool_services/followup.rs index 7af2b5fc2a..b791cfb94f 100644 --- a/crates/forge_services/src/tool_services/followup.rs +++ b/crates/forge_services/src/tool_services/followup.rs @@ -39,7 +39,7 @@ impl FollowUpService for ForgeFollowup { ) }), (false, false) => inquire - .select_one(&question, options) + .select_one(question.as_str(), options) .await? .map(|selected| format!("User selected: {selected}")), };