Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions crates/rmcp-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod common;
mod prompt;
mod prompt_handler;
mod prompt_router;
mod task_handler;
mod tool;
mod tool_handler;
mod tool_router;
Expand Down Expand Up @@ -263,3 +264,17 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
.unwrap_or_else(|err| err.to_compile_error())
.into()
}

/// # task_handler
///
/// Generates basic task-handling methods (`enqueue_task` and `list_tasks`) for a server handler
/// using a shared [`OperationProcessor`]. The default processor expression assumes a
/// `self.processor` field holding an `Arc<Mutex<OperationProcessor>>`, but it can be customized
/// via `#[task_handler(processor = ...)]`. Because the macro captures `self` inside spawned
/// futures, the handler type must implement [`Clone`].
Comment on lines +273 to +274
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The macro requires the handler type to implement Clone (line 82) to spawn the task, but this requirement is not documented in the macro's documentation comment. Users will encounter confusing compiler errors if their handler doesn't implement Clone. The documentation at lines 268-274 should explicitly state that the handler must implement Clone.

Suggested change
/// via `#[task_handler(processor = ...)]`. Because the macro captures `self` inside spawned
/// futures, the handler type must implement [`Clone`].
/// via `#[task_handler(processor = ...)]`.
///
/// **Requirements:** The handler type must implement [`Clone`], as the macro captures `self` inside spawned futures.

Copilot uses AI. Check for mistakes.
#[proc_macro_attribute]
pub fn task_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
task_handler::task_handler(attr.into(), input.into())
.unwrap_or_else(|err| err.to_compile_error())
.into()
}
278 changes: 278 additions & 0 deletions crates/rmcp-macros/src/task_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
use darling::{FromMeta, ast::NestedMeta};
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{Expr, ImplItem, ItemImpl};

#[derive(FromMeta)]
#[darling(default)]
struct TaskHandlerAttribute {
processor: Expr,
}

impl Default for TaskHandlerAttribute {
fn default() -> Self {
Self {
processor: syn::parse2(quote! { self.processor }).expect("default processor expr"),
}
}
}

pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
let attr_args = NestedMeta::parse_meta_list(attr)?;
let TaskHandlerAttribute { processor } = TaskHandlerAttribute::from_list(&attr_args)?;
let mut item_impl = syn::parse2::<ItemImpl>(input.clone())?;

let has_method = |name: &str, item_impl: &ItemImpl| -> bool {
item_impl.items.iter().any(|item| match item {
ImplItem::Fn(func) => func.sig.ident == name,
_ => false,
})
};

