Skip to content

Commit 59bfb6e

Browse files
authored
refactor(sse_server): separate router and server startup (modelcontextprotocol#52)
* feat: expose axum router * feat: add axum_router example with SSE server implementation * refactor: simplify SseServer configuration handling in server setup * docs: add warning to SseServer::new about potential post_path issues with embedded routers
1 parent b9e1922 commit 59bfb6e

File tree

3 files changed

+77
-11
lines changed

3 files changed

+77
-11
lines changed

crates/rmcp/src/transport/sse_server.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,9 @@ impl SseServer {
219219
.await
220220
}
221221
pub async fn serve_with_config(config: SseServerConfig) -> io::Result<Self> {
222-
let (app, transport_rx) = App::new(config.post_path.clone());
223-
let listener = tokio::net::TcpListener::bind(config.bind).await?;
224-
let service = Router::new()
225-
.route(&config.sse_path, get(sse_handler))
226-
.route(&config.post_path, post(post_event_handler))
227-
.with_state(app);
228-
let ct = config.ct.child_token();
222+
let (sse_server, service) = Self::new(config);
223+
let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?;
224+
let ct = sse_server.config.ct.child_token();
229225
let server = axum::serve(listener, service).with_graceful_shutdown(async move {
230226
ct.cancelled().await;
231227
tracing::info!("sse server cancelled");
@@ -236,13 +232,28 @@ impl SseServer {
236232
tracing::error!(error = %e, "sse server shutdown with error");
237233
}
238234
}
239-
.instrument(tracing::info_span!("sse-server", bind_address = %config.bind)),
235+
.instrument(tracing::info_span!("sse-server", bind_address = %sse_server.config.bind)),
240236
);
241-
Ok(Self {
237+
Ok(sse_server)
238+
}
239+
240+
/// Warning: This function creates a new SseServer instance with the provided configuration.
241+
/// `App.post_path` may be incorrect if using `Router` as an embedded router.
242+
pub fn new(config: SseServerConfig) -> (SseServer, Router) {
243+
let (app, transport_rx) = App::new(config.post_path.clone());
244+
let router = Router::new()
245+
.route(&config.sse_path, get(sse_handler))
246+
.route(&config.post_path, post(post_event_handler))
247+
.with_state(app);
248+
249+
let server = SseServer {
242250
transport_rx,
243251
config,
244-
})
252+
};
253+
254+
(server, router)
245255
}
256+
246257
pub fn with_service<S, F>(mut self, service_provider: F) -> CancellationToken
247258
where
248259
S: Service<RoleServer>,

examples/servers/Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,8 @@ path = "src/std_io.rs"
3333

3434
[[example]]
3535
name = "axum"
36-
path = "src/axum.rs"
36+
path = "src/axum.rs"
37+
38+
[[example]]
39+
name = "axum_router"
40+
path = "src/axum_router.rs"

examples/servers/src/axum_router.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use rmcp::transport::sse_server::{SseServer, SseServerConfig};
2+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
3+
4+
use tracing_subscriber::{self};
5+
mod common;
6+
use common::counter::Counter;
7+
8+
const BIND_ADDRESS: &str = "127.0.0.1:8000";
9+
10+
#[tokio::main]
11+
async fn main() -> anyhow::Result<()> {
12+
tracing_subscriber::registry()
13+
.with(
14+
tracing_subscriber::EnvFilter::try_from_default_env()
15+
.unwrap_or_else(|_| "debug".to_string().into()),
16+
)
17+
.with(tracing_subscriber::fmt::layer())
18+
.init();
19+
20+
let config = SseServerConfig {
21+
bind: BIND_ADDRESS.parse()?,
22+
sse_path: "/sse".to_string(),
23+
post_path: "/message".to_string(),
24+
ct: tokio_util::sync::CancellationToken::new(),
25+
};
26+
27+
let (sse_server, router) = SseServer::new(config);
28+
29+
// Do something with the router, e.g., add routes or middleware
30+
31+
let listener = tokio::net::TcpListener::bind(sse_server.config.bind).await?;
32+
33+
let ct = sse_server.config.ct.child_token();
34+
35+
let server = axum::serve(listener, router).with_graceful_shutdown(async move {
36+
ct.cancelled().await;
37+
tracing::info!("sse server cancelled");
38+
});
39+
40+
tokio::spawn(async move {
41+
if let Err(e) = server.await {
42+
tracing::error!(error = %e, "sse server shutdown with error");
43+
}
44+
});
45+
46+
let ct = sse_server.with_service(Counter::new);
47+
48+
tokio::signal::ctrl_c().await?;
49+
ct.cancel();
50+
Ok(())
51+
}

0 commit comments

Comments
 (0)