diff --git a/iroh/src/disco.rs b/iroh/src/disco.rs
index f1beaedf77..eb6087b18c 100644
--- a/iroh/src/disco.rs
+++ b/iroh/src/disco.rs
@@ -24,6 +24,7 @@ use std::{
 };
 
 use anyhow::{anyhow, bail, ensure, Context, Result};
+use bytes::Bytes;
 use data_encoding::HEXLOWER;
 use iroh_base::{PublicKey, RelayUrl};
 use serde::{Deserialize, Serialize};
@@ -102,6 +103,19 @@ pub fn source_and_box(p: &[u8]) -> Option<(PublicKey, &[u8])> {
     Some((sender, sealed_box))
 }
 
+/// If `p` looks like a disco message, returns the disco public key of the source
+/// and the part of `p` that is the sealed box.
+pub fn source_and_box_bytes(p: &Bytes) -> Option<(PublicKey, Bytes)> {
+    if !looks_like_disco_wrapper(p) {
+        return None;
+    }
+
+    let source = &p[MAGIC_LEN..MAGIC_LEN + KEY_LEN];
+    let sender = PublicKey::try_from(source).ok()?;
+    let sealed_box = p.slice(MAGIC_LEN + KEY_LEN..);
+    Some((sender, sealed_box))
+}
+
 /// A discovery message.
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Message {
diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs
index 9cda164c82..880277fe9b 100644
--- a/iroh/src/magicsock.rs
+++ b/iroh/src/magicsock.rs
@@ -1094,11 +1094,6 @@ impl MagicSock {
             return None;
         }
 
-        if self.handle_relay_disco_message(&dm.buf, &dm.url, dm.src) {
-            // DISCO messages are handled internally in the MagicSock, do not pass to Quinn.
-            return None;
-        }
-
         let quic_mapped_addr = self.node_map.receive_relay(&dm.url, dm.src);
 
         // Normalize local_ip
@@ -1119,32 +1114,6 @@ impl MagicSock {
         Some((dm.src, meta, dm.buf))
     }
 
-    fn handle_relay_disco_message(
-        &self,
-        msg: &[u8],
-        url: &RelayUrl,
-        relay_node_src: PublicKey,
-    ) -> bool {
-        match disco::source_and_box(msg) {
-            Some((source, sealed_box)) => {
-                if relay_node_src != source {
-                    // TODO: return here?
-                    warn!("Received relay disco message from connection for {}, but with message from {}", relay_node_src.fmt_short(), source.fmt_short());
-                }
-                self.handle_disco_message(
-                    source,
-                    sealed_box,
-                    DiscoMessageSource::Relay {
-                        url: url.clone(),
-                        key: relay_node_src,
-                    },
-                );
-                true
-            }
-            None => false,
-        }
-    }
-
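Reviewer note on the new helper: `source_and_box_bytes` mirrors the existing `source_and_box`, but returns the sealed box as an owned `Bytes` slice so it can later be moved into a `RelayDiscoMessage` and sent across a channel without copying. A minimal standalone sketch of the same framing split follows; the constants are placeholders, not the real `MAGIC`/`MAGIC_LEN`/`KEY_LEN` from `disco.rs`, and a raw `[u8; 32]` stands in for `PublicKey`:

```rust
// Standalone sketch of the framing that `source_and_box_bytes` parses.
// All constants and names here are illustrative assumptions.
use bytes::Bytes;

const MAGIC: &[u8] = b"EXAMPLE_MAGIC"; // placeholder, not the real disco magic
const MAGIC_LEN: usize = MAGIC.len();
const KEY_LEN: usize = 32; // ed25519 public keys are 32 bytes

/// Splits `[MAGIC][32-byte sender key][sealed box]` without copying the payload:
/// the sealed box is returned as a zero-copy `Bytes` slice of the input.
fn split_disco(p: &Bytes) -> Option<([u8; 32], Bytes)> {
    if p.len() < MAGIC_LEN + KEY_LEN || &p[..MAGIC_LEN] != MAGIC {
        return None;
    }
    let mut key = [0u8; 32];
    key.copy_from_slice(&p[MAGIC_LEN..MAGIC_LEN + KEY_LEN]);
    // `Bytes::slice` keeps a reference into the same allocation.
    Some((key, p.slice(MAGIC_LEN + KEY_LEN..)))
}

fn main() {
    let mut frame = Vec::new();
    frame.extend_from_slice(MAGIC);
    frame.extend_from_slice(&[7u8; 32]);
    frame.extend_from_slice(b"sealed payload");
    let frame = Bytes::from(frame);

    let (key, sealed) = split_disco(&frame).expect("well-formed frame");
    assert_eq!(key, [7u8; 32]);
    assert_eq!(&sealed[..], b"sealed payload");
}
```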
#[instrument("disco_in", skip_all, fields(node = %sender.fmt_short(), %src))] fn handle_disco_message(&self, sender: PublicKey, sealed_box: &[u8], src: DiscoMessageSource) { @@ -1827,7 +1796,13 @@ impl Handle { let mut actor_tasks = JoinSet::default(); - let relay_actor = RelayActor::new(msock.clone(), relay_datagram_recv_queue, relay_protocol); + let (relay_disco_recv_tx, mut relay_disco_recv_rx) = tokio::sync::mpsc::channel(1024); + let relay_actor = RelayActor::new( + msock.clone(), + relay_datagram_recv_queue, + relay_disco_recv_tx, + relay_protocol, + ); let relay_actor_cancel_token = relay_actor.cancel_token(); actor_tasks.spawn( async move { @@ -1837,6 +1812,23 @@ impl Handle { } .instrument(info_span!("relay-actor")), ); + actor_tasks.spawn({ + let msock = msock.clone(); + async move { + while let Some(message) = relay_disco_recv_rx.recv().await { + msock.handle_disco_message( + message.source, + &message.sealed_box, + DiscoMessageSource::Relay { + url: message.relay_url, + key: message.relay_remote_node_id, + }, + ); + } + debug!("relay-disco-recv actor closed"); + } + .instrument(info_span!("relay-disco-recv")) + }); #[cfg(not(wasm_browser))] let _ = actor_tasks.spawn({ @@ -2123,7 +2115,8 @@ impl RelayDatagramSendChannelReceiver { #[derive(Debug)] struct RelayDatagramRecvQueue { queue: ConcurrentQueue, - waker: AtomicWaker, + recv_waker: AtomicWaker, + send_wakers: ConcurrentQueue, } impl RelayDatagramRecvQueue { @@ -2131,7 +2124,8 @@ impl RelayDatagramRecvQueue { fn new() -> Self { Self { queue: ConcurrentQueue::bounded(512), - waker: AtomicWaker::new(), + recv_waker: AtomicWaker::new(), + send_wakers: ConcurrentQueue::unbounded(), } } @@ -2144,10 +2138,49 @@ impl RelayDatagramRecvQueue { item: RelayRecvDatagram, ) -> Result<(), concurrent_queue::PushError> { self.queue.push(item).inspect(|_| { - self.waker.wake(); + self.recv_waker.wake(); }) } + /// Polls for whether the queue has free slots for sending items. + /// + /// If the queue has free slots, this returns [`Poll::Ready`]. + /// If the queue is full, [`Poll::Pending`] is returned and the waker + /// is stored and woken once the queue has free slots. + /// + /// This can be called from multiple tasks concurrently. If a slot becomes + /// available, all stored wakers will be woken simultaneously. + /// This also means that even if [`Poll::Ready`] is returned, it is not + /// guaranteed that [`Self::try_send`] will return `Ok` on the next call, + /// because another send task could have used the slot already. + fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll> { + if self.queue.is_closed() { + Poll::Ready(Err(anyhow!("Queue closed"))) + } else if !self.queue.is_full() { + Poll::Ready(Ok(())) + } else { + match self.send_wakers.push(cx.waker().clone()) { + Ok(()) => Poll::Pending, + Err(concurrent_queue::PushError::Full(_)) => { + unreachable!("Send waker queue is unbounded") + } + Err(concurrent_queue::PushError::Closed(_)) => { + Poll::Ready(Err(anyhow!("Queue closed"))) + } + } + } + } + + async fn send_ready(&self) -> Result<()> { + std::future::poll_fn(|cx| self.poll_send_ready(cx)).await + } + + fn wake_senders(&self) { + while let Ok(waker) = self.send_wakers.pop() { + waker.wake(); + } + } + /// Polls for new items in the queue. /// /// Although this method is available from `&self`, it must not be @@ -2162,23 +2195,31 @@ impl RelayDatagramRecvQueue { /// to be able to poll from `&self`. 
     /// Polls for new items in the queue.
     ///
     /// Although this method is available from `&self`, it must not be
@@ -2162,23 +2195,31 @@
     /// to be able to poll from `&self`.
     fn poll_recv(&self, cx: &mut Context) -> Poll<Result<RelayRecvDatagram>> {
         match self.queue.pop() {
-            Ok(value) => Poll::Ready(Ok(value)),
+            Ok(value) => {
+                self.wake_senders();
+                Poll::Ready(Ok(value))
+            }
             Err(concurrent_queue::PopError::Empty) => {
-                self.waker.register(cx.waker());
+                self.recv_waker.register(cx.waker());
 
                 match self.queue.pop() {
                     Ok(value) => {
-                        self.waker.take();
+                        self.recv_waker.take();
+                        self.wake_senders();
                         Poll::Ready(Ok(value))
                     }
                     Err(concurrent_queue::PopError::Empty) => Poll::Pending,
                     Err(concurrent_queue::PopError::Closed) => {
-                        self.waker.take();
+                        self.recv_waker.take();
+                        self.wake_senders();
                         Poll::Ready(Err(anyhow!("Queue closed")))
                     }
                 }
             }
-            Err(concurrent_queue::PopError::Closed) => Poll::Ready(Err(anyhow!("Queue closed"))),
+            Err(concurrent_queue::PopError::Closed) => {
+                self.wake_senders();
+                Poll::Ready(Err(anyhow!("Queue closed")))
+            }
         }
     }
 }
diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs
index d603724fe1..bfa91c327e 100644
--- a/iroh/src/magicsock/relay_actor.rs
+++ b/iroh/src/magicsock/relay_actor.rs
@@ -51,7 +51,10 @@ use n0_future::{
     time::{self, Duration, Instant, MissedTickBehavior},
     FuturesUnorderedBounded, SinkExt, StreamExt,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::{
+    mpsc::{self, OwnedPermit},
+    oneshot,
+};
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, event, info_span, instrument, trace, warn, Instrument, Level};
 use url::Url;
@@ -133,8 +136,6 @@ struct ActiveRelayActor {
     prio_inbox: mpsc::Receiver<ActiveRelayPrioMessage>,
     /// Inbox for messages which involve sending to the relay server.
     inbox: mpsc::Receiver<ActiveRelayMessage>,
-    /// Queue for received relay datagrams.
-    relay_datagrams_recv: Arc<RelayDatagramRecvQueue>,
     /// Channel on which we queue packets to send to the relay.
     relay_datagrams_send: mpsc::Receiver<RelaySendItem>,
@@ -157,6 +158,15 @@ struct ActiveRelayActor {
     /// Token indicating the [`ActiveRelayActor`] should stop.
     stop_token: CancellationToken,
     metrics: Arc<MagicsockMetrics>,
+    receive_queue: ReceiveQueue,
+}
+
+#[derive(Debug)]
+pub(super) struct RelayDiscoMessage {
+    pub(super) source: PublicKey,
+    pub(super) sealed_box: Bytes,
+    pub(super) relay_url: RelayUrl,
+    pub(super) relay_remote_node_id: PublicKey,
 }
 
 #[derive(Debug)]
@@ -197,6 +207,7 @@ struct ActiveRelayActorOptions {
     inbox: mpsc::Receiver<ActiveRelayMessage>,
     relay_datagrams_send: mpsc::Receiver<RelaySendItem>,
     relay_datagrams_recv: Arc<RelayDatagramRecvQueue>,
+    relay_disco_recv: mpsc::Sender<RelayDiscoMessage>,
     connection_opts: RelayConnectionOptions,
     stop_token: CancellationToken,
     metrics: Arc<MagicsockMetrics>,
@@ -234,6 +245,7 @@ impl ActiveRelayActor {
             inbox,
             relay_datagrams_send,
             relay_datagrams_recv,
+            relay_disco_recv,
             connection_opts,
             stop_token,
             metrics,
@@ -242,14 +254,19 @@ impl ActiveRelayActor {
         ActiveRelayActor {
             prio_inbox,
             inbox,
-            relay_datagrams_recv,
             relay_datagrams_send,
-            url,
+            url: url.clone(),
             relay_client_builder,
             is_home_relay: false,
             inactive_timeout: Box::pin(time::sleep(RELAY_INACTIVE_CLEANUP_TIME)),
             stop_token,
             metrics,
+            receive_queue: ReceiveQueue {
+                relay_url: url,
+                relay_datagrams_recv,
+                relay_disco_recv,
+                pending: None,
+            },
         }
     }
@@ -438,6 +455,7 @@ impl ActiveRelayActor {
                     }
                 }
             }
+            _ = self.receive_queue.forward_pending(), if self.receive_queue.is_pending() => {},
             _ = &mut self.inactive_timeout, if !self.is_home_relay => {
                 debug!(?RELAY_INACTIVE_CLEANUP_TIME, "Inactive, exiting.");
                 break None;
@@ -599,7 +617,12 @@ impl ActiveRelayActor {
                     let fut = client_sink.send_all(&mut packet_stream);
                     self.run_sending(fut, &mut state, &mut client_stream).await?;
                 }
-                msg = client_stream.next() => {
+                res = self.receive_queue.forward_pending(), if self.receive_queue.is_pending() => {
+                    if let Err(err) = res {
+                        break Err(err);
+                    }
+                }
+                msg = client_stream.next(), if !self.receive_queue.is_pending() => {
                     let Some(msg) = msg else {
                         break Err(anyhow!("Stream closed by server."));
                     };
@@ -642,19 +665,11 @@ impl ActiveRelayActor {
                         .map(|p| *p != remote_node_id)
                         .unwrap_or(true)
                     {
                         // Avoid map lookup with high throughput single peer.
                         state.last_packet_src = Some(remote_node_id);
                         state.nodes_present.insert(remote_node_id);
                     }
-                    for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, data) {
-                        let Ok(datagram) = datagram else {
-                            warn!("Invalid packet split");
-                            break;
-                        };
-                        if let Err(err) = self.relay_datagrams_recv.try_send(datagram) {
-                            warn!("Dropping received relay packet: {err:#}");
-                        }
-                    }
+                    self.receive_queue.queue_packets(remote_node_id, data);
                 }
                 ReceivedMessage::NodeGone(node_id) => {
                     state.nodes_present.remove(&node_id);
@@ -737,7 +752,12 @@ impl ActiveRelayActor {
                     break Err(anyhow!("Ping timeout"));
                 }
                 // No need to read the inbox or datagrams to send.
-                msg = client_stream.next() => {
+                res = self.receive_queue.forward_pending(), if self.receive_queue.is_pending() => {
+                    if let Err(err) = res {
+                        break Err(err);
+                    }
+                }
+                msg = client_stream.next(), if !self.receive_queue.is_pending() => {
                     let Some(msg) = msg else {
                         break Err(anyhow!("Stream closed by server."));
                     };
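The new `forward_pending` select arms above are the core of the backpressure change: while `receive_queue.is_pending()` the actor stops polling `client_stream` entirely and only drives the pending datagrams towards their queues. A self-contained toy version of that pattern, with tokio channels standing in for the relay stream and a slow receive queue (all names are illustrative):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // `stream` stands in for `client_stream`, `slow_tx` for a nearly full receive queue.
    let (tx, mut stream) = mpsc::channel::<u32>(8);
    let (slow_tx, mut slow_rx) = mpsc::channel::<u32>(1);

    tokio::spawn(async move {
        for i in 0..5 {
            tx.send(i).await.unwrap();
        }
    });
    tokio::spawn(async move {
        while let Some(i) = slow_rx.recv().await {
            println!("consumed {i}");
        }
    });

    let mut pending: Option<u32> = None;
    loop {
        tokio::select! {
            // Only enabled while something is pending: forward it once there is room.
            permit = slow_tx.reserve(), if pending.is_some() => {
                permit.unwrap().send(pending.take().unwrap());
            }
            // Only read the next frame while nothing is pending (backpressure).
            next = stream.recv(), if pending.is_none() => {
                match next {
                    Some(item) => {
                        if let Err(mpsc::error::TrySendError::Full(item)) = slow_tx.try_send(item) {
                            // Queue full: stash the item and stop reading until it is forwarded.
                            pending = Some(item);
                        }
                    }
                    None => break, // stream closed and nothing left pending
                }
            }
        }
    }
}
```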
@@ -756,6 +776,144 @@ impl ActiveRelayActor {
         }
     }
 }
 
+#[derive(Debug)]
+struct ReceiveQueue {
+    relay_url: RelayUrl,
+    /// Received relay packets that could not yet be forwarded to the magicsocket.
+    pending: Option<PendingRecv>,
+    /// Queue for received relay datagrams.
+    relay_datagrams_recv: Arc<RelayDatagramRecvQueue>,
+    /// Queue for received relay disco packets.
+    relay_disco_recv: mpsc::Sender<RelayDiscoMessage>,
+}
+
+#[derive(Debug)]
+struct PendingRecv {
+    packets: PacketSplitIter,
+    blocked_on: RecvPath,
+}
+
+#[derive(Debug)]
+enum RecvPath {
+    Data,
+    Disco,
+}
+
+impl ReceiveQueue {
+    fn is_pending(&self) -> bool {
+        self.pending.is_some()
+    }
+
+    /// Send packets to their respective queues.
+    ///
+    /// If a queue is blocked, the packets that were not yet sent will be stored on [`Self`],
+    /// and [`Self::is_pending`] will return `true`. You then need to await [`Self::forward_pending`]
+    /// in a loop until [`Self::is_pending`] returns `false` again. Only then may [`Self::queue_packets`]
+    /// be called again, otherwise this function will panic.
+    ///
+    /// ## Panics
+    ///
+    /// Panics if [`Self::is_pending`] returns `true`.
+    fn queue_packets(&mut self, remote_node_id: NodeId, data: Bytes) {
+        let packets = PacketSplitIter::new(self.relay_url.clone(), remote_node_id, data);
+        assert!(
+            !self.is_pending(),
+            "ReceiveQueue::queue_packets may not be called if is_pending() returns true"
+        );
+        self.handle_packets(packets, None);
+    }
+
+    /// Forward pending received packets to their queues.
+    ///
+    /// This will wait for the path the last received item is blocked on (via [`PendingRecv::blocked_on`])
+    /// to become unblocked. It will then forward the pending items, until a queue is blocked again.
+    /// In that case, the remaining items will be stored and [`Self::is_pending`] returns `true`.
+    ///
+    /// Returns an error if the queue we're blocked on is closed.
+    ///
+    /// This function is cancellation-safe: if the future is dropped at any point, all items are guaranteed
+    /// to either be sent into their respective queues or preserved here.
+    async fn forward_pending(&mut self) -> Result<()> {
+        // We only take a reference to the inner value here:
+        // `take`ing it at this point would make the function not cancellation-safe.
+        let Some(ref pending) = self.pending else {
+            return Ok(());
+        };
+        let disco_permit = match pending.blocked_on {
+            RecvPath::Data => {
+                // The data receive queue does not have permits, so we can only wait for free slots.
+                self.relay_datagrams_recv.send_ready().await?;
+                None
+            }
+            RecvPath::Disco => {
+                // The disco receive channel has permits, so we can reserve a permit to use afterwards
+                // to send at least one item.
+                let permit = self.relay_disco_recv.clone().reserve_owned().await?;
+                Some(permit)
+            }
+        };
+        // We checked above that `self.pending` is not `None` so this `expect` is safe.
+        let packets = self
+            .pending
+            .take()
+            .expect("checked to be not empty")
+            .packets;
+        self.handle_packets(packets, disco_permit);
+        Ok(())
+    }
+
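`forward_pending` stays cancellation-safe because its only await points happen before `self.pending` is taken: the disco path first reserves an `OwnedPermit`, and only then moves the pending packets out, so the send after that point can no longer block or fail. A small self-contained sketch of that `reserve_owned` pattern with tokio (names are illustrative):

```rust
// Self-contained sketch of the `reserve_owned` trick used by `forward_pending`:
// acquire an OwnedPermit *before* taking the pending state, so dropping the
// future at the await point (select! cancellation) cannot lose an item.
use tokio::sync::mpsc::{self, OwnedPermit};

async fn forward_one(
    pending: &mut Option<String>,
    tx: &mpsc::Sender<String>,
) -> Result<(), mpsc::error::SendError<()>> {
    if pending.is_none() {
        return Ok(());
    }
    // Await while `pending` is untouched: if this future is dropped here,
    // the item is still stored and nothing is lost.
    let permit: OwnedPermit<String> = tx.clone().reserve_owned().await?;
    // Past the last await point: sending via the permit cannot fail or block.
    let item = pending.take().expect("checked above");
    permit.send(item);
    Ok(())
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(1);
    let mut pending = Some("hello".to_string());
    forward_one(&mut pending, &tx).await.unwrap();
    assert!(pending.is_none());
    assert_eq!(rx.recv().await.as_deref(), Some("hello"));
}
```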
+ warn!("Received relay disco message from connection for {}, but with message from {}", remote_node_id.fmt_short(), source.fmt_short()); + } + let message = RelayDiscoMessage { + source, + sealed_box, + relay_url: datagram.url.clone(), + relay_remote_node_id: datagram.src, + }; + if let Some(permit) = disco_permit.take() { + permit.send(message); + } else if let Err(err) = self.relay_disco_recv.try_send(message) { + warn!("Relay disco receive queue blocked: {err}"); + packets.push_front(datagram); + self.pending = Some(PendingRecv { + packets, + blocked_on: RecvPath::Disco, + }); + return; + } + } + None => { + if let Err(err) = self.relay_datagrams_recv.try_send(datagram) { + warn!("Relay data receive queue blocked: {err}"); + packets.push_front(err.into_inner()); + self.pending = Some(PendingRecv { + packets, + blocked_on: RecvPath::Data, + }); + return; + } + } + } + } + self.pending = None; + } +} + /// Shared state when the [`ActiveRelayActor`] is connected to a relay server. /// /// Common state between [`ActiveRelayActor::run_connected`] and @@ -814,6 +972,7 @@ pub(super) struct RelayActor { /// /// [`AsyncUdpSocket::poll_recv`]: quinn::AsyncUdpSocket::poll_recv relay_datagram_recv_queue: Arc, + relay_disco_recv_tx: mpsc::Sender, /// The actors managing each currently used relay server. /// /// These actors will exit when they have any inactivity. Otherwise they will keep @@ -829,12 +988,14 @@ impl RelayActor { pub(super) fn new( msock: Arc, relay_datagram_recv_queue: Arc, + relay_disco_recv_tx: mpsc::Sender, protocol: iroh_relay::http::Protocol, ) -> Self { let cancel_token = CancellationToken::new(); Self { msock, relay_datagram_recv_queue, + relay_disco_recv_tx, active_relays: Default::default(), active_relay_tasks: JoinSet::new(), cancel_token, @@ -1056,6 +1217,7 @@ impl RelayActor { inbox: inbox_rx, relay_datagrams_send: send_datagram_rx, relay_datagrams_recv: self.relay_datagram_recv_queue.clone(), + relay_disco_recv: self.relay_disco_recv_tx.clone(), connection_opts, stop_token: self.cancel_token.child_token(), metrics: self.msock.metrics.magicsock.clone(), @@ -1234,12 +1396,22 @@ struct PacketSplitIter { url: RelayUrl, src: NodeId, bytes: Bytes, + next: Option, } impl PacketSplitIter { /// Create a new PacketSplitIter from a packet. 
@@ -1234,12 +1396,22 @@ struct PacketSplitIter {
     url: RelayUrl,
     src: NodeId,
     bytes: Bytes,
+    next: Option<RelayRecvDatagram>,
 }
 
 impl PacketSplitIter {
     /// Create a new PacketSplitIter from a packet.
     fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self {
-        Self { url, src, bytes }
+        Self {
+            url,
+            src,
+            bytes,
+            next: None,
+        }
+    }
+
+    fn remote_node_id(&self) -> NodeId {
+        self.src
     }
 
     fn fail(&mut self) -> Option<std::io::Result<RelayRecvDatagram>> {
@@ -1249,6 +1421,10 @@ impl PacketSplitIter {
             "",
         )))
     }
+
+    fn push_front(&mut self, item: RelayRecvDatagram) {
+        self.next = Some(item);
+    }
 }
 
 impl Iterator for PacketSplitIter {
@@ -1256,6 +1432,9 @@ impl Iterator for PacketSplitIter {
     fn next(&mut self) -> Option<Self::Item> {
         use bytes::Buf;
 
+        if let Some(item) = self.next.take() {
+            return Some(Ok(item));
+        }
         if self.bytes.has_remaining() {
             if self.bytes.remaining() < 2 {
                 return self.fail();
@@ -1331,6 +1510,7 @@ mod tests {
         inbox_rx: mpsc::Receiver<ActiveRelayMessage>,
         relay_datagrams_send: mpsc::Receiver<RelaySendItem>,
         relay_datagrams_recv: Arc<RelayDatagramRecvQueue>,
+        relay_disco_recv: mpsc::Sender<RelayDiscoMessage>,
         span: tracing::Span,
     ) -> AbortOnDropHandle<anyhow::Result<()>> {
         let opts = ActiveRelayActorOptions {
@@ -1339,6 +1519,7 @@ mod tests {
             inbox: inbox_rx,
             relay_datagrams_send,
             relay_datagrams_recv,
+            relay_disco_recv,
             connection_opts: RelayConnectionOptions {
                 secret_key,
                 dns_resolver: DnsResolver::new(),
@@ -1363,6 +1544,7 @@ mod tests {
         let secret_key = SecretKey::from_bytes(&[8u8; 32]);
         let recv_datagram_queue = Arc::new(RelayDatagramRecvQueue::new());
         let (send_datagram_tx, send_datagram_rx) = mpsc::channel(16);
+        let (relay_disco_recv_tx, _relay_disco_recv_rx) = mpsc::channel(16);
         let (prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8);
         let (inbox_tx, inbox_rx) = mpsc::channel(16);
         let cancel_token = CancellationToken::new();
@@ -1374,6 +1556,7 @@ mod tests {
             inbox_rx,
             send_datagram_rx,
             recv_datagram_queue.clone(),
+            relay_disco_recv_tx,
             info_span!("echo-node"),
         );
         let echo_task = tokio::spawn({
@@ -1455,6 +1638,7 @@ mod tests {
 
         let secret_key = SecretKey::from_bytes(&[1u8; 32]);
         let datagram_recv_queue = Arc::new(RelayDatagramRecvQueue::new());
+        let (relay_disco_recv_tx, _relay_disco_recv_rx) = mpsc::channel(16);
         let (send_datagram_tx, send_datagram_rx) = mpsc::channel(16);
         let (_prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8);
         let (inbox_tx, inbox_rx) = mpsc::channel(16);
@@ -1467,6 +1651,7 @@ mod tests {
             inbox_rx,
             send_datagram_rx,
             datagram_recv_queue.clone(),
+            relay_disco_recv_tx,
             info_span!("actor-under-test"),
         );
 
@@ -1541,6 +1726,7 @@ mod tests {
 
         let secret_key = SecretKey::from_bytes(&[1u8; 32]);
         let datagram_recv_queue = Arc::new(RelayDatagramRecvQueue::new());
+        let (relay_disco_recv_tx, _relay_disco_recv_rx) = mpsc::channel(16);
         let (_send_datagram_tx, send_datagram_rx) = mpsc::channel(16);
         let (_prio_inbox_tx, prio_inbox_rx) = mpsc::channel(8);
         let (inbox_tx, inbox_rx) = mpsc::channel(16);
@@ -1553,6 +1739,7 @@ mod tests {
             inbox_rx,
             send_datagram_rx,
             datagram_recv_queue.clone(),
+            relay_disco_recv_tx,
             info_span!("actor-under-test"),
         );