if !has_method("list_tasks", &item_impl) {
let list_fn = quote! {
async fn list_tasks(
&self,
_request: Option<rmcp::model::PaginatedRequestParam>,
_: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::ListTasksResult, McpError> {
let running_ids = (#processor).lock().await.list_running();
let total = running_ids.len() as u64;
let tasks = running_ids
.into_iter()
.map(|task_id| {
let timestamp = rmcp::task_manager::current_timestamp();
rmcp::model::Task {
task_id,
status: rmcp::model::TaskStatus::Working,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
ttl: None,
poll_interval: None,
}
})
.collect::<Vec<_>>();

Ok(rmcp::model::ListTasksResult {
tasks,
next_cursor: None,
total: Some(total),
})
}
Comment on lines +34 to +62
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated list_tasks method assumes all running tasks have status Working, but it doesn't check completed results that might not have been collected yet. This means tasks that have just completed but haven't been polled yet won't appear in the list, which could confuse clients. The method should call collect_completed_results first and include recently completed tasks in the listing.

Copilot uses AI. Check for mistakes.
};
item_impl.items.push(syn::parse2::<ImplItem>(list_fn)?);
}

if !has_method("enqueue_task", &item_impl) {
let enqueue_fn = quote! {
async fn enqueue_task(
&self,
request: rmcp::model::CallToolRequestParam,
context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::CreateTaskResult, McpError> {
use rmcp::task_manager::{
current_timestamp, OperationDescriptor, OperationMessage, OperationResultTransport,
ToolCallTaskResult,
};
let task_id = context.id.to_string();
let operation_name = request.name.to_string();
let future_request = request.clone();
let future_context = context.clone();
let server = self.clone();

let descriptor = OperationDescriptor::new(task_id.clone(), operation_name)
.with_context(context)
.with_client_request(rmcp::model::ClientRequest::CallToolRequest(
rmcp::model::Request::new(request),
));

let task_result_id = task_id.clone();
let future = Box::pin(async move {
let result = server.call_tool(future_request, future_context).await;
Ok(
Box::new(ToolCallTaskResult::new(task_result_id, result))
as Box<dyn OperationResultTransport>,
)
});

(#processor)
.lock()
.await
.submit_operation(OperationMessage::new(descriptor, future))
.map_err(|err| rmcp::ErrorData::internal_error(
format!("failed to enqueue task: {err}"),
None,
))?;

let timestamp = current_timestamp();
let task = rmcp::model::Task {
task_id,
status: rmcp::model::TaskStatus::Working,
status_message: Some("Task accepted".to_string()),
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
ttl: None,
poll_interval: None,
};

Ok(rmcp::model::CreateTaskResult { task })
}
};
item_impl.items.push(syn::parse2::<ImplItem>(enqueue_fn)?);
}

if !has_method("get_task_info", &item_impl) {
let get_info_fn = quote! {
async fn get_task_info(
&self,
request: rmcp::model::GetTaskInfoParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::GetTaskInfoResult, McpError> {
use rmcp::task_manager::current_timestamp;
let task_id = request.task_id.clone();
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

// Check completed results first
let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id);
if let Some(completed_result) = completed {
// Determine Finished vs Failed
let status = match &completed_result.result {
Ok(boxed) => {
if let Some(tool) = boxed.as_any().downcast_ref::<rmcp::task_manager::ToolCallTaskResult>() {
match &tool.result {
Ok(_) => rmcp::model::TaskStatus::Completed,
Err(_) => rmcp::model::TaskStatus::Failed,
}
} else {
rmcp::model::TaskStatus::Completed
}
}
Err(_) => rmcp::model::TaskStatus::Failed,
};
let timestamp = current_timestamp();
let task = rmcp::model::Task {
task_id,
status,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
ttl: completed_result.descriptor.ttl,
poll_interval: None,
};
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
}

// If not completed, check running
let running = processor.list_running();
if running.into_iter().any(|id| id == task_id) {
let timestamp = current_timestamp();
let task = rmcp::model::Task {
task_id,
status: rmcp::model::TaskStatus::Working,
status_message: None,
created_at: timestamp.clone(),
last_updated_at: Some(timestamp),
ttl: None,
poll_interval: None,
};
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
Comment on lines +169 to +180
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated created_at and last_updated_at timestamps use current_timestamp() which calls chrono::Utc::now() at the time of query, not when the task was actually created. This means the timestamps don't reflect the true task creation or update times, but rather when the status was queried. These timestamps should be stored in the RunningTask structure and retrieved from there for accuracy.

Copilot uses AI. Check for mistakes.
}

Ok(rmcp::model::GetTaskInfoResult { task: None })
}
};
item_impl.items.push(syn::parse2::<ImplItem>(get_info_fn)?);
}

if !has_method("get_task_result", &item_impl) {
let get_result_fn = quote! {
async fn get_task_result(
&self,
request: rmcp::model::GetTaskResultParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<rmcp::model::TaskResult, McpError> {
use std::time::Duration;
let task_id = request.task_id.clone();

loop {
// Scope the lock so we can await outside if needed
{
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

if let Some(task_result) = processor.take_completed_result(&task_id) {
match task_result.result {
Ok(boxed) => {
if let Some(tool) = boxed.as_any().downcast_ref::<rmcp::task_manager::ToolCallTaskResult>() {
match &tool.result {
Ok(call_tool) => {
let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null);
return Ok(rmcp::model::TaskResult {
content_type: "application/json".to_string(),
value,
summary: None,
});
}
Err(err) => return Err(McpError::internal_error(
format!("task failed: {}", err),
None,
)),
}
} else {
return Err(McpError::internal_error("unsupported task result transport", None));
}
}
Err(err) => return Err(McpError::internal_error(
format!("task execution error: {}", err),
None,
)),
}
}

// Not completed yet: if not running, return not found
let running = processor.list_running();
if !running.iter().any(|id| id == &task_id) {
return Err(McpError::resource_not_found(format!("task not found: {}", task_id), None));
}
}

tokio::time::sleep(Duration::from_millis(100)).await;
}
}
Comment on lines +199 to +243
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_task_result method uses an infinite polling loop with a 100ms sleep between iterations. This can lead to resource exhaustion if many clients are simultaneously polling for task results. Consider implementing a more efficient notification mechanism (e.g., using tokio::sync::watch or tokio::sync::Notify) to wake up waiting clients when results are available, or at minimum add a maximum retry count or timeout to prevent infinite polling.

Copilot uses AI. Check for mistakes.
Comment on lines +190 to +243
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_task_result method implementation lacks test coverage. While there's a basic integration test that verifies task enqueueing and listing, there's no test that validates the actual result retrieval mechanism via GetTaskResultRequest. This is a critical path that involves complex polling logic and should be tested to ensure it correctly waits for and returns task results.

Copilot uses AI. Check for mistakes.
};
item_impl
.items
.push(syn::parse2::<ImplItem>(get_result_fn)?);
}

if !has_method("cancel_task", &item_impl) {
let cancel_fn = quote! {
async fn cancel_task(
&self,
request: rmcp::model::CancelTaskParam,
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
) -> Result<(), McpError> {
let task_id = request.task_id;
let mut processor = (#processor).lock().await;
processor.collect_completed_results();

if processor.cancel_task(&task_id) {
return Ok(());
}

// If already completed, signal it's not cancellable
let exists_completed = processor.peek_completed().iter().any(|r| r.descriptor.operation_id == task_id);
if exists_completed {
return Err(McpError::invalid_request(format!("task already completed: {}", task_id), None));
}

Err(McpError::resource_not_found(format!("task not found: {}", task_id), None))
}
};
Comment on lines +251 to +273
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cancel_task method implementation lacks test coverage. There's no test verifying that task cancellation works correctly, that cancelled tasks return the appropriate error, or that attempting to cancel already-completed tasks produces the expected error message. Given the complexity of the cancellation logic, this should have explicit test coverage.

Copilot uses AI. Check for mistakes.
item_impl.items.push(syn::parse2::<ImplItem>(cancel_fn)?);
}

Ok(item_impl.into_token_stream())
}
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs"
name = "test_elicitation"
required-features = ["elicitation", "client", "server"]
path = "tests/test_elicitation.rs"

[[test]]
name = "test_task"
required-features = ["server", "client", "macros"]
path = "tests/test_task.rs"
4 changes: 4 additions & 0 deletions crates/rmcp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub enum RmcpError {
error: Box<dyn std::error::Error + Send + Sync>,
},
// and cancellation shouldn't be an error?

// TODO: add more error variants as needed
#[error("Task error: {0}")]
TaskError(String),
Comment on lines +46 to +47
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TaskError variant only contains a String message, which loses the source error information. This makes debugging difficult when tasks fail due to underlying errors (like I/O errors, network errors, etc.). Consider changing this to store a boxed error like other variants, or adding a separate variant for errors with sources: TaskError { message: String, source: Option<Box<dyn std::error::Error + Send + Sync>> }.

Copilot uses AI. Check for mistakes.
}

impl RmcpError {
Expand Down
Loading
Loading