Skip to content

Commit 72abe20

Browse files
committed
feat(socket): implement transport_stats for pub
1 parent 9016991 commit 72abe20

4 files changed

Lines changed: 43 additions & 17 deletions

File tree

msg-socket/src/pub/driver.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use super::{
1414
PubError, PubMessage, PubOptions, SocketState, session::SubscriberSession, trie::PrefixTrie,
1515
};
1616
use crate::{ConnectionHookErased, hooks};
17-
use msg_transport::{Address, PeerAddress, Transport};
17+
use msg_transport::{Address, MeteredIo, PeerAddress, Transport};
1818
use msg_wire::pubsub;
1919

2020
/// The driver for the publisher socket. This is responsible for accepting incoming connections,
@@ -28,7 +28,7 @@ pub(crate) struct PubDriver<T: Transport<A>, A: Address> {
2828
/// The publisher options (shared with the socket)
2929
pub(super) options: Arc<PubOptions>,
3030
/// The publisher socket state, shared with the socket front-end.
31-
pub(crate) state: Arc<SocketState>,
31+
pub(crate) state: Arc<SocketState<T::Stats>>,
3232
/// Optional connection hook.
3333
pub(super) hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
3434
/// A set of pending incoming connections, represented by [`Transport::Accept`].
@@ -58,13 +58,15 @@ where
5858
Ok((stream, _addr)) => {
5959
info!("connection hook passed");
6060

61-
let framed = Framed::new(stream, pubsub::Codec::new());
61+
let metered =
62+
MeteredIo::new(stream, Arc::clone(&this.state.transport_stats));
63+
let framed = Framed::new(metered, pubsub::Codec::new());
6264

6365
let session = SubscriberSession {
6466
seq: 0,
6567
session_id: this.id_counter,
6668
from_socket_bcast: this.from_socket_bcast.resubscribe().into(),
67-
state: Arc::clone(&this.state),
69+
stats: this.state.stats.clone(),
6870
pending_egress: None,
6971
conn: framed,
7072
topic_filter: PrefixTrie::new(),
@@ -158,13 +160,14 @@ where
158160

159161
self.hook_tasks.spawn(fut.with_span(span));
160162
} else {
161-
let framed = Framed::new(io, pubsub::Codec::new());
163+
let metered = MeteredIo::new(io, Arc::clone(&self.state.transport_stats));
164+
let framed = Framed::new(metered, pubsub::Codec::new());
162165

163166
let session = SubscriberSession {
164167
seq: 0,
165168
session_id: self.id_counter,
166169
from_socket_bcast: self.from_socket_bcast.resubscribe().into(),
167-
state: Arc::clone(&self.state),
170+
stats: self.state.stats.clone(),
168171
pending_egress: None,
169172
conn: framed,
170173
topic_filter: PrefixTrie::new(),

msg-socket/src/pub/mod.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{io, time::Duration};
1+
use std::{io, sync::Arc, time::Duration};
22

33
use bytes::Bytes;
44
use msg_common::constants::KiB;
@@ -13,6 +13,7 @@ pub use socket::*;
1313

1414
mod stats;
1515
use crate::{Profile, stats::SocketStats};
16+
use arc_swap::ArcSwap;
1617
use stats::PubStats;
1718

1819
mod trie;
@@ -208,9 +209,22 @@ impl PubMessage {
208209
}
209210

210211
/// The publisher socket state, shared between the backend task and the socket.
211-
#[derive(Debug, Default)]
212-
pub(crate) struct SocketState {
213-
pub(crate) stats: SocketStats<PubStats>,
212+
/// Generic over the transport-level stats type.
213+
#[derive(Debug)]
214+
pub(crate) struct SocketState<S: Default> {
215+
pub(crate) stats: Arc<SocketStats<PubStats>>,
216+
/// The transport-level stats. We wrap the inner stats in an `Arc`
217+
/// for cheap clone on read.
218+
pub(crate) transport_stats: Arc<ArcSwap<S>>,
219+
}
220+
221+
impl<S: Default> Default for SocketState<S> {
222+
fn default() -> Self {
223+
Self {
224+
stats: Arc::new(SocketStats::default()),
225+
transport_stats: Arc::new(ArcSwap::from_pointee(S::default())),
226+
}
227+
}
214228
}
215229

216230
#[cfg(test)]

msg-socket/src/pub/session.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ use tokio_stream::wrappers::BroadcastStream;
1111
use tokio_util::codec::Framed;
1212
use tracing::{debug, error, trace, warn};
1313

14-
use super::{PubMessage, SocketState, trie::PrefixTrie};
14+
use super::{PubMessage, trie::PrefixTrie};
1515
use msg_wire::pubsub;
1616

17+
use super::stats::PubStats;
18+
use crate::stats::SocketStats;
19+
1720
/// A subscriber session. This struct represents a single subscriber session, which is a
1821
/// connection to a subscriber. This struct is responsible for handling incoming and outgoing
1922
/// messages, as well as managing the connection state.
@@ -26,8 +29,8 @@ pub(super) struct SubscriberSession<Io> {
2629
pub(super) from_socket_bcast: BroadcastStream<PubMessage>,
2730
/// Messages queued to be sent on the connection
2831
pub(super) pending_egress: Option<pubsub::Message>,
29-
/// The socket state, shared between the backend task and the socket.
30-
pub(super) state: Arc<SocketState>,
32+
/// The socket stats.
33+
pub(super) stats: Arc<SocketStats<PubStats>>,
3134
/// The framed connection.
3235
pub(super) conn: Framed<Io, pubsub::Codec>,
3336
/// The topic filter (a prefix trie that works with strings)
@@ -76,7 +79,7 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> SubscriberSession<Io> {
7679

7780
impl<Io> Drop for SubscriberSession<Io> {
7881
fn drop(&mut self) {
79-
self.state.stats.specific.decrement_active_clients();
82+
self.stats.specific.decrement_active_clients();
8083
}
8184
}
8285

@@ -130,7 +133,7 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> Future for SubscriberSession<Io> {
130133

131134
match this.conn.start_send_unpin(msg) {
132135
Ok(_) => {
133-
this.state.stats.specific.increment_tx(msg_len);
136+
this.stats.specific.increment_tx(msg_len);
134137
}
135138
Err(e) => {
136139
error!(err = ?e, "Failed to send message to socket");

msg-socket/src/pub/socket.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
22

3+
use arc_swap::Guard;
34
use bytes::Bytes;
45
use futures::stream::FuturesUnordered;
56
use tokio::{
@@ -28,7 +29,7 @@ pub struct PubSocket<T: Transport<A>, A: Address> {
2829
/// The reply socket options, shared with the driver.
2930
options: Arc<PubOptions>,
3031
/// The reply socket state, shared with the driver.
31-
state: Arc<SocketState>,
32+
state: Arc<SocketState<T::Stats>>,
3233
/// The transport used by this socket. This value is temporary and will be moved
3334
/// to the driver task once the socket is bound.
3435
transport: Option<T>,
@@ -89,7 +90,7 @@ where
8990
to_sessions_bcast: None,
9091
options: Arc::new(options),
9192
transport: Some(transport),
92-
state: Arc::new(SocketState::default()),
93+
state: Arc::new(SocketState::<T::Stats>::default()),
9394
hook: None,
9495
compressor: None,
9596
}
@@ -212,6 +213,11 @@ where
212213
&self.state.stats.specific
213214
}
214215

216+
/// Get the latest transport-level stats snapshot.
217+
pub fn transport_stats(&self) -> Guard<Arc<T::Stats>> {
218+
self.state.transport_stats.load()
219+
}
220+
215221
/// Returns the local address this socket is bound to. `None` if the socket is not bound.
216222
pub fn local_addr(&self) -> Option<&A> {
217223
self.local_addr.as_ref()

0 commit comments

Comments
 (0)