Skip to content

Commit fec1283

Browse files
committed
feat: easy manual connection loop
1 parent 4aca16f commit fec1283

File tree

2 files changed

+313
-24
lines changed

2 files changed

+313
-24
lines changed

irpc-iroh/examples/auth.rs

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
//! This example demonstrates a few things:
2+
//! * Using irpc with a cloneable server struct instead of with an actor loop
3+
//! * Manually implementing the connection loop
4+
//! * Authenticating peers
5+
6+
use anyhow::Result;
7+
use iroh::{protocol::Router, Endpoint};
8+
9+
use self::storage::{StorageClient, StorageServer};
10+
11+
#[tokio::main]
12+
async fn main() -> Result<()> {
13+
tracing_subscriber::fmt::init();
14+
println!("Remote use");
15+
remote().await?;
16+
Ok(())
17+
}
18+
19+
async fn remote() -> Result<()> {
20+
let (server_router, server_addr) = {
21+
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
22+
let server = StorageServer::new("secret".to_string());
23+
let router = Router::builder(endpoint.clone())
24+
.accept(StorageServer::ALPN, server.clone())
25+
.spawn()
26+
.await?;
27+
let addr = endpoint.node_addr().await?;
28+
(router, addr)
29+
};
30+
31+
let client_endpoint = Endpoint::builder().bind().await?;
32+
let api = StorageClient::connect(client_endpoint, server_addr.clone());
33+
api.auth("secret").await?;
34+
api.set("hello".to_string(), "world".to_string()).await?;
35+
api.set("goodbye".to_string(), "world".to_string()).await?;
36+
let value = api.get("hello".to_string()).await?;
37+
println!("value = {:?}", value);
38+
let mut list = api.list().await?;
39+
while let Some(value) = list.recv().await? {
40+
println!("list value = {:?}", value);
41+
}
42+
43+
let client_endpoint = Endpoint::builder().bind().await?;
44+
let api = StorageClient::connect(client_endpoint, server_addr.clone());
45+
assert!(api.auth("bad").await.is_err());
46+
assert!(api.get("hello".to_string()).await.is_err());
47+
48+
let client_endpoint = Endpoint::builder().bind().await?;
49+
let api = StorageClient::connect(client_endpoint, server_addr);
50+
assert!(api.get("hello".to_string()).await.is_err());
51+
52+
drop(server_router);
53+
Ok(())
54+
}
55+
56+
mod storage {
57+
//! Implementation of our storage service.
58+
//!
59+
//! The only `pub` item is [`StorageApi`], everything else is private.
60+
61+
use std::{
62+
collections::BTreeMap,
63+
sync::{Arc, Mutex},
64+
};
65+
66+
use anyhow::Result;
67+
use iroh::{endpoint::Connection, protocol::ProtocolHandler, Endpoint};
68+
use irpc::{
69+
channel::{oneshot, spsc},
70+
Client, Service, WithChannels,
71+
};
72+
// Import the macro
73+
use irpc_derive::rpc_requests;
74+
use irpc_iroh::{read_request, IrohRemoteConnection};
75+
use serde::{Deserialize, Serialize};
76+
use tracing::info;
77+
78+
const ALPN: &[u8] = b"storage-api/0";
79+
80+
/// A simple storage service, just to try it out
81+
#[derive(Debug, Clone, Copy)]
82+
struct StorageService;
83+
84+
impl Service for StorageService {}
85+
86+
#[derive(Debug, Serialize, Deserialize)]
87+
struct Auth {
88+
token: String,
89+
}
90+
91+
#[derive(Debug, Serialize, Deserialize)]
92+
struct Get {
93+
key: String,
94+
}
95+
96+
#[derive(Debug, Serialize, Deserialize)]
97+
struct List;
98+
99+
#[derive(Debug, Serialize, Deserialize)]
100+
struct Set {
101+
key: String,
102+
value: String,
103+
}
104+
105+
#[derive(Debug, Serialize, Deserialize)]
106+
struct SetMany;
107+
108+
// Use the macro to generate both the StorageProtocol and StorageMessage enums
109+
// plus implement Channels for each type
110+
#[rpc_requests(StorageService, message = StorageMessage)]
111+
#[derive(Serialize, Deserialize)]
112+
enum StorageProtocol {
113+
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
114+
Auth(Auth),
115+
#[rpc(tx=oneshot::Sender<Option<String>>)]
116+
Get(Get),
117+
#[rpc(tx=oneshot::Sender<()>)]
118+
Set(Set),
119+
#[rpc(tx=oneshot::Sender<u64>, rx=spsc::Receiver<(String, String)>)]
120+
SetMany(SetMany),
121+
#[rpc(tx=spsc::Sender<String>)]
122+
List(List),
123+
}
124+
125+
#[derive(Debug, Clone)]
126+
pub struct StorageServer {
127+
state: Arc<Mutex<BTreeMap<String, String>>>,
128+
auth_token: String,
129+
}
130+
131+
#[derive(Default)]
132+
struct PeerState {
133+
authed: bool,
134+
}
135+
136+
impl ProtocolHandler for StorageServer {
137+
fn accept(&self, conn: Connection) -> n0_future::future::Boxed<Result<()>> {
138+
let this = self.clone();
139+
Box::pin(async move {
140+
let mut peer_state = PeerState::default();
141+
while let Some((msg, rx, tx)) = read_request(&conn).await? {
142+
// Upcast the send/receive streams to the channel types each message needs.
143+
let msg: StorageMessage = match msg {
144+
StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(),
145+
StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(),
146+
StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(),
147+
StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(),
148+
StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(),
149+
};
150+
151+
// Handle the message
152+
if let Err(err) = this.handle(&mut peer_state, msg).await {
153+
match err {
154+
Error::Unauthorized => conn.close(401u32.into(), b"unauthorized"),
155+
Error::InvalidMessage => conn.close(400u32.into(), b"invalid message"),
156+
}
157+
break;
158+
}
159+
}
160+
conn.closed().await;
161+
Ok(())
162+
})
163+
}
164+
}
165+
166+
enum Error {
167+
Unauthorized,
168+
InvalidMessage,
169+
}
170+
171+
impl StorageServer {
172+
pub const ALPN: &[u8] = ALPN;
173+
174+
pub fn new(auth_token: String) -> Self {
175+
Self {
176+
state: Default::default(),
177+
auth_token,
178+
}
179+
}
180+
181+
async fn handle(
182+
&self,
183+
peer_state: &mut PeerState,
184+
msg: StorageMessage,
185+
) -> Result<(), Error> {
186+
if !peer_state.authed && !matches!(msg, StorageMessage::Auth(_)) {
187+
return Err(Error::InvalidMessage);
188+
}
189+
match msg {
190+
StorageMessage::Auth(auth) => {
191+
let WithChannels { tx, inner, .. } = auth;
192+
if peer_state.authed {
193+
return Err(Error::InvalidMessage);
194+
} else if inner.token != self.auth_token {
195+
return Err(Error::Unauthorized);
196+
} else {
197+
peer_state.authed = true;
198+
tx.send(Ok(())).await.ok();
199+
}
200+
}
201+
StorageMessage::Get(get) => {
202+
info!("get {:?}", get);
203+
let WithChannels { tx, inner, .. } = get;
204+
let res = self.state.lock().unwrap().get(&inner.key).cloned();
205+
tx.send(res).await.ok();
206+
}
207+
StorageMessage::Set(set) => {
208+
info!("set {:?}", set);
209+
let WithChannels { tx, inner, .. } = set;
210+
self.state.lock().unwrap().insert(inner.key, inner.value);
211+
tx.send(()).await.ok();
212+
}
213+
StorageMessage::SetMany(list) => {
214+
let WithChannels { tx, mut rx, .. } = list;
215+
let mut i = 0;
216+
while let Ok(Some((key, value))) = rx.recv().await {
217+
let mut state = self.state.lock().unwrap();
218+
state.insert(key, value);
219+
i += 1;
220+
}
221+
tx.send(i).await.ok();
222+
}
223+
StorageMessage::List(list) => {
224+
info!("list {:?}", list);
225+
let WithChannels { mut tx, .. } = list;
226+
let values = {
227+
let state = self.state.lock().unwrap();
228+
// TODO: use async lock to not clone here.
229+
let values: Vec<_> = state
230+
.iter()
231+
.map(|(key, value)| format!("{key}={value}"))
232+
.collect();
233+
values
234+
};
235+
for value in values {
236+
if tx.send(value).await.is_err() {
237+
break;
238+
}
239+
}
240+
}
241+
}
242+
Ok(())
243+
}
244+
}
245+
246+
pub struct StorageClient {
247+
inner: Client<StorageMessage, StorageProtocol, StorageService>,
248+
}
249+
250+
impl StorageClient {
251+
pub const ALPN: &[u8] = ALPN;
252+
253+
pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
254+
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
255+
StorageClient {
256+
inner: Client::boxed(conn),
257+
}
258+
}
259+
260+
pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
261+
self.inner
262+
.rpc(Auth {
263+
token: token.to_string(),
264+
})
265+
.await?
266+
.map_err(|err| anyhow::anyhow!(err))
267+
}
268+
269+
pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
270+
self.inner.rpc(Get { key }).await
271+
}
272+
273+
pub async fn list(&self) -> Result<spsc::Receiver<String>, irpc::Error> {
274+
self.inner.server_streaming(List, 10).await
275+
}
276+
277+
pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
278+
let msg = Set { key, value };
279+
self.inner.rpc(msg).await
280+
}
281+
}
282+
}

