Skip to content

[Server] Explicitly sequence I/O operations on the underlying stream. #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ categories = [
"parser-implementations",
"web-programming",
"web-programming::http-client",
"web-programming::http-server"
"web-programming::http-server",
]
authors = ["Yoshua Wuyts <[email protected]>"]
readme = "README.md"
@@ -28,6 +28,10 @@ log = "0.4.11"
pin-project = "1.0.2"
async-channel = "1.5.1"
async-dup = "1.2.2"
futures-channel = "0.3.12"
futures-io = "0.3.12"
futures-lite = "1.11.3"
futures-util = "0.3.12"

[dev-dependencies]
pretty_assertions = "0.6.1"
658 changes: 233 additions & 425 deletions src/chunked/decoder.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/client/decode.rs
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ where
if let Some(encoding) = transfer_encoding {
if encoding.last().as_str() == "chunked" {
let trailers_sender = res.send_trailers();
let reader = BufReader::new(ChunkedDecoder::new(reader, trailers_sender));
let reader = ChunkedDecoder::new(reader, trailers_sender);
res.set_body(Body::from_reader(reader, None));

// Return the response.
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -108,13 +108,16 @@ mod body_encoder;
mod chunked;
mod date;
mod read_notifier;
mod sequenced;
mod unite;

pub mod client;
pub mod server;

use async_std::io::Cursor;
use body_encoder::BodyEncoder;
pub use client::connect;
pub use sequenced::Sequenced;
pub use server::{accept, accept_with_opts, ServerOptions};

#[derive(Debug)]
232 changes: 232 additions & 0 deletions src/sequenced.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
use core::future::Future;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};

use futures_channel::oneshot;
use futures_core::ready;
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use futures_lite::future::poll_fn;

/// Internal state passed along the chain of `Sequenced` instances via a
/// oneshot channel when each instance is released.
#[derive(Debug)]
enum SequencedState<T> {
    /// This instance currently owns the inner reader/writer.
    Active {
        value: T,
        // True when a prior holder was dropped without calling `release()`;
        // a poisoned value is withheld from I/O until `cure()` clears the flag.
        poisoned: bool,
    },
    /// Waiting for the preceding instance in the sequence to send the state.
    Waiting {
        receiver: oneshot::Receiver<Self>,
        // Poison override to apply once the state arrives;
        // `None` means "leave the received poison flag as-is".
        poisoned: Option<bool>,
    },
}

/// Allows multiple asynchronous tasks to access the same reader or writer concurrently
/// without conflicting.
/// The `split_seq` and `split_seq_rev` methods produce a new instance of the type such that
/// all I/O operations on one instance are sequenced before all I/O operations on the other.
///
/// When one task has finished with the reader/writer it should call `release`, which will
/// unblock operations on the task with the other instance. If dropped without calling
/// `release`, the inner reader/writer will become poisoned before being returned. The
/// caller can explicitly remove the poisoned status.
///
/// The `Sequenced<T>` can be split as many times as necessary, and it is valid to call
/// `release()` at any time, although no further operations can be performed via a released
/// instance. If this type is dropped without calling `release()`, then the reader/writer will
/// become poisoned.
///
/// As only one task has access to the reader/writer at once, no additional synchronization
/// is necessary, and so this wrapper adds very little overhead. What synchronization does
/// occur only needs to happen when an instance is released, in order to send its state to
/// the next instance in the sequence.
///
/// Merging can be achieved by simply releasing one of the two instances, and then using the
/// other one as normal. It does not matter which one is released.
#[derive(Debug)]
pub struct Sequenced<T> {
    // Channel to the instance sequenced *after* this one; `release()` sends the
    // state through it. `None` for the last instance in the chain.
    parent: Option<oneshot::Sender<SequencedState<T>>>,
    // `None` once this instance has been released or its value taken.
    state: Option<SequencedState<T>>,
}

impl<T> Sequenced<T> {
    /// Constructs a new sequenced reader/writer
    ///
    /// The new instance is the head of its own chain: it owns the value
    /// immediately and has no successor to notify.
    pub fn new(value: T) -> Self {
        Self {
            parent: None,
            state: Some(SequencedState::Active {
                value,
                poisoned: false,
            }),
        }
    }
    /// Splits this reader/writer into two such that the returned instance is sequenced before this one.
    pub fn split_seq(&mut self) -> Self {
        // NOTE: `ready!` comes from `futures_core`, which is not listed in the
        // visible Cargo.toml additions — presumably pulled in transitively; verify.
        let (sender, receiver) = oneshot::channel();
        // Hand our current state to the new (earlier) instance and keep a
        // `Waiting` receiver; the state flows back to us when it is released.
        let state = mem::replace(
            &mut self.state,
            Some(SequencedState::Waiting {
                receiver,
                poisoned: None,
            }),
        );
        Self {
            parent: Some(sender),
            state,
        }
    }
    /// Splits this reader/writer into two such that the returned instance is sequenced after this one.
    pub fn split_seq_rev(&mut self) -> Self {
        let other = self.split_seq();
        // Swap roles: `self` keeps the "earlier" half, the "later" half is returned.
        mem::replace(self, other)
    }

    /// Release this reader/writer immediately, allowing instances sequenced after this one to proceed.
    pub fn release(&mut self) {
        if let (Some(state), Some(parent)) = (self.state.take(), self.parent.take()) {
            // Best-effort: the successor may already have been dropped.
            let _ = parent.send(state);
        }
    }
    // Sets or clears the poison flag on whatever state we currently hold.
    // For a `Waiting` state the flag is recorded as an override to apply
    // once the state arrives (see `resolve`). No-op if already released.
    fn set_poisoned(&mut self, value: bool) {
        match &mut self.state {
            Some(SequencedState::Active { poisoned, .. }) => *poisoned = value,
            Some(SequencedState::Waiting { poisoned, .. }) => *poisoned = Some(value),
            None => {}
        }
    }
    /// Removes the poison status if set
    pub(crate) fn cure(&mut self) {
        self.set_poisoned(false)
    }
    // Drives the `Waiting` → `Active` transition, returning a reference to the
    // inner value once this instance owns it. Returns `Ready(None)` if the value
    // is poisoned, was lost (a predecessor dropped its sender without sending),
    // or if we would deadlock by waiting on our own sender.
    fn resolve(&mut self, cx: &mut Context<'_>) -> Poll<Option<&mut T>> {
        while let Some(SequencedState::Waiting { receiver, poisoned }) = &mut self.state {
            if let Some(sender) = &self.parent {
                // Check if we're waiting on ourselves.
                if sender.is_connected_to(receiver) {
                    return Poll::Ready(None);
                }
            }
            let poisoned = *poisoned;
            // `Err` (sender dropped without sending) becomes `None`: the value is lost.
            self.state = ready!(Pin::new(receiver).poll(cx)).ok();
            // Apply any poison override recorded while we were waiting.
            if let Some(value) = poisoned {
                self.set_poisoned(value)
            }
        }
        Poll::Ready(match &mut self.state {
            Some(SequencedState::Active {
                poisoned: false,
                value,
            }) => Some(value),
            Some(SequencedState::Active { poisoned: true, .. }) => None,
            // The loop above only exits once the state is no longer `Waiting`.
            Some(SequencedState::Waiting { .. }) => unreachable!(),
            None => None,
        })
    }
    /// Attempt to take the inner reader/writer. This will require waiting until prior instances
    /// have been released, and will fail with `None` if any were dropped without being released,
    /// or were themselves taken.
    /// Instances sequenced after this one will see the reader/writer be closed.
    pub fn poll_take_inner(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        // Wait until we own the state (poison handling is repeated below on the
        // taken state, so the resolve result itself is ignored).
        ready!(self.as_mut().resolve(cx));
        if let Some(SequencedState::Active {
            value,
            poisoned: false,
        }) = self.as_mut().state.take()
        {
            Poll::Ready(Some(value))
        } else {
            Poll::Ready(None)
        }
    }
    /// Attempt to take the inner reader/writer. This will require waiting until prior instances
    /// have been released, and will fail with `None` if any were dropped without being released,
    /// or were themselves taken.
    /// Instances sequenced after this one will see the reader/writer be closed.
    pub async fn take_inner(&mut self) -> Option<T> {
        poll_fn(|cx| Pin::new(&mut *self).poll_take_inner(cx)).await
    }

    /// Swap the two reader/writers at this sequence point.
    pub fn swap(&mut self, other: &mut Self) {
        mem::swap(&mut self.state, &mut other.state);
    }
}

impl<T> Drop for Sequenced<T> {
    // Poison and release the inner reader/writer. Has no effect if the reader/writer
    // was already released.
    // Order matters: the poison flag must be set on the state *before* it is
    // sent onward by `release()`.
    fn drop(&mut self) {
        self.set_poisoned(true);
        self.release();
    }
}

// `Sequenced` never exposes a pinned reference to its contents (the inner
// value is required to be `Unpin` by the I/O impls below), so it is
// unconditionally `Unpin` itself.
impl<T> Unpin for Sequenced<T> {}

impl<T: Unpin + AsyncRead> AsyncRead for Sequenced<T> {
    /// Reads from the inner reader once this instance holds it in the sequence.
    /// A poisoned, lost, or taken reader reports EOF (`Ok(0)`).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<futures_io::Result<usize>> {
        match ready!(self.get_mut().resolve(cx)) {
            Some(reader) => Pin::new(reader).poll_read(cx, buf),
            None => Poll::Ready(Ok(0)),
        }
    }
}

