diff --git a/Cargo.toml b/Cargo.toml index 6d6a917..4a44787 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ tracing-appender = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } # MCP SDK -rmcp = { version = "0.9", features = ["server", "transport-io", "schemars"] } +rmcp = { version = "0.12.0", features = ["server", "transport-io", "schemars"] } # CLI parsing clap = { version = "4.5", features = ["derive"] } diff --git a/src/sacp-conductor/tests/scoped_mcp_server.rs b/src/sacp-conductor/tests/scoped_mcp_server.rs index 95033cd..9a00d34 100644 --- a/src/sacp-conductor/tests/scoped_mcp_server.rs +++ b/src/sacp-conductor/tests/scoped_mcp_server.rs @@ -37,7 +37,7 @@ async fn test_scoped_mcp_server_through_proxy() -> Result<(), sacp::Error> { .await?; expect_test::expect![[r#" - "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"2\", meta: None }), annotations: None }], structured_content: Some(Number(2)), is_error: Some(false), meta: None }" + "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"2\", meta: None }), annotations: None }], structured_content: None, is_error: Some(false), meta: None }" "#]].assert_debug_eq(&result); Ok(()) @@ -76,7 +76,7 @@ async fn test_scoped_mcp_server_through_session() -> Result<(), sacp::Error> { .await?; expect_test::expect![[r#" - "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"2\", meta: None }), annotations: None }], structured_content: Some(Number(2)), is_error: Some(false), meta: None }" + "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"2\", meta: None }), annotations: None }], structured_content: None, is_error: Some(false), meta: None }" "#]].assert_debug_eq(&result); Ok(()) diff --git a/src/sacp-conductor/tests/test_mcp_tool_output_types.rs b/src/sacp-conductor/tests/test_mcp_tool_output_types.rs new file mode 100644 index 0000000..fe09695 --- /dev/null +++ b/src/sacp-conductor/tests/test_mcp_tool_output_types.rs @@ -0,0 +1,121 @@ +//! Test MCP tools with various output types (string, integer, object) +//! +//! MCP structured output requires JSON objects. This test verifies behavior +//! when tools return non-object types like bare strings or integers. + +use sacp::Component; +use sacp::ProxyToConductor; +use sacp::mcp_server::McpServer; +use sacp_conductor::Conductor; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tokio::io::duplex; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +/// Empty input for test tools +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +struct EmptyInput {} + +/// Create a proxy with tools that return different types +fn create_test_proxy() -> Result { + let mcp_server = McpServer::builder("test_server".to_string()) + .instructions("Test MCP server with various output types") + .tool_fn_mut( + "return_string", + "Returns a bare string", + async |_input: EmptyInput, _context| Ok("hello world".to_string()), + sacp::tool_fn_mut!(), + ) + .tool_fn_mut( + "return_integer", + "Returns a bare integer", + async |_input: EmptyInput, _context| Ok(42i32), + sacp::tool_fn_mut!(), + ) + .build(); + + Ok(sacp::DynComponent::new(ProxyWithTestServer { mcp_server })) +} + +struct ProxyWithTestServer> { + mcp_server: McpServer, +} + +impl + 'static + Send> Component for ProxyWithTestServer { + async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + ProxyToConductor::builder() + .name("test-proxy") + .with_mcp_server(self.mcp_server) + .serve(client) + .await + } +} + +/// Elizacp agent component wrapper for testing +struct ElizacpAgentComponent; + +impl Component for ElizacpAgentComponent { + async fn serve(self, client: impl Component) -> Result<(), sacp::Error> { + let (elizacp_write, client_read) = duplex(8192); + let (client_write, elizacp_read) = duplex(8192); + + let elizacp_transport = + sacp::ByteStreams::new(elizacp_write.compat_write(), elizacp_read.compat()); + + let client_transport = + sacp::ByteStreams::new(client_write.compat_write(), client_read.compat()); + + tokio::spawn(async move { + if let Err(e) = elizacp::ElizaAgent::new().serve(elizacp_transport).await { + tracing::error!("Elizacp error: {}", e); + } + }); + + client_transport.serve(client).await + } +} + +#[tokio::test] +async fn test_tool_returning_string() -> Result<(), sacp::Error> { + let result = yopo::prompt( + Conductor::new( + "test-conductor".to_string(), + vec![ + create_test_proxy()?, + sacp::DynComponent::new(ElizacpAgentComponent), + ], + Default::default(), + ), + r#"Use tool test_server::return_string with {}"#, + ) + .await?; + + // The result should contain "hello world" somewhere + assert!( + result.contains("hello world"), + "expected 'hello world' in result: {result}" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_tool_returning_integer() -> Result<(), sacp::Error> { + let result = yopo::prompt( + Conductor::new( + "test-conductor".to_string(), + vec![ + create_test_proxy()?, + sacp::DynComponent::new(ElizacpAgentComponent), + ], + Default::default(), + ), + r#"Use tool test_server::return_integer with {}"#, + ) + .await?; + + // The result should contain "42" somewhere + assert!(result.contains("42"), "expected '42' in result: {result}"); + + Ok(()) +} diff --git a/src/sacp-conductor/tests/test_tool_fn.rs b/src/sacp-conductor/tests/test_tool_fn.rs index 3b4cb5c..23e6c39 100644 --- a/src/sacp-conductor/tests/test_tool_fn.rs +++ b/src/sacp-conductor/tests/test_tool_fn.rs @@ -94,7 +94,7 @@ async fn test_tool_fn_greet() -> Result<(), sacp::Error> { .await?; expect_test::expect![[r#" - "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"\\\"Hello, World!\\\"\", meta: None }), annotations: None }], structured_content: Some(String(\"Hello, World!\")), is_error: Some(false), meta: None }" + "OK: CallToolResult { content: [Annotated { raw: Text(RawTextContent { text: \"\\\"Hello, World!\\\"\", meta: None }), annotations: None }], structured_content: None, is_error: Some(false), meta: None }" "#]].assert_debug_eq(&result); Ok(()) diff --git a/src/sacp/src/mcp_server/builder.rs b/src/sacp/src/mcp_server/builder.rs index 042b667..750bae6 100644 --- a/src/sacp/src/mcp_server/builder.rs +++ b/src/sacp/src/mcp_server/builder.rs @@ -10,7 +10,7 @@ use futures::{ use fxhash::FxHashMap; use rmcp::{ ErrorData, ServerHandler, - handler::server::tool::cached_schema_for_type, + handler::server::tool::{schema_for_output, schema_for_type}, model::{CallToolResult, ListToolsResult, Tool}, }; use schemars::JsonSchema; @@ -46,18 +46,34 @@ use crate::{ /// ) /// .build(); /// ``` -pub struct McpServerBuilder { +pub struct McpServerBuilder { role: Role, name: String, data: McpServerData, responder: Responder, } -#[derive(Default)] -struct McpServerData { +struct McpServerData { instructions: Option, tool_models: Vec, - tools: FxHashMap>>, + tools: FxHashMap>, +} + +/// A registered tool with its metadata. +struct RegisteredTool { + tool: Arc>, + /// Whether this tool returns structured output (i.e., has an output_schema). + has_structured_output: bool, +} + +impl Default for McpServerData { + fn default() -> Self { + Self { + instructions: None, + tool_models: Vec::new(), + tools: FxHashMap::default(), + } + } } impl McpServerBuilder @@ -87,10 +103,15 @@ where /// Add a tool to the server. pub fn tool(mut self, tool: impl McpTool + 'static) -> Self { let tool_model = make_tool_model(&tool); + let has_structured_output = tool_model.output_schema.is_some(); self.data.tool_models.push(tool_model); - self.data - .tools - .insert(tool.name(), make_erased_mcp_tool(tool)); + self.data.tools.insert( + tool.name(), + RegisteredTool { + tool: make_erased_mcp_tool(tool), + has_structured_output, + }, + ); self } @@ -313,7 +334,7 @@ where context: rmcp::service::RequestContext, ) -> Result { // Lookup the tool definition, erroring if not found - let Some(tool) = self.data.tools.get(&request.name[..]) else { + let Some(registered) = self.data.tools.get(&request.name[..]) else { return Err(rmcp::model::ErrorData::invalid_params( format!("tool `{}` not found", request.name), None, @@ -324,15 +345,25 @@ where let serde_value = serde_json::to_value(request.arguments).expect("valid json"); // Execute the user's tool, unless cancellation occurs + let has_structured_output = registered.has_structured_output; match futures::future::select( - tool.call_tool(serde_value, self.mcp_cx.clone()), + registered.tool.call_tool(serde_value, self.mcp_cx.clone()), pin!(context.ct.cancelled()), ) .await { // If completed successfully Either::Left((m, _)) => match m { - Ok(result) => Ok(CallToolResult::structured(result)), + Ok(result) => { + // Use structured output only if the tool declared an output_schema + if has_structured_output { + Ok(CallToolResult::structured(result)) + } else { + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + result.to_string(), + )])) + } + } Err(error) => Err(to_rmcp_error(error)), }, @@ -382,8 +413,11 @@ fn make_tool_model>(tool: &M) -> Tool { name: tool.name().into(), title: tool.title(), description: Some(tool.description().into()), - input_schema: cached_schema_for_type::(), - output_schema: Some(cached_schema_for_type::()), + input_schema: schema_for_type::(), + // schema_for_output returns Err for non-object types (strings, integers, etc.) + // since MCP structured output requires JSON objects. We use .ok() to set + // output_schema to None for these tools, signaling unstructured output. + output_schema: schema_for_output::().ok(), annotations: None, icons: None, meta: None,