diff --git a/code-rs/rmcp-client/src/rmcp_client.rs b/code-rs/rmcp-client/src/rmcp_client.rs index b8896ae5859..c82ef1184b3 100644 --- a/code-rs/rmcp-client/src/rmcp_client.rs +++ b/code-rs/rmcp-client/src/rmcp_client.rs @@ -14,6 +14,7 @@ use mcp_types::InitializeRequestParams; use mcp_types::InitializeResult; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; +use mcp_types::MCP_SCHEMA_VERSION; use rmcp::model::CallToolRequestParam; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; @@ -149,20 +150,30 @@ impl RmcpClient { }; let service = match timeout { - Some(duration) => time::timeout(duration, service_future) - .await - .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, - None => service_future - .await - .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + Some(duration) => match time::timeout(duration, service_future).await { + Ok(Ok(service)) => service, + Ok(Err(err)) => return Err(handshake_failed_error(err)), + Err(_) => return Err(handshake_timeout_error(duration)), + }, + None => match service_future.await { + Ok(service) => service, + Err(err) => return Err(handshake_failed_error(err)), + }, }; let initialize_result_rmcp = service .peer() .peer_info() .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; - let initialize_result = convert_to_mcp(initialize_result_rmcp)?; + let initialize_result: InitializeResult = convert_to_mcp(initialize_result_rmcp)?; + + if initialize_result.protocol_version != MCP_SCHEMA_VERSION { + let reported_version = initialize_result.protocol_version.clone(); + return Err(anyhow!( + "MCP server reported protocol version {reported_version}, but this client expects {}. Update either side so both speak the same schema.", + MCP_SCHEMA_VERSION + )); + } { let mut guard = self.state.lock().await; @@ -217,3 +228,16 @@ impl RmcpClient { } } } + +fn handshake_failed_error(err: impl Into) -> anyhow::Error { + let err = err.into(); + anyhow!( + "handshaking with MCP server failed: {err} (this client supports MCP schema version {MCP_SCHEMA_VERSION})" + ) +} + +fn handshake_timeout_error(duration: Duration) -> anyhow::Error { + anyhow!( + "timed out handshaking with MCP server after {duration:?} (expected MCP schema version {MCP_SCHEMA_VERSION})" + ) +}