Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,237 changes: 1,209 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
[package]
name = "systemd-udp-proxy"
version = "0.1.3"
version = "0.2.0"
edition = "2024"

[dependencies]
clap = { version = "4.5", features = ["derive"] }
env_logger = "0.11"
listenfd = "1.0"
log = "0.4"
opentelemetry = "0.31"
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
opentelemetry-otlp = { version = "0.31", features = ["metrics", "grpc-tonic"] }
opentelemetry-semantic-conventions = { version = "0.31", features = ["semconv_experimental"] }
tokio = { version = "1.48", features = ["io-util", "macros", "net", "rt-multi-thread", "sync", "time"] }
44 changes: 37 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@ use clap::Parser;
use listenfd::ListenFd;
#[cfg(not(debug_assertions))]
use log::warn;
use primary_tasks::{rx_task, tx_task};
use primary_tasks::{SessionCache, rx_task, tx_task};
use session::SessionReply;
use tokio::{net::UdpSocket, sync::mpsc};
use tokio::{
net::UdpSocket,
sync::{RwLock, mpsc},
};

mod error_util;
mod log_config;
mod primary_tasks;
mod session;
mod telemetry;

#[derive(Parser, Debug)]
struct ProxyConfig {
Expand All @@ -34,9 +38,20 @@ struct ProxyConfig {
/// How many seconds sessions should be cached before expiring
#[arg(short = 't', long, default_value_t = 60)]
session_timeout: u64,
}
/// Maximum UDP packet size to receive in bytes (packets larger will be truncated)
#[arg(short = 'm', long, default_value_t = 1500)]
max_packet_size: usize,

const MAX_UDP_PACKET_SIZE: u16 = u16::MAX;
/// The OTel collector endpoint
#[arg(long, default_value = "http://localhost:4317")]
otel_endpoint: String,
/// The service name for OTel tagging
#[arg(long, default_value = "systemd-udp-proxy")]
service_name: String,
/// The deployment environment for OTel tagging
#[arg(long, default_value = "prod")]
environment: String,
}

