diff --git a/src/elizacp/tests/mcp_tool_invocation.rs b/src/elizacp/tests/mcp_tool_invocation.rs index ca4f971..1357e78 100644 --- a/src/elizacp/tests/mcp_tool_invocation.rs +++ b/src/elizacp/tests/mcp_tool_invocation.rs @@ -16,7 +16,7 @@ async fn recv( response: sacp::JrResponse, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp-conductor/src/conductor.rs b/src/sacp-conductor/src/conductor.rs index 699be9d..a8c8b09 100644 --- a/src/sacp-conductor/src/conductor.rs +++ b/src/sacp-conductor/src/conductor.rs @@ -612,7 +612,7 @@ impl ConductorResponder { meta: None, }, ) - .await_when_result_received({ + .on_receiving_result({ let mut conductor_tx = conductor_tx.clone(); async move |result| { match result { @@ -979,7 +979,7 @@ impl ConductorResponder { .as_ref() .expect("we have an agent component") .send_request(initialize_req) - .await_when_result_received(async move |response| { + .on_receiving_result(async move |response| { tracing::debug!(?response, "got initialize response from agent"); request_cx .respond_with_result_via(conductor_tx, response) @@ -999,7 +999,7 @@ impl ConductorResponder { // Forward initialize request to our successor connection_cx .send_request_to(Agent, initialize_req) - .await_when_result_received(async move |result| { + .on_receiving_result(async move |result| { tracing::trace!( ?result, "received response to initialize_proxy from empty conductor" @@ -1016,7 +1016,7 @@ impl ConductorResponder { let proxy_req = InitializeProxyRequest::from(initialize_req); self.proxies[target_component_index] .send_request(proxy_req) - .await_when_result_received(async move |result| { + .on_receiving_result(async move |result| { tracing::debug!(?result, "got initialize_proxy response from proxy"); // Convert InitializeProxyResponse back to InitializeResponse request_cx @@ -1550,7 +1550,7 @@ impl JrResponseExt for JrResponse { request_cx: JrRequestCx, ) -> Result<(), sacp::Error> { let conductor_tx = conductor_tx.clone(); - self.await_when_result_received(async move |result| { + self.on_receiving_result(async move |result| { request_cx .respond_with_result_via(conductor_tx, result) .await diff --git a/src/sacp-conductor/tests/initialization_sequence.rs b/src/sacp-conductor/tests/initialization_sequence.rs index bc21fd7..a1a4f23 100644 --- a/src/sacp-conductor/tests/initialization_sequence.rs +++ b/src/sacp-conductor/tests/initialization_sequence.rs @@ -21,7 +21,7 @@ async fn recv( response: sacp::JrResponse, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? @@ -83,7 +83,7 @@ impl Component for InitComponent { // Forward to successor and respond cx.send_request_to(sacp::Agent, request) - .await_when_result_received(async move |response| { + .on_receiving_result(async move |response| { let response: InitializeResponse = response?; request_cx.respond(response) }) diff --git a/src/sacp-conductor/tests/mcp-integration.rs b/src/sacp-conductor/tests/mcp-integration.rs index 05ac376..a28d938 100644 --- a/src/sacp-conductor/tests/mcp-integration.rs +++ b/src/sacp-conductor/tests/mcp-integration.rs @@ -25,7 +25,7 @@ async fn recv( response: sacp::JrResponse, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp-conductor/tests/mcp_server_handler_chain.rs b/src/sacp-conductor/tests/mcp_server_handler_chain.rs index f999422..2d543f7 100644 --- a/src/sacp-conductor/tests/mcp_server_handler_chain.rs +++ b/src/sacp-conductor/tests/mcp_server_handler_chain.rs @@ -38,7 +38,7 @@ async fn recv( response: sacp::JrResponse, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? @@ -101,7 +101,7 @@ impl Component for ProxyWithMcpAndHandler { // Forward to agent and relay response cx.send_request_to(Agent, request) - .await_when_result_received(async move |result| { + .on_receiving_result(async move |result| { let response: NewSessionResponse = result?; request_cx.respond(response) }) diff --git a/src/sacp-conductor/tests/test_tool_fn.rs b/src/sacp-conductor/tests/test_tool_fn.rs new file mode 100644 index 0000000..3b4cb5c --- /dev/null +++ b/src/sacp-conductor/tests/test_tool_fn.rs @@ -0,0 +1,101 @@ +//! Integration test for `tool_fn` - stateless concurrent tools +//! +//! This test verifies that `tool_fn` works correctly for stateless tools +//! that don't need mutable state. + +use sacp::Component; +use sacp::ProxyToConductor; +use sacp::mcp_server::McpServer; +use sacp_conductor::Conductor; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tokio::io::duplex; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +/// Input for the greet tool +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct GreetInput { + name: String, +} + +/// Create a proxy that provides an MCP server with a stateless greet tool +fn create_greet_proxy() -> Result { + // Create MCP server with a stateless greet tool using tool_fn + let mcp_server = McpServer::builder("greet_server".to_string()) + .instructions("Test MCP server with stateless greet tool") + .tool_fn( + "greet", + "Greet someone by name", + async |input: GreetInput, _context| Ok(format!("Hello, {}!", input.name)), + sacp::tool_fn!(), + ) + .build(); + + // Create proxy component + Ok(sacp::DynComponent::new(ProxyWithGreetServer { mcp_server })) +} + +struct ProxyWithGreetServer> { + mcp_server: McpServer, +} + +impl + 'static + Send> Component + for ProxyWithGreetServer +{ + async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + ProxyToConductor::builder() + .name("greet-proxy") + .with_mcp_server(self.mcp_server) + .serve(client) + .await + } +} + +/// Elizacp agent component wrapper for testing +struct ElizacpAgentComponent; + +impl Component for ElizacpAgentComponent { + async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + // Create duplex channels for bidirectional communication + let (elizacp_write, client_read) = duplex(8192); + let (client_write, elizacp_read) = duplex(8192); + + let elizacp_transport = + sacp::ByteStreams::new(elizacp_write.compat_write(), elizacp_read.compat()); + + let client_transport = + sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); + + // Spawn elizacp in a background task + tokio::spawn(async move { + if let Err(e) = elizacp::ElizaAgent::new().serve(elizacp_transport).await { + tracing::error!("Elizacp error: {}", e); + } + }); + + // Serve the client with the transport connected to elizacp + client_transport.serve(client).await + } +} + +#[tokio::test] +async fn test_tool_fn_greet() -> Result<(), sacp::Error> { + let result = yopo::prompt( + Conductor::new( + "test-conductor".to_string(), + vec![ + create_greet_proxy()?, + sacp::DynComponent::new(ElizacpAgentComponent), + ], + Default::default(), + ), + r#"Use tool greet_server::greet with {"name": "World"}"#, + ) + .await?; + + expect_test::expect![[r#" + "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"\\\"Hello, World!\\\"\", meta: None }), annotations: None }], structured_content: Some(String(\"Hello, World!\")), is_error: Some(false), meta: None }" + "#]].assert_debug_eq(&result); + + Ok(()) +} diff --git a/src/sacp-tokio/tests/debug_logging.rs b/src/sacp-tokio/tests/debug_logging.rs index 499ce45..847d378 100644 --- a/src/sacp-tokio/tests/debug_logging.rs +++ b/src/sacp-tokio/tests/debug_logging.rs @@ -12,7 +12,7 @@ async fn recv( response: sacp::JrResponse, ) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp/src/cookbook.rs b/src/sacp/src/cookbook.rs new file mode 100644 index 0000000..b38f19d --- /dev/null +++ b/src/sacp/src/cookbook.rs @@ -0,0 +1,370 @@ +//! Cookbook of common patterns for building ACP components. +//! +//! This module contains documented examples of patterns that come up +//! frequently when building agents, proxies, and other ACP components. +//! +//! # Roles and Endpoints +//! +//! ACP connections are typed by their *role*, which captures both "who I am" +//! and "who I'm talking to". Roles implement [`JrRole`] and determine what +//! operations are valid on a connection. +//! +//! ## Endpoints +//! +//! *Endpoints* ([`JrEndpoint`]) are logical destinations for messages: +//! +//! - [`Client`] - The client endpoint (IDE, CLI, etc.) +//! - [`Agent`] - The agent endpoint (AI-powered component) +//! - [`Conductor`] - The conductor endpoint (orchestrates proxy chains) +//! +//! Most roles have a single implicit endpoint, but proxies can send to +//! multiple endpoints. Use [`send_request_to`] and [`send_notification_to`] +//! to specify the destination explicitly. +//! +//! ## Role Types +//! +//! The built-in role types are: +//! +//! | Role | Description | Can send to | +//! |------|-------------|-------------| +//! | [`ClientToAgent`] | Client's connection to an agent | `Agent` | +//! | [`AgentToClient`] | Agent's connection to a client | `Client` | +//! | [`ProxyToConductor`] | Proxy's connection to the conductor | `Client`, `Agent` | +//! | [`ConductorToClient`] | Conductor's connection to a client | `Client`, `Agent` | +//! | [`ConductorToProxy`] | Conductor's connection to a proxy | `Agent` | +//! | [`ConductorToAgent`] | Conductor's connection to the final agent | `Agent` | +//! | [`UntypedRole`] | Generic role for testing/dynamic scenarios | any | +//! +//! ## Proxies and Multiple Endpoints +//! +//! A proxy sits between client and agent, so it needs to send messages in +//! both directions. [`ProxyToConductor`] implements `HasEndpoint` and +//! `HasEndpoint`, allowing it to forward messages appropriately: +//! +//! ```ignore +//! // Forward a request toward the agent +//! cx.send_request_to(Agent, request).forward_to_request_cx(request_cx)?; +//! +//! // Send a notification toward the client +//! cx.send_notification_to(Client, notification)?; +//! ``` +//! +//! When sending to `Agent` from a proxy, messages are automatically wrapped +//! in [`SuccessorMessage`] envelopes. When receiving from `Agent`, they're +//! automatically unwrapped. +//! +//! [`JrRole`]: crate::role::JrRole +//! [`JrEndpoint`]: crate::role::JrEndpoint +//! [`Client`]: crate::Client +//! [`Agent`]: crate::Agent +//! [`Conductor`]: crate::Conductor +//! [`ClientToAgent`]: crate::ClientToAgent +//! [`AgentToClient`]: crate::AgentToClient +//! [`ProxyToConductor`]: crate::ProxyToConductor +//! [`ConductorToClient`]: crate::role::ConductorToClient +//! [`ConductorToProxy`]: crate::role::ConductorToProxy +//! [`ConductorToAgent`]: crate::role::ConductorToAgent +//! [`UntypedRole`]: crate::role::UntypedRole +//! [`SuccessorMessage`]: crate::schema::SuccessorMessage +//! [`send_request_to`]: crate::JrConnectionCx::send_request_to +//! [`send_notification_to`]: crate::JrConnectionCx::send_notification_to +//! +//! # Patterns +//! +//! - [`reusable_components`] - Defining agents/proxies with [`Component`] +//! - [`custom_message_handlers`] - Implementing [`JrMessageHandler`] +//! - [`connecting_as_client`] - Using `with_client` to send requests +//! - [`global_mcp_server`] - Adding a shared MCP server to a handler chain +//! - [`per_session_mcp_server`] - Creating per-session MCP servers +//! +//! [`Component`]: crate::Component +//! [`JrMessageHandler`]: crate::JrMessageHandler +//! [`reusable_components`]: crate::cookbook::reusable_components +//! [`custom_message_handlers`]: crate::cookbook::custom_message_handlers +//! [`connecting_as_client`]: crate::cookbook::connecting_as_client +//! [`global_mcp_server`]: crate::cookbook::global_mcp_server +//! [`per_session_mcp_server`]: crate::cookbook::per_session_mcp_server + +pub mod reusable_components { + //! Pattern: Defining reusable components. + //! + //! When building agents or proxies, define a struct that implements [`Component`]. + //! Internally, use the role's `builder()` method to set up handlers. + //! + //! # Example + //! + //! ``` + //! use sacp::{Component, AgentToClient}; + //! use sacp::schema::{ + //! InitializeRequest, InitializeResponse, AgentCapabilities, + //! }; + //! + //! struct MyAgent { + //! name: String, + //! } + //! + //! impl Component for MyAgent { + //! async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + //! AgentToClient::builder() + //! .name(&self.name) + //! .on_receive_request(async move |req: InitializeRequest, request_cx, _cx| { + //! request_cx.respond(InitializeResponse { + //! protocol_version: req.protocol_version, + //! agent_capabilities: AgentCapabilities::default(), + //! auth_methods: vec![], + //! agent_info: None, + //! meta: None, + //! }) + //! }, sacp::on_receive_request!()) + //! .serve(client) + //! .await + //! } + //! } + //! + //! let agent = MyAgent { name: "my-agent".into() }; + //! ``` + //! + //! # Important: Don't block the event loop + //! + //! Message handlers run on the event loop. Blocking in a handler prevents the + //! connection from processing new messages. For expensive work: + //! + //! - Use [`JrConnectionCx::spawn`] to offload work to a background task + //! - Use [`on_receiving_result`] to schedule work when a response arrives + //! + //! [`Component`]: crate::Component + //! [`JrConnectionCx::spawn`]: crate::JrConnectionCx::spawn + //! [`on_receiving_result`]: crate::JrResponse::on_receiving_result +} + +pub mod custom_message_handlers { + //! Pattern: Custom message handlers. + //! + //! For reusable message handling logic, implement [`JrMessageHandler`] and use + //! [`MatchMessage`] or [`MatchMessageFrom`] for type-safe dispatching. + //! + //! # Example + //! + //! ``` + //! use sacp::{JrMessageHandler, MessageCx, Handled, JrConnectionCx}; + //! use sacp::schema::{InitializeRequest, InitializeResponse, AgentCapabilities}; + //! use sacp::util::MatchMessage; + //! + //! struct MyHandler; + //! + //! impl JrMessageHandler for MyHandler { + //! type Role = sacp::role::UntypedRole; + //! + //! async fn handle_message( + //! &mut self, + //! message: MessageCx, + //! _cx: JrConnectionCx, + //! ) -> Result, sacp::Error> { + //! MatchMessage::new(message) + //! .if_request(async |req: InitializeRequest, request_cx| { + //! request_cx.respond(InitializeResponse { + //! protocol_version: req.protocol_version, + //! agent_capabilities: AgentCapabilities::default(), + //! auth_methods: vec![], + //! agent_info: None, + //! meta: None, + //! }) + //! }) + //! .await + //! .done() + //! } + //! + //! fn describe_chain(&self) -> impl std::fmt::Debug { + //! "MyHandler" + //! } + //! } + //! ``` + //! + //! # When to use `MatchMessage` vs `MatchMessageFrom` + //! + //! - [`MatchMessage`] - Use when you don't need endpoint-aware handling + //! - [`MatchMessageFrom`] - Use in proxies where messages come from different + //! endpoints (`Client` vs `Agent`) and may need different handling + //! + //! [`JrMessageHandler`]: crate::JrMessageHandler + //! [`MatchMessage`]: crate::util::MatchMessage + //! [`MatchMessageFrom`]: crate::util::MatchMessageFrom +} + +pub mod connecting_as_client { + //! Pattern: Connecting as a client. + //! + //! To connect to a JSON-RPC server and send requests, use [`with_client`]. + //! This gives you a connection context for sending requests while the + //! connection handles incoming messages in the background. + //! + //! # Example + //! + //! ``` + //! use sacp::{ClientToAgent, Component}; + //! use sacp::schema::{InitializeRequest, NewSessionRequest, SessionNotification}; + //! + //! async fn connect_to_agent(transport: impl Component) -> Result<(), sacp::Error> { + //! ClientToAgent::builder() + //! .name("my-client") + //! .on_receive_notification(async |notif: SessionNotification, _cx| { + //! // Handle notifications from the agent + //! println!("Session updated: {:?}", notif); + //! Ok(()) + //! }, sacp::on_receive_notification!()) + //! .with_client(transport, async |cx| { + //! // Initialize the connection + //! let _init_response = cx.send_request(InitializeRequest { + //! protocol_version: Default::default(), + //! client_capabilities: Default::default(), + //! client_info: None, + //! meta: None, + //! }) + //! .block_task() + //! .await?; + //! + //! // Create a session + //! let session = cx.send_request(NewSessionRequest { + //! cwd: ".".into(), + //! mcp_servers: vec![], + //! meta: None, + //! }) + //! .block_task() + //! .await?; + //! + //! println!("Session created: {:?}", session.session_id); + //! Ok(()) + //! }) + //! .await + //! } + //! ``` + //! + //! # Note on `block_task` + //! + //! Using [`block_task`] is safe inside `with_client` because the closure runs + //! as a spawned task, not on the event loop. The event loop continues processing + //! messages (including the response you're waiting for) while your task blocks. + //! + //! [`with_client`]: crate::JrConnectionBuilder::with_client + //! [`block_task`]: crate::JrResponse::block_task +} + +pub mod global_mcp_server { + //! Pattern: Global MCP server in handler chain. + //! + //! Use this pattern when you want a single MCP server that handles tool calls + //! for all sessions. The server is added to the connection's handler chain and + //! automatically injects itself into every `NewSessionRequest` that passes through. + //! + //! # When to use + //! + //! - The MCP server provides stateless tools (no per-session state needed) + //! - You want the simplest setup with minimal boilerplate + //! - Tools don't need access to session-specific context + //! + //! # Example + //! + //! ``` + //! use sacp::mcp_server::McpServer; + //! use sacp::{Component, JrResponder, ProxyToConductor}; + //! use schemars::JsonSchema; + //! use serde::{Deserialize, Serialize}; + //! + //! #[derive(Debug, Deserialize, JsonSchema)] + //! struct EchoParams { message: String } + //! + //! #[derive(Debug, Serialize, JsonSchema)] + //! struct EchoOutput { echoed: String } + //! + //! // Build the MCP server with tools + //! let mcp_server = McpServer::builder("my-tools") + //! .tool_fn("echo", "Echoes the input", + //! async |params: EchoParams, _cx| { + //! Ok(EchoOutput { echoed: params.message }) + //! }, + //! sacp::tool_fn!()) + //! .build(); + //! + //! // The proxy component is generic over the MCP server's responder type + //! struct MyProxy { + //! mcp_server: McpServer, + //! } + //! + //! impl + Send + 'static> Component for MyProxy { + //! async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + //! ProxyToConductor::builder() + //! .with_mcp_server(self.mcp_server) + //! .serve(client) + //! .await + //! } + //! } + //! + //! let proxy = MyProxy { mcp_server }; + //! ``` + //! + //! # How it works + //! + //! When you call [`with_mcp_server`], the MCP server is added as a message + //! handler. It: + //! + //! 1. Intercepts `NewSessionRequest` messages and adds its `acp:UUID` URL to the + //! request's `mcp_servers` list + //! 2. Passes the modified request through to the next handler + //! 3. Handles incoming MCP protocol messages (tool calls, etc.) for its URL + //! + //! [`with_mcp_server`]: crate::JrConnectionBuilder::with_mcp_server +} + +pub mod per_session_mcp_server { + //! Pattern: Per-session MCP server. + //! + //! Use this pattern when each session needs its own MCP server instance, + //! typically because tools need access to session-specific state or context. + //! + //! # When to use + //! + //! - Tools need access to the session ID or session-specific state + //! - You want to customize the MCP server based on session parameters + //! - Tools need to send notifications back to a specific session + //! + //! # Example + //! + //! ``` + //! use sacp::mcp_server::McpServer; + //! use sacp::schema::NewSessionRequest; + //! use sacp::{Agent, Client, Component, JrResponder, ProxyToConductor}; + //! + //! async fn run_proxy(transport: impl Component) -> Result<(), sacp::Error> { + //! ProxyToConductor::builder() + //! .on_receive_request_from(Client, async |request: NewSessionRequest, request_cx, cx| { + //! // Create an MCP server for this session + //! let cwd = request.cwd.clone(); + //! let mcp_server = McpServer::builder("session-tools") + //! .tool_fn("get_cwd", "Returns session working directory", + //! async move |_params: (), _cx| { + //! Ok(cwd.display().to_string()) + //! }, sacp::tool_fn!()) + //! .build(); + //! + //! // Build the session with the MCP server attached and proxy it + //! cx.build_session_from(request) + //! .with_mcp_server(mcp_server)? + //! .proxy_session(request_cx, JrResponder::run) + //! .await + //! }, sacp::on_receive_request!()) + //! .serve(transport) + //! .await + //! } + //! ``` + //! + //! # How it works + //! + //! When you call [`SessionBuilder::with_mcp_server`]: + //! + //! 1. The MCP server is converted into a dynamic handler via `into_dynamic_handler()` + //! 2. The handler is registered for the session's message routing + //! 3. The MCP server's URL is added to the `NewSessionRequest` + //! 4. The handler lives as long as the session (dropped when `run_session` completes) + //! + //! [`SessionBuilder::with_mcp_server`]: crate::SessionBuilder::with_mcp_server +} diff --git a/src/sacp/src/handler.rs b/src/sacp/src/handler.rs index 406bd63..8923e10 100644 --- a/src/sacp/src/handler.rs +++ b/src/sacp/src/handler.rs @@ -3,10 +3,5 @@ //! This module contains the handler types used by [`JrConnection`](crate::JrConnection) //! to process incoming messages. Most users won't need to use these types directly, //! as the builder methods on `JrConnection` handle the construction automatically. -//! -//! However, these types can be useful for: -//! - Building reusable handler components -//! - Composing handlers programmatically -//! - Understanding the handler infrastructure -pub use crate::jsonrpc::handlers::*; +pub use crate::jsonrpc::{JrMessageHandler, handlers::NullHandler}; diff --git a/src/sacp/src/jsonrpc.rs b/src/sacp/src/jsonrpc.rs index fae3877..d0377b9 100644 --- a/src/sacp/src/jsonrpc.rs +++ b/src/sacp/src/jsonrpc.rs @@ -25,9 +25,9 @@ pub(crate) mod responder; mod task_actor; mod transport_actor; -use crate::handler::{ChainedHandler, NamedHandler}; use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage; pub use crate::jsonrpc::handlers::NullHandler; +use crate::jsonrpc::handlers::{ChainedHandler, NamedHandler}; use crate::jsonrpc::handlers::{MessageHandler, NotificationHandler, RequestHandler}; use crate::jsonrpc::outgoing_actor::{OutgoingMessageTx, send_raw_message}; use crate::jsonrpc::responder::SpawnedResponder; @@ -202,6 +202,22 @@ pub trait JrMessageHandler: Send { fn describe_chain(&self) -> impl std::fmt::Debug; } +impl JrMessageHandler for &mut H { + type Role = H::Role; + + fn handle_message( + &mut self, + message: MessageCx, + cx: JrConnectionCx, + ) -> impl Future, crate::Error>> + Send { + H::handle_message(self, message, cx) + } + + fn describe_chain(&self) -> impl std::fmt::Debug { + H::describe_chain(self) + } +} + /// A JSON-RPC connection that can act as either a server, client, or both. /// /// `JrConnection` provides a builder-style API for creating JSON-RPC servers and clients. @@ -403,7 +419,7 @@ pub trait JrMessageHandler: Send { /// * [`spawn`](JrConnectionCx::spawn) - Run tasks concurrently without blocking the event loop /// /// The [`JrResponse`] returned by `send_request` provides methods like -/// [`await_when_result_received`](JrResponse::await_when_result_received) that help you +/// [`on_receiving_result`](JrResponse::on_receiving_result) that help you /// avoid accidentally blocking the event loop while waiting for responses. /// /// # Driving the Connection @@ -1564,7 +1580,7 @@ impl JrConnectionCx { /// The returned [`JrResponse`] provides methods for receiving the response without /// blocking the event loop: /// - /// * [`await_when_result_received`](JrResponse::await_when_result_received) - Schedule + /// * [`on_receiving_result`](JrResponse::on_receiving_result) - Schedule /// a callback to run when the response arrives (doesn't block the event loop) /// * [`block_task`](JrResponse::block_task) - Block the current task until the response /// arrives (only safe in spawned tasks, not in handlers) @@ -1588,7 +1604,7 @@ impl JrConnectionCx { /// # async fn example(cx: sacp::JrConnectionCx) -> Result<(), sacp::Error> { /// // ✅ Option 1: Schedule callback (safe in handlers) /// cx.send_request(MyRequest {}) - /// .await_when_result_received(async |result| { + /// .on_receiving_result(async |result| { /// // Handle the response /// Ok(()) /// })?; @@ -2396,14 +2412,14 @@ impl JrNotification for UntypedMessage {} /// /// ## Option 1: Schedule a Callback (Safe in Handlers) /// -/// Use [`await_when_result_received`](Self::await_when_result_received) to schedule a task +/// Use [`on_receiving_result`](Self::on_receiving_result) to schedule a task /// that runs when the response arrives. This doesn't block the event loop: /// /// ```no_run /// # use sacp_test::*; /// # async fn example(cx: sacp::JrConnectionCx) -> Result<(), sacp::Error> { /// cx.send_request(MyRequest {}) -/// .await_when_result_received(async |result| { +/// .on_receiving_result(async |result| { /// match result { /// Ok(response) => { /// // Handle successful response @@ -2554,13 +2570,13 @@ impl JrResponse { /// - You want to forward responses without processing them /// - The response types match between the outgoing request and incoming request /// - /// This is equivalent to calling `await_when_result_received` and manually forwarding + /// This is equivalent to calling `on_receiving_result` and manually forwarding /// the result, but more concise. pub fn forward_to_request_cx(self, request_cx: JrRequestCx) -> Result<(), crate::Error> where T: Send, { - self.await_when_result_received(async move |result| request_cx.respond_with_result(result)) + self.on_receiving_result(async move |result| request_cx.respond_with_result(result)) } /// Block the current task until the response is received. @@ -2624,7 +2640,7 @@ impl JrResponse { /// - You need the response value to proceed with your logic /// - Linear control flow is more natural than callbacks /// - /// For handler callbacks, use [`await_when_result_received`](Self::await_when_result_received) instead. + /// For handler callbacks, use [`on_receiving_result`](Self::on_receiving_result) instead. pub async fn block_task(self) -> Result where T: Send, @@ -2644,7 +2660,7 @@ impl JrResponse { /// Schedule an async task to run when a successful response is received. /// - /// This is a convenience wrapper around [`await_when_result_received`](Self::await_when_result_received) + /// This is a convenience wrapper around [`on_receiving_result`](Self::on_receiving_result) /// for the common pattern of forwarding errors to a request context while only processing /// successful responses. /// @@ -2663,7 +2679,7 @@ impl JrResponse { /// connection.on_receive_request(async |req: ValidateRequest, request_cx, cx| { /// // Send initial request /// cx.send_request(ValidateRequest { data: req.data.clone() }) - /// .await_when_ok_response_received(request_cx, async |validation, request_cx| { + /// .on_receiving_ok_result(request_cx, async |validation, request_cx| { /// // Only runs if validation succeeded /// if validation.is_valid { /// // Respond to original request @@ -2687,9 +2703,9 @@ impl JrResponse { /// - You want errors to automatically propagate to the request context /// - You only care about the success case /// - /// For more control over error handling, use [`await_when_result_received`](Self::await_when_result_received). + /// For more control over error handling, use [`on_receiving_result`](Self::on_receiving_result). #[track_caller] - pub fn await_when_ok_response_received( + pub fn on_receiving_ok_result( self, request_cx: JrRequestCx, task: impl FnOnce(T, JrRequestCx) -> F + 'static + Send, @@ -2698,7 +2714,7 @@ impl JrResponse { F: Future> + 'static + Send, T: Send, { - self.await_when_result_received(async move |result| match result { + self.on_receiving_result(async move |result| match result { Ok(value) => task(value, request_cx).await, Err(err) => request_cx.respond_with_error(err), }) @@ -2718,7 +2734,7 @@ impl JrResponse { /// connection.on_receive_request(async |req: MyRequest, request_cx, cx| { /// // Send a request and schedule a callback for the response /// cx.send_request(QueryRequest { id: 22 }) - /// .await_when_result_received({ + /// .on_receiving_result({ /// let connection_cx = cx.clone(); /// async move |result| { /// match result { @@ -2764,7 +2780,7 @@ impl JrResponse { /// /// For spawned tasks where you need linear control flow, consider [`block_task`](Self::block_task). #[track_caller] - pub fn await_when_result_received( + pub fn on_receiving_result( self, task: impl FnOnce(Result) -> F + 'static + Send, ) -> Result<(), crate::Error> diff --git a/src/sacp/src/jsonrpc/handlers.rs b/src/sacp/src/jsonrpc/handlers.rs index a2d2dce..e28fd11 100644 --- a/src/sacp/src/jsonrpc/handlers.rs +++ b/src/sacp/src/jsonrpc/handlers.rs @@ -52,27 +52,20 @@ pub struct RequestHandler< > { handler: F, to_future_hack: ToFut, - role: Role, - phantom: PhantomData, + phantom: PhantomData, } impl RequestHandler { /// Creates a new request handler - pub fn new(_endpoint: End, role: Role, handler: F, to_future_hack: ToFut) -> Self { + pub fn new(_endpoint: End, _role: Role, handler: F, to_future_hack: ToFut) -> Self { Self { handler, to_future_hack, - role, phantom: PhantomData, } } - - /// Returns the role. - pub fn role(&self) -> Role { - self.role - } } impl JrMessageHandler @@ -186,27 +179,20 @@ pub struct NotificationHandler< > { handler: F, to_future_hack: ToFut, - role: Role, - phantom: PhantomData, + phantom: PhantomData, } impl NotificationHandler { /// Creates a new notification handler - pub fn new(_endpoint: End, role: Role, handler: F, to_future_hack: ToFut) -> Self { + pub fn new(_endpoint: End, _role: Role, handler: F, to_future_hack: ToFut) -> Self { Self { handler, to_future_hack, - role, phantom: PhantomData, } } - - /// Returns the role. - pub fn role(&self) -> Role { - self.role - } } impl JrMessageHandler @@ -312,27 +298,20 @@ pub struct MessageHandler< > { handler: F, to_future_hack: ToFut, - role: Role, - phantom: PhantomData, + phantom: PhantomData, } impl MessageHandler { /// Creates a new message handler - pub fn new(_endpoint: End, role: Role, handler: F, to_future_hack: ToFut) -> Self { + pub fn new(_endpoint: End, _role: Role, handler: F, to_future_hack: ToFut) -> Self { Self { handler, to_future_hack, - role, phantom: PhantomData, } } - - /// Returns the role. - pub fn role(&self) -> Role { - self.role - } } impl diff --git a/src/sacp/src/jsonrpc/task_actor.rs b/src/sacp/src/jsonrpc/task_actor.rs index e30dce0..f61f6fe 100644 --- a/src/sacp/src/jsonrpc/task_actor.rs +++ b/src/sacp/src/jsonrpc/task_actor.rs @@ -1,14 +1,10 @@ use std::panic::Location; -use futures::{ - FutureExt, StreamExt, - channel::mpsc, - future::BoxFuture, - stream::{FusedStream, FuturesUnordered}, -}; +use futures::{FutureExt, channel::mpsc, future::BoxFuture}; use crate::JrConnectionCx; use crate::role::JrRole; +use crate::util::process_stream_concurrently; pub type TaskTx = mpsc::UnboundedSender; @@ -37,8 +33,9 @@ impl Task { } })) } - } - ).boxed() + }, + ) + .boxed() } } @@ -52,40 +49,13 @@ impl Task { /// The "task actor" manages dynamically spawned tasks. pub(super) async fn task_actor( - mut task_rx: mpsc::UnboundedReceiver, + task_rx: mpsc::UnboundedReceiver, _cx: &JrConnectionCx, ) -> Result<(), crate::Error> { - let mut futures = FuturesUnordered::new(); - - loop { - // If we have no futures to run, wait until we do. - if futures.is_empty() { - match task_rx.next().await { - Some(task) => futures.push(task.future), - None => return Ok(()), - } - continue; - } - - // If there are no more tasks coming in, just drain our queue and return. - if task_rx.is_terminated() { - while let Some(result) = futures.next().await { - result?; - } - return Ok(()); - } - - // Otherwise, run futures until we get a request for a new task. - futures::select! { - result = futures.next() => if let Some(result) = result { - result?; - }, - - task = task_rx.next() => { - if let Some(task) = task { - futures.push(task.future); - } - } - } - } + process_stream_concurrently( + task_rx, + async |task| task.future.await, + |a, b| Box::pin(a(b)), + ) + .await } diff --git a/src/sacp/src/lib.rs b/src/sacp/src/lib.rs index 0197965..8305250 100644 --- a/src/sacp/src/lib.rs +++ b/src/sacp/src/lib.rs @@ -44,94 +44,16 @@ //! # } //! ``` //! -//! ## Common Patterns +//! ## Cookbook //! -//! ### Pattern 1: Defining Reusable Components +//! The [`cookbook`] module contains documented patterns for common tasks: //! -//! When building agents or proxies, define a struct that implements [`Component`]. Internally, use the role's `builder()` method to set up handlers: -//! -//! ```rust,ignore -//! use sacp::Component; -//! -//! struct MyAgent { -//! config: AgentConfig, -//! } -//! -//! impl Component for MyAgent { -//! async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { -//! UntypedRole::builder() -//! .name("my-agent") -//! .on_receive_request(async move |req: PromptRequest, cx| { -//! // Don't block the message loop! Use await_when_* for async work -//! cx.respond(self.process_prompt(req)) -//! .await_when_result_received(async move |response| { -//! // This runs after the response is received -//! log_response(&response); -//! cx.respond(response) -//! }) -//! }) -//! .serve(client) -//! .await -//! } -//! } -//! ``` -//! -//! **Important:** Message handlers run on the event loop. Blocking in a handler will prevent the connection from processing new messages. -//! Use [`JrConnectionCx::spawn`] to offload expensive work, or use the `await_when_*` methods to avoid blocking. -//! -//! ### Pattern 2: Custom Message Handlers -//! -//! For reusable message handling logic, implement [`JrMessageHandler`] and use [`MatchMessage`](crate::util::MatchMessage) for dispatching: -//! -//! ```rust,ignore -//! use sacp::{JrMessageHandler, MessageAndCx, Handled}; -//! use sacp::util::MatchMessage; -//! -//! struct MyHandler { -//! state: Arc>, -//! } -//! -//! impl JrMessageHandler for MyHandler { -//! async fn handle_message(&mut self, message: MessageAndCx) -//! -> Result, sacp::Error> -//! { -//! MatchMessage::new(message) -//! .if_request(async |req: MyRequest, cx| { -//! // Handle using self.state -//! cx.respond(MyResponse { /* ... */ }) -//! }) -//! .await -//! .done() -//! } -//! -//! fn describe_chain(&self) -> impl std::fmt::Debug { "MyHandler" } -//! } -//! ``` -//! -//! ### Pattern 3: Connecting as a Client -//! -//! To connect to a JSON-RPC server and send requests, use `with_client`. Note the use of `async` (not `async move`) -//! to share access to local variables: -//! -//! ```rust,ignore -//! UntypedRole::builder() -//! .on_receive_notification(async |notif: SessionUpdate, cx| { -//! // Handle notifications from the server -//! Ok(()) -//! }) -//! .with_client(sacp::ByteStreams::new(stdout, stdin), async |cx| { -//! // Send requests using the connection context -//! let response = cx.send_request(MyRequest { /* ... */ }) -//! .block_task() -//! .await?; -//! -//! // Can access local variables here -//! process_response(response); -//! -//! Ok(()) -//! }) -//! .await -//! ``` +//! - [Roles and endpoints](cookbook#roles-and-endpoints) - Understanding `JrRole`, `JrEndpoint`, and how proxies work +//! - [Reusable components](cookbook::reusable_components) - Defining agents/proxies with [`Component`] +//! - [Custom message handlers](cookbook::custom_message_handlers) - Implementing [`JrMessageHandler`] +//! - [Connecting as a client](cookbook::connecting_as_client) - Using `with_client` to send requests +//! - [Global MCP server](cookbook::global_mcp_server) - Adding MCP servers to a handler chain +//! - [Per-session MCP server](cookbook::per_session_mcp_server) - Creating MCP servers per session //! //! ## Using the Request Context //! @@ -169,6 +91,8 @@ mod capabilities; /// Component abstraction for agents and proxies pub mod component; +/// Cookbook of common patterns for building ACP components +pub mod cookbook; /// JSON-RPC handler types for building custom message handlers pub mod handler; /// JSON-RPC connection and handler infrastructure @@ -245,6 +169,16 @@ macro_rules! tool_fn_mut { }; } +/// This is a hack that must be given as the final argument of +/// [`McpServerBuilder::tool_fn`] when defining stateless concurrent tools. +/// See [`tool_fn_mut!`] for the gory details. +#[macro_export] +macro_rules! tool_fn { + () => { + |func, params, context| Box::pin(func(params, context)) + }; +} + /// This macro is used for the value of the `to_future_hack` parameter of /// [`JrConnectionBuilder::on_receive_request`] and [`JrConnectionBuilder::on_receive_request_from`]. /// diff --git a/src/sacp/src/mcp_server/active_session.rs b/src/sacp/src/mcp_server/active_session.rs index aeeafb5..afa3e1c 100644 --- a/src/sacp/src/mcp_server/active_session.rs +++ b/src/sacp/src/mcp_server/active_session.rs @@ -18,10 +18,7 @@ use std::sync::Arc; /// This is added as a 'dynamic' handler to the connection context /// (see [`JrConnectionCx::add_dynamic_handler`]) and handles MCP-over-ACP messages /// with the appropriate ACP url. -pub(super) struct McpActiveSession -where - Role: HasEndpoint, -{ +pub(super) struct McpActiveSession { /// The role of the server #[expect(dead_code)] role: Role, @@ -228,17 +225,6 @@ where message: MessageCx, connection_cx: JrConnectionCx, ) -> Result, crate::Error> { - // Hmm, this is a bit wacky: - // - // * In a proxy, we expect to receive MCP over ACP notifications wrapped as a "FromSuccessorNotification" - // and we don't expect to receive them unwrapped (that would be the client sending it to us, not our agent, - // and that's weird); - // * But in a *client*, we expect to receive incoming messages unwrapped (i.e., from our successor), - // and not wrapped (we don't expect *anything* wrapped). - // - // So we just accept them in either direction for now. The whole thing feels a bit inelegant, - // but I guess it works. - MatchMessageFrom::new(message, &connection_cx) // MCP connect requests come from the Agent direction (wrapped in SuccessorMessage) .if_request_from(Agent, async |request, request_cx| { diff --git a/src/sacp/src/mcp_server/builder.rs b/src/sacp/src/mcp_server/builder.rs index 46fe955..042b667 100644 --- a/src/sacp/src/mcp_server/builder.rs +++ b/src/sacp/src/mcp_server/builder.rs @@ -23,7 +23,7 @@ use crate::{ jsonrpc::responder::{ChainResponder, JrResponder, NullResponder}, mcp_server::{ McpServer, McpServerConnect, - responder::{ToolCall, ToolFnResponder}, + responder::{ToolCall, ToolFnMutResponder, ToolFnResponder}, }, }; @@ -100,7 +100,7 @@ where self, tool: impl McpTool + 'static, tool_responder: R, - ) -> McpServerBuilder> { + ) -> McpServerBuilder> { let this = self.tool(tool); McpServerBuilder { role: this.role, @@ -144,56 +144,66 @@ where ) -> BoxFuture<'a, Result> + Send + 'static, - ) -> McpServerBuilder>> + ) -> McpServerBuilder> where P: JsonSchema + DeserializeOwned + 'static + Send, R: JsonSchema + Serialize + 'static + Send, F: AsyncFnMut(P, McpContext) -> Result + Send, { - struct ToolFnTool { - name: String, - description: String, - call_tx: mpsc::Sender>, - } - - impl McpTool for ToolFnTool - where - Role: JrRole, - P: JsonSchema + DeserializeOwned + 'static + Send, - R: JsonSchema + Serialize + 'static + Send, - { - type Input = P; - type Output = R; - - fn name(&self) -> String { - self.name.clone() - } - - fn description(&self) -> String { - self.description.clone() - } - - async fn call_tool( - &self, - params: P, - mcp_cx: McpContext, - ) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - - self.call_tx - .clone() - .send(ToolCall { - params, - mcp_cx, - result_tx, - }) - .await - .map_err(crate::util::internal_error)?; - - result_rx.await.map_err(crate::util::internal_error)? - } - } + let (call_tx, call_rx) = mpsc::channel(128); + self.tool_with_responder( + ToolFnTool { + name: name.to_string(), + description: description.to_string(), + call_tx, + }, + ToolFnMutResponder { + func, + call_rx, + tool_future_fn: Box::new(tool_future_hack), + }, + ) + } + /// Convenience wrapper for defining a stateless tool that can run concurrently. + /// Unlike [`tool_fn_mut`](Self::tool_fn_mut), multiple invocations of this tool can run + /// at the same time since the function is `Fn` rather than `FnMut`. + /// + /// # Parameters + /// + /// * `name`: The name of the tool. + /// * `description`: The description of the tool. + /// * `func`: The function that implements the tool. Use an async closure like `async |args, cx| { .. }`. + /// + /// # Examples + /// + /// ```rust,ignore + /// McpServer::builder("my-server") + /// .tool_fn( + /// "greet", + /// "Greet someone by name", + /// async |input: GreetInput, _cx| Ok(format!("Hello, {}!", input.name)), + /// ) + /// ``` + pub fn tool_fn( + self, + name: impl ToString, + description: impl ToString, + func: F, + tool_future_hack: impl for<'a> Fn( + &'a F, + P, + McpContext, + ) -> BoxFuture<'a, Result> + + Send + + Sync + + 'static, + ) -> McpServerBuilder> + where + P: JsonSchema + DeserializeOwned + 'static + Send, + R: JsonSchema + Serialize + 'static + Send, + F: AsyncFn(P, McpContext) -> Result + Send + Sync + 'static, + { let (call_tx, call_rx) = mpsc::channel(128); self.tool_with_responder( ToolFnTool { @@ -202,7 +212,7 @@ where call_tx, }, ToolFnResponder { - func, + func: func, call_rx, tool_future_fn: Box::new(tool_future_hack), }, @@ -417,3 +427,45 @@ fn to_rmcp_error(error: crate::Error) -> rmcp::ErrorData { data: error.data, } } + +/// MCP tool used for `tool_fn` and `tooL_fn_mut`. +/// Each time it is invoked, it sends a `ToolCall` message to `call_tx`. +struct ToolFnTool { + name: String, + description: String, + call_tx: mpsc::Sender>, +} + +impl McpTool for ToolFnTool +where + Role: JrRole, + P: JsonSchema + DeserializeOwned + 'static + Send, + R: JsonSchema + Serialize + 'static + Send, +{ + type Input = P; + type Output = R; + + fn name(&self) -> String { + self.name.clone() + } + + fn description(&self) -> String { + self.description.clone() + } + + async fn call_tool(&self, params: P, mcp_cx: McpContext) -> Result { + let (result_tx, result_rx) = oneshot::channel(); + + self.call_tx + .clone() + .send(ToolCall { + params, + mcp_cx, + result_tx, + }) + .await + .map_err(crate::util::internal_error)?; + + result_rx.await.map_err(crate::util::internal_error)? + } +} diff --git a/src/sacp/src/mcp_server/mod.rs b/src/sacp/src/mcp_server/mod.rs index 0adfd60..74d0a42 100644 --- a/src/sacp/src/mcp_server/mod.rs +++ b/src/sacp/src/mcp_server/mod.rs @@ -56,6 +56,5 @@ mod tool; pub use builder::McpServerBuilder; pub use connect::McpServerConnect; pub use context::McpContext; -pub use responder::{ToolCall, ToolFnResponder}; -pub use server::{McpMessageHandler, McpServer}; +pub use server::McpServer; pub use tool::McpTool; diff --git a/src/sacp/src/mcp_server/responder.rs b/src/sacp/src/mcp_server/responder.rs index 4a41922..90f40e9 100644 --- a/src/sacp/src/mcp_server/responder.rs +++ b/src/sacp/src/mcp_server/responder.rs @@ -1,11 +1,15 @@ //! MCP-specific responder types. -use futures::{StreamExt, channel::mpsc, future::BoxFuture}; +use futures::{ + StreamExt, + channel::{mpsc, oneshot}, + future::BoxFuture, +}; use crate::{JrConnectionCx, JrRole, jsonrpc::responder::JrResponder, mcp_server::McpContext}; /// A tool call request sent through the channel. -pub struct ToolCall { +pub(super) struct ToolCall { pub(crate) params: P, pub(crate) mcp_cx: McpContext, pub(crate) result_tx: futures::channel::oneshot::Sender>, @@ -13,13 +17,16 @@ pub struct ToolCall { /// Responder for a `tool_fn` closure that receives tool calls through a channel /// and invokes the user's async function. -pub struct ToolFnResponder { +pub(super) struct ToolFnMutResponder { pub(crate) func: F, pub(crate) call_rx: mpsc::Receiver>, - pub(crate) tool_future_fn: Box Fn(&'a mut F, P, McpContext) -> BoxFuture<'a, Result> + Send>, + pub(crate) tool_future_fn: Box< + dyn for<'a> Fn(&'a mut F, P, McpContext) -> BoxFuture<'a, Result> + + Send, + >, } -impl JrResponder for ToolFnResponder +impl JrResponder for ToolFnMutResponder where Role: JrRole, P: Send, @@ -27,7 +34,11 @@ where F: Send, { async fn run(self, _cx: JrConnectionCx) -> Result<(), crate::Error> { - let ToolFnResponder { mut func, mut call_rx, tool_future_fn } = self; + let ToolFnMutResponder { + mut func, + mut call_rx, + tool_future_fn, + } = self; while let Some(ToolCall { params, mcp_cx, @@ -42,3 +53,74 @@ where Ok(()) } } + +/// Responder for a `tool_fn` closure that receives tool calls through a channel +/// and invokes the user's async function concurrently. +pub(super) struct ToolFnResponder { + pub(crate) func: F, + pub(crate) call_rx: mpsc::Receiver>, + pub(crate) tool_future_fn: Box< + dyn for<'a> Fn(&'a F, P, McpContext) -> BoxFuture<'a, Result> + + Send + + Sync, + >, +} + +impl JrResponder for ToolFnResponder +where + Role: JrRole, + P: Send, + R: Send, + F: Send + Sync, +{ + async fn run(self, _cx: JrConnectionCx) -> Result<(), crate::Error> { + let ToolFnResponder { + func, + call_rx, + tool_future_fn, + } = self; + crate::util::process_stream_concurrently( + call_rx, + async |tool_call| { + fn hack<'a, F, P, R, Role>( + func: &'a F, + params: P, + mcp_cx: McpContext, + tool_future_fn: &'a ( + dyn Fn( + &'a F, + P, + McpContext, + ) -> BoxFuture<'a, Result> + + Send + + Sync + ), + result_tx: oneshot::Sender>, + ) -> BoxFuture<'a, ()> + where + Role: JrRole, + P: Send, + R: Send, + F: Send + Sync, + { + Box::pin(async move { + let result = tool_future_fn(func, params, mcp_cx).await; + // Ignore send errors - the receiver may have been dropped + let _ = result_tx.send(result); + }) + } + + let ToolCall { + params, + mcp_cx, + result_tx, + } = tool_call; + + hack(&func, params, mcp_cx, &*tool_future_fn, result_tx).await; + Ok(()) + }, + |a, b| Box::pin(a(b)), + ) + .await + } +} diff --git a/src/sacp/src/mcp_server/server.rs b/src/sacp/src/mcp_server/server.rs index f889db1..2c328f8 100644 --- a/src/sacp/src/mcp_server/server.rs +++ b/src/sacp/src/mcp_server/server.rs @@ -20,7 +20,7 @@ use crate::{ /// `McpServer` wraps an [`McpServerConnect`] implementation and can be used either: /// - As a message handler via [`JrConnectionBuilder::with_handler`], automatically /// attaching to new sessions -/// - Manually via [`Self::add_to_new_session`] for more control +/// - Manually via [`Self::into_dynamic_handler`] for more control /// /// # Creating an MCP Server /// @@ -40,7 +40,7 @@ use crate::{ /// ``` pub struct McpServer { /// The "message handler" handles incoming messages to the MCP server (speaks the MCP protocol). - message_handler: McpMessageHandler, + message_handler: McpNewSessionHandler, /// The "responder" is a task that should be run alongside the message handler. /// Some futures direct messages back through channels to this future which actually @@ -72,29 +72,42 @@ where /// See [`Self::builder`] to construct MCP servers from Rust code. pub fn new(c: impl McpServerConnect, responder: Responder) -> Self { McpServer { - message_handler: McpMessageHandler { - connect: Arc::new(c), - }, + message_handler: McpNewSessionHandler::new(c), responder, } } /// Split this MCP server into the message handler and a future that must be run while the handler is active. - pub(crate) fn into_handler_and_responder(self) -> (McpMessageHandler, Responder) { + pub(crate) fn into_handler_and_responder(self) -> (McpNewSessionHandler, Responder) { (self.message_handler, self.responder) } } /// Message handler created from a [`McpServer`]. -#[derive(Clone)] -pub struct McpMessageHandler { +pub(crate) struct McpNewSessionHandler { + acp_url: String, connect: Arc>, + active_session: McpActiveSession, } -impl McpMessageHandler +impl McpNewSessionHandler where Role: HasEndpoint, { + pub fn new(c: impl McpServerConnect) -> Self { + let acp_url = format!("acp:{}", Uuid::new_v4()); + let connect = Arc::new(c); + Self { + active_session: McpActiveSession::new( + Role::default(), + acp_url.clone(), + connect.clone(), + ), + acp_url, + connect, + } + } + /// Attach this server to the new session, spawning off a dynamic handler that will /// manage requests coming from this session. /// @@ -105,24 +118,26 @@ where /// will no longer be received, so you need to keep this value alive as long as the session /// is in use. You can also invoke [`DynamicHandlerRegistration::run_indefinitely`] /// if you want to keep the handler running indefinitely. - pub fn add_to_new_session( - &self, + pub fn into_dynamic_handler( + self, request: &mut NewSessionRequest, cx: &JrConnectionCx, ) -> Result, crate::Error> { - let acp_url = format!("acp:{}", Uuid::new_v4()); - let connection = - McpActiveSession::new(Role::default(), acp_url.clone(), self.connect.clone()); + self.modify_new_session_request(request); + cx.add_dynamic_handler(self.active_session) + } + + /// Modify the new session request to include this MCP server. + fn modify_new_session_request(&self, request: &mut NewSessionRequest) { request.mcp_servers.push(crate::schema::McpServer::Http { name: self.connect.name(), - url: acp_url, + url: self.acp_url.clone(), headers: Default::default(), }); - cx.add_dynamic_handler(connection) } } -impl JrMessageHandler for McpMessageHandler +impl JrMessageHandler for McpNewSessionHandler where Role: HasEndpoint + HasEndpoint, { @@ -137,8 +152,7 @@ where .if_request_from( Client, async |mut request: NewSessionRequest, request_cx| { - self.add_to_new_session(&mut request, &cx)? - .run_indefinitely(); + self.modify_new_session_request(&mut request); Ok(Handled::No { message: (request, request_cx), retry: false, @@ -146,7 +160,8 @@ where }, ) .await - .done() + .otherwise_delegate(&mut self.active_session) + .await } fn describe_chain(&self) -> impl std::fmt::Debug { diff --git a/src/sacp/src/role.rs b/src/sacp/src/role.rs index ce18b9e..41eec53 100644 --- a/src/sacp/src/role.rs +++ b/src/sacp/src/role.rs @@ -398,19 +398,16 @@ impl JrRole for ProxyToConductor { // and add a dynamic handler for that // session-id. .if_request_from(Client, async |request: NewSessionRequest, request_cx| { - cx.send_request_to(Agent, request) - .await_when_result_received({ - let cx = cx.clone(); - async move |result| { - if let Ok(NewSessionResponse { session_id, .. }) = &result { - cx.add_dynamic_handler(ProxySessionMessages { - session_id: session_id.clone(), - })? + cx.send_request_to(Agent, request).on_receiving_result({ + let cx = cx.clone(); + async move |result| { + if let Ok(NewSessionResponse { session_id, .. }) = &result { + cx.add_dynamic_handler(ProxySessionMessages::new(session_id.clone()))? .run_indefinitely(); - } - request_cx.respond_with_result(result) } - }) + request_cx.respond_with_result(result) + } + }) }) .await // Incoming notification from the agent -- forward to the client @@ -455,12 +452,30 @@ impl JrRole for ProxyToConductor { } } -struct ProxySessionMessages { +/// Dynamic handler that proxies session messages from Agent to Client. +/// +/// This is used internally to handle session message routing after a +/// `session.new` request has been forwarded. +pub(crate) struct ProxySessionMessages { session_id: SessionId, + _marker: std::marker::PhantomData, +} + +impl ProxySessionMessages { + /// Create a new proxy handler for the given session. + pub fn new(session_id: SessionId) -> Self { + Self { + session_id, + _marker: std::marker::PhantomData, + } + } } -impl JrMessageHandler for ProxySessionMessages { - type Role = ProxyToConductor; +impl JrMessageHandler for ProxySessionMessages +where + Role: HasEndpoint + HasEndpoint, +{ + type Role = Role; async fn handle_message( &mut self, diff --git a/src/sacp/src/session.rs b/src/sacp/src/session.rs index 326b8e5..faed433 100644 --- a/src/sacp/src/session.rs +++ b/src/sacp/src/session.rs @@ -7,12 +7,14 @@ use agent_client_protocol_schema::{ use futures::channel::mpsc; use crate::{ - Agent, Handled, HasEndpoint, JrConnectionCx, JrMessageHandler, JrRole, MessageCx, + Agent, Client, Handled, HasEndpoint, JrConnectionCx, JrMessageHandler, JrRequestCx, JrRole, + MessageCx, jsonrpc::{ DynamicHandlerRegistration, responder::{ChainResponder, JrResponder, NullResponder}, }, mcp_server::McpServer, + role::ProxySessionMessages, schema::SessionId, util::{MatchMessage, MatchMessageFrom, run_until}, }; @@ -33,6 +35,17 @@ where ) } + /// Session builder starting from an existing request. + /// + /// Use this when you've intercepted a `session.new` request and want to + /// modify it (e.g., inject MCP servers) before forwarding. + pub fn build_session_from( + &self, + request: NewSessionRequest, + ) -> SessionBuilder { + SessionBuilder::new(self, request) + } + /// Given a session response received from the agent, /// attach a handler to process messages related to this session /// and let you access them. @@ -114,7 +127,7 @@ where { let (handler, responder) = mcp_server.into_handler_and_responder(); self.dynamic_handler_registrations - .push(handler.add_to_new_session(&mut self.request, &self.connection)?); + .push(handler.into_dynamic_handler(&mut self.request, &self.connection)?); Ok(SessionBuilder { connection: self.connection, request: self.request, @@ -173,6 +186,54 @@ where self.connection .attach_session(response, self.dynamic_handler_registrations) } + + /// Forward the session request to the agent and proxy all messages. + /// + /// Use this when you want to inject MCP servers into a session but don't need + /// to actively interact with it. The session messages will be proxied between + /// client and agent automatically. + /// + /// # Parameters + /// + /// * `request_cx`: The request context from the intercepted `session.new` request, + /// used to send the response back to the client. + /// * `run_responder`: this is typically just `Responder::run`; + /// the need for this parameter is a workaround for Rust limitations. + pub async fn proxy_session( + self, + request_cx: JrRequestCx, + run_responder: impl FnOnce(Responder, JrConnectionCx) -> F, + ) -> Result<(), crate::Error> + where + Role: HasEndpoint, + F: Future> + Send + 'static, + { + let response = self + .connection + .send_request_to(Agent, self.request) + .block_task() + .await?; + + // Add dynamic handler to proxy session messages + let session_id = response.session_id.clone(); + self.connection + .add_dynamic_handler(ProxySessionMessages::new(session_id))? + .run_indefinitely(); + + // Keep MCP server handlers alive + for registration in self.dynamic_handler_registrations { + registration.run_indefinitely(); + } + + // Spawn the responder + let cx = self.connection.clone(); + self.connection.spawn(run_responder(self.responder, cx))?; + + // Send response back to client + request_cx.respond(response)?; + + Ok(()) + } } /// Active session struct that lets you send prompts and receive updates. @@ -242,7 +303,7 @@ where meta: None, }, ) - .await_when_result_received(async move |result| { + .on_receiving_result(async move |result| { let PromptResponse { stop_reason, meta: _, diff --git a/src/sacp/src/util.rs b/src/sacp/src/util.rs index 1e558a9..9826793 100644 --- a/src/sacp/src/util.rs +++ b/src/sacp/src/util.rs @@ -1,5 +1,10 @@ // Types re-exported from crate root +use futures::{ + future::BoxFuture, + stream::{Stream, StreamExt}, +}; + mod typed; pub use typed::{MatchMessage, MatchMessageFrom, TypeNotification}; @@ -98,3 +103,74 @@ pub async fn run_until( } } } + +/// Process items from a stream concurrently. +/// +/// For each item received from `stream`, calls `process_fn` to create a future, +/// then runs all futures concurrently. If any future returns an error, +/// stops processing and returns that error. +/// +/// This is useful for patterns where you receive work items from a channel +/// and want to process them concurrently while respecting backpressure. +pub async fn process_stream_concurrently( + stream: impl Stream, + process_fn: F, + process_fn_hack: impl for<'a> Fn(&'a F, T) -> BoxFuture<'a, Result<(), crate::Error>>, +) -> Result<(), crate::Error> +where + F: AsyncFn(T) -> Result<(), crate::Error>, +{ + use std::pin::pin; + + use futures::stream::{FusedStream, FuturesUnordered}; + use futures_concurrency::future::Race; + + let mut stream = pin!(stream.fuse()); + let mut futures: FuturesUnordered<_> = FuturesUnordered::new(); + + loop { + // If we have no futures to run, wait until we do. + if futures.is_empty() { + match stream.next().await { + Some(item) => futures.push(process_fn_hack(&process_fn, item)), + None => return Ok(()), + } + continue; + } + + // If there are no more items coming in, just drain our queue and return. + if stream.is_terminated() { + while let Some(result) = futures.next().await { + result?; + } + return Ok(()); + } + + // Otherwise, race between getting a new item and completing a future. + enum Event { + NewItem(Option), + FutureCompleted(Option>), + } + + let event = (async { Event::NewItem(stream.next().await) }, async { + Event::FutureCompleted(futures.next().await) + }) + .race() + .await; + + match event { + Event::NewItem(Some(item)) => { + futures.push(process_fn_hack(&process_fn, item)); + } + Event::NewItem(None) => { + // Stream closed, loop will catch is_terminated + } + Event::FutureCompleted(Some(result)) => { + result?; + } + Event::FutureCompleted(None) => { + // No futures were pending, shouldn't happen since we checked is_empty + } + } + } +} diff --git a/src/sacp/src/util/typed.rs b/src/sacp/src/util/typed.rs index 5b21ca8..adc99e5 100644 --- a/src/sacp/src/util/typed.rs +++ b/src/sacp/src/util/typed.rs @@ -20,8 +20,8 @@ use jsonrpcmsg::Params; use crate::{ - Handled, HasDefaultEndpoint, JrConnectionCx, JrNotification, JrRequest, JrRequestCx, MessageCx, - UntypedMessage, + Handled, HasDefaultEndpoint, JrConnectionCx, JrMessageHandler, JrNotification, JrRequest, + JrRequestCx, MessageCx, UntypedMessage, role::{HasEndpoint, JrEndpoint, JrRole}, util::json_cast, }; @@ -519,6 +519,33 @@ impl MatchMessageFrom { Err(err) => Err(err), } } + + /// Handle messages that didn't match any previous `handle_if` call. + /// + /// This is the fallback handler that receives the original untyped message if none + /// of the typed handlers matched. You must call this method to complete the pattern + /// matching chain and get the final result. + pub async fn otherwise_delegate( + self, + mut handler: impl JrMessageHandler, + ) -> Result, crate::Error> { + match self.state? { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message, + retry: outer_retry, + } => match handler.handle_message(message, self.cx).await? { + Handled::Yes => Ok(Handled::Yes), + Handled::No { + message, + retry: inner_retry, + } => Ok(Handled::No { + message, + retry: inner_retry | outer_retry, + }), + }, + } + } } /// Builder for pattern-matching on untyped JSON-RPC notifications. diff --git a/src/sacp/tests/jsonrpc_advanced.rs b/src/sacp/tests/jsonrpc_advanced.rs index 0f24c00..f027b55 100644 --- a/src/sacp/tests/jsonrpc_advanced.rs +++ b/src/sacp/tests/jsonrpc_advanced.rs @@ -14,7 +14,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Test helper to block and wait for a JSON-RPC response. async fn recv(response: JrResponse) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp/tests/jsonrpc_connection_builder.rs b/src/sacp/tests/jsonrpc_connection_builder.rs index b469f68..a203b31 100644 --- a/src/sacp/tests/jsonrpc_connection_builder.rs +++ b/src/sacp/tests/jsonrpc_connection_builder.rs @@ -17,7 +17,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Test helper to block and wait for a JSON-RPC response. async fn recv(response: JrResponse) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp/tests/jsonrpc_edge_cases.rs b/src/sacp/tests/jsonrpc_edge_cases.rs index 23df8c6..62a275b 100644 --- a/src/sacp/tests/jsonrpc_edge_cases.rs +++ b/src/sacp/tests/jsonrpc_edge_cases.rs @@ -15,7 +15,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Test helper to block and wait for a JSON-RPC response. async fn recv(response: JrResponse) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp/tests/jsonrpc_error_handling.rs b/src/sacp/tests/jsonrpc_error_handling.rs index 08fed55..12bcde2 100644 --- a/src/sacp/tests/jsonrpc_error_handling.rs +++ b/src/sacp/tests/jsonrpc_error_handling.rs @@ -17,7 +17,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Test helper to block and wait for a JSON-RPC response. async fn recv(response: JrResponse) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())? diff --git a/src/sacp/tests/jsonrpc_hello.rs b/src/sacp/tests/jsonrpc_hello.rs index dac5e65..36b60ae 100644 --- a/src/sacp/tests/jsonrpc_hello.rs +++ b/src/sacp/tests/jsonrpc_hello.rs @@ -17,7 +17,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; /// Test helper to block and wait for a JSON-RPC response. async fn recv(response: JrResponse) -> Result { let (tx, rx) = tokio::sync::oneshot::channel(); - response.await_when_result_received(async move |result| { + response.on_receiving_result(async move |result| { tx.send(result).map_err(|_| sacp::Error::internal_error()) })?; rx.await.map_err(|_| sacp::Error::internal_error())?