impl<T: Unpin + AsyncBufRead> AsyncBufRead for Sequenced<T> {
    /// Fills the inner reader's buffer once this instance holds it.
    /// A poisoned, lost, or taken reader yields an empty buffer (EOF).
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<futures_io::Result<&[u8]>> {
        match ready!(self.get_mut().resolve(cx)) {
            Some(reader) => Pin::new(reader).poll_fill_buf(cx),
            None => Poll::Ready(Ok(&[])),
        }
    }

    /// Marks `amt` buffered bytes as consumed. Only valid after a successful
    /// `poll_fill_buf`, which guarantees the state is active and un-poisoned.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        match &mut this.state {
            Some(SequencedState::Active {
                value,
                poisoned: false,
            }) => Pin::new(value).consume(amt),
            // Consuming zero bytes is always a harmless no-op.
            _ if amt == 0 => {}
            _ => panic!("Called `consume()` without having filled the buffer"),
        }
    }
}

impl<T: Unpin + AsyncWrite> AsyncWrite for Sequenced<T> {
    /// Writes to the inner writer once this instance holds it.
    /// A poisoned, lost, or taken writer accepts zero bytes, like a closed stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<futures_io::Result<usize>> {
        match ready!(self.get_mut().resolve(cx)) {
            Some(writer) => Pin::new(writer).poll_write(cx, buf),
            None => Poll::Ready(Ok(0)),
        }
    }

    /// Flushes the inner writer; succeeds trivially if it is unavailable.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
        match ready!(self.get_mut().resolve(cx)) {
            Some(writer) => Pin::new(writer).poll_flush(cx),
            None => Poll::Ready(Ok(())),
        }
    }

    /// Closes the inner writer; succeeds trivially if it is unavailable.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
        match ready!(self.get_mut().resolve(cx)) {
            Some(writer) => Pin::new(writer).poll_close(cx),
            None => Poll::Ready(Ok(())),
        }
    }
}
35 changes: 0 additions & 35 deletions src/server/body_reader.rs

