diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index 6bb06827..4175aff4 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -5,6 +5,7 @@ mod common; mod prompt; mod prompt_handler; mod prompt_router; +mod task_handler; mod tool; mod tool_handler; mod tool_router; @@ -263,3 +264,17 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream { .unwrap_or_else(|err| err.to_compile_error()) .into() } + +/// # task_handler +/// +/// Generates basic task-handling methods (`enqueue_task` and `list_tasks`) for a server handler +/// using a shared [`OperationProcessor`]. The default processor expression assumes a +/// `self.processor` field holding an `Arc>`, but it can be customized +/// via `#[task_handler(processor = ...)]`. Because the macro captures `self` inside spawned +/// futures, the handler type must implement [`Clone`]. +#[proc_macro_attribute] +pub fn task_handler(attr: TokenStream, input: TokenStream) -> TokenStream { + task_handler::task_handler(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} diff --git a/crates/rmcp-macros/src/task_handler.rs b/crates/rmcp-macros/src/task_handler.rs new file mode 100644 index 00000000..f94cf130 --- /dev/null +++ b/crates/rmcp-macros/src/task_handler.rs @@ -0,0 +1,278 @@ +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::{Expr, ImplItem, ItemImpl}; + +#[derive(FromMeta)] +#[darling(default)] +struct TaskHandlerAttribute { + processor: Expr, +} + +impl Default for TaskHandlerAttribute { + fn default() -> Self { + Self { + processor: syn::parse2(quote! { self.processor }).expect("default processor expr"), + } + } +} + +pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let TaskHandlerAttribute { processor } = TaskHandlerAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + + let has_method = |name: &str, item_impl: &ItemImpl| -> bool { + item_impl.items.iter().any(|item| match item { + ImplItem::Fn(func) => func.sig.ident == name, + _ => false, + }) + }; + + if !has_method("list_tasks", &item_impl) { + let list_fn = quote! { + async fn list_tasks( + &self, + _request: Option, + _: rmcp::service::RequestContext, + ) -> Result { + let running_ids = (#processor).lock().await.list_running(); + let total = running_ids.len() as u64; + let tasks = running_ids + .into_iter() + .map(|task_id| { + let timestamp = rmcp::task_manager::current_timestamp(); + rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + } + }) + .collect::>(); + + Ok(rmcp::model::ListTasksResult { + tasks, + next_cursor: None, + total: Some(total), + }) + } + }; + item_impl.items.push(syn::parse2::(list_fn)?); + } + + if !has_method("enqueue_task", &item_impl) { + let enqueue_fn = quote! { + async fn enqueue_task( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + use rmcp::task_manager::{ + current_timestamp, OperationDescriptor, OperationMessage, OperationResultTransport, + ToolCallTaskResult, + }; + let task_id = context.id.to_string(); + let operation_name = request.name.to_string(); + let future_request = request.clone(); + let future_context = context.clone(); + let server = self.clone(); + + let descriptor = OperationDescriptor::new(task_id.clone(), operation_name) + .with_context(context) + .with_client_request(rmcp::model::ClientRequest::CallToolRequest( + rmcp::model::Request::new(request), + )); + + let task_result_id = task_id.clone(); + let future = Box::pin(async move { + let result = server.call_tool(future_request, future_context).await; + Ok( + Box::new(ToolCallTaskResult::new(task_result_id, result)) + as Box, + ) + }); + + (#processor) + .lock() + .await + .submit_operation(OperationMessage::new(descriptor, future)) + .map_err(|err| rmcp::ErrorData::internal_error( + format!("failed to enqueue task: {err}"), + None, + ))?; + + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: Some("Task accepted".to_string()), + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + }; + + Ok(rmcp::model::CreateTaskResult { task }) + } + }; + item_impl.items.push(syn::parse2::(enqueue_fn)?); + } + + if !has_method("get_task_info", &item_impl) { + let get_info_fn = quote! { + async fn get_task_info( + &self, + request: rmcp::model::GetTaskInfoParam, + _context: rmcp::service::RequestContext, + ) -> Result { + use rmcp::task_manager::current_timestamp; + let task_id = request.task_id.clone(); + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + // Check completed results first + let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id); + if let Some(completed_result) = completed { + // Determine Finished vs Failed + let status = match &completed_result.result { + Ok(boxed) => { + if let Some(tool) = boxed.as_any().downcast_ref::() { + match &tool.result { + Ok(_) => rmcp::model::TaskStatus::Completed, + Err(_) => rmcp::model::TaskStatus::Failed, + } + } else { + rmcp::model::TaskStatus::Completed + } + } + Err(_) => rmcp::model::TaskStatus::Failed, + }; + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: completed_result.descriptor.ttl, + poll_interval: None, + }; + return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) }); + } + + // If not completed, check running + let running = processor.list_running(); + if running.into_iter().any(|id| id == task_id) { + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + }; + return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) }); + } + + Ok(rmcp::model::GetTaskInfoResult { task: None }) + } + }; + item_impl.items.push(syn::parse2::(get_info_fn)?); + } + + if !has_method("get_task_result", &item_impl) { + let get_result_fn = quote! { + async fn get_task_result( + &self, + request: rmcp::model::GetTaskResultParam, + _context: rmcp::service::RequestContext, + ) -> Result { + use std::time::Duration; + let task_id = request.task_id.clone(); + + loop { + // Scope the lock so we can await outside if needed + { + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + if let Some(task_result) = processor.take_completed_result(&task_id) { + match task_result.result { + Ok(boxed) => { + if let Some(tool) = boxed.as_any().downcast_ref::() { + match &tool.result { + Ok(call_tool) => { + let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null); + return Ok(rmcp::model::TaskResult { + content_type: "application/json".to_string(), + value, + summary: None, + }); + } + Err(err) => return Err(McpError::internal_error( + format!("task failed: {}", err), + None, + )), + } + } else { + return Err(McpError::internal_error("unsupported task result transport", None)); + } + } + Err(err) => return Err(McpError::internal_error( + format!("task execution error: {}", err), + None, + )), + } + } + + // Not completed yet: if not running, return not found + let running = processor.list_running(); + if !running.iter().any(|id| id == &task_id) { + return Err(McpError::resource_not_found(format!("task not found: {}", task_id), None)); + } + } + + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + }; + item_impl + .items + .push(syn::parse2::(get_result_fn)?); + } + + if !has_method("cancel_task", &item_impl) { + let cancel_fn = quote! { + async fn cancel_task( + &self, + request: rmcp::model::CancelTaskParam, + _context: rmcp::service::RequestContext, + ) -> Result<(), McpError> { + let task_id = request.task_id; + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + if processor.cancel_task(&task_id) { + return Ok(()); + } + + // If already completed, signal it's not cancellable + let exists_completed = processor.peek_completed().iter().any(|r| r.descriptor.operation_id == task_id); + if exists_completed { + return Err(McpError::invalid_request(format!("task already completed: {}", task_id), None)); + } + + Err(McpError::resource_not_found(format!("task not found: {}", task_id), None)) + } + }; + item_impl.items.push(syn::parse2::(cancel_fn)?); + } + + Ok(item_impl.into_token_stream()) +} diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 098287da..d908f6c3 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs" name = "test_elicitation" required-features = ["elicitation", "client", "server"] path = "tests/test_elicitation.rs" + +[[test]] +name = "test_task" +required-features = ["server", "client", "macros"] +path = "tests/test_task.rs" \ No newline at end of file diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index e0da2b3d..f51a7158 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -41,6 +41,10 @@ pub enum RmcpError { error: Box, }, // and cancellation shouldn't be an error? + + // TODO: add more error variants as needed + #[error("Task error: {0}")] + TaskError(String), } impl RmcpError { diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index fd062dbd..16466b07 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -11,6 +11,7 @@ pub mod router; pub mod tool; pub mod tool_name_validation; pub mod wrapper; + impl Service for H { async fn handle_request( &self, @@ -61,14 +62,38 @@ impl Service for H { .unsubscribe(request.params, context) .await .map(ServerResult::empty), - ClientRequest::CallToolRequest(request) => self - .call_tool(request.params, context) - .await - .map(ServerResult::CallToolResult), + ClientRequest::CallToolRequest(request) => { + if request.params.task.is_some() { + tracing::info!("Enqueueing task for tool call: {}", request.params.name); + self.enqueue_task(request.params, context.clone()) + .await + .map(ServerResult::CreateTaskResult) + } else { + self.call_tool(request.params, context) + .await + .map(ServerResult::CallToolResult) + } + } ClientRequest::ListToolsRequest(request) => self .list_tools(request.params, context) .await .map(ServerResult::ListToolsResult), + ClientRequest::ListTasksRequest(request) => self + .list_tasks(request.params, context) + .await + .map(ServerResult::ListTasksResult), + ClientRequest::GetTaskInfoRequest(request) => self + .get_task_info(request.params, context) + .await + .map(ServerResult::GetTaskInfoResult), + ClientRequest::GetTaskResultRequest(request) => self + .get_task_result(request.params, context) + .await + .map(ServerResult::TaskResult), + ClientRequest::CancelTaskRequest(request) => self + .cancel_task(request.params, context) + .await + .map(ServerResult::empty), } } @@ -104,6 +129,16 @@ impl Service for H { #[allow(unused_variables)] pub trait ServerHandler: Sized + Send + Sync + 'static { + fn enqueue_task( + &self, + _request: CallToolRequestParam, + _context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::internal_error( + "Task processing not implemented".to_string(), + None, + ))) + } fn ping( &self, context: RequestContext, @@ -240,4 +275,38 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { fn get_info(&self) -> ServerInfo { ServerInfo::default() } + + fn list_tasks( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + + fn get_task_info( + &self, + request: GetTaskInfoParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + + fn get_task_result( + &self, + request: GetTaskResultParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let _ = (request, context); + std::future::ready(Err(McpError::method_not_found::())) + } + + fn cancel_task( + &self, + request: CancelTaskParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let _ = (request, context); + std::future::ready(Err(McpError::method_not_found::())) + } } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 21e7e1a2..16435e42 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -33,12 +33,17 @@ pub struct ToolCallContext<'s, S> { pub service: &'s S, pub name: Cow<'static, str>, pub arguments: Option, + pub task: Option, } impl<'s, S> ToolCallContext<'s, S> { pub fn new( service: &'s S, - CallToolRequestParam { name, arguments }: CallToolRequestParam, + CallToolRequestParam { + name, + arguments, + task, + }: CallToolRequestParam, request_context: RequestContext, ) -> Self { Self { @@ -46,6 +51,7 @@ impl<'s, S> ToolCallContext<'s, S> { service, name, arguments, + task, } } pub fn name(&self) -> &str { diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9f81eabe..cba4eeba 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -162,6 +162,7 @@ pub use service::{RoleClient, serve_client}; pub use service::{RoleServer, serve_server}; pub mod handler; +pub mod task_manager; pub mod transport; // re-export diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 33b507da..ec12b3cd 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -8,6 +8,7 @@ mod meta; mod prompt; mod resource; mod serde_impl; +mod task; mod tool; pub use annotated::*; pub use capabilities::*; @@ -19,6 +20,7 @@ pub use prompt::*; pub use resource::*; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Value; +pub use task::*; pub use tool::*; /// A JSON object type alias for convenient handling of JSON data. @@ -1653,6 +1655,8 @@ pub struct CallToolRequestParam { /// Arguments to pass to the tool (must match the tool's input schema) #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, } /// Request to call a specific tool @@ -1691,6 +1695,61 @@ pub struct GetPromptResult { pub messages: Vec, } +// ============================================================================= +// TASK MANAGEMENT +// ============================================================================= + +const_string!(GetTaskInfoMethod = "tasks/get"); +pub type GetTaskInfoRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoParam { + pub task_id: String, +} + +const_string!(ListTasksMethod = "tasks/list"); +pub type ListTasksRequest = RequestOptionalParam; + +const_string!(GetTaskResultMethod = "tasks/result"); +pub type GetTaskResultRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskResultParam { + pub task_id: String, +} + +const_string!(CancelTaskMethod = "tasks/cancel"); +pub type CancelTaskRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CancelTaskParam { + pub task_id: String, +} +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoResult { + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ListTasksResult { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} + // ============================================================================= // MESSAGE TYPE UNIONS // ============================================================================= @@ -1757,7 +1816,11 @@ ts_union!( | SubscribeRequest | UnsubscribeRequest | CallToolRequest - | ListToolsRequest; + | ListToolsRequest + | GetTaskInfoRequest + | ListTasksRequest + | GetTaskResultRequest + | CancelTaskRequest; ); impl ClientRequest { @@ -1776,6 +1839,10 @@ impl ClientRequest { ClientRequest::UnsubscribeRequest(r) => r.method.as_str(), ClientRequest::CallToolRequest(r) => r.method.as_str(), ClientRequest::ListToolsRequest(r) => r.method.as_str(), + ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(), + ClientRequest::ListTasksRequest(r) => r.method.as_str(), + ClientRequest::GetTaskResultRequest(r) => r.method.as_str(), + ClientRequest::CancelTaskRequest(r) => r.method.as_str(), } } } @@ -1833,6 +1900,10 @@ ts_union!( | ListToolsResult | CreateElicitationResult | EmptyResult + | CreateTaskResult + | ListTasksResult + | GetTaskInfoResult + | TaskResult ; ); diff --git a/crates/rmcp/src/model/capabilities.rs b/crates/rmcp/src/model/capabilities.rs index cbe1a6ea..1740b3ee 100644 --- a/crates/rmcp/src/model/capabilities.rs +++ b/crates/rmcp/src/model/capabilities.rs @@ -40,6 +40,25 @@ pub struct RootsCapabilities { pub list_changed: Option, } +/// Task capability negotiation for SEP-1686. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TasksCapability { + /// Map of request category (e.g. "tools.call") to a boolean indicating support. + #[serde(skip_serializing_if = "Option::is_none")] + pub requests: Option, + /// Whether the receiver supports `tasks/list`. + #[serde(skip_serializing_if = "Option::is_none")] + pub list: Option, + /// Whether the receiver supports `tasks/cancel`. + #[serde(skip_serializing_if = "Option::is_none")] + pub cancel: Option, +} + +/// A convenience alias for describing per-request task support. +pub type TaskRequestMap = BTreeMap; + /// Capability for handling elicitation requests from servers. /// /// Elicitation allows servers to request interactive input from users during tool execution. @@ -78,6 +97,8 @@ pub struct ClientCapabilities { /// Capability to handle elicitation requests from servers for interactive user input #[serde(skip_serializing_if = "Option::is_none")] pub elicitation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tasks: Option, } /// @@ -109,6 +130,8 @@ pub struct ServerCapabilities { pub resources: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tasks: Option, } macro_rules! builder { @@ -223,12 +246,13 @@ builder! { completions: JsonObject, prompts: PromptsCapability, resources: ResourcesCapability, - tools: ToolsCapability + tools: ToolsCapability, + tasks: TasksCapability } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_tool_list_changed(mut self) -> Self { if let Some(c) = self.tools.as_mut() { @@ -238,8 +262,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_prompts_list_changed(mut self) -> Self { if let Some(c) = self.prompts.as_mut() { @@ -249,8 +273,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_resources_list_changed(mut self) -> Self { if let Some(c) = self.resources.as_mut() { @@ -273,11 +297,12 @@ builder! { roots: RootsCapabilities, sampling: JsonObject, elicitation: ElicitationCapability, + tasks: TasksCapability, } } -impl - ClientCapabilitiesBuilder> +impl + ClientCapabilitiesBuilder> { pub fn enable_roots_list_changed(mut self) -> Self { if let Some(c) = self.roots.as_mut() { @@ -288,8 +313,8 @@ impl } #[cfg(feature = "elicitation")] -impl - ClientCapabilitiesBuilder> +impl + ClientCapabilitiesBuilder> { /// Enable JSON Schema validation for elicitation responses. /// When enabled, the client will validate user input against the requested_schema diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index a03fc056..46e40779 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -86,6 +86,10 @@ variant_extension! { UnsubscribeRequest CallToolRequest ListToolsRequest + GetTaskInfoRequest + ListTasksRequest + GetTaskResultRequest + CancelTaskRequest } } diff --git a/crates/rmcp/src/model/task.rs b/crates/rmcp/src/model/task.rs new file mode 100644 index 00000000..8cb0ee58 --- /dev/null +++ b/crates/rmcp/src/model/task.rs @@ -0,0 +1,79 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Canonical task lifecycle status as defined by SEP-1686. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum TaskStatus { + /// The receiver accepted the request and is currently working on it. + #[default] + Working, + /// The receiver requires additional input before work can continue. + InputRequired, + /// The underlying operation completed successfully and the result is ready. + Completed, + /// The underlying operation failed and will not continue. + Failed, + /// The task was cancelled and will not continue processing. + Cancelled, +} + +/// Final result for a succeeded task (returned from `tasks/result`). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskResult { + /// MIME type or custom content-type identifier. + pub content_type: String, + /// The actual result payload, matching the underlying request's schema. + pub value: Value, + /// Optional short summary for UI surfaces. + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +/// Primary Task object that surfaces metadata during the task lifecycle. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Task { + /// Unique task identifier generated by the receiver. + pub task_id: String, + /// Current lifecycle status (see [`TaskStatus`]). + pub status: TaskStatus, + /// Optional human-readable status message for UI surfaces. + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, + /// ISO-8601 creation timestamp. + pub created_at: String, + /// ISO-8601 timestamp for the most recent status change. + #[serde(skip_serializing_if = "Option::is_none")] + pub last_updated_at: Option, + /// Retention window in milliseconds that the receiver agreed to honor. + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl: Option, + /// Suggested polling interval (milliseconds). + #[serde(skip_serializing_if = "Option::is_none")] + pub poll_interval: Option, +} + +/// Wrapper returned by task-augmented requests (CreateTaskResult in SEP-1686). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateTaskResult { + pub task: Task, +} + +/// Paginated list of tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskList { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs new file mode 100644 index 00000000..d8768902 --- /dev/null +++ b/crates/rmcp/src/task_manager.rs @@ -0,0 +1,294 @@ +use std::{any::Any, collections::HashMap, pin::Pin}; + +use futures::Future; +use tokio::{ + sync::mpsc, + time::{Duration, timeout}, +}; + +use crate::{ + RoleServer, + error::{ErrorData as McpError, RmcpError as Error}, + model::{CallToolResult, ClientRequest}, + service::RequestContext, +}; + +/// Boxed future that represents an asynchronous operation managed by the processor. +pub type OperationFuture = + Pin, Error>> + Send>>; + +/// Describes metadata associated with an enqueued task. +#[derive(Debug, Clone)] +pub struct OperationDescriptor { + pub operation_id: String, + pub name: String, + pub client_request: Option, + pub context: Option>, + pub ttl: Option, +} + +impl OperationDescriptor { + pub fn new(operation_id: impl Into, name: impl Into) -> Self { + Self { + operation_id: operation_id.into(), + name: name.into(), + client_request: None, + context: None, + ttl: None, + } + } + + pub fn with_client_request(mut self, request: ClientRequest) -> Self { + self.client_request = Some(request); + self + } + + pub fn with_context(mut self, context: RequestContext) -> Self { + self.context = Some(context); + self + } + + pub fn with_ttl(mut self, ttl: u64) -> Self { + self.ttl = Some(ttl); + self + } +} + +/// Operation message describing a unit of asynchronous work. +pub struct OperationMessage { + pub descriptor: OperationDescriptor, + pub future: OperationFuture, +} + +impl OperationMessage { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + Self { descriptor, future } + } +} + +/// Trait for operation result transport +pub trait OperationResultTransport: Send + Sync + 'static { + fn operation_id(&self) -> &String; + fn as_any(&self) -> &dyn std::any::Any; +} + +// ===== Operation Processor ===== +pub const DEFAULT_TASK_TIMEOUT_SECS: u64 = 300; // 5 minutes +/// Operation processor that coordinates extractors and handlers +pub struct OperationProcessor { + /// Currently running tasks keyed by id + running_tasks: HashMap, + /// Completed results waiting to be collected + completed_results: Vec, + task_result_receiver: Option>, + task_result_sender: mpsc::UnboundedSender, +} + +struct RunningTask { + task_handle: tokio::task::JoinHandle<()>, + started_at: std::time::Instant, + timeout: Option, + descriptor: OperationDescriptor, +} + +pub struct TaskResult { + pub descriptor: OperationDescriptor, + pub result: Result, Error>, +} + +/// Helper to generate an ISO 8601 timestamp for task metadata. +pub fn current_timestamp() -> String { + chrono::Utc::now().to_rfc3339() +} + +/// Result transport for tool calls executed as tasks. +pub struct ToolCallTaskResult { + id: String, + pub result: Result, +} + +impl ToolCallTaskResult { + pub fn new(id: impl Into, result: Result) -> Self { + Self { + id: id.into(), + result, + } + } +} + +impl OperationResultTransport for ToolCallTaskResult { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Default for OperationProcessor { + fn default() -> Self { + Self::new() + } +} + +impl OperationProcessor { + pub fn new() -> Self { + let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); + Self { + running_tasks: HashMap::new(), + completed_results: Vec::new(), + task_result_receiver: Some(task_result_receiver), + task_result_sender, + } + } + + /// Submit an operation for asynchronous execution. + #[allow(clippy::result_large_err)] + pub fn submit_operation(&mut self, message: OperationMessage) -> Result<(), Error> { + if self + .running_tasks + .contains_key(&message.descriptor.operation_id) + { + return Err(Error::TaskError(format!( + "Operation with id {} is already running", + message.descriptor.operation_id + ))); + } + self.spawn_async_task(message); + Ok(()) + } + + fn spawn_async_task(&mut self, message: OperationMessage) { + let OperationMessage { descriptor, future } = message; + let task_id = descriptor.operation_id.clone(); + let timeout_secs = descriptor.ttl.or(Some(DEFAULT_TASK_TIMEOUT_SECS)); + let sender = self.task_result_sender.clone(); + let descriptor_for_result = descriptor.clone(); + + let timed_future = async move { + if let Some(secs) = timeout_secs { + match timeout(Duration::from_secs(secs), future).await { + Ok(result) => result, + Err(_) => Err(Error::TaskError("Operation timed out".to_string())), + } + } else { + future.await + } + }; + + let handle = tokio::spawn(async move { + let result = timed_future.await; + let task_result = TaskResult { + descriptor: descriptor_for_result, + result, + }; + let _ = sender.send(task_result); + }); + let running_task = RunningTask { + task_handle: handle, + started_at: std::time::Instant::now(), + timeout: timeout_secs, + descriptor, + }; + self.running_tasks.insert(task_id, running_task); + } + + /// Collect completed results from running tasks and remove them from the running tasks map. + pub fn collect_completed_results(&mut self) -> Vec { + if let Some(receiver) = &mut self.task_result_receiver { + while let Ok(result) = receiver.try_recv() { + self.running_tasks.remove(&result.descriptor.operation_id); + self.completed_results.push(result); + } + } + std::mem::take(&mut self.completed_results) + } + + /// Check for tasks that have exceeded their timeout and handle them appropriately. + pub fn check_timeouts(&mut self) { + let now = std::time::Instant::now(); + let mut timed_out_tasks = Vec::new(); + + for (task_id, task) in &self.running_tasks { + if let Some(timeout_duration) = task.timeout { + if now.duration_since(task.started_at).as_secs() > timeout_duration { + task.task_handle.abort(); + timed_out_tasks.push(task_id.clone()); + } + } + } + + for task_id in timed_out_tasks { + if let Some(task) = self.running_tasks.remove(&task_id) { + let timeout_result = TaskResult { + descriptor: task.descriptor, + result: Err(Error::TaskError("Operation timed out".to_string())), + }; + self.completed_results.push(timeout_result); + } + } + } + + /// Get the number of running tasks. + pub fn running_task_count(&self) -> usize { + self.running_tasks.len() + } + + /// Cancel all running tasks. + pub fn cancel_all_tasks(&mut self) { + for (_, task) in self.running_tasks.drain() { + task.task_handle.abort(); + } + self.completed_results.clear(); + } + /// List running task ids. + pub fn list_running(&self) -> Vec { + self.running_tasks.keys().cloned().collect() + } + + /// Note: collectors should call collect_completed_results; this provides a snapshot of queued results. + pub fn peek_completed(&self) -> &[TaskResult] { + &self.completed_results + } + + /// Fetch the metadata for a running or recently completed task. + pub fn task_descriptor(&self, task_id: &str) -> Option<&OperationDescriptor> { + if let Some(task) = self.running_tasks.get(task_id) { + return Some(&task.descriptor); + } + self.completed_results + .iter() + .rev() + .find(|result| result.descriptor.operation_id == task_id) + .map(|result| &result.descriptor) + } + + /// Attempt to cancel a running task. + pub fn cancel_task(&mut self, task_id: &str) -> bool { + if let Some(task) = self.running_tasks.remove(task_id) { + task.task_handle.abort(); + // Insert a cancelled result so callers can observe the terminal state. + let cancel_result = TaskResult { + descriptor: task.descriptor, + result: Err(Error::TaskError("Operation cancelled".to_string())), + }; + self.completed_results.push(cancel_result); + return true; + } + false + } + + /// Retrieve a completed task result if available. + pub fn take_completed_result(&mut self, task_id: &str) -> Option { + if let Some(position) = self + .completed_results + .iter() + .position(|result| result.descriptor.operation_id == task_id) + { + Some(self.completed_results.remove(position)) + } else { + None + } + } +} diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 4db461a4..61f5074b 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -384,6 +384,7 @@ impl Worker for StreamableHttpClientWorker { "process initialized notification response", ))?; let _ = initialized_notification.responder.send(Ok(())); + #[allow(clippy::large_enum_variant)] enum Event { ClientMessage(WorkerSendRequest), ServerMessage(ServerJsonRpcMessage), diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 531b1692..b5d185ab 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -110,6 +110,7 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { ClientRequest::CallToolRequest(Request::new(CallToolRequestParam { name: "some_progress".into(), arguments: None, + task: None, })), PeerRequestOptions::no_options(), ) diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs new file mode 100644 index 00000000..31fc9a9b --- /dev/null +++ b/crates/rmcp/tests/test_task.rs @@ -0,0 +1,77 @@ +use std::{any::Any, time::Duration}; + +use rmcp::task_manager::{ + OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, +}; + +struct DummyTransport { + id: String, + value: u32, +} + +impl OperationResultTransport for DummyTransport { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[tokio::test] +async fn executes_enqueued_future() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("op1", "dummy"); + let future = Box::pin(async { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(Box::new(DummyTransport { + id: "op1".to_string(), + value: 42, + }) as Box) + }); + + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("submit operation"); + + tokio::time::sleep(Duration::from_millis(30)).await; + let results = processor.collect_completed_results(); + assert_eq!(results.len(), 1); + let payload = results[0] + .result + .as_ref() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(payload.value, 42); +} + +#[tokio::test] +async fn rejects_duplicate_operation_ids() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("dup", "dummy"); + let future = Box::pin(async { + Ok(Box::new(DummyTransport { + id: "dup".to_string(), + value: 1, + }) as Box) + }); + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("first submit"); + + let descriptor_dup = OperationDescriptor::new("dup", "dummy"); + let future_dup = Box::pin(async { + Ok(Box::new(DummyTransport { + id: "dup".to_string(), + value: 2, + }) as Box) + }); + + let err = processor + .submit_operation(OperationMessage::new(descriptor_dup, future_dup)) + .expect_err("duplicate should fail"); + assert!(format!("{err}").contains("already running")); +} diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index db5242b3..763c4f43 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -320,6 +320,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; @@ -348,6 +349,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index 67969ae4..c714da54 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -49,6 +49,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; } diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 107adc07..f1cbcae5 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -40,6 +40,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "echo".into(), arguments: Some(object!({ "message": "hi from rmcp" })), + task: None, }) .await?; tracing::info!("Tool result for echo: {tool_result:#?}"); @@ -49,6 +50,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "longRunningOperation".into(), arguments: Some(object!({ "duration": 3, "steps": 1 })), + task: None, }) .await?; tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index d1298b36..7b516f38 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -42,6 +42,7 @@ async fn main() -> Result<(), RmcpError> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index ddf18b2f..c795ce22 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -184,6 +184,7 @@ async fn test_stdio_transport(records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; @@ -238,6 +239,7 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index 8f5aba22..b30a3c26 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -106,6 +106,7 @@ async fn main() -> Result<()> { arguments: Some(object!({ "question": "Hello world" })), + task: None, }) .await { diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index 2f1f1598..cd4b73c4 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -44,6 +44,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "increment".into(), arguments: serde_json::json!({}).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs index 483c6e02..286e58d5 100644 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ b/examples/rig-integration/src/mcp_adaptor.rs @@ -47,6 +47,7 @@ impl RigTool for McpToolAdaptor { name: self.tool.name.clone(), arguments: serde_json::from_str(&args) .map_err(rig::tool::ToolError::JsonError)?, + task: None, }) .await .inspect(|result| tracing::info!(?result)) diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index dc2472bb..ac271cba 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] -use std::sync::Arc; +use std::{any::Any, sync::Arc}; +use chrono::Utc; use rmcp::{ ErrorData as McpError, RoleServer, ServerHandler, handler::server::{ @@ -10,10 +11,30 @@ use rmcp::{ model::*, prompt, prompt_handler, prompt_router, schemars, service::RequestContext, + task_handler, + task_manager::{ + OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, + }, tool, tool_handler, tool_router, }; use serde_json::json; use tokio::sync::Mutex; +use tracing::info; + +struct ToolCallOperationResult { + id: String, + result: Result, +} + +impl OperationResultTransport for ToolCallOperationResult { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct StructRequest { @@ -41,6 +62,7 @@ pub struct Counter { counter: Arc>, tool_router: ToolRouter, prompt_router: PromptRouter, + processor: Arc>, } #[tool_router] @@ -51,6 +73,7 @@ impl Counter { counter: Arc::new(Mutex::new(0)), tool_router: Self::tool_router(), prompt_router: Self::prompt_router(), + processor: Arc::new(Mutex::new(OperationProcessor::new())), } } @@ -84,6 +107,14 @@ impl Counter { )])) } + #[tool(description = "Long running task example")] + async fn long_task(&self) -> Result { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + Ok(CallToolResult::success(vec![Content::text( + "Long task completed", + )])) + } + #[tool(description = "Say hello to the client")] fn say_hello(&self) -> Result { Ok(CallToolResult::success(vec![Content::text("hello")])) @@ -166,6 +197,7 @@ impl Counter { #[tool_handler(meta = Meta(rmcp::object!({"tool_meta_key": "tool_meta_value"})))] #[prompt_handler(meta = Meta(rmcp::object!({"router_meta_key": "router_meta_value"})))] +#[task_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -250,8 +282,16 @@ impl ServerHandler for Counter { #[cfg(test)] mod tests { + use rmcp::{ClientHandler, ServiceExt}; + use tokio::time::Duration; + use super::*; + #[derive(Default, Clone)] + struct TestClient; + + impl ClientHandler for TestClient {} + #[tokio::test] async fn test_prompt_attributes_generated() { // Verify that the prompt macros generate the expected attributes @@ -289,34 +329,56 @@ mod tests { } #[tokio::test] - async fn test_example_prompt_execution() { + async fn test_client_enqueues_long_task() -> anyhow::Result<()> { let counter = Counter::new(); - let context = rmcp::handler::server::prompt::PromptContext::new( - &counter, - "example_prompt".to_string(), - Some({ - let mut map = serde_json::Map::new(); - map.insert( - "message".to_string(), - serde_json::Value::String("Test message".to_string()), - ); - map - }), - RequestContext { - meta: Default::default(), - ct: tokio_util::sync::CancellationToken::new(), - id: rmcp::model::NumberOrString::String("test-1".to_string()), - peer: Default::default(), - extensions: Default::default(), - }, + let processor = counter.processor.clone(); + let client = TestClient::default(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = tokio::spawn(async move { + let service = counter.serve(server_transport).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + + let client_service = client.serve(client_transport).await?; + let mut task_meta = serde_json::Map::new(); + task_meta.insert( + "source".into(), + serde_json::Value::String("integration-test".into()), ); - - let router = Counter::prompt_router(); - let result = router.get_prompt(context).await; - assert!(result.is_ok()); - - let prompt_result = result.unwrap(); - assert_eq!(prompt_result.messages.len(), 1); - assert_eq!(prompt_result.messages[0].role, PromptMessageRole::User); + let params = CallToolRequestParam { + name: "long_task".into(), + arguments: None, + task: Some(task_meta), + }; + let response = client_service + .send_request(ClientRequest::CallToolRequest(Request::new(params.clone()))) + .await?; + + let ServerResult::CreateTaskResult(info) = response else { + panic!("expected task creation result, got {response:?}"); + }; + let task = info.task; + + assert_eq!(task.status, TaskStatus::Working); + // task list should show the task + let tasks = client_service + .send_request(ClientRequest::ListTasksRequest( + RequestOptionalParam::default(), + )) + .await + .unwrap(); + let ServerResult::ListTasksResult(listed) = tasks else { + panic!("expected list tasks result, got {tasks:?}"); + }; + assert_eq!(listed.tasks[0].task_id, task.task_id); + tokio::time::sleep(Duration::from_millis(50)).await; + let running = processor.lock().await.running_task_count(); + assert_eq!(running, 1); + + client_service.cancel().await?; + let _ = server_handle.await; + Ok(()) } } diff --git a/examples/simple-chat-client/src/tool.rs b/examples/simple-chat-client/src/tool.rs index 771b4e9e..174f4274 100644 --- a/examples/simple-chat-client/src/tool.rs +++ b/examples/simple-chat-client/src/tool.rs @@ -62,6 +62,7 @@ impl Tool for McpToolAdapter { .call_tool(CallToolRequestParam { name: self.tool.name.clone(), arguments, + task: None, }) .await?; diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index feeb2b87..0d91dfee 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -52,6 +52,7 @@ async fn main() -> anyhow::Result<()> { "a": 10, "b": 20 })), + task: None, }) .await?;