proto: Replace write_source with WriteGuard<'_> #2242

Draft · wants to merge 1 commit into main
68 changes: 54 additions & 14 deletions quinn-proto/src/connection/streams/mod.rs
@@ -19,9 +19,8 @@ use recv::Recv;
pub use recv::{Chunks, ReadError, ReadableError};

mod send;
pub(crate) use send::{ByteSlice, BytesArray};
use send::{BytesSource, Send, SendState};
pub use send::{FinishError, WriteError, Written};
use send::{Send, SendState};

mod state;
#[allow(unreachable_pub)] // fuzzing only
@@ -221,7 +220,10 @@ impl<'a> SendStream<'a> {
///
/// Returns the number of bytes successfully written.
pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
let mut guard = self.write_guard()?;
let written = data.len().min(guard.limit);
guard.write(data[..written].to_vec().into());
Collaborator suggested replacing
guard.write(data[..written].to_vec().into());
with
guard.write(Bytes::copy_from_slice(&data[..written]));
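Note on the suggestion: both forms copy the written prefix into a freshly allocated, owned `Bytes`; `Bytes::copy_from_slice` simply states that directly instead of going through an intermediate `Vec`. A minimal standalone sketch (illustrative only, not part of the diff):

use bytes::Bytes;

fn main() {
    let data: &[u8] = b"example payload";
    let written = 7;

    // Both lines copy `data[..written]` into an owned `Bytes`.
    let via_vec: Bytes = data[..written].to_vec().into();
    let via_copy = Bytes::copy_from_slice(&data[..written]);

    assert_eq!(via_vec, via_copy);
}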

Ok(written)
}

/// Send data on the given stream
@@ -231,10 +233,25 @@ impl<'a> SendStream<'a> {
/// [`Written::chunks`] will not count this chunk as fully written. However
/// the chunk will be advanced and contain only non-written data after the call.
pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
self.write_source(&mut BytesArray::from_chunks(data))
let mut guard = self.write_guard()?;
let mut written = Written::default();
for chunk in data {
let prefix = chunk.split_to(chunk.len().min(guard.limit));
written.bytes += prefix.len();
guard.write(prefix);

if chunk.is_empty() {
written.chunks += 1;
}

if guard.limit == 0 {
break;
}
}
Ok(written)
}
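For context (not part of the diff), here is a standalone model of the chunk accounting described in the doc comment above: only chunks that end up empty count toward `Written::chunks`, and a chunk truncated by the limit is advanced past the written prefix. The function name and the fixed budget below are assumptions for illustration, not the crate's API.

use bytes::Bytes;

// Standalone model of the accounting in `write_chunks`; `limit` plays the
// role of the budget computed by `write_guard`.
fn split_up_to(chunks: &mut [Bytes], mut limit: usize) -> (usize, usize) {
    let (mut bytes, mut whole_chunks) = (0, 0);
    for chunk in chunks {
        let prefix = chunk.split_to(chunk.len().min(limit));
        bytes += prefix.len();
        limit -= prefix.len();
        if chunk.is_empty() {
            whole_chunks += 1; // chunk fully written
        }
        if limit == 0 {
            break;
        }
    }
    (bytes, whole_chunks)
}

fn main() {
    let mut chunks = [Bytes::from_static(b"hello "), Bytes::from_static(b"world")];
    // Budget of 8 bytes: the first chunk (6 bytes) is fully written, the
    // second is truncated to "wo" and advanced to "rld".
    let (bytes, whole) = split_up_to(&mut chunks, 8);
    assert_eq!((bytes, whole), (8, 1));
    assert_eq!(&chunks[1][..], b"rld");
}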

fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
fn write_guard(&mut self) -> Result<WriteGuard, WriteError> {
Collaborator: bikeshed: build_write, or maybe begin_write.

if self.conn_state.is_closed() {
trace!(%self.id, "write blocked; connection draining");
return Err(WriteError::Blocked);
@@ -263,15 +280,16 @@
return Err(WriteError::Blocked);
}

let was_pending = stream.is_pending();
let written = stream.write(source, limit)?;
self.state.data_sent += written.bytes as u64;
self.state.unacked_data += written.bytes as u64;
trace!(stream = %self.id, "wrote {} bytes", written.bytes);
if !was_pending {
self.state.pending.push_pending(self.id, stream.priority);
}
Ok(written)
let limit = stream.write_limit(limit)?;

Ok(WriteGuard {
limit,
stream,
id: self.id,
data_sent: &mut self.state.data_sent,
unacked_data: &mut self.state.unacked_data,
pending: &mut self.state.pending,
})
}

/// Check if this stream was stopped, get the reason if it was
@@ -367,6 +385,28 @@ impl<'a> SendStream<'a> {
}
}

struct WriteGuard<'a> {
Collaborator: bikeshed: WriteBuilder? "Guard" usually means "cleans something up on Drop".

Member: Maybe StreamWriter? Builder invokes builder pattern for construction to me.

limit: usize,
id: StreamId,
stream: &'a mut Send,
data_sent: &'a mut u64,
unacked_data: &'a mut u64,
pending: &'a mut PendingStreamsQueue,
Collaborator (on lines +391 to +394): Should these be a single reference to a helper struct? That might also be easier to adapt to the quinn layer. (A hypothetical sketch of that idea follows the WriteGuard impl below.)

}

impl<'a> WriteGuard<'a> {
fn write(&mut self, bytes: Bytes) {
self.limit -= bytes.len();
*self.data_sent += bytes.len() as u64;
*self.unacked_data += bytes.len() as u64;
let was_pending = self.stream.is_pending();
self.stream.pending.write(bytes);
if !was_pending {
self.pending.push_pending(self.id, self.stream.priority);
}
}
}
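
Regarding the helper-struct suggestion above: one hypothetical shape for it would bundle the connection-level bookkeeping behind a single borrow, which might also be simpler to thread through the quinn layer. The names `SendAccounting`, `PendingQueue`, and `GuardSketch` are illustrative stand-ins, not part of this PR:

// Hypothetical stand-ins; the real types live in quinn-proto.
struct PendingQueue(Vec<u64>);

// Groups the connection-level counters that `WriteGuard` currently borrows
// field by field.
struct SendAccounting {
    data_sent: u64,
    unacked_data: u64,
    pending: PendingQueue,
}

struct GuardSketch<'a> {
    limit: usize,
    accounting: &'a mut SendAccounting,
}

impl GuardSketch<'_> {
    fn write(&mut self, len: usize) {
        // Mirrors the counter updates in `WriteGuard::write`.
        self.limit -= len;
        self.accounting.data_sent += len as u64;
        self.accounting.unacked_data += len as u64;
    }
}

fn main() {
    let mut acct = SendAccounting {
        data_sent: 0,
        unacked_data: 0,
        pending: PendingQueue(Vec::new()),
    };
    let mut guard = GuardSketch { limit: 10, accounting: &mut acct };
    guard.write(4);
    assert_eq!(guard.limit, 6);
    drop(guard);
    assert_eq!((acct.data_sent, acct.unacked_data), (4, 4));
}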

/// A queue of streams with pending outgoing data, sorted by priority
struct PendingStreamsQueue {
streams: BinaryHeap<PendingStream>,
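One more illustrative note on the pending-queue handling in `WriteGuard::write`: the stream is pushed onto the priority queue only when it transitions from "nothing buffered" to "something buffered", so repeated writes do not create duplicate heap entries. A simplified standalone model (its types are stand-ins, not the crate's `PendingStreamsQueue` or `Send`):

use std::collections::BinaryHeap;

// Simplified model: `buffered` stands in for the stream's send buffer and
// `queue` for the connection's queue of pending streams.
struct Model {
    queue: BinaryHeap<u64>,
    buffered: u64,
}

impl Model {
    fn write(&mut self, id: u64, len: u64) {
        let was_pending = self.buffered > 0;
        self.buffered += len;
        if !was_pending {
            // First write since the stream drained: register it once.
            self.queue.push(id);
        }
    }
}

fn main() {
    let mut m = Model { queue: BinaryHeap::new(), buffered: 0 };
    m.write(3, 10);
    m.write(3, 5); // already pending, so no duplicate queue entry
    assert_eq!(m.queue.len(), 1);
}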
222 changes: 2 additions & 220 deletions quinn-proto/src/connection/streams/send.rs
@@ -1,4 +1,3 @@
use bytes::Bytes;
use thiserror::Error;

use crate::{VarInt, connection::send_buffer::SendBuffer, frame};
@@ -49,11 +48,7 @@ impl Send {
}
}

pub(super) fn write<S: BytesSource>(
&mut self,
source: &mut S,
limit: u64,
) -> Result<Written, WriteError> {
pub(super) fn write_limit(&self, limit: u64) -> Result<usize, WriteError> {
if !self.is_writable() {
return Err(WriteError::ClosedStream);
}
@@ -64,23 +59,7 @@
if budget == 0 {
return Err(WriteError::Blocked);
}
let mut limit = limit.min(budget) as usize;

let mut result = Written::default();
loop {
let (chunk, chunks_consumed) = source.pop_chunk(limit);
result.chunks += chunks_consumed;
result.bytes += chunk.len();

if chunk.is_empty() {
break;
}

limit -= chunk.len();
self.pending.write(chunk);
}

Ok(result)
Ok(limit.min(budget) as usize)
}

/// Update stream state due to a reset sent by the local application
@@ -143,106 +122,6 @@ impl Send {
}
}

/// A [`BytesSource`] implementation for `&'a mut [Bytes]`
///
/// The type allows to dequeue [`Bytes`] chunks from an array of chunks, up to
/// a configured limit.
pub(crate) struct BytesArray<'a> {
/// The wrapped slice of `Bytes`
chunks: &'a mut [Bytes],
/// The amount of chunks consumed from this source
consumed: usize,
}

impl<'a> BytesArray<'a> {
pub(crate) fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
Self {
chunks,
consumed: 0,
}
}
}

impl BytesSource for BytesArray<'_> {
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
// The loop exists to skip empty chunks while still marking them as
// consumed
let mut chunks_consumed = 0;

while self.consumed < self.chunks.len() {
let chunk = &mut self.chunks[self.consumed];

if chunk.len() <= limit {
let chunk = std::mem::take(chunk);
self.consumed += 1;
chunks_consumed += 1;
if chunk.is_empty() {
continue;
}
return (chunk, chunks_consumed);
} else if limit > 0 {
let chunk = chunk.split_to(limit);
return (chunk, chunks_consumed);
} else {
break;
}
}

(Bytes::new(), chunks_consumed)
}
}

/// A [`BytesSource`] implementation for `&[u8]`
///
/// The type allows to dequeue a single [`Bytes`] chunk, which will be lazily
/// created from a reference. This allows to defer the allocation until it is
/// known how much data needs to be copied.
pub(crate) struct ByteSlice<'a> {
/// The wrapped byte slice
data: &'a [u8],
}

impl<'a> ByteSlice<'a> {
pub(crate) fn from_slice(data: &'a [u8]) -> Self {
Self { data }
}
}

impl BytesSource for ByteSlice<'_> {
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
let limit = limit.min(self.data.len());
if limit == 0 {
return (Bytes::new(), 0);
}

let chunk = Bytes::from(self.data[..limit].to_owned());
self.data = &self.data[chunk.len()..];

let chunks_consumed = usize::from(self.data.is_empty());
(chunk, chunks_consumed)
}
}

/// A source of one or more buffers which can be converted into `Bytes` buffers on demand
///
/// The purpose of this data type is to defer conversion as long as possible,
/// so that no heap allocation is required in case no data is writable.
pub(super) trait BytesSource {
/// Returns the next chunk from the source of owned chunks.
///
/// This method will consume parts of the source.
/// Calling it will yield `Bytes` elements up to the configured `limit`.
///
/// The method returns a tuple:
/// - The first item is the yielded `Bytes` element. The element will be
/// empty if the limit is zero or no more data is available.
/// - The second item returns how many complete chunks inside the source had
/// had been consumed. This can be less than 1, if a chunk inside the
/// source had been truncated in order to adhere to the limit. It can also
/// be more than 1, if zero-length chunks had been skipped.
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
}

/// Indicates how many bytes and chunks had been transferred in a write operation
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Written {
@@ -303,100 +182,3 @@ pub enum FinishError {
#[error("closed stream")]
ClosedStream,
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn bytes_array() {
let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
for limit in 0..full.len() {
let mut chunks = [
Bytes::from_static(b""),
Bytes::from_static(b"Hello "),
Bytes::from_static(b"Wo"),
Bytes::from_static(b""),
Bytes::from_static(b"r"),
Bytes::from_static(b"ld"),
Bytes::from_static(b""),
Bytes::from_static(b" 12345678"),
Bytes::from_static(b"9 ABCDE"),
Bytes::from_static(b"F"),
Bytes::from_static(b"GHJIJKLMNOPQRSTUVWXYZ"),
];
let num_chunks = chunks.len();
let last_chunk_len = chunks[chunks.len() - 1].len();

let mut array = BytesArray::from_chunks(&mut chunks);

let mut buf = Vec::new();
let mut chunks_popped = 0;
let mut chunks_consumed = 0;
let mut remaining = limit;
loop {
let (chunk, consumed) = array.pop_chunk(remaining);
chunks_consumed += consumed;

if !chunk.is_empty() {
buf.extend_from_slice(&chunk);
remaining -= chunk.len();
chunks_popped += 1;
} else {
break;
}
}

assert_eq!(&buf[..], &full[..limit]);

if limit == full.len() {
// Full consumption of the last chunk
assert_eq!(chunks_consumed, num_chunks);
// Since there are empty chunks, we consume more than there are popped
assert_eq!(chunks_consumed, chunks_popped + 3);
} else if limit > full.len() - last_chunk_len {
// Partial consumption of the last chunk
assert_eq!(chunks_consumed, num_chunks - 1);
assert_eq!(chunks_consumed, chunks_popped + 2);
}
}
}

#[test]
fn byte_slice() {
let full = b"Hello World 123456789 ABCDEFGHJIJKLMNOPQRSTUVWXYZ".to_owned();
for limit in 0..full.len() {
let mut array = ByteSlice::from_slice(&full[..]);

let mut buf = Vec::new();
let mut chunks_popped = 0;
let mut chunks_consumed = 0;
let mut remaining = limit;
loop {
let (chunk, consumed) = array.pop_chunk(remaining);
chunks_consumed += consumed;

if !chunk.is_empty() {
buf.extend_from_slice(&chunk);
remaining -= chunk.len();
chunks_popped += 1;
} else {
break;
}
}

assert_eq!(&buf[..], &full[..limit]);
if limit != 0 {
assert_eq!(chunks_popped, 1);
} else {
assert_eq!(chunks_popped, 0);
}

if limit == full.len() {
assert_eq!(chunks_consumed, 1);
} else {
assert_eq!(chunks_consumed, 0);
}
}
}
}