This file was deleted.

145 changes: 125 additions & 20 deletions src/server/decode.rs
Original file line number Diff line number Diff line change
@@ -2,17 +2,18 @@
use std::str::FromStr;

use async_dup::{Arc, Mutex};
use async_std::io::{BufReader, Read, Write};
use async_std::io::{self, BufRead, BufReader, Read, Write};
use async_std::{prelude::*, task};
use futures_channel::oneshot;
use futures_util::{select_biased, FutureExt};
use http_types::content::ContentLength;
use http_types::headers::{EXPECT, TRANSFER_ENCODING};
use http_types::{ensure, ensure_eq, format_err};
use http_types::{Body, Method, Request, Url};

use super::body_reader::BodyReader;
use crate::chunked::ChunkedDecoder;
use crate::read_notifier::ReadNotifier;
use crate::sequenced::Sequenced;
use crate::{MAX_HEADERS, MAX_HEAD_LENGTH};

const LF: u8 = b'\n';
@@ -24,15 +25,64 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue";
const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";

/// Decode an HTTP request on the server.
pub async fn decode<IO>(mut io: IO) -> http_types::Result<Option<(Request, BodyReader<IO>)>>
pub async fn decode<IO>(io: IO) -> http_types::Result<Option<(Request, impl Future)>>
where
IO: Read + Write + Clone + Send + Sync + Unpin + 'static,
{
let mut reader = BufReader::new(io.clone());
let mut reader = Sequenced::new(BufReader::new(io.clone()));
let mut writer = Sequenced::new(io);
let res = decode_rw(reader.split_seq(), writer.split_seq()).await?;
Ok(res.map(|(r, _)| {
(r, async move {
reader.take_inner().await;
writer.take_inner().await;
})
}))
}

