diff --git a/tests/compression/src/lib.rs b/tests/compression/src/lib.rs
index 02f729b60..be1c77b8f 100644
--- a/tests/compression/src/lib.rs
+++ b/tests/compression/src/lib.rs
@@ -66,10 +66,15 @@ impl test_server::Test for Svc {
         &self,
         _req: Request<()>,
     ) -> Result<Response<Self::CompressOutputServerStreamStream>, Status> {
-        let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
-        let stream = tokio_stream::iter(std::iter::repeat(SomeData { data }))
-            .take(2)
-            .map(Ok::<_, Status>);
+        // Messages smaller than the 1024-byte threshold don't get compressed; the
+        // first message exercises that path, while the second one gets compressed.
+        let small = vec![0u8; UNCOMPRESSED_MIN_BODY_SIZE - 100];
+        let big = vec![0u8; UNCOMPRESSED_MIN_BODY_SIZE];
+
+        let stream = tokio_stream::iter([
+            Ok::<_, Status>(SomeData { data: small }),
+            Ok::<_, Status>(SomeData { data: big }),
+        ]);
 
         Ok(self.prepare_response(Response::new(Box::pin(stream))))
     }
diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs
index 7a6e1dffe..2d468eb8f 100644
--- a/tests/compression/src/server_stream.rs
+++ b/tests/compression/src/server_stream.rs
@@ -58,7 +58,8 @@ async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
         .await
         .expect("stream empty")
         .expect("item was error");
-    assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE);
+    // The first message shouldn't get compressed because it's below the threshold
+    assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE - 100);
 
     stream
         .next()
diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs
index 3c2f3420a..3dae032b2 100644
--- a/tonic/src/codec/compression.rs
+++ b/tonic/src/codec/compression.rs
@@ -4,13 +4,34 @@ use bytes::{Buf, BufMut, BytesMut};
 use flate2::read::{GzDecoder, GzEncoder};
 #[cfg(feature = "deflate")]
 use flate2::read::{ZlibDecoder, ZlibEncoder};
-use std::{borrow::Cow, fmt};
+use std::{borrow::Cow, fmt, sync::OnceLock};
 #[cfg(feature = "zstd")]
 use zstd::stream::read::{Decoder, Encoder};
 
 pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
 pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
 
+/// Get the compression threshold from the environment, falling back to the default (1024 bytes)
+fn get_compression_threshold() -> usize {
+    static THRESHOLD: OnceLock<usize> = OnceLock::new();
+    *THRESHOLD.get_or_init(|| {
+        std::env::var("TONIC_COMPRESSION_THRESHOLD")
+            .ok()
+            .and_then(|v| v.parse().ok())
+            .unwrap_or(1024)
+    })
+}
+
+/// Get the spawn_blocking threshold from the environment (disabled by default)
+fn get_spawn_blocking_threshold() -> Option<usize> {
+    static THRESHOLD: OnceLock<Option<usize>> = OnceLock::new();
+    *THRESHOLD.get_or_init(|| {
+        std::env::var("TONIC_SPAWN_BLOCKING_THRESHOLD")
+            .ok()
+            .and_then(|v| v.parse().ok())
+    })
+}
+
 /// Struct used to configure which encodings are enabled on a server or channel.
 ///
 /// Represents an ordered list of compression encodings that are enabled.
@@ -77,6 +98,26 @@ pub(crate) struct CompressionSettings {
     /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
     /// The default buffer growth interval is 8 kilobytes.
     pub(crate) buffer_growth_interval: usize,
+    /// Minimum message size (in bytes) to compress. Smaller messages are sent uncompressed.
+    /// Configurable via the TONIC_COMPRESSION_THRESHOLD environment variable. Default: 1024 bytes.
+    pub(crate) compression_threshold: usize,
+    /// Minimum message size (in bytes) at which compression runs on a blocking thread.
+    /// If set, messages at or above this threshold are compressed via spawn_blocking.
+    /// Configurable via the TONIC_SPAWN_BLOCKING_THRESHOLD environment variable. Default: None (disabled).
+    pub(crate) spawn_blocking_threshold: Option<usize>,
+}
+
+impl CompressionSettings {
+    /// Create new CompressionSettings with thresholds loaded from environment variables
+    #[inline]
+    pub(crate) fn new(encoding: CompressionEncoding, buffer_growth_interval: usize) -> Self {
+        Self {
+            encoding,
+            buffer_growth_interval,
+            compression_threshold: get_compression_threshold(),
+            spawn_blocking_threshold: get_spawn_blocking_threshold(),
+        }
+    }
 }
 
 /// The compression encodings Tonic supports.
@@ -252,6 +293,7 @@ pub(crate) fn compress(
 }
 
 /// Decompress `len` bytes from `compressed_buf` into `out_buf`.
+#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
 #[allow(unused_variables, unreachable_code)]
 pub(crate) fn decompress(
     settings: CompressionSettings,
diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs
index a221a5c93..270ac87f7 100644
--- a/tonic/src/codec/decode.rs
+++ b/tonic/src/codec/decode.rs
@@ -1,4 +1,6 @@
-use super::compression::{decompress, CompressionEncoding, CompressionSettings};
+use super::compression::CompressionEncoding;
+#[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
+use super::compression::{decompress, CompressionSettings};
 use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
 use crate::{body::Body, metadata::MetadataMap, Code, Status};
 use bytes::{Buf, BufMut, BytesMut};
@@ -30,6 +32,7 @@ struct StreamingInner {
     direction: Direction,
     buf: BytesMut,
     trailers: Option<HeaderMap>,
+    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
     decompress_buf: BytesMut,
     encoding: Option<CompressionEncoding>,
     max_message_size: Option<usize>,
@@ -136,6 +139,7 @@ impl<T> Streaming<T> {
             direction,
             buf: BytesMut::with_capacity(buffer_size),
             trailers: None,
+            #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
             decompress_buf: BytesMut::new(),
             encoding,
             max_message_size,
@@ -147,6 +151,10 @@ impl<T> Streaming<T> {
 impl StreamingInner {
     fn decode_chunk(
         &mut self,
+        #[cfg_attr(
+            not(any(feature = "gzip", feature = "deflate", feature = "zstd")),
+            allow(unused_variables)
+        )]
         buffer_settings: BufferSettings,
     ) -> Result<Option<DecodeBuf<'_>>, Status> {
         if let State::ReadHeader = self.state {
@@ -209,29 +217,36 @@ impl StreamingInner {
             return Ok(None);
         }
 
-        let decode_buf = if let Some(encoding) = compression {
-            self.decompress_buf.clear();
-
-            if let Err(err) = decompress(
-                CompressionSettings {
-                    encoding,
-                    buffer_growth_interval: buffer_settings.buffer_size,
-                },
-                &mut self.buf,
-                &mut self.decompress_buf,
-                len,
-            ) {
-                let message = if let Direction::Response(status) = self.direction {
-                    format!(
-                        "Error decompressing: {err}, while receiving response with status: {status}"
-                    )
-                } else {
-                    format!("Error decompressing: {err}, while sending request")
-                };
-                return Err(Status::internal(message));
-            }
-            let decompressed_len = self.decompress_buf.len();
-            DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
+        let decode_buf = if let Some(_encoding) = compression {
+            #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
+            {
+                let encoding = _encoding;
+                self.decompress_buf.clear();
+
+                if let Err(err) = decompress(
+                    CompressionSettings::new(encoding, buffer_settings.buffer_size),
+                    &mut self.buf,
+                    &mut self.decompress_buf,
+                    len,
+                ) {
+                    let message = if let Direction::Response(status) = self.direction {
+                        format!(
+                            "Error decompressing: {err}, while receiving response with status: {status}"
+                        )
+                    } else {
+                        format!("Error decompressing: {err}, while sending request")
+                    };
+                    return Err(Status::internal(message));
+                }
+                let decompressed_len = self.decompress_buf.len();
+                DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
+            }
+            #[cfg(not(any(feature = "gzip", feature = "deflate", feature = "zstd")))]
+            {
+                // This branch is unreachable when no compression features are enabled,
+                // because CompressionEncoding has no variants in that case.
+                unreachable!("Compression encoding without compression features")
+            }
         } else {
             DecodeBuf::new(&mut self.buf, len)
         };
diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs
index 7568f0515..c7858c3ef 100644
--- a/tonic/src/codec/encode.rs
+++ b/tonic/src/codec/encode.rs
@@ -7,12 +7,24 @@ use bytes::{BufMut, Bytes, BytesMut};
 use http::HeaderMap;
 use http_body::{Body, Frame};
 use pin_project::pin_project;
+#[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+use std::future::Future;
 use std::{
     pin::Pin,
     task::{ready, Context, Poll},
 };
+#[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+use tokio::task::JoinHandle;
 use tokio_stream::{adapters::Fuse, Stream, StreamExt};
 
+#[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+#[derive(Debug)]
+struct CompressionResult {
+    compressed_data: BytesMut,
+    was_compressed: bool,
+    encoding: Option<CompressionEncoding>,
+}
+
 /// Combinator for efficient encoding of messages into reasonably sized buffers.
 /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
 /// splitting off and yielding a buffer when either:
@@ -29,6 +41,9 @@ struct EncodedBytes<T, U> {
     buf: BytesMut,
     uncompression_buf: BytesMut,
     error: Option<Status>,
+    #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+    #[pin]
+    compression_task: Option<JoinHandle<Result<CompressionResult, Status>>>,
 }
 
 impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
@@ -63,6 +78,145 @@ impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
             buf,
             uncompression_buf,
             error: None,
+            #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+            compression_task: None,
         }
     }
 }
+
+impl<T, U> EncodedBytes<T, U>
+where
+    T: Encoder<Error = Status>,
+    U: Stream<Item = Result<T::Item, Status>>,
+{
+    fn encode_item_uncompressed(
+        encoder: &mut T,
+        item: T::Item,
+        buf: &mut BytesMut,
+        max_message_size: Option<usize>,
+    ) -> Result<(), Status> {
+        let offset = buf.len();
+        buf.reserve(HEADER_SIZE);
+        unsafe {
+            buf.advance_mut(HEADER_SIZE);
+        }
+
+        if let Err(err) = encoder.encode(item, &mut EncodeBuf::new(buf)) {
+            return Err(Status::internal(format!("Error encoding: {err}")));
+        }
+
+        finish_encoding(None, max_message_size, &mut buf[offset..])
+    }
+
+    /// Process the next item from the stream.
+    /// Returns `Ok(true)` if a blocking compression task was spawned (and stored in
+    /// `compression_task`), or `Ok(false)` if the item was processed inline.
+    fn process_next_item(
+        encoder: &mut T,
+        item: T::Item,
+        buf: &mut BytesMut,
+        uncompression_buf: &mut BytesMut,
+        compression_encoding: Option<CompressionEncoding>,
+        max_message_size: Option<usize>,
+        #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+        compression_task: &mut Pin<
+            &mut Option<JoinHandle<Result<CompressionResult, Status>>>,
+        >,
+        buffer_settings: &BufferSettings,
+    ) -> Result<bool, Status> {
+        let compression_settings = compression_encoding
+            .map(|encoding| CompressionSettings::new(encoding, buffer_settings.buffer_size));
+
+        if let Some(settings) = compression_settings {
+            uncompression_buf.clear();
+            if let Err(err) = encoder.encode(item, &mut EncodeBuf::new(uncompression_buf)) {
+                return Err(Status::internal(format!("Error encoding: {err}")));
+            }
+
+            let uncompressed_len = uncompression_buf.len();
+
+            // Check whether to offload compression via spawn_blocking (needs a tokio runtime)
+            #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+            if let Some(spawn_threshold) = settings.spawn_blocking_threshold {
+                if uncompressed_len >= spawn_threshold
+                    && uncompressed_len >= settings.compression_threshold
+                {
+                    let data_to_compress = uncompression_buf.split().freeze();
+
+                    let task = tokio::task::spawn_blocking(move || {
+                        compress_blocking(data_to_compress, settings)
+                    });
+
+                    compression_task.set(Some(task));
+                    return Ok(true);
+                }
+            }
+
+            compress_and_encode_item(
+                buf,
+                uncompression_buf,
+                settings,
+                max_message_size,
+                uncompressed_len,
+            )?;
+        } else {
+            Self::encode_item_uncompressed(encoder, item, buf, max_message_size)?;
+        }
+
+        Ok(false)
+    }
+
+    #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+    fn poll_compression_task(
+        compression_task: &mut Pin<&mut Option<JoinHandle<Result<CompressionResult, Status>>>>,
+        buf: &mut BytesMut,
+        max_message_size: Option<usize>,
+        buffer_settings: &BufferSettings,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<Bytes, Status>>> {
+        if let Some(task) = compression_task.as_mut().as_pin_mut() {
+            match Future::poll(task, cx) {
+                Poll::Ready(Ok(Ok(result))) => {
+                    compression_task.set(None);
+
+                    buf.reserve(HEADER_SIZE + result.compressed_data.len());
+                    let offset = buf.len();
+
+                    unsafe {
+                        buf.advance_mut(HEADER_SIZE);
+                    }
+
+                    buf.extend_from_slice(&result.compressed_data);
+
+                    let final_compression = if result.was_compressed {
+                        result.encoding
+                    } else {
+                        None
+                    };
+
+                    if let Err(status) =
+                        finish_encoding(final_compression, max_message_size, &mut buf[offset..])
+                    {
+                        return Poll::Ready(Some(Err(status)));
+                    }
+
+                    if buf.len() >= buffer_settings.yield_threshold {
+                        return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
+                    }
+                    Poll::Ready(None)
+                }
+                Poll::Ready(Ok(Err(status))) => {
+                    compression_task.set(None);
+                    Poll::Ready(Some(Err(status)))
+                }
+                Poll::Ready(Err(_)) => {
+                    compression_task.set(None);
+                    Poll::Ready(Some(Err(Status::internal("compression task panicked"))))
+                }
+                Poll::Pending => Poll::Pending,
+            }
+        } else {
+            Poll::Ready(None)
+        }
+    }
+}
@@ -83,6 +237,8 @@ where
             buf,
             uncompression_buf,
             error,
+            #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+            mut compression_task,
         } = self.project();
         let buffer_settings = encoder.buffer_settings();
 
@@ -90,6 +246,24 @@ where
             return Poll::Ready(Some(Err(status)));
         }
 
+        // Check for an in-flight compression task before pulling new items
+        #[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+        {
+            match Self::poll_compression_task(
+                &mut compression_task,
+                buf,
+                *max_message_size,
+                &buffer_settings,
+                cx,
+            ) {
+                Poll::Ready(Some(result)) => return Poll::Ready(Some(result)),
+                Poll::Pending => return Poll::Pending,
+                Poll::Ready(None) => {
+                    // Task completed (or none was running), continue processing
+                }
+            }
+        }
+
         loop {
             match source.as_mut().poll_next(cx) {
                 Poll::Pending if buf.is_empty() => {
@@ -102,20 +276,70 @@ where
                     return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
                 }
                 Poll::Ready(Some(Ok(item))) => {
-                    if let Err(status) = encode_item(
+                    match Self::process_next_item(
                         encoder,
+                        item,
                         buf,
                         uncompression_buf,
                         *compression_encoding,
                         *max_message_size,
-                        buffer_settings,
-                        item,
+                        #[cfg(any(
+                            feature = "transport",
+                            feature = "channel",
+                            feature = "server"
+                        ))]
+                        &mut compression_task,
+                        &buffer_settings,
                     ) {
-                        return Poll::Ready(Some(Err(status)));
-                    }
-
-                    if buf.len() >= buffer_settings.yield_threshold {
-                        return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
+                        Ok(true) => {
+                            #[cfg(any(
+                                feature = "transport",
+                                feature = "channel",
+                                feature = "server"
+                            ))]
+                            {
+                                // We just spawned the blocking compression task.
+                                // Poll it once right away so it registers our waker.
+                                match Self::poll_compression_task(
+                                    &mut compression_task,
+                                    buf,
+                                    *max_message_size,
+                                    &buffer_settings,
+                                    cx,
+                                ) {
+                                    Poll::Ready(Some(result)) => {
+                                        return Poll::Ready(Some(result));
+                                    }
+                                    Poll::Ready(None) => {
+                                        if buf.len() >= buffer_settings.yield_threshold {
+                                            return Poll::Ready(Some(Ok(buf
+                                                .split_to(buf.len())
+                                                .freeze())));
+                                        }
+                                    }
+                                    Poll::Pending => {
+                                        return Poll::Pending;
+                                    }
+                                }
+                            }
+                            #[cfg(not(any(
+                                feature = "transport",
+                                feature = "channel",
+                                feature = "server"
+                            )))]
+                            {
+                                // process_next_item never returns Ok(true) without tokio
+                                unreachable!("spawn_blocking returned true without tokio")
+                            }
+                        }
+                        Ok(false) => {
+                            if buf.len() >= buffer_settings.yield_threshold {
+                                return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
+                            }
+                        }
+                        Err(status) => {
+                            return Poll::Ready(Some(Err(status)));
+                        }
                     }
                 }
                 Poll::Ready(Some(Err(status))) => {
@@ -130,18 +354,39 @@ where
     }
 }
 
-fn encode_item<T>(
-    encoder: &mut T,
+/// Compress data on the blocking thread pool (called via spawn_blocking)
+#[cfg(any(feature = "transport", feature = "channel", feature = "server"))]
+fn compress_blocking(
    data: Bytes,
+    settings: CompressionSettings,
+) -> Result<CompressionResult, Status> {
+    let uncompressed_len = data.len();
+    let mut uncompression_buf = BytesMut::from(data.as_ref());
+    let mut compressed_buf = BytesMut::new();
+
+    compress(
+        settings,
+        &mut uncompression_buf,
+        &mut compressed_buf,
+        uncompressed_len,
+    )
+    .map_err(|err| Status::internal(format!("Error compressing: {err}")))?;
+
+    Ok(CompressionResult {
+        compressed_data: compressed_buf,
+        was_compressed: true,
+        encoding: Some(settings.encoding),
+    })
+}
+
+/// Frame an already-serialized item inline (without spawn_blocking),
+/// compressing it when it meets the threshold
+fn compress_and_encode_item(
     buf: &mut BytesMut,
     uncompression_buf: &mut BytesMut,
-    compression_encoding: Option<CompressionEncoding>,
+    settings: CompressionSettings,
     max_message_size: Option<usize>,
-    buffer_settings: BufferSettings,
-    item: T::Item,
-) -> Result<(), Status>
-where
-    T: Encoder<Error = Status>,
-{
+    uncompressed_len: usize,
+) -> Result<(), Status> {
     let offset = buf.len();
     buf.reserve(HEADER_SIZE);
@@ -149,33 +394,24 @@ where
         buf.advance_mut(HEADER_SIZE);
     }
 
-    if let Some(encoding) = compression_encoding {
-        uncompression_buf.clear();
-
-        encoder
-            .encode(item, &mut EncodeBuf::new(uncompression_buf))
-            .map_err(|err| Status::internal(format!("Error encoding: {err}")))?;
+    let mut was_compressed = false;
 
-        let uncompressed_len = uncompression_buf.len();
-
-        compress(
-            CompressionSettings {
-                encoding,
-                buffer_growth_interval: buffer_settings.buffer_size,
-            },
-            uncompression_buf,
-            buf,
-            uncompressed_len,
-        )
-        .map_err(|err| Status::internal(format!("Error compressing: {err}")))?;
+    if uncompressed_len >= settings.compression_threshold {
+        compress(settings, uncompression_buf, buf, uncompressed_len)
+            .map_err(|err| Status::internal(format!("Error compressing: {err}")))?;
+        was_compressed = true;
     } else {
-        encoder
-            .encode(item, &mut EncodeBuf::new(buf))
-            .map_err(|err| Status::internal(format!("Error encoding: {err}")))?;
+        buf.reserve(uncompressed_len);
+        buf.extend_from_slice(&uncompression_buf[..]);
     }
 
     // now that we know length, we can write the header
-    finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
+    let final_compression = if was_compressed {
+        Some(settings.encoding)
+    } else {
+        None
+    };
+    finish_encoding(final_compression, max_message_size, &mut buf[offset..])
 }
 
 fn finish_encoding(
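
---

Notes for reviewers (sketches, not part of the patch):

Both new knobs are plain environment variables, read once through a `OnceLock`
on first use, so they must be set before the first message is encoded or
decoded. Typically that means setting them in the shell, e.g.
`TONIC_COMPRESSION_THRESHOLD=4096 ./server`. An equivalent in-process setup
(the values here are illustrative, not defaults):

    fn main() {
        // Cached in a OnceLock on first use, so set these before any RPC traffic.
        std::env::set_var("TONIC_COMPRESSION_THRESHOLD", "4096"); // compress messages >= 4 KiB
        std::env::set_var("TONIC_SPAWN_BLOCKING_THRESHOLD", "65536"); // offload messages >= 64 KiB

        // ... build and run the tonic server/client as usual ...
    }

For context on the `offset`/`HEADER_SIZE` bookkeeping in encode.rs: the encoder
reserves the 5-byte length-prefixed-message header that `finish_encoding` fills
in once the payload size is known, a 1-byte compressed flag followed by a
4-byte big-endian length (per the gRPC wire spec). A sketch of that final write
(the helper name is ours, not tonic's):

    fn write_message_prefix(header: &mut [u8; 5], compressed: bool, payload_len: u32) {
        header[0] = compressed as u8; // 1 = payload is compressed
        header[1..5].copy_from_slice(&payload_len.to_be_bytes()); // message length
    }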