#[tokio::main]
async fn main() -> io::Result<()> {
Expand Down Expand Up @@ -69,10 +84,25 @@ async fn main() -> io::Result<()> {
let source_socket = Arc::new(UdpSocket::from_std(std_source_socket)?);
let (reply_channel_tx, reply_channel_rx) = mpsc::unbounded_channel::<SessionReply>();

let rx_task = tokio::spawn(rx_task(config, reply_channel_tx, source_socket.clone()));
let tx_task = tokio::spawn(tx_task(reply_channel_rx, source_socket.clone()));
let sessions = Arc::new(RwLock::new(SessionCache::new()));
let (metrics, meter) = telemetry::init_metrics(&config, sessions.clone()).map_err(|err| {
io::Error::other(format!("Failed to initialize OTel metrics exporter: {err}"))
})?;

let rx_task = tokio::spawn(rx_task(
config,
reply_channel_tx,
source_socket.clone(),
sessions,
metrics.clone(),
));
let tx_task = tokio::spawn(tx_task(
reply_channel_rx,
source_socket.clone(),
metrics.clone(),
));

rx_task.await??;
tx_task.await??;
Ok(())
meter.shutdown().map_err(io::Error::other)
}
1 change: 1 addition & 0 deletions src/primary_tasks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod rx_task;
mod tx_task;

pub use rx_task::SessionCache;
pub use rx_task::rx_task;
pub use tx_task::tx_task;
42 changes: 30 additions & 12 deletions src/primary_tasks/rx_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,38 @@ use tokio::{
};

use crate::{
MAX_UDP_PACKET_SIZE, ProxyConfig,
ProxyConfig,
error_util::{ErrorAction, handle_io_error},
session::{Session, SessionReply, SessionSource},
telemetry::{NetworkDirection, Peer, ProxyMetrics},
};

type SessionChannel = UnboundedSender<Vec<u8>>;
type SessionCache = HashMap<SessionSource, (SessionChannel, Arc<Session>)>;
pub type SessionCache = HashMap<SessionSource, (SessionChannel, Arc<Session>)>;

/// Loops infinitely over the `rx_socket` to recieve traffic from the original source of the proxy.
///
/// For each unique [`std::net::SocketAddr`] that sends traffic to `rx_socket`, a [`Session`] is created and
/// tx/rx loop tasks are spawned to proxy traffic for that session to and from the destination. If a [`Session`]
/// does not recieve traffic for [`ProxyConfig::session_timeout`] seconds, it will close its tasks and a new one will
/// be created if any traffic resumes from it.
///
/// If a packet arrives after a Session's channel has closed but before the session is removed
/// from the cache, that packet will be dropped and the session will be cleaned up. Subsequent
/// packets from the same source will trigger creation of a new session.
pub async fn rx_task(
config: ProxyConfig,
reply_channel_tx: UnboundedSender<SessionReply>,
rx_socket: Arc<UdpSocket>,
sessions: Arc<RwLock<SessionCache>>,
metrics: Arc<ProxyMetrics>,
) -> io::Result<()> {
let shared_reply_channel = Arc::new(reply_channel_tx);
let sessions = Arc::new(RwLock::new(SessionCache::new()));
let dir = NetworkDirection::Receive;
let peer = Peer::Client;

loop {
let mut buf = Vec::with_capacity(MAX_UDP_PACKET_SIZE.into());
let mut buf = Vec::with_capacity(config.max_packet_size);
match rx_socket.recv_buf_from(&mut buf).await {
Err(err) => match handle_io_error(err) {
ErrorAction::Terminate(err) => return Err(err),
Expand All @@ -48,13 +56,14 @@ pub async fn rx_task(
let session_channel_tx = match session_cache.entry(source.into()) {
Entry::Vacant(entry) => {
info!("Creating a new session for {source}");
let session = match Session::new(&config, source.into()).await {
Ok(created_session) => Arc::new(created_session),
Err(err) => {
error!("Failed to create a session for {}: {:?}", source, err);
continue;
}
};
let session =
match Session::new(&config, source.into(), metrics.clone()).await {
Ok(created_session) => Arc::new(created_session),
Err(err) => {
error!("Failed to create a session for {}: {:?}", source, err);
continue;
}
};

let (tx, rx) = mpsc::unbounded_channel();

Expand All @@ -72,7 +81,11 @@ pub async fn rx_task(
let rx_reply_channel = shared_reply_channel.clone();
tokio::spawn(async move {
if let Err(err) = rx_session
.rx_loop(rx_reply_channel, config.session_timeout)
.rx_loop(
rx_reply_channel,
config.session_timeout,
config.max_packet_size,
)
.await
{
error!("RX error for {}: {:?}", source, err);
Expand All @@ -89,12 +102,17 @@ pub async fn rx_task(
}
};

let bytes = buf.len() as u64;
if session_channel_tx.send(buf).is_err() {
error!(
"Dropped packet for {} because its proxy session is closed",
source
);
metrics.count_dropped_packet(&Peer::Client);
sessions.write().await.remove(&source.into());
} else {
metrics.count_packet(&dir, &peer);
metrics.count_bytes(&dir, &peer, bytes);
}
}
}
Expand Down
26 changes: 21 additions & 5 deletions src/primary_tasks/tx_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use tokio::{net::UdpSocket, sync::mpsc::UnboundedReceiver};
use crate::{
error_util::{ErrorAction, handle_io_error},
session::SessionReply,
telemetry::{NetworkDirection, Peer, ProxyMetrics},
};

/// Loops infinitely over the `reply_channel_rx` to forward traffic from the destination of the proxy.
Expand All @@ -15,17 +16,32 @@ use crate::{
pub async fn tx_task(
mut reply_channel_rx: UnboundedReceiver<SessionReply>,
tx_socket: Arc<UdpSocket>,
metrics: Arc<ProxyMetrics>,
) -> io::Result<()> {
let dir = NetworkDirection::Transmit;
let peer = Peer::Client;

while let Some(reply) = reply_channel_rx.recv().await {
match tx_socket
.send_to(&reply.data, (reply.source.address, reply.source.port))
.await
{
Ok(_) => {}
Err(err) => match handle_io_error(err) {
ErrorAction::Terminate(err) => return Err::<(), io::Error>(err),
ErrorAction::Continue => {}
},
Ok(_) => {
metrics.count_packet(&dir, &peer);
metrics.count_bytes(&dir, &peer, reply.data.len() as u64);
}
Err(err) => {
metrics.count_dropped_packet(&peer);
match handle_io_error(err) {
ErrorAction::Terminate(err) => {
metrics.count_io_error(&dir, &peer, false);
return Err::<(), io::Error>(err);
}
ErrorAction::Continue => {
metrics.count_io_error(&dir, &peer, true);
}
}
}
}
}
Ok(())
Expand Down
80 changes: 62 additions & 18 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ use std::{
time::Duration,
};

use log::info;
use log::{info, warn};
use tokio::{
net::UdpSocket,
sync::mpsc::{UnboundedReceiver, UnboundedSender},
time::timeout,
time::{Instant, timeout},
};

use crate::{
MAX_UDP_PACKET_SIZE, ProxyConfig,
ProxyConfig,
error_util::{ErrorAction, handle_io_error},
telemetry::{NetworkDirection, Peer, ProxyMetrics},
};

#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
Expand All @@ -27,15 +28,15 @@ pub struct SessionSource {
/// Wrapper around a [`UdpSocket`] that handles the boiler plate of establishing a connection to the appropriate
/// backend destination. It retains the original [`SessionSource`] of the traffic it will be proxying
/// so that replies from the backend can be properly routed back.
#[derive(Debug)]
pub struct Session {
/// The source that this session is receiving traffic from
source: SessionSource,
/// The socket that this session is using to communicate with the destination
destination_socket: Arc<UdpSocket>,
metrics: Arc<ProxyMetrics>,
start: Instant,
}

#[derive(Debug)]
pub struct SessionReply {
pub source: SessionSource,
pub data: Vec<u8>,
Expand All @@ -51,7 +52,11 @@ impl Session {
/// Establish a new session that binds to an [`ProxyConfig::source_address`] and establishes
/// a connection to [`ProxyConfig::destination_address`] on [`ProxyConfig::destination_port`].
/// Returns an [`io::Error`] if the connection fails to establish.
pub async fn new(config: &ProxyConfig, source: SessionSource) -> io::Result<Self> {
pub async fn new(
config: &ProxyConfig,
source: SessionSource,
metrics: Arc<ProxyMetrics>,
) -> io::Result<Self> {
// Let the OS assign us an available port
let destination_socket = Arc::new(UdpSocket::bind((config.source_address, 0)).await?);
// Connect to the destination
Expand All @@ -62,6 +67,8 @@ impl Session {
Ok(Session {
source,
destination_socket,
metrics,
start: Instant::now(),
})
}

Expand All @@ -74,17 +81,33 @@ impl Session {
session_timeout: u64,
) -> io::Result<()> {
let duration = Duration::from_secs(session_timeout);
let dir = NetworkDirection::Transmit;
let peer = Peer::Backend;

while let Ok(Some(data)) = timeout(duration, source_channel.recv()).await {
match self.destination_socket.send(&data).await {
Ok(_) => {}
Err(err) => match err.kind() {
// Destination service hasn't started yet
ErrorKind::ConnectionRefused => {}
_ => match handle_io_error(err) {
ErrorAction::Terminate(cause) => return Err(cause),
ErrorAction::Continue => {}
},
},
Ok(_) => {
self.metrics.count_packet(&dir, &peer);
self.metrics.count_bytes(&dir, &peer, data.len() as u64);
}
Err(err) => {
self.metrics.count_dropped_packet(&peer);
match err.kind() {
// Destination service hasn't started yet
ErrorKind::ConnectionRefused => {
warn!("Destination service refused connection");
}
_ => match handle_io_error(err) {
ErrorAction::Terminate(cause) => {
self.metrics.count_io_error(&dir, &peer, false);
return Err(cause);
}
ErrorAction::Continue => {
self.metrics.count_io_error(&dir, &peer, true);
}
},
}
}
}
}
info!("Closing tx session for {}", self.source);
Expand All @@ -97,17 +120,30 @@ impl Session {
&self,
reply_channel: Arc<UnboundedSender<SessionReply>>,
session_timeout: u64,
max_packet_size: usize,
) -> io::Result<()> {
let duration = Duration::from_secs(session_timeout);
let dir = NetworkDirection::Receive;
let peer = Peer::Backend;

loop {
let mut buf = Vec::with_capacity(MAX_UDP_PACKET_SIZE.into());
let mut buf = Vec::with_capacity(max_packet_size);
match timeout(duration, self.destination_socket.recv_buf(&mut buf)).await {
Ok(result) => {
if let Err(err) = result {
self.metrics.count_dropped_packet(&peer);
match handle_io_error(err) {
ErrorAction::Terminate(cause) => return Err(cause),
ErrorAction::Continue => {}
ErrorAction::Terminate(cause) => {
self.metrics.count_io_error(&dir, &peer, false);
return Err(cause);
}
ErrorAction::Continue => {
self.metrics.count_io_error(&dir, &peer, true);
}
}
} else {
self.metrics.count_packet(&dir, &peer);
self.metrics.count_bytes(&dir, &peer, buf.len() as u64);
}
}
Err(_timeout_exceeded) => {
Expand All @@ -120,6 +156,7 @@ impl Session {
.send(SessionReply::new(self.source, buf))
.is_err()
{
self.metrics.count_dropped_packet(&peer);
return Err(io::Error::new(
ErrorKind::ConnectionAborted,
"Primary tx task has stopped listening, dropping reply as the proxy will soon terminate",
Expand All @@ -129,6 +166,13 @@ impl Session {
}
}

impl Drop for Session {
fn drop(&mut self) {
self.metrics
.record_session_duration(Instant::now().duration_since(self.start).as_secs_f64());
}
}

impl From<SocketAddr> for SessionSource {
fn from(value: SocketAddr) -> Self {
SessionSource {
Expand Down
Loading