irpc-iroh/src/lib.rs

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -128,35 +128,42 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(
128128
handler: Handler<R>,
129129
) -> io::Result<()> {
130130
loop {
131-
let (send, mut recv) = match connection.accept_bi().await {
132-
Ok((s, r)) => (s, r),
133-
Err(ConnectionError::ApplicationClosed(cause))
134-
if cause.error_code.into_inner() == 0 =>
135-
{
136-
trace!("remote side closed connection {cause:?}");
137-
return Ok(());
138-
}
139-
Err(cause) => {
140-
warn!("failed to accept bi stream {cause:?}");
141-
return Err(cause.into());
142-
}
131+
let Some((msg, rx, tx)) = read_request(&connection).await? else {
132+
return Ok(());
143133
};
144-
let size = recv
145-
.read_varint_u64()
146-
.await?
147-
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
148-
let mut buf = vec![0; size as usize];
149-
recv.read_exact(&mut buf)
150-
.await
151-
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
152-
let msg: R = postcard::from_bytes(&buf)
153-
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
154-
let rx = recv;
155-
let tx = send;
156134
handler(msg, rx, tx).await?;
157135
}
158136
}
159137

138+
pub async fn read_request<R: DeserializeOwned + 'static>(
139+
connection: &Connection,
140+
) -> std::io::Result<Option<(R, RecvStream, SendStream)>> {
141+
let (send, mut recv) = match connection.accept_bi().await {
142+
Ok((s, r)) => (s, r),
143+
Err(ConnectionError::ApplicationClosed(cause)) if cause.error_code.into_inner() == 0 => {
144+
trace!("remote side closed connection {cause:?}");
145+
return Ok(None);
146+
}
147+
Err(cause) => {
148+
warn!("failed to accept bi stream {cause:?}");
149+
return Err(cause.into());
150+
}
151+
};
152+
let size = recv
153+
.read_varint_u64()
154+
.await?
155+
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
156+
let mut buf = vec![0; size as usize];
157+
recv.read_exact(&mut buf)
158+
.await
159+
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
160+
let msg: R =
161+
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
162+
let rx = recv;
163+
let tx = send;
164+
Ok(Some((msg, rx, tx)))
165+
}
166+
160167
/// Utility function to listen for incoming connections and handle them with the provided handler
161168
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
162169
let mut request_id = 0u64;

0 commit comments

Comments
 (0)