Skip to content

Add a way to listen to multiple protocols on the same endpoint. #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ pub mod rpc {
use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};

use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
use quinn::ConnectionError;
use quinn::{ConnectionError, Endpoint};
use serde::{de::DeserializeOwned, Serialize};
use smallvec::SmallVec;
use tracing::{trace, trace_span, warn, Instrument};
Expand Down Expand Up @@ -1470,6 +1470,89 @@ pub mod rpc {
request_id += 1;
}
}

type MultiHandler = Arc<
dyn Fn(
&[u8],
quinn::RecvStream,
quinn::SendStream,
) -> std::result::Result<
BoxFuture<std::result::Result<(), SendError>>,
(quinn::RecvStream, quinn::SendStream),
> + Send
+ Sync
+ 'static,
>;

pub struct Listener {
handlers: Vec<MultiHandler>,
}

impl Listener {
pub fn add_handler<R: DeserializeOwned + 'static>(mut self, handler: Handler<R>) -> Self {
self.handlers.push(Arc::new(
move |buf, recv, send| match postcard::from_bytes::<R>(buf) {
Err(_) => Err((recv, send)),
Ok(msg) => Ok(handler(msg, recv, send)),
},
));
self
}

pub async fn listen(self, endpoint: Endpoint) {
let mut request_id = 0u64;
let mut tasks = JoinSet::new();
while let Some(incoming) = endpoint.accept().await {
let handlers = self.handlers.clone();
let fut = async move {
let connection = match incoming.await {
Ok(connection) => connection,
Err(cause) => {
warn!("failed to accept connection {cause:?}");
return io::Result::Ok(());
}
};
loop {
let (mut send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause))
if cause.error_code.into_inner() == 0 =>
{
trace!("remote side closed connection {cause:?}");
return Ok(());
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
};
let size = recv.read_varint_u64().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size")
})?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
for handler in &handlers {
match handler(&buf, recv, send) {
Ok(fut) => {
fut.await?;
break;
}
Err((recv_ret, send_ret)) => {
recv = recv_ret;
send = send_ret;
}
}
}
}
};
let span = trace_span!("rpc", id = request_id);
tasks.spawn(fut.instrument(span));
request_id += 1;
}
}
}
}

/// A request to a service. This can be either local or remote.
Expand Down