Commit 106f7f3

refactor: apply backpressure without blocking the actor loop
1 parent 7c100fa commit 106f7f3

2 files changed: +171 −39 lines changed

iroh/src/magicsock.rs

Lines changed: 29 additions & 8 deletions
@@ -2104,15 +2104,17 @@ impl RelayDatagramSendChannelReceiver {
 #[derive(Debug)]
 struct RelayDatagramRecvQueue {
     queue: ConcurrentQueue<RelayRecvDatagram>,
-    waker: AtomicWaker,
+    recv_waker: AtomicWaker,
+    send_waker: AtomicWaker,
 }

 impl RelayDatagramRecvQueue {
     /// Creates a new, empty queue with a fixed size bound of 512 items.
     fn new() -> Self {
         Self {
             queue: ConcurrentQueue::bounded(512),
-            waker: AtomicWaker::new(),
+            recv_waker: AtomicWaker::new(),
+            send_waker: AtomicWaker::new(),
         }
     }

@@ -2125,10 +2127,21 @@ impl RelayDatagramRecvQueue {
         item: RelayRecvDatagram,
     ) -> Result<(), concurrent_queue::PushError<RelayRecvDatagram>> {
         self.queue.push(item).inspect(|_| {
-            self.waker.wake();
+            self.recv_waker.wake();
         })
     }

+    fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
+        if self.queue.is_closed() {
+            Poll::Ready(Err(anyhow!("Queue closed")))
+        } else if !self.queue.is_full() {
+            Poll::Ready(Ok(()))
+        } else {
+            self.send_waker.register(cx.waker());
+            Poll::Pending
+        }
+    }
+
     /// Polls for new items in the queue.
     ///
     /// Although this method is available from `&self`, it must not be

@@ -2143,23 +2156,31 @@ impl RelayDatagramRecvQueue {
     /// to be able to poll from `&self`.
     fn poll_recv(&self, cx: &mut Context) -> Poll<Result<RelayRecvDatagram>> {
         match self.queue.pop() {
-            Ok(value) => Poll::Ready(Ok(value)),
+            Ok(value) => {
+                self.send_waker.wake();
+                Poll::Ready(Ok(value))
+            }
             Err(concurrent_queue::PopError::Empty) => {
-                self.waker.register(cx.waker());
+                self.recv_waker.register(cx.waker());

                 match self.queue.pop() {
                     Ok(value) => {
-                        self.waker.take();
+                        self.send_waker.wake();
+                        self.recv_waker.take();
                         Poll::Ready(Ok(value))
                     }
                     Err(concurrent_queue::PopError::Empty) => Poll::Pending,
                     Err(concurrent_queue::PopError::Closed) => {
-                        self.waker.take();
+                        self.recv_waker.take();
+                        self.send_waker.wake();
                         Poll::Ready(Err(anyhow!("Queue closed")))
                     }
                 }
             }
-            Err(concurrent_queue::PopError::Closed) => Poll::Ready(Err(anyhow!("Queue closed"))),
+            Err(concurrent_queue::PopError::Closed) => {
+                self.send_waker.wake();
+                Poll::Ready(Err(anyhow!("Queue closed")))
+            }
         }
     }
 }
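The waker split above is the core of the non-blocking backpressure: `poll_send_ready` parks the sending side on `send_waker` while the queue is full, and every successful `poll_recv` wakes it again once space frees up. The following standalone sketch shows the same idea with simplified types; it is illustrative only (not code from this commit), assumes the `concurrent-queue`, `futures`, and `tokio` crates, and omits the closed-queue handling.

use std::future::poll_fn;
use std::sync::Arc;
use std::task::{Context, Poll};

use concurrent_queue::ConcurrentQueue;
use futures::task::AtomicWaker;

/// A bounded queue with separate wakers for the receive and send sides.
struct Queue<T> {
    queue: ConcurrentQueue<T>,
    recv_waker: AtomicWaker, // woken when an item is pushed
    send_waker: AtomicWaker, // woken when space frees up
}

impl<T> Queue<T> {
    fn new(cap: usize) -> Self {
        Self {
            queue: ConcurrentQueue::bounded(cap),
            recv_waker: AtomicWaker::new(),
            send_waker: AtomicWaker::new(),
        }
    }

    /// Resolves once `try_send` is guaranteed not to fail with "full".
    fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if !self.queue.is_full() {
            return Poll::Ready(());
        }
        self.send_waker.register(cx.waker());
        // Re-check after registering to avoid a lost wakeup.
        if !self.queue.is_full() {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }

    fn try_send(&self, item: T) -> Result<(), T> {
        self.queue
            .push(item)
            .map(|_| self.recv_waker.wake())
            .map_err(|err| err.into_inner())
    }

    fn poll_recv(&self, cx: &mut Context<'_>) -> Poll<T> {
        if let Ok(item) = self.queue.pop() {
            self.send_waker.wake(); // space freed: unpark a waiting sender
            return Poll::Ready(item);
        }
        self.recv_waker.register(cx.waker());
        match self.queue.pop() {
            Ok(item) => {
                self.send_waker.wake();
                Poll::Ready(item)
            }
            Err(_) => Poll::Pending,
        }
    }
}

#[tokio::main]
async fn main() {
    let q = Arc::new(Queue::new(2));
    let sender = {
        let q = Arc::clone(&q);
        tokio::spawn(async move {
            for i in 0..5u32 {
                // Backpressure: wait for space instead of dropping or blocking a thread.
                poll_fn(|cx| q.poll_send_ready(cx)).await;
                let _ = q.try_send(i);
            }
        })
    };
    for _ in 0..5 {
        let item = poll_fn(|cx| q.poll_recv(cx)).await;
        println!("received {item}");
    }
    sender.await.unwrap();
}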

iroh/src/magicsock/relay_actor.rs

Lines changed: 142 additions & 31 deletions
@@ -51,7 +51,10 @@ use n0_future::{
     time::{self, Duration, Instant, MissedTickBehavior},
     FuturesUnorderedBounded, SinkExt, StreamExt,
 };
-use tokio::sync::{mpsc, oneshot};
+use tokio::sync::{
+    mpsc::{self, OwnedPermit},
+    oneshot,
+};
 use tokio_util::sync::CancellationToken;
 use tracing::{debug, error, event, info_span, instrument, trace, warn, Instrument, Level};
 use url::Url;

@@ -159,6 +162,20 @@ struct ActiveRelayActor {
     /// Token indicating the [`ActiveRelayActor`] should stop.
     stop_token: CancellationToken,
     metrics: Arc<MagicsockMetrics>,
+    /// Received relay packets that could not yet be forwarded to the magicsocket.
+    pending_received: Option<PendingRecv>,
+}
+
+#[derive(Debug)]
+struct PendingRecv {
+    packet_iter: PacketSplitIter,
+    blocked_on: RecvPath,
+}
+
+#[derive(Debug)]
+enum RecvPath {
+    Data,
+    Disco,
 }

 #[derive(Debug)]

@@ -263,6 +280,7 @@ impl ActiveRelayActor {
             inactive_timeout: Box::pin(time::sleep(RELAY_INACTIVE_CLEANUP_TIME)),
             stop_token,
             metrics,
+            pending_received: None,
         }
     }

@@ -612,7 +630,8 @@ impl ActiveRelayActor {
                     let fut = client_sink.send_all(&mut packet_stream);
                     self.run_sending(fut, &mut state, &mut client_stream).await?;
                 }
-                msg = client_stream.next() => {
+                _ = forward_pending(&mut self.pending_received, &self.relay_datagrams_recv, &mut self.relay_disco_recv), if self.pending_received.is_some() => {}
+                msg = client_stream.next(), if self.pending_received.is_none() => {
                     let Some(msg) = msg else {
                         break Err(anyhow!("Stream closed by server."));
                     };
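The two guarded `select!` branches are where the backpressure takes effect: while `pending_received` is set, the branch reading `client_stream` is disabled and the loop instead drives `forward_pending`, so the relay connection is read no faster than the datagrams can be handed off, and the actor loop itself never blocks. A rough standalone sketch of that control flow, using illustrative names (`actor_loop`, `drain_backlog`, plain `mpsc` channels) rather than the commit's types:

use std::collections::VecDeque;

use tokio::sync::mpsc;

/// Forward as many backlogged items as possible; stop once `out` is full again.
///
/// Cancellation-safe: an item is only removed from `backlog` after it has
/// definitely been handed to `out`.
async fn drain_backlog(backlog: &mut VecDeque<u32>, out: &mpsc::Sender<u32>) {
    while let Some(&item) = backlog.front() {
        match out.try_send(item) {
            Ok(()) => {
                backlog.pop_front();
            }
            Err(mpsc::error::TrySendError::Full(_)) => {
                // Wait for capacity; if we are cancelled here the item stays queued.
                match out.reserve().await {
                    Ok(permit) => {
                        permit.send(item);
                        backlog.pop_front();
                    }
                    Err(_) => return, // receiver dropped
                }
            }
            Err(mpsc::error::TrySendError::Closed(_)) => return,
        }
    }
}

async fn actor_loop(mut inbound: mpsc::Receiver<u32>, out: mpsc::Sender<u32>) {
    let mut backlog: VecDeque<u32> = VecDeque::new();
    loop {
        tokio::select! {
            // Runs only while there is a backlog; the branch below is disabled
            // meanwhile, so we stop reading from the peer instead of dropping.
            _ = drain_backlog(&mut backlog, &out), if !backlog.is_empty() => {}
            msg = inbound.recv(), if backlog.is_empty() => {
                let Some(msg) = msg else { break };
                // Queue full: remember the item and switch to draining.
                // (A closed `out` is ignored here for brevity.)
                if let Err(mpsc::error::TrySendError::Full(msg)) = out.try_send(msg) {
                    backlog.push_back(msg);
                }
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (in_tx, in_rx) = mpsc::channel(16);
    let (out_tx, mut out_rx) = mpsc::channel(1); // tiny capacity forces backpressure
    tokio::spawn(async move {
        for i in 0..8u32 {
            in_tx.send(i).await.ok();
        }
    });
    tokio::spawn(actor_loop(in_rx, out_tx));
    while let Some(item) = out_rx.recv().await {
        println!("forwarded {item}");
    }
}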
@@ -659,33 +678,14 @@ impl ActiveRelayActor {
                     state.last_packet_src = Some(remote_node_id);
                     state.nodes_present.insert(remote_node_id);
                 }
-                for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, data) {
-                    let Ok(datagram) = datagram else {
-                        warn!("Invalid packet split");
-                        break;
-                    };
-                    match crate::disco::source_and_box_bytes(&datagram.buf) {
-                        Some((source, sealed_box)) => {
-                            if remote_node_id != source {
-                                // TODO: return here?
-                                warn!("Received relay disco message from connection for {}, but with message from {}", remote_node_id.fmt_short(), source.fmt_short());
-                            }
-                            let message = RelayDiscoMessage {
-                                source,
-                                sealed_box,
-                                relay_url: datagram.url.clone(),
-                                relay_remote_node_id: datagram.src,
-                            };
-                            if let Err(err) = self.relay_disco_recv.try_send(message) {
-                                warn!("Dropping received relay disco packet: {err:#}");
-                            }
-                        }
-                        None => {
-                            if let Err(err) = self.relay_datagrams_recv.try_send(datagram) {
-                                warn!("Dropping received relay data packet: {err:#}");
-                            }
-                        }
-                    }
+                let packet_iter = PacketSplitIter::new(self.url.clone(), remote_node_id, data);
+                if let Some(pending) = handle_received_packet_iter(
+                    packet_iter,
+                    None,
+                    &self.relay_datagrams_recv,
+                    &mut self.relay_disco_recv,
+                ) {
+                    self.pending_received = Some(pending);
                 }
             }
             ReceivedMessage::NodeGone(node_id) => {

@@ -769,7 +769,8 @@ impl ActiveRelayActor {
                     break Err(anyhow!("Ping timeout"));
                 }
                 // No need to read the inbox or datagrams to send.
-                msg = client_stream.next() => {
+                _ = forward_pending(&mut self.pending_received, &self.relay_datagrams_recv, &mut self.relay_disco_recv), if self.pending_received.is_some() => {}
+                msg = client_stream.next(), if self.pending_received.is_none() => {
                     let Some(msg) = msg else {
                         break Err(anyhow!("Stream closed by server."));
                     };

@@ -788,6 +789,99 @@ impl ActiveRelayActor {
     }
 }

+/// Forward pending received packets to their queues.
+///
+/// If `maybe_pending` is not empty, this will wait for the path the last received item
+/// is blocked on (via [`PendingRecv::blocked_on`]) to become unblocked. It will then forward
+/// the pending items, until a queue is blocked again. In that case, the remaining items will
+/// be put back into `maybe_pending`. If all items could be sent, `maybe_pending` will be set
+/// to `None`.
+///
+/// This function is cancellation-safe: If the future is dropped at any point, all items are guaranteed
+/// to either be sent into their respective queues, or are still in `maybe_pending`.
+async fn forward_pending(
+    maybe_pending: &mut Option<PendingRecv>,
+    relay_datagrams_recv: &RelayDatagramRecvQueue,
+    relay_disco_recv: &mut mpsc::Sender<RelayDiscoMessage>,
+) {
+    let Some(ref mut pending) = maybe_pending else {
+        return;
+    };
+    let disco_permit = match pending.blocked_on {
+        RecvPath::Data => {
+            std::future::poll_fn(|cx| relay_datagrams_recv.poll_send_ready(cx))
+                .await
+                .ok();
+            None
+        }
+        RecvPath::Disco => {
+            let Ok(permit) = relay_disco_recv.clone().reserve_owned().await else {
+                return;
+            };
+            Some(permit)
+        }
+    };
+    let pending = maybe_pending.take().unwrap();
+    if let Some(pending) = handle_received_packet_iter(
+        pending.packet_iter,
+        disco_permit,
+        relay_datagrams_recv,
+        relay_disco_recv,
+    ) {
+        *maybe_pending = Some(pending);
+    }
+}
+
+fn handle_received_packet_iter(
+    mut packet_iter: PacketSplitIter,
+    mut disco_permit: Option<OwnedPermit<RelayDiscoMessage>>,
+    relay_datagrams_recv: &RelayDatagramRecvQueue,
+    relay_disco_recv: &mut mpsc::Sender<RelayDiscoMessage>,
+) -> Option<PendingRecv> {
+    let remote_node_id = packet_iter.remote_node_id();
+    for datagram in &mut packet_iter {
+        let Ok(datagram) = datagram else {
+            warn!("Invalid packet split");
+            return None;
+        };
+        match crate::disco::source_and_box_bytes(&datagram.buf) {
+            Some((source, sealed_box)) => {
+                if remote_node_id != source {
+                    // TODO: return here?
+                    warn!("Received relay disco message from connection for {}, but with message from {}", remote_node_id.fmt_short(), source.fmt_short());
+                }
+                let message = RelayDiscoMessage {
+                    source,
+                    sealed_box,
+                    relay_url: datagram.url.clone(),
+                    relay_remote_node_id: datagram.src,
+                };
+                if let Some(permit) = disco_permit.take() {
+                    permit.send(message);
+                } else if let Err(err) = relay_disco_recv.try_send(message) {
+                    warn!("Dropping received relay disco packet: {err:#}");
+                    packet_iter.push_front(datagram);
+                    return Some(PendingRecv {
+                        packet_iter,
+                        blocked_on: RecvPath::Disco,
+                    });
+                }
+            }
+            None => {
+                if let Err(err) = relay_datagrams_recv.try_send(datagram) {
+                    warn!("Dropping received relay data packet: {err:#}");
+                    packet_iter.push_front(err.into_inner());
+                    return Some(PendingRecv {
+                        packet_iter,
+                        blocked_on: RecvPath::Data,
+                    });
+                }
+            }
+        }
+    }
+    None
+}
+
 /// Shared state when the [`ActiveRelayActor`] is connected to a relay server.
 ///
 /// Common state between [`ActiveRelayActor::run_connected`] and
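The cancellation-safety described in the `forward_pending` doc comment comes from reserving capacity before taking anything out of the buffer: on the disco path an `OwnedPermit` is obtained with `reserve_owned` first, and only then is the pending item consumed, so dropping the future at the `.await` cannot lose a datagram. A minimal sketch of that permit-first pattern with illustrative types (not the commit's code):

use tokio::sync::mpsc;

/// Forward a buffered item, reserving channel capacity *before* taking it out
/// of the buffer. If this future is dropped at the `.await`, the item is still
/// in `buffered` and nothing is lost.
async fn forward_one(buffered: &mut Option<String>, tx: &mpsc::Sender<String>) {
    if buffered.is_none() {
        return;
    }
    // `reserve_owned` consumes a Sender, so reserve on a clone.
    let Ok(permit) = tx.clone().reserve_owned().await else {
        return; // channel closed; keep the item buffered
    };
    // Only now remove the item; sending through the permit cannot fail or block.
    if let Some(item) = buffered.take() {
        let _ = permit.send(item);
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel(1);
    let mut buffered = Some("hello".to_string());
    forward_one(&mut buffered, &tx).await;
    assert!(buffered.is_none());
    println!("{:?}", rx.recv().await);
}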
@@ -1270,12 +1364,22 @@ struct PacketSplitIter {
     url: RelayUrl,
     src: NodeId,
     bytes: Bytes,
+    next: Option<RelayRecvDatagram>,
 }

 impl PacketSplitIter {
     /// Create a new PacketSplitIter from a packet.
     fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self {
-        Self { url, src, bytes }
+        Self {
+            url,
+            src,
+            bytes,
+            next: None,
+        }
+    }
+
+    fn remote_node_id(&self) -> NodeId {
+        self.src
     }

     fn fail(&mut self) -> Option<std::io::Result<RelayRecvDatagram>> {

@@ -1285,13 +1389,20 @@ impl PacketSplitIter {
             "",
         )))
     }
+
+    fn push_front(&mut self, item: RelayRecvDatagram) {
+        self.next = Some(item);
+    }
 }

 impl Iterator for PacketSplitIter {
     type Item = std::io::Result<RelayRecvDatagram>;

     fn next(&mut self) -> Option<Self::Item> {
         use bytes::Buf;
+        if let Some(item) = self.next.take() {
+            return Some(Ok(item));
+        }
         if self.bytes.has_remaining() {
             if self.bytes.remaining() < 2 {
                 return self.fail();
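The new `next` slot and `push_front` turn `PacketSplitIter` into a put-back iterator: a datagram that could not be forwarded is stashed and yielded again on the next `next()` call, so the remaining bytes never need to be re-split. The same idea in isolation, as a generic sketch rather than the commit's code:

/// An iterator adapter with a one-item "put back" slot, yielded before the
/// underlying iterator is advanced again.
struct PutBack<I: Iterator> {
    next: Option<I::Item>,
    inner: I,
}

impl<I: Iterator> PutBack<I> {
    fn new(inner: I) -> Self {
        Self { next: None, inner }
    }

    /// Return an item to the front of the iteration.
    fn push_front(&mut self, item: I::Item) {
        self.next = Some(item);
    }
}

impl<I: Iterator> Iterator for PutBack<I> {
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Yield the stashed item first, if any.
        if let Some(item) = self.next.take() {
            return Some(item);
        }
        self.inner.next()
    }
}

fn main() {
    let mut it = PutBack::new([1, 2, 3].into_iter());
    let first = it.next().unwrap();
    it.push_front(first); // put it back, e.g. because a queue was full
    assert_eq!(it.collect::<Vec<_>>(), vec![1, 2, 3]);
}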