async fn discard_unread_body<R1: Read + Unpin, R2>(
mut body_reader: Sequenced<R1>,
mut reader: Sequenced<R2>,
) -> io::Result<()> {
// Unpoison the body reader, as we don't require it to be in any particular state
body_reader.cure();

// Consume the remainder of the request body
let body_bytes_discarded = io::copy(&mut body_reader, &mut io::sink()).await?;

log::trace!(
"discarded {} unread request body bytes",
body_bytes_discarded
);

// Unpoison the reader, as it's easier than trying to reach into the body reader to
// release the inner `Sequenced<T>`
reader.cure();
reader.release();

Ok(())
}

#[derive(Debug)]
pub struct NotifyWrite {
sender: Option<oneshot::Sender<()>>,
}

/// Decode an HTTP request on the server.
pub async fn decode_rw<R, W>(
mut reader: Sequenced<R>,
mut writer: Sequenced<W>,
) -> http_types::Result<Option<(Request, NotifyWrite)>>
where
R: BufRead + Send + Sync + Unpin + 'static,
W: Write + Send + Sync + Unpin + 'static,
{
let mut buf = Vec::new();
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut httparse_req = httparse::Request::new(&mut headers);

let mut notify_write = NotifyWrite { sender: None };

// Keep reading bytes from the stream until we hit the end of the stream.
loop {
let bytes_read = reader.read_until(LF, &mut buf).await?;
@@ -103,12 +153,47 @@ where
let (body_read_sender, body_read_receiver) = async_channel::bounded(1);

if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) {
// Prevent the response being written until we've decided whether to send
// the continue message or not.
let mut continue_writer = writer.split_seq();

// We can swap these later to effectively deactivate the body reader, in the event
// that we don't ask the client to send a body.
let mut continue_reader = reader.split_seq();
let mut after_reader = reader.split_seq_rev();

let (notify_tx, notify_rx) = oneshot::channel();
notify_write.sender = Some(notify_tx);

// If the client expects a 100-continue header, spawn a
// task to wait for the first read attempt on the body.
task::spawn(async move {
// If the client expects a 100-continue header, spawn a
// task to wait for the first read attempt on the body.
if let Ok(()) = body_read_receiver.recv().await {
io.write_all(CONTINUE_RESPONSE).await.ok();
// It's important that we fuse this future, or else the `select` won't
// wake up properly if the sender is dropped.
let mut notify_rx = notify_rx.fuse();

let should_continue = select_biased! {
x = body_read_receiver.recv().fuse() => x.is_ok(),
_ = notify_rx => true,
};

if should_continue {
if continue_writer.write_all(CONTINUE_RESPONSE).await.is_err() {
return;
}
} else {
// We never asked for the body, so just allow the next
// request to continue from our current point in the stream.
continue_reader.swap(&mut after_reader);
}
// Allow the rest of the response to be written
continue_writer.release();

// Allow the body to be read
continue_reader.release();

// Allow the next request to be read (after the body, if requested, has been read)
after_reader.release();
// Since the sender is moved into the Body, this task will
// finish when the client disconnects, whether or not
// 100-continue was sent.
@@ -121,23 +206,43 @@ where
.unwrap_or(false)
{
let trailer_sender = req.send_trailers();
let reader = ChunkedDecoder::new(reader, trailer_sender);
let reader = Arc::new(Mutex::new(reader));
let reader_clone = reader.clone();
let reader = ReadNotifier::new(reader, body_read_sender);
let reader = BufReader::new(reader);
req.set_body(Body::from_reader(reader, None));
return Ok(Some((req, BodyReader::Chunked(reader_clone))));
let mut body_reader =
Sequenced::new(ChunkedDecoder::new(reader.split_seq(), trailer_sender));
req.set_body(Body::from_reader(
ReadNotifier::new(body_reader.split_seq(), body_read_sender),
None,
));
let reader_to_cure = reader.split_seq();

// Spawn a task to consume any part of the body which is unread
task::spawn(async move {
let _ = discard_unread_body(body_reader, reader_to_cure).await;
});

reader.release();
writer.release();
return Ok(Some((req, notify_write)));
} else if let Some(len) = content_length {
let len = len.len();
let reader = Arc::new(Mutex::new(reader.take(len)));
let mut body_reader = Sequenced::new(reader.split_seq().take(len));
req.set_body(Body::from_reader(
BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)),
ReadNotifier::new(body_reader.split_seq(), body_read_sender),
Some(len as usize),
));
Ok(Some((req, BodyReader::Fixed(reader))))
let reader_to_cure = reader.split_seq();

// Spawn a task to consume any part of the body which is unread
task::spawn(async move {
let _ = discard_unread_body(body_reader, reader_to_cure).await;
});

reader.release();
writer.release();
Ok(Some((req, notify_write)))
} else {
Ok(Some((req, BodyReader::None)))
reader.release();
writer.release();
Ok(Some((req, notify_write)))
}
}

70 changes: 46 additions & 24 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
//! Process HTTP connections on the server.
use async_std::future::{timeout, Future, TimeoutError};
use async_std::io::{self, Read, Write};
use async_std::io::{self, BufRead, BufReader, Read, Write};
use http_types::headers::{CONNECTION, UPGRADE};
use http_types::upgrade::Connection;
use http_types::{Request, Response, StatusCode};
use std::{marker::PhantomData, time::Duration};
mod body_reader;

mod decode;
mod encode;

pub use decode::decode;
pub use decode::{decode, decode_rw};
pub use encode::Encoder;

use crate::sequenced::Sequenced;
use crate::unite::Unite;

/// Configure the server.
#[derive(Debug, Clone)]
pub struct ServerOptions {
@@ -23,7 +26,7 @@ pub struct ServerOptions {
impl Default for ServerOptions {
fn default() -> Self {
Self {
headers_timeout: Some(Duration::from_secs(60)),
headers_timeout: Some(Duration::from_secs(30)),
}
}
}
@@ -58,8 +61,9 @@ where

/// struct for server
#[derive(Debug)]
pub struct Server<RW, F, Fut> {
io: RW,
pub struct Server<R, W, F, Fut> {
reader: Sequenced<R>,
writer: Sequenced<W>,
endpoint: F,
opts: ServerOptions,
_phantom: PhantomData<Fut>,
@@ -75,16 +79,34 @@ pub enum ConnectionStatus {
KeepAlive,
}

impl<RW, F, Fut> Server<RW, F, Fut>
impl<RW, F, Fut> Server<BufReader<RW>, RW, F, Fut>
where
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
RW: Read + Write + Send + Sync + Clone + Unpin + 'static,
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
/// builds a new server
pub fn new(io: RW, endpoint: F) -> Self {
Self::new_rw(
Sequenced::new(BufReader::new(io.clone())),
Sequenced::new(io),
endpoint,
)
}
}

impl<R, W, F, Fut> Server<R, W, F, Fut>
where
R: BufRead + Send + Sync + Unpin + 'static,
W: Write + Send + Sync + Unpin + 'static,
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
/// builds a new server
pub fn new_rw(reader: Sequenced<R>, writer: Sequenced<W>, endpoint: F) -> Self {
Self {
io,
reader,
writer,
endpoint,
opts: Default::default(),
_phantom: PhantomData,
@@ -104,16 +126,11 @@ where
}

/// accept one request
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus>
where
RW: Read + Write + Clone + Send + Sync + Unpin + 'static,
F: Fn(Request) -> Fut,
Fut: Future<Output = http_types::Result<Response>>,
{
pub async fn accept_one(&mut self) -> http_types::Result<ConnectionStatus> {
// Decode a new request, timing out if this takes longer than the timeout duration.
let fut = decode(self.io.clone());
let fut = decode_rw(self.reader.split_seq(), self.writer.split_seq());

let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout {
let (req, notify_write) = if let Some(timeout_duration) = self.opts.headers_timeout {
match timeout(timeout_duration, fut).await {
Ok(Ok(Some(r))) => r,
Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */
@@ -159,17 +176,22 @@ where

let mut encoder = Encoder::new(res, method);

let bytes_written = io::copy(&mut encoder, &mut self.io).await?;
// This should be dropped before we begin writing the response.
drop(notify_write);

let bytes_written = io::copy(&mut encoder, &mut self.writer).await?;
log::trace!("wrote {} response bytes", bytes_written);

let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?;
log::trace!(
"discarded {} unread request body bytes",
body_bytes_discarded
);
async_std::task::sleep(Duration::from_millis(1)).await;

if let Some(upgrade_sender) = upgrade_sender {
upgrade_sender.send(Connection::new(self.io.clone())).await;
let reader = self.reader.take_inner().await;
let writer = self.writer.take_inner().await;
if let (Some(reader), Some(writer)) = (reader, writer) {
upgrade_sender
.send(Connection::new(Unite::new(reader, writer)))
.await;
}
return Ok(ConnectionStatus::Close);
} else if close_connection {
Ok(ConnectionStatus::Close)
60 changes: 60 additions & 0 deletions src/unite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use core::pin::Pin;
use core::task::{Context, Poll};

use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use pin_project::pin_project;

// Rejoins an independently-owned reader and writer into one bidirectional
// stream: read-side trait impls delegate to `reader`, write-side to `writer`.
// Used to reconstruct a single connection object after the stream has been
// split into `Sequenced` halves (e.g. for protocol upgrades).
#[pin_project]
pub(crate) struct Unite<R, W> {
    #[pin]
    reader: R,
    #[pin]
    writer: W,
}

impl<R, W> Unite<R, W> {
    // Combines `reader` and `writer` into a single `Unite`.
    pub(crate) fn new(reader: R, writer: W) -> Self {
        Self { reader, writer }
    }
}

impl<R: AsyncRead, W> AsyncRead for Unite<R, W> {
    /// Delegates reads to the reader half.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<futures_io::Result<usize>> {
        let this = self.project();
        this.reader.poll_read(cx, buf)
    }
}

impl<R: AsyncBufRead, W> AsyncBufRead for Unite<R, W> {
    /// Delegates buffered reads to the reader half.
    fn poll_fill_buf(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<futures_io::Result<&[u8]>> {
        let this = self.project();
        this.reader.poll_fill_buf(cx)
    }

    /// Delegates buffer consumption to the reader half.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.project();
        this.reader.consume(amt)
    }
}

impl<R, W: AsyncWrite> AsyncWrite for Unite<R, W> {
    /// Delegates writes to the writer half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<futures_io::Result<usize>> {
        let this = self.project();
        this.writer.poll_write(cx, buf)
    }

    /// Delegates flushing to the writer half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
        let this = self.project();
        this.writer.poll_flush(cx)
    }

    /// Delegates closing to the writer half; the reader half is unaffected.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
        let this = self.project();
        this.writer.poll_close(cx)
    }
}
63 changes: 60 additions & 3 deletions tests/accept.rs
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ mod accept {
let content_length = 10;

let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
@@ -33,6 +33,36 @@ mod accept {
Ok(())
}

#[async_std::test]
async fn pipelined() -> Result<()> {
let mut server = TestServer::new(|req| async {
let mut response = Response::new(200);
let len = req.len();
response.set_body(Body::from_reader(req, len));
Ok(response)
});

let content_length = 10;

let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);

server.write_all(request_str.as_bytes()).await?;
server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);

server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn request_close() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });
@@ -74,7 +104,7 @@ mod accept {
let content_length = 10;

let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
@@ -130,7 +160,7 @@ mod accept {
let content_length = 10000;

let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n",
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);
@@ -169,6 +199,33 @@ mod accept {
"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n"
))
.await?;
server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);

assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn echo_server() -> Result<()> {
let mut server = TestServer::new(|mut req| async move {
let mut resp = Response::new(200);
resp.set_body(req.take_body());
Ok(resp)
});

let content_length = 10;

let request_str = format!(
"POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}",
content_length,
std::str::from_utf8(&vec![b'|'; content_length]).unwrap()
);

server.write_all(request_str.as_bytes()).await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);

server.close();
208 changes: 204 additions & 4 deletions tests/continue.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
mod test_utils;

use async_h1::server::ConnectionStatus;
use async_h1::Sequenced;
use async_std::future::timeout;
use async_std::io::BufReader;
use async_std::{io, prelude::*, task};
use http_types::Result;
use http_types::{Response, Result};
use std::time::Duration;
use test_utils::TestIO;
use test_utils::{TestIO, TestServer};

const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
Host: example.com\r\n\
@@ -16,7 +20,12 @@ async fn test_with_expect_when_reading_body() -> Result<()> {
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;

let (mut request, _) = async_h1::server::decode(server).await?.unwrap();
let (mut request, _notify_write) = async_h1::server::decode_rw(
Sequenced::new(BufReader::new(server.clone())),
Sequenced::new(server.clone()),
)
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written

@@ -44,11 +53,202 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> {
let (mut client, server) = TestIO::new();
client.write_all(REQUEST_WITH_EXPECT).await?;

let (_, _) = async_h1::server::decode(server).await?.unwrap();
let _ = async_h1::server::decode_rw(
Sequenced::new(BufReader::new(server.clone())),
Sequenced::new(server.clone()),
)
.await?
.unwrap();

task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel

assert_eq!("", &client.read.to_string()); // we haven't written 100-continue

client.write_all(REQUEST_WITH_EXPECT).await?;

// Make sure the server doesn't try to read the body before processing the next request
task::sleep(SLEEP_DURATION).await;
let (_, _) = async_h1::server::decode(server).await?.unwrap();

Ok(())
}

#[async_std::test]
async fn test_accept_unread_body() -> Result<()> {
let mut server = TestServer::new(|_| async { Ok(Response::new(200)) });

server.write_all(REQUEST_WITH_EXPECT).await?;
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);

server.write_all(REQUEST_WITH_EXPECT).await?;
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);

server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn test_echo_server() -> Result<()> {
let mut server = TestServer::new(|mut req| async move {
let mut resp = Response::new(200);
resp.set_body(req.take_body());
Ok(resp)
});

server.write_all(REQUEST_WITH_EXPECT).await?;
server.write_all(b"0123456789").await?;
assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive);

task::sleep(SLEEP_DURATION).await; // wait for "continue" to be sent

server.close();

assert!(server
.client
.read
.to_string()
.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));

assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn test_delayed_read() -> Result<()> {
let mut server = TestServer::new(|mut req| async move {
let mut body = req.take_body();
task::spawn(async move {
let mut buf = Vec::new();
body.read_to_end(&mut buf).await.unwrap();
});
Ok(Response::new(200))
});

server.write_all(REQUEST_WITH_EXPECT).await?;
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);
server.write_all(b"0123456789").await?;

server.write_all(REQUEST_WITH_EXPECT).await?;
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);
server.write_all(b"0123456789").await?;

server.close();
assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn test_accept_fast_unread_sequential_requests() -> Result<()> {
let mut server = TestServer::new(|_| async move { Ok(Response::new(200)) });
let mut client = server.client.clone();

task::spawn(async move {
let mut reader = BufReader::new(client.clone());
for _ in 0..10 {
let mut buf = String::new();
client.write_all(REQUEST_WITH_EXPECT).await.unwrap();

while !buf.ends_with("\r\n\r\n") {
reader.read_line(&mut buf).await.unwrap();
}

assert!(buf.starts_with("HTTP/1.1 200 OK\r\n"));
}
client.close();
});

for _ in 0..10 {
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);
}

assert_eq!(server.accept_one().await?, ConnectionStatus::Close);

assert!(server.all_read());

Ok(())
}

#[async_std::test]
async fn test_accept_partial_read_sequential_requests() -> Result<()> {
const LARGE_REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\
Host: example.com\r\n\
Content-Length: 1000\r\n\
Expect: 100-continue\r\n\r\n";

let mut server = TestServer::new(|mut req| async move {
let mut body = req.take_body();
let mut buf = [0];
body.read(&mut buf).await.unwrap();
Ok(Response::new(200))
});
let mut client = server.client.clone();

task::spawn(async move {
let mut reader = BufReader::new(client.clone());
for _ in 0..10 {
let mut buf = String::new();
client.write_all(LARGE_REQUEST_WITH_EXPECT).await.unwrap();

// Wait for body to be requested
while !buf.ends_with("\r\n\r\n") {
reader.read_line(&mut buf).await.unwrap();
}
assert!(buf.starts_with("HTTP/1.1 100 Continue\r\n"));

// Write body
for _ in 0..100 {
client.write_all(b"0123456789").await.unwrap();
}

// Wait for response
buf.clear();
while !buf.ends_with("\r\n\r\n") {
reader.read_line(&mut buf).await.unwrap();
}

assert!(buf.starts_with("HTTP/1.1 200 OK\r\n"));
}
client.close();
});

for _ in 0..10 {
assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::KeepAlive
);
}

assert_eq!(
timeout(Duration::from_secs(1), server.accept_one()).await??,
ConnectionStatus::Close
);

assert!(server.all_read());

Ok(())
}
2 changes: 2 additions & 0 deletions tests/server_decode.rs
Original file line number Diff line number Diff line change
@@ -67,6 +67,7 @@ mod server_decode {
"llo",
"0",
"",
"",
])
.await?
.unwrap();
@@ -93,6 +94,7 @@ mod server_decode {
"0",
"x-invalid: å",
"",
"",
])
.await?
.unwrap();
84 changes: 47 additions & 37 deletions tests/test_utils.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use async_h1::{
client::Encoder,
server::{ConnectionStatus, Server},
};
use async_std::io::{Read, Write};
use async_std::io::{BufReader, Read, Write};
use http_types::{Request, Response, Result};
use std::{
fmt::{Debug, Display},
@@ -17,9 +17,9 @@ use async_dup::Arc;

#[pin_project::pin_project]
pub struct TestServer<F, Fut> {
server: Server<TestIO, F, Fut>,
server: Server<BufReader<TestIO>, TestIO, F, Fut>,
#[pin]
client: TestIO,
pub(crate) client: TestIO,
}

impl<F, Fut> TestServer<F, Fut>
@@ -102,35 +102,47 @@ pub struct TestIO {
}

#[derive(Default)]
pub struct CloseableCursor {
data: RwLock<Vec<u8>>,
cursor: RwLock<usize>,
waker: RwLock<Option<Waker>>,
closed: RwLock<bool>,
struct CloseableCursorInner {
data: Vec<u8>,
cursor: usize,
waker: Option<Waker>,
closed: bool,
}

#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);

impl CloseableCursor {
fn len(&self) -> usize {
self.data.read().unwrap().len()
pub fn len(&self) -> usize {
self.0.read().unwrap().data.len()
}

pub fn cursor(&self) -> usize {
self.0.read().unwrap().cursor
}

fn cursor(&self) -> usize {
*self.cursor.read().unwrap()
pub fn is_empty(&self) -> bool {
self.len() == 0
}

fn current(&self) -> bool {
self.len() == self.cursor()
pub fn current(&self) -> bool {
let inner = self.0.read().unwrap();
inner.data.len() == inner.cursor
}

fn close(&self) {
*self.closed.write().unwrap() = true;
pub fn close(&self) {
let mut inner = self.0.write().unwrap();
inner.closed = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
}

impl Display for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let data = &*self.data.read().unwrap();
let s = std::str::from_utf8(data).unwrap_or("not utf8");
let inner = self.0.read().unwrap();
let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8");
write!(f, "{}", s)
}
}
@@ -163,13 +175,14 @@ impl TestIO {

impl Debug for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let inner = self.0.read().unwrap();
f.debug_struct("CloseableCursor")
.field(
"data",
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
&std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
)
.field("closed", &*self.closed.read().unwrap())
.field("cursor", &*self.cursor.read().unwrap())
.field("closed", &inner.closed)
.field("cursor", &inner.cursor)
.finish()
}
}
@@ -180,18 +193,17 @@ impl Read for &CloseableCursor {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let len = self.len();
let cursor = self.cursor();
if cursor < len {
let data = &*self.data.read().unwrap();
let bytes_to_copy = buf.len().min(len - cursor);
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
*self.cursor.write().unwrap() += bytes_to_copy;
let mut inner = self.0.write().unwrap();
if inner.cursor < inner.data.len() {
let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
buf[..bytes_to_copy]
.copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
inner.cursor += bytes_to_copy;
Poll::Ready(Ok(bytes_to_copy))
} else if *self.closed.read().unwrap() {
} else if inner.closed {
Poll::Ready(Ok(0))
} else {
*self.waker.write().unwrap() = Some(cx.waker().clone());
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}
@@ -203,11 +215,12 @@ impl Write for &CloseableCursor {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if *self.closed.read().unwrap() {
let mut inner = self.0.write().unwrap();
if inner.closed {
Poll::Ready(Ok(0))
} else {
self.data.write().unwrap().extend_from_slice(buf);
if let Some(waker) = self.waker.write().unwrap().take() {
inner.data.extend_from_slice(buf);
if let Some(waker) = inner.waker.take() {
waker.wake();
}
Poll::Ready(Ok(buf.len()))
@@ -219,10 +232,7 @@ impl Write for &CloseableCursor {
}

fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let Some(waker) = self.waker.write().unwrap().take() {
waker.wake();
}
*self.closed.write().unwrap() = true;
self.close();
Poll::Ready(Ok(()))
}
}