diff --git a/Cargo.toml b/Cargo.toml index 16e9ea7..35eb79d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ categories = [ "parser-implementations", "web-programming", "web-programming::http-client", - "web-programming::http-server" + "web-programming::http-server", ] authors = ["Yoshua Wuyts "] readme = "README.md" @@ -28,6 +28,10 @@ log = "0.4.11" pin-project = "1.0.2" async-channel = "1.5.1" async-dup = "1.2.2" +futures-channel = "0.3.12" +futures-io = "0.3.12" +futures-lite = "1.11.3" +futures-util = "0.3.12" [dev-dependencies] pretty_assertions = "0.6.1" diff --git a/src/chunked/decoder.rs b/src/chunked/decoder.rs index 625e674..7531387 100644 --- a/src/chunked/decoder.rs +++ b/src/chunked/decoder.rs @@ -1,506 +1,314 @@ use std::fmt; +use std::fmt::Display; use std::future::Future; -use std::ops::Range; use std::pin::Pin; use std::task::{Context, Poll}; -use async_std::io::{self, Read}; -use async_std::sync::Arc; -use byte_pool::{Block, BytePool}; +use async_std::io::{self, BufRead, Read}; +use futures_core::ready; use http_types::trailers::{Sender, Trailers}; - -const INITIAL_CAPACITY: usize = 1024 * 4; -const MAX_CAPACITY: usize = 512 * 1024 * 1024; // 512 MiB - -lazy_static::lazy_static! { - /// The global buffer pool we use for storing incoming data. - pub(crate) static ref POOL: Arc = Arc::new(BytePool::new()); -} +use pin_project::pin_project; /// Decodes a chunked body according to /// https://tools.ietf.org/html/rfc7230#section-4.1 +#[pin_project] #[derive(Debug)] -pub struct ChunkedDecoder { +pub(crate) struct ChunkedDecoder { /// The underlying stream + #[pin] inner: R, - /// Buffer for the already read, but not yet parsed data. - buffer: Block<'static>, - /// Range of valid read data into buffer. - current: Range, - /// Whether we should attempt to decode whatever is currently inside the buffer. - /// False indicates that we know for certain that the buffer is incomplete. - initial_decode: bool, /// Current state. state: State, /// Trailer channel sender. trailer_sender: Option, } -impl ChunkedDecoder { +impl ChunkedDecoder { pub(crate) fn new(inner: R, trailer_sender: Sender) -> Self { ChunkedDecoder { inner, - buffer: POOL.alloc(INITIAL_CAPACITY), - current: Range { start: 0, end: 0 }, - initial_decode: false, // buffer is empty initially, nothing to decode} - state: State::Init, + state: State::Read(ReadState::BeforeChunk { + size: 0, + inner: ChunkSizeState::ChunkSize, + }), trailer_sender: Some(trailer_sender), } } } -impl ChunkedDecoder { - fn poll_read_chunk( - &mut self, - cx: &mut Context<'_>, - buffer: Block<'static>, - pos: &Range, - buf: &mut [u8], - current: u64, - len: u64, - ) -> io::Result { - let mut new_pos = pos.clone(); - let remaining = (len - current) as usize; - let to_read = std::cmp::min(remaining, buf.len()); - - let mut new_current = current; - - // position into buf - let mut read = 0; - - // first drain the buffer - if new_pos.len() > 0 { - let to_read_buf = std::cmp::min(to_read, pos.len()); - buf[..to_read_buf].copy_from_slice(&buffer[new_pos.start..new_pos.start + to_read_buf]); - - if new_pos.start + to_read_buf == new_pos.end { - new_pos = 0..0 - } else { - new_pos.start += to_read_buf; - } - new_current += to_read_buf as u64; - read += to_read_buf; - - let new_state = if new_current == len { - State::ChunkEnd - } else { - State::Chunk(new_current, len) - }; - - return Ok(DecodeResult::Some { - read, - new_state: Some(new_state), - new_pos, - buffer, - pending: false, - }); - } - - // attempt to fill the buffer - match Pin::new(&mut self.inner).poll_read(cx, &mut buf[read..read + to_read]) { - Poll::Ready(val) => { - let n = val?; - new_current += n as u64; - read += n; - let new_state = if new_current == len { - State::ChunkEnd - } else if n == 0 { - // Unexpected end - // TODO: do something? - State::Done - } else { - State::Chunk(new_current, len) - }; - - Ok(DecodeResult::Some { - read, - new_state: Some(new_state), - new_pos, - buffer, - pending: false, - }) - } - Poll::Pending => Ok(DecodeResult::Some { - read: 0, - new_state: Some(State::Chunk(new_current, len)), - new_pos, - buffer, - pending: true, - }), +const MAX_CHUNK_SIZE: u64 = 0x0FFF_FFFF_FFFF_FFFF; + +fn read_chunk_size( + buf: &[u8], + size: &mut u64, + state: &mut ChunkSizeState, +) -> io::Result<(usize, bool)> { + for (offset, c) in buf.iter().copied().enumerate() { + match *state { + ChunkSizeState::ChunkSize => match c { + b'0'..=b'9' => *size = (*size << 4) + (c - b'0') as u64, + b'a'..=b'f' => *size = (*size << 4) + (c + 10 - b'a') as u64, + b'A'..=b'F' => *size = (*size << 4) + (c + 10 - b'A') as u64, + b';' => *state = ChunkSizeState::Extension, + b'\r' => *state = ChunkSizeState::NewLine, + _ => return Err(other_err(httparse::InvalidChunkSize)), + }, + ChunkSizeState::Extension => match c { + b'\r' => *state = ChunkSizeState::NewLine, + _ => return Err(other_err(httparse::InvalidChunkSize)), + }, + ChunkSizeState::NewLine => match c { + b'\n' => return Ok((offset + 1, true)), + _ => return Err(other_err(httparse::InvalidChunkSize)), + }, } - } - - fn poll_read_inner( - &mut self, - cx: &mut Context<'_>, - buffer: Block<'static>, - pos: &Range, - buf: &mut [u8], - ) -> io::Result { - match self.state { - State::Init => { - // Initial read - decode_init(buffer, pos) - } - State::Chunk(current, len) => { - // reading a chunk - self.poll_read_chunk(cx, buffer, pos, buf, current, len) - } - State::ChunkEnd => decode_chunk_end(buffer, pos), - State::Trailer => { - // reading the trailer headers - decode_trailer(buffer, pos) - } - State::TrailerDone(ref mut headers) => { - let headers = std::mem::replace(headers, Trailers::new()); - let sender = self.trailer_sender.take(); - let sender = - sender.expect("invalid chunked state, tried sending multiple trailers"); - - let fut = Box::pin(sender.send(headers)); - Ok(DecodeResult::Some { - read: 0, - new_state: Some(State::TrailerSending(fut)), - new_pos: pos.clone(), - buffer, - pending: false, - }) - } - State::TrailerSending(ref mut fut) => { - match Pin::new(fut).poll(cx) { - Poll::Ready(_) => {} - Poll::Pending => { - return Ok(DecodeResult::Some { - read: 0, - new_state: None, - new_pos: pos.clone(), - buffer, - pending: true, - }); - } - } - - Ok(DecodeResult::Some { - read: 0, - new_state: Some(State::Done), - new_pos: pos.clone(), - buffer, - pending: false, - }) - } - State::Done => Ok(DecodeResult::Some { - read: 0, - new_state: Some(State::Done), - new_pos: pos.clone(), - buffer, - pending: false, - }), + if *size > MAX_CHUNK_SIZE { + return Err(other_err(httparse::InvalidChunkSize)); } } + Ok((buf.len(), false)) } -impl Read for ChunkedDecoder { +impl Read for ChunkedDecoder { #[allow(missing_doc_code_examples)] fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let this = &mut *self; + let inner_buf = ready!(self.as_mut().poll_fill_buf(cx))?; + let amt = buf.len().min(inner_buf.len()); + buf[0..amt].copy_from_slice(&inner_buf[0..amt]); + self.consume(amt); - if let State::Done = this.state { - return Poll::Ready(Ok(0)); - } - - let mut n = std::mem::replace(&mut this.current, 0..0); - let buffer = std::mem::replace(&mut this.buffer, POOL.alloc(INITIAL_CAPACITY)); - let mut needs_read = !matches!(this.state, State::Chunk(_, _)); - - let mut buffer = if n.len() > 0 && this.initial_decode { - // initial buffer filling, if needed - match this.poll_read_inner(cx, buffer, &n, buf)? { - DecodeResult::Some { - read, - buffer, - new_pos, - new_state, - pending, - } => { - this.current = new_pos.clone(); - if let Some(state) = new_state { - this.state = state; - } + Poll::Ready(Ok(amt)) + } +} - if pending { - // initial_decode is still true - this.buffer = buffer; - return Poll::Pending; - } +impl BufRead for ChunkedDecoder { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); - if let State::Done = this.state { - // initial_decode is still true - this.buffer = buffer; - return Poll::Ready(Ok(read)); - } - - if read > 0 { - // initial_decode is still true - this.buffer = buffer; - return Poll::Ready(Ok(read)); + let pass_through_state = loop { + match this.state { + State::PassThrough(pass_through_state) => { + if pass_through_state.offset < pass_through_state.size { + break pass_through_state; + } else { + *this.state = State::Read(ReadState::AfterChunk { new_line: false }); } - - n = new_pos; - needs_read = false; - buffer } - DecodeResult::None(buffer) => buffer, - } - } else { - buffer - }; - - loop { - if n.len() >= buffer.capacity() { - if buffer.capacity() + 1024 <= MAX_CAPACITY { - buffer.realloc(buffer.capacity() + 1024); - } else { - this.buffer = buffer; - this.current = n; - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "incoming data too large", - ))); + State::Poll(poll_state) => { + *this.state = ready!(poll_state.poll(cx, this.trailer_sender))?; } - } + State::Read(read_state) => { + let inner_buf = ready!(this.inner.as_mut().poll_fill_buf(cx))?; - if needs_read { - let bytes_read = match Pin::new(&mut this.inner).poll_read(cx, &mut buffer[n.end..]) - { - Poll::Ready(result) => result?, - Poll::Pending => { - // if we're here, it means that we need more data but there is none yet, - // so no decoding attempts are necessary until we get more data - this.initial_decode = false; - this.buffer = buffer; - this.current = n; - return Poll::Pending; + if inner_buf.is_empty() { + return Poll::Ready(Err(unexpected_eof())); } - }; - match (bytes_read, &this.state) { - (0, State::Done) => {} - (0, _) => { - // Unexpected end - // TODO: do something? - this.state = State::Done; + + let mut read = 0; + while read < inner_buf.len() { + let (nread, next_state) = read_state.advance(&inner_buf[read..])?; + read += nread; + if let Some(next_state) = next_state { + *this.state = next_state; + break; + } } - _ => {} + this.inner.as_mut().consume(read); } - n.end += bytes_read; + State::Done => return Poll::Ready(Ok(&[])), } - match this.poll_read_inner(cx, buffer, &n, buf)? { - DecodeResult::Some { - read, - buffer: new_buffer, - new_pos, - new_state, - pending, - } => { - // current buffer might now contain more data inside, so we need to attempt - // to decode it next time - this.initial_decode = true; - if let Some(state) = new_state { - this.state = state; - } - this.current = new_pos.clone(); - n = new_pos; - - if let State::Done = this.state { - this.buffer = new_buffer; - return Poll::Ready(Ok(read)); - } - - if read > 0 { - this.buffer = new_buffer; - return Poll::Ready(Ok(read)); - } + }; - if pending { - this.buffer = new_buffer; - return Poll::Pending; - } + // Unfortunately due to lifetime limitations, this can't be part of the main loop + let inner_buf = ready!(this.inner.poll_fill_buf(cx))?; - buffer = new_buffer; - needs_read = false; - continue; - } - DecodeResult::None(buf) => { - buffer = buf; + // Work out how much of the buffer we can pass through + let max_read = pass_through_state.size - pass_through_state.offset; + let amt = max_read.min(inner_buf.len() as u64) as usize; - if this.buffer.is_empty() || n.start == 0 && n.end == 0 { - // "logical buffer" is empty, there is nothing to decode on the next step - this.initial_decode = false; - this.buffer = buffer; - this.current = n; + Poll::Ready(if amt == 0 { + Err(unexpected_eof()) + } else { + Ok(&inner_buf[0..amt]) + }) + } - return Poll::Ready(Ok(0)); - } else { - needs_read = true; - } - } + fn consume(self: Pin<&mut Self>, amt: usize) { + let this = self.project(); + if amt > 0 { + if let State::PassThrough(pass_through_state) = this.state { + pass_through_state.offset += amt as u64; + assert!(pass_through_state.offset <= pass_through_state.size); + this.inner.consume(amt); + } else { + panic!("Called consume without first filling buffer"); } } } } -/// Possible return values from calling `decode` methods. -enum DecodeResult { - /// Something was decoded successfully. - Some { - /// How much data was read. - read: usize, - /// The passed in block returned. - buffer: Block<'static>, - /// The new range of valid data in `buffer`. - new_pos: Range, - /// The new state. - new_state: Option, - /// Should poll return `Pending`. - pending: bool, - }, - /// Nothing was decoded. - None(Block<'static>), +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ChunkSizeState { + ChunkSize, + Extension, + NewLine, } -/// Decoder state. +// Decoder state +#[derive(Debug)] enum State { - /// Initial state. - Init, - /// Decoding a chunk, first value is the current position, second value is the length of the chunk. - Chunk(u64, u64), - /// Decoding the end part of a chunk. - ChunkEnd, - /// Decoding trailers. - Trailer, - /// Trailers were decoded, are now set to the decoded trailers. - TrailerDone(Trailers), - TrailerSending(Pin + 'static + Send + Sync>>), - /// All is said and done. + // We're inside a chunk + PassThrough(PassThroughState), + // We're reading the framing around a chunk + Read(ReadState), + // We're driving an internal future + Poll(PollState), + // We're done Done, } -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use State::*; - match self { - Init => write!(f, "State::Init"), - Chunk(a, b) => write!(f, "State::Chunk({}, {})", a, b), - ChunkEnd => write!(f, "State::ChunkEnd"), - Trailer => write!(f, "State::Trailer"), - TrailerDone(trailers) => write!(f, "State::TrailerDone({:?})", &trailers), - TrailerSending(_) => write!(f, "State::TrailerSending"), - Done => write!(f, "State::Done"), - } - } + +#[derive(Debug)] +struct PassThroughState { + // Where we are within the chunk + offset: u64, + // How big the chunk is + size: u64, } -impl fmt::Debug for DecodeResult { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +#[derive(Debug)] +enum ReadState { + // Reading the framing before a chunk + BeforeChunk { size: u64, inner: ChunkSizeState }, + // Just finished reading the chunk data + AfterChunk { new_line: bool }, + // Just read CRLF after chunk data + MaybeTrailer { new_line: bool }, + // Accumulating trailers into a buffer + Trailer { buffer: Vec }, +} + +impl ReadState { + fn advance(&mut self, buf: &[u8]) -> io::Result<(usize, Option)> { match self { - DecodeResult::Some { - read, - buffer, - new_pos, - new_state, - pending, - } => f - .debug_struct("DecodeResult::Some") - .field("read", read) - .field("block", &buffer.len()) - .field("new_pos", new_pos) - .field("new_state", new_state) - .field("pending", pending) - .finish(), - DecodeResult::None(block) => write!(f, "DecodeResult::None({})", block.len()), + ReadState::BeforeChunk { size, inner } => { + let (amt, done) = read_chunk_size(buf, size, inner)?; + if done { + Ok(( + amt, + if *size > 0 { + Some(State::PassThrough(PassThroughState { + offset: 0, + size: *size, + })) + } else { + *self = ReadState::MaybeTrailer { new_line: false }; + None + }, + )) + } else { + Ok((amt, None)) + } + } + ReadState::AfterChunk { new_line } => match (*new_line, buf[0]) { + (false, b'\r') => { + *new_line = true; + Ok((1, None)) + } + (true, b'\n') => { + *self = ReadState::BeforeChunk { + size: 0, + inner: ChunkSizeState::ChunkSize, + }; + Ok((1, None)) + } + _ => Err(invalid_data_err()), + }, + ReadState::MaybeTrailer { new_line } => match (*new_line, buf[0]) { + (false, b'\r') => { + *new_line = true; + Ok((1, None)) + } + (true, b'\n') => Ok(( + 1, + Some(State::Poll(PollState::TrailerDone(Trailers::new()))), + )), + (false, _) => { + *self = ReadState::Trailer { buffer: Vec::new() }; + Ok((0, None)) + } + (true, _) => Err(invalid_data_err()), + }, + ReadState::Trailer { buffer } => { + buffer.extend_from_slice(buf); + let mut headers = [httparse::EMPTY_HEADER; 16]; + match httparse::parse_headers(&buffer, &mut headers) { + Ok(httparse::Status::Complete((amt, headers))) => { + let mut trailers = Trailers::new(); + for header in headers { + trailers.insert( + header.name, + String::from_utf8_lossy(header.value).as_ref(), + ); + } + Ok((amt, Some(State::Poll(PollState::TrailerDone(trailers))))) + } + Ok(httparse::Status::Partial) => Ok((buf.len(), None)), + Err(err) => Err(other_err(err)), + } + } } } } -fn decode_init(buffer: Block<'static>, pos: &Range) -> io::Result { - use httparse::Status; - match httparse::parse_chunk_size(&buffer[pos.start..pos.end]) { - Ok(Status::Complete((used, chunk_len))) => { - let new_pos = Range { - start: pos.start + used, - end: pos.end, - }; - - let new_state = if chunk_len == 0 { - State::Trailer - } else { - State::Chunk(0, chunk_len) - }; - - Ok(DecodeResult::Some { - read: 0, - buffer, - new_pos, - new_state: Some(new_state), - pending: false, - }) - } - Ok(Status::Partial) => Ok(DecodeResult::None(buffer)), - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())), - } +enum PollState { + /// Trailers were decoded, are now set to the decoded trailers. + TrailerDone(Trailers), + TrailerSending(Pin + 'static + Send + Sync>>), } -fn decode_chunk_end(buffer: Block<'static>, pos: &Range) -> io::Result { - if pos.len() < 2 { - return Ok(DecodeResult::None(buffer)); +impl PollState { + fn poll( + &mut self, + cx: &mut Context<'_>, + trailer_sender: &mut Option, + ) -> Poll> { + Poll::Ready(match self { + PollState::TrailerDone(trailers) => { + let trailers = std::mem::replace(trailers, Trailers::new()); + let sender = trailer_sender + .take() + .expect("invalid chunked state, tried sending multiple trailers"); + let fut = Box::pin(sender.send(trailers)); + Ok(State::Poll(PollState::TrailerSending(fut))) + } + PollState::TrailerSending(fut) => { + ready!(fut.as_mut().poll(cx)); + Ok(State::Done) + } + }) } +} - if &buffer[pos.start..pos.start + 2] == b"\r\n" { - // valid chunk end move on to a new header - return Ok(DecodeResult::Some { - read: 0, - buffer, - new_pos: Range { - start: pos.start + 2, - end: pos.end, - }, - new_state: Some(State::Init), - pending: false, - }); +impl fmt::Debug for PollState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PollState {{ .. }}") } - - Err(io::Error::from(io::ErrorKind::InvalidData)) } -fn decode_trailer(buffer: Block<'static>, pos: &Range) -> io::Result { - use httparse::Status; - - // read headers - let mut headers = [httparse::EMPTY_HEADER; 16]; +fn other_err(err: E) -> io::Error { + io::Error::new(io::ErrorKind::Other, err.to_string()) +} - match httparse::parse_headers(&buffer[pos.start..pos.end], &mut headers) { - Ok(Status::Complete((used, headers))) => { - let mut trailers = Trailers::new(); - for header in headers { - trailers.insert(header.name, String::from_utf8_lossy(header.value).as_ref()); - } +fn invalid_data_err() -> io::Error { + io::Error::from(io::ErrorKind::InvalidData) +} - Ok(DecodeResult::Some { - read: 0, - buffer, - new_state: Some(State::TrailerDone(trailers)), - new_pos: Range { - start: pos.start + used, - end: pos.end, - }, - pending: false, - }) - } - Ok(Status::Partial) => Ok(DecodeResult::None(buffer)), - Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string())), - } +fn unexpected_eof() -> io::Error { + io::Error::from(io::ErrorKind::UnexpectedEof) } #[cfg(test)] diff --git a/src/client/decode.rs b/src/client/decode.rs index 9ef6317..ca7dfa9 100644 --- a/src/client/decode.rs +++ b/src/client/decode.rs @@ -80,7 +80,7 @@ where if let Some(encoding) = transfer_encoding { if encoding.last().as_str() == "chunked" { let trailers_sender = res.send_trailers(); - let reader = BufReader::new(ChunkedDecoder::new(reader, trailers_sender)); + let reader = ChunkedDecoder::new(reader, trailers_sender); res.set_body(Body::from_reader(reader, None)); // Return the response. diff --git a/src/lib.rs b/src/lib.rs index 55d569b..730f15b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,6 +108,8 @@ mod body_encoder; mod chunked; mod date; mod read_notifier; +mod sequenced; +mod unite; pub mod client; pub mod server; @@ -115,6 +117,7 @@ pub mod server; use async_std::io::Cursor; use body_encoder::BodyEncoder; pub use client::connect; +pub use sequenced::Sequenced; pub use server::{accept, accept_with_opts, ServerOptions}; #[derive(Debug)] diff --git a/src/sequenced.rs b/src/sequenced.rs new file mode 100644 index 0000000..502c766 --- /dev/null +++ b/src/sequenced.rs @@ -0,0 +1,232 @@ +use core::future::Future; +use core::mem; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures_channel::oneshot; +use futures_core::ready; +use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use futures_lite::future::poll_fn; + +#[derive(Debug)] +enum SequencedState { + Active { + value: T, + poisoned: bool, + }, + Waiting { + receiver: oneshot::Receiver, + poisoned: Option, + }, +} + +/// Allows multiple asynchronous tasks to access the same reader or writer concurrently +/// without conflicting. +/// The `split_seq` and `split_seq_rev` methods produce a new instance of the type such that +/// all I/O operations on one instance are sequenced before all I/O operations on the other. +/// +/// When one task has finished with the reader/writer it should call `release`, which will +/// unblock operations on the task with the other instance. If dropped without calling +/// `release`, the inner reader/writer will become poisoned before being returned. The +/// caller can explicitly remove the poisoned status. +/// +/// The `Sequenced` can be split as many times as necessary, and it is valid to call +/// `release()` at any time, although no further operations can be performed via a released +/// instance. If this type is dropped without calling `release()`, then the reader/writer will +/// become poisoned. +/// +/// As only one task has access to the reader/writer at once, no additional synchronization +/// is necessary, and so this wrapper adds very little overhead. What synchronization does +/// occur only needs to happen when an instance is released, in order to send its state to +/// the next instance in the sequence. +/// +/// Merging can be achieved by simply releasing one of the two instances, and then using the +/// other one as normal. It does not matter Which one is released. +#[derive(Debug)] +pub struct Sequenced { + parent: Option>>, + state: Option>, +} + +impl Sequenced { + /// Constructs a new sequenced reader/writer + pub fn new(value: T) -> Self { + Self { + parent: None, + state: Some(SequencedState::Active { + value, + poisoned: false, + }), + } + } + /// Splits this reader/writer into two such that the returned instance is sequenced before this one. + pub fn split_seq(&mut self) -> Self { + let (sender, receiver) = oneshot::channel(); + let state = mem::replace( + &mut self.state, + Some(SequencedState::Waiting { + receiver, + poisoned: None, + }), + ); + Self { + parent: Some(sender), + state, + } + } + /// Splits this reader/writer into two such that the returned instance is sequenced after this one. + pub fn split_seq_rev(&mut self) -> Self { + let other = self.split_seq(); + mem::replace(self, other) + } + + /// Release this reader/writer immediately, allowing instances sequenced after this one to proceed. + pub fn release(&mut self) { + if let (Some(state), Some(parent)) = (self.state.take(), self.parent.take()) { + let _ = parent.send(state); + } + } + fn set_poisoned(&mut self, value: bool) { + match &mut self.state { + Some(SequencedState::Active { poisoned, .. }) => *poisoned = value, + Some(SequencedState::Waiting { poisoned, .. }) => *poisoned = Some(value), + None => {} + } + } + /// Removes the poison status if set + pub(crate) fn cure(&mut self) { + self.set_poisoned(false) + } + fn resolve(&mut self, cx: &mut Context<'_>) -> Poll> { + while let Some(SequencedState::Waiting { receiver, poisoned }) = &mut self.state { + if let Some(sender) = &self.parent { + // Check if we're waiting on ourselves. + if sender.is_connected_to(receiver) { + return Poll::Ready(None); + } + } + let poisoned = *poisoned; + self.state = ready!(Pin::new(receiver).poll(cx)).ok(); + if let Some(value) = poisoned { + self.set_poisoned(value) + } + } + Poll::Ready(match &mut self.state { + Some(SequencedState::Active { + poisoned: false, + value, + }) => Some(value), + Some(SequencedState::Active { poisoned: true, .. }) => None, + Some(SequencedState::Waiting { .. }) => unreachable!(), + None => None, + }) + } + /// Attempt to take the inner reader/writer. This will require waiting until prior instances + /// have been released, and will fail with `None` if any were dropped without being released, + /// or were themselves taken. + /// Instances sequenced after this one will see the reader/writer be closed. + pub fn poll_take_inner(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().resolve(cx)); + if let Some(SequencedState::Active { + value, + poisoned: false, + }) = self.as_mut().state.take() + { + Poll::Ready(Some(value)) + } else { + Poll::Ready(None) + } + } + /// Attempt to take the inner reader/writer. This will require waiting until prior instances + /// have been released, and will fail with `None` if any were dropped without being released, + /// or were themselves taken. + /// Instances sequenced after this one will see the reader/writer be closed. + pub async fn take_inner(&mut self) -> Option { + poll_fn(|cx| Pin::new(&mut *self).poll_take_inner(cx)).await + } + + /// Swap the two reader/writers at this sequence point. + pub fn swap(&mut self, other: &mut Self) { + mem::swap(&mut self.state, &mut other.state); + } +} + +impl Drop for Sequenced { + // Poison and release the inner reader/writer. Has no effect if the reader/writer + // was already released. + fn drop(&mut self) { + self.set_poisoned(true); + self.release(); + } +} + +impl Unpin for Sequenced {} + +impl AsyncRead for Sequenced { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + if let Some(inner) = ready!(self.get_mut().resolve(cx)) { + Pin::new(inner).poll_read(cx, buf) + } else { + Poll::Ready(Ok(0)) + } + } +} + +impl AsyncBufRead for Sequenced { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(inner) = ready!(self.get_mut().resolve(cx)) { + Pin::new(inner).poll_fill_buf(cx) + } else { + Poll::Ready(Ok(&[])) + } + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + if let Some(SequencedState::Active { + value, + poisoned: false, + }) = &mut self.get_mut().state + { + Pin::new(value).consume(amt); + } else if amt > 0 { + panic!("Called `consume()` without having filled the buffer") + } + } +} + +impl AsyncWrite for Sequenced { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Some(inner) = ready!(self.get_mut().resolve(cx)) { + Pin::new(inner).poll_write(cx, buf) + } else { + Poll::Ready(Ok(0)) + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(inner) = ready!(self.get_mut().resolve(cx)) { + Pin::new(inner).poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(inner) = ready!(self.get_mut().resolve(cx)) { + Pin::new(inner).poll_close(cx) + } else { + Poll::Ready(Ok(())) + } + } +} diff --git a/src/server/body_reader.rs b/src/server/body_reader.rs deleted file mode 100644 index 7586e31..0000000 --- a/src/server/body_reader.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::chunked::ChunkedDecoder; -use async_dup::{Arc, Mutex}; -use async_std::io::{BufReader, Read, Take}; -use async_std::task::{Context, Poll}; -use std::{fmt::Debug, io, pin::Pin}; - -pub enum BodyReader { - Chunked(Arc>>>), - Fixed(Arc>>>), - None, -} - -impl Debug for BodyReader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - BodyReader::Chunked(_) => f.write_str("BodyReader::Chunked"), - BodyReader::Fixed(_) => f.write_str("BodyReader::Fixed"), - BodyReader::None => f.write_str("BodyReader::None"), - } - } -} - -impl Read for BodyReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - match &*self { - BodyReader::Chunked(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf), - BodyReader::Fixed(r) => Pin::new(&mut *r.lock()).poll_read(cx, buf), - BodyReader::None => Poll::Ready(Ok(0)), - } - } -} diff --git a/src/server/decode.rs b/src/server/decode.rs index 133b58e..24605ec 100644 --- a/src/server/decode.rs +++ b/src/server/decode.rs @@ -2,17 +2,18 @@ use std::str::FromStr; -use async_dup::{Arc, Mutex}; -use async_std::io::{BufReader, Read, Write}; +use async_std::io::{self, BufRead, BufReader, Read, Write}; use async_std::{prelude::*, task}; +use futures_channel::oneshot; +use futures_util::{select_biased, FutureExt}; use http_types::content::ContentLength; use http_types::headers::{EXPECT, TRANSFER_ENCODING}; use http_types::{ensure, ensure_eq, format_err}; use http_types::{Body, Method, Request, Url}; -use super::body_reader::BodyReader; use crate::chunked::ChunkedDecoder; use crate::read_notifier::ReadNotifier; +use crate::sequenced::Sequenced; use crate::{MAX_HEADERS, MAX_HEAD_LENGTH}; const LF: u8 = b'\n'; @@ -24,15 +25,64 @@ const CONTINUE_HEADER_VALUE: &str = "100-continue"; const CONTINUE_RESPONSE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n"; /// Decode an HTTP request on the server. -pub async fn decode(mut io: IO) -> http_types::Result)>> +pub async fn decode(io: IO) -> http_types::Result> where IO: Read + Write + Clone + Send + Sync + Unpin + 'static, { - let mut reader = BufReader::new(io.clone()); + let mut reader = Sequenced::new(BufReader::new(io.clone())); + let mut writer = Sequenced::new(io); + let res = decode_rw(reader.split_seq(), writer.split_seq()).await?; + Ok(res.map(|(r, _)| { + (r, async move { + reader.take_inner().await; + writer.take_inner().await; + }) + })) +} + +async fn discard_unread_body( + mut body_reader: Sequenced, + mut reader: Sequenced, +) -> io::Result<()> { + // Unpoison the body reader, as we don't require it to be in any particular state + body_reader.cure(); + + // Consume the remainder of the request body + let body_bytes_discarded = io::copy(&mut body_reader, &mut io::sink()).await?; + + log::trace!( + "discarded {} unread request body bytes", + body_bytes_discarded + ); + + // Unpoison the reader, as it's easier than trying to reach into the body reader to + // release the inner `Sequenced` + reader.cure(); + reader.release(); + + Ok(()) +} + +#[derive(Debug)] +pub struct NotifyWrite { + sender: Option>, +} + +/// Decode an HTTP request on the server. +pub async fn decode_rw( + mut reader: Sequenced, + mut writer: Sequenced, +) -> http_types::Result> +where + R: BufRead + Send + Sync + Unpin + 'static, + W: Write + Send + Sync + Unpin + 'static, +{ let mut buf = Vec::new(); let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut httparse_req = httparse::Request::new(&mut headers); + let mut notify_write = NotifyWrite { sender: None }; + // Keep reading bytes from the stream until we hit the end of the stream. loop { let bytes_read = reader.read_until(LF, &mut buf).await?; @@ -103,12 +153,47 @@ where let (body_read_sender, body_read_receiver) = async_channel::bounded(1); if Some(CONTINUE_HEADER_VALUE) == req.header(EXPECT).map(|h| h.as_str()) { + // Prevent the response being written until we've decided whether to send + // the continue message or not. + let mut continue_writer = writer.split_seq(); + + // We can swap these later to effectively deactivate the body reader, in the event + // that we don't ask the client to send a body. + let mut continue_reader = reader.split_seq(); + let mut after_reader = reader.split_seq_rev(); + + let (notify_tx, notify_rx) = oneshot::channel(); + notify_write.sender = Some(notify_tx); + + // If the client expects a 100-continue header, spawn a + // task to wait for the first read attempt on the body. task::spawn(async move { - // If the client expects a 100-continue header, spawn a - // task to wait for the first read attempt on the body. - if let Ok(()) = body_read_receiver.recv().await { - io.write_all(CONTINUE_RESPONSE).await.ok(); + // It's important that we fuse this future, or else the `select` won't + // wake up properly if the sender is dropped. + let mut notify_rx = notify_rx.fuse(); + + let should_continue = select_biased! { + x = body_read_receiver.recv().fuse() => x.is_ok(), + _ = notify_rx => true, }; + + if should_continue { + if continue_writer.write_all(CONTINUE_RESPONSE).await.is_err() { + return; + } + } else { + // We never asked for the body, so just allow the next + // request to continue from our current point in the stream. + continue_reader.swap(&mut after_reader); + } + // Allow the rest of the response to be written + continue_writer.release(); + + // Allow the body to be read + continue_reader.release(); + + // Allow the next request to be read (after the body, if requested, has been read) + after_reader.release(); // Since the sender is moved into the Body, this task will // finish when the client disconnects, whether or not // 100-continue was sent. @@ -121,23 +206,43 @@ where .unwrap_or(false) { let trailer_sender = req.send_trailers(); - let reader = ChunkedDecoder::new(reader, trailer_sender); - let reader = Arc::new(Mutex::new(reader)); - let reader_clone = reader.clone(); - let reader = ReadNotifier::new(reader, body_read_sender); - let reader = BufReader::new(reader); - req.set_body(Body::from_reader(reader, None)); - return Ok(Some((req, BodyReader::Chunked(reader_clone)))); + let mut body_reader = + Sequenced::new(ChunkedDecoder::new(reader.split_seq(), trailer_sender)); + req.set_body(Body::from_reader( + ReadNotifier::new(body_reader.split_seq(), body_read_sender), + None, + )); + let reader_to_cure = reader.split_seq(); + + // Spawn a task to consume any part of the body which is unread + task::spawn(async move { + let _ = discard_unread_body(body_reader, reader_to_cure).await; + }); + + reader.release(); + writer.release(); + return Ok(Some((req, notify_write))); } else if let Some(len) = content_length { let len = len.len(); - let reader = Arc::new(Mutex::new(reader.take(len))); + let mut body_reader = Sequenced::new(reader.split_seq().take(len)); req.set_body(Body::from_reader( - BufReader::new(ReadNotifier::new(reader.clone(), body_read_sender)), + ReadNotifier::new(body_reader.split_seq(), body_read_sender), Some(len as usize), )); - Ok(Some((req, BodyReader::Fixed(reader)))) + let reader_to_cure = reader.split_seq(); + + // Spawn a task to consume any part of the body which is unread + task::spawn(async move { + let _ = discard_unread_body(body_reader, reader_to_cure).await; + }); + + reader.release(); + writer.release(); + Ok(Some((req, notify_write))) } else { - Ok(Some((req, BodyReader::None))) + reader.release(); + writer.release(); + Ok(Some((req, notify_write))) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 1cfa4e9..1d644c5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,18 +1,21 @@ //! Process HTTP connections on the server. use async_std::future::{timeout, Future, TimeoutError}; -use async_std::io::{self, Read, Write}; +use async_std::io::{self, BufRead, BufReader, Read, Write}; use http_types::headers::{CONNECTION, UPGRADE}; use http_types::upgrade::Connection; use http_types::{Request, Response, StatusCode}; use std::{marker::PhantomData, time::Duration}; -mod body_reader; + mod decode; mod encode; -pub use decode::decode; +pub use decode::{decode, decode_rw}; pub use encode::Encoder; +use crate::sequenced::Sequenced; +use crate::unite::Unite; + /// Configure the server. #[derive(Debug, Clone)] pub struct ServerOptions { @@ -23,7 +26,7 @@ pub struct ServerOptions { impl Default for ServerOptions { fn default() -> Self { Self { - headers_timeout: Some(Duration::from_secs(60)), + headers_timeout: Some(Duration::from_secs(30)), } } } @@ -58,8 +61,9 @@ where /// struct for server #[derive(Debug)] -pub struct Server { - io: RW, +pub struct Server { + reader: Sequenced, + writer: Sequenced, endpoint: F, opts: ServerOptions, _phantom: PhantomData, @@ -75,16 +79,34 @@ pub enum ConnectionStatus { KeepAlive, } -impl Server +impl Server, RW, F, Fut> where - RW: Read + Write + Clone + Send + Sync + Unpin + 'static, + RW: Read + Write + Send + Sync + Clone + Unpin + 'static, F: Fn(Request) -> Fut, Fut: Future>, { /// builds a new server pub fn new(io: RW, endpoint: F) -> Self { + Self::new_rw( + Sequenced::new(BufReader::new(io.clone())), + Sequenced::new(io), + endpoint, + ) + } +} + +impl Server +where + R: BufRead + Send + Sync + Unpin + 'static, + W: Write + Send + Sync + Unpin + 'static, + F: Fn(Request) -> Fut, + Fut: Future>, +{ + /// builds a new server + pub fn new_rw(reader: Sequenced, writer: Sequenced, endpoint: F) -> Self { Self { - io, + reader, + writer, endpoint, opts: Default::default(), _phantom: PhantomData, @@ -104,16 +126,11 @@ where } /// accept one request - pub async fn accept_one(&mut self) -> http_types::Result - where - RW: Read + Write + Clone + Send + Sync + Unpin + 'static, - F: Fn(Request) -> Fut, - Fut: Future>, - { + pub async fn accept_one(&mut self) -> http_types::Result { // Decode a new request, timing out if this takes longer than the timeout duration. - let fut = decode(self.io.clone()); + let fut = decode_rw(self.reader.split_seq(), self.writer.split_seq()); - let (req, mut body) = if let Some(timeout_duration) = self.opts.headers_timeout { + let (req, notify_write) = if let Some(timeout_duration) = self.opts.headers_timeout { match timeout(timeout_duration, fut).await { Ok(Ok(Some(r))) => r, Ok(Ok(None)) | Err(TimeoutError { .. }) => return Ok(ConnectionStatus::Close), /* EOF or timeout */ @@ -159,17 +176,22 @@ where let mut encoder = Encoder::new(res, method); - let bytes_written = io::copy(&mut encoder, &mut self.io).await?; + // This should be dropped before we begin writing the response. + drop(notify_write); + + let bytes_written = io::copy(&mut encoder, &mut self.writer).await?; log::trace!("wrote {} response bytes", bytes_written); - let body_bytes_discarded = io::copy(&mut body, &mut io::sink()).await?; - log::trace!( - "discarded {} unread request body bytes", - body_bytes_discarded - ); + async_std::task::sleep(Duration::from_millis(1)).await; if let Some(upgrade_sender) = upgrade_sender { - upgrade_sender.send(Connection::new(self.io.clone())).await; + let reader = self.reader.take_inner().await; + let writer = self.writer.take_inner().await; + if let (Some(reader), Some(writer)) = (reader, writer) { + upgrade_sender + .send(Connection::new(Unite::new(reader, writer))) + .await; + } return Ok(ConnectionStatus::Close); } else if close_connection { Ok(ConnectionStatus::Close) diff --git a/src/unite.rs b/src/unite.rs new file mode 100644 index 0000000..ad23785 --- /dev/null +++ b/src/unite.rs @@ -0,0 +1,60 @@ +use core::pin::Pin; +use core::task::{Context, Poll}; + +use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite}; +use pin_project::pin_project; + +#[pin_project] +pub(crate) struct Unite { + #[pin] + reader: R, + #[pin] + writer: W, +} + +impl Unite { + pub(crate) fn new(reader: R, writer: W) -> Self { + Self { reader, writer } + } +} + +impl AsyncRead for Unite { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().reader.poll_read(cx, buf) + } +} + +impl AsyncBufRead for Unite { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().reader.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().reader.consume(amt) + } +} + +impl AsyncWrite for Unite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().writer.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.poll_close(cx) + } +} diff --git a/tests/accept.rs b/tests/accept.rs index 92283a8..a1882f6 100644 --- a/tests/accept.rs +++ b/tests/accept.rs @@ -17,7 +17,7 @@ mod accept { let content_length = 10; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); @@ -33,6 +33,36 @@ mod accept { Ok(()) } + #[async_std::test] + async fn pipelined() -> Result<()> { + let mut server = TestServer::new(|req| async { + let mut response = Response::new(200); + let len = req.len(); + response.set_body(Body::from_reader(req, len)); + Ok(response) + }); + + let content_length = 10; + + let request_str = format!( + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", + content_length, + std::str::from_utf8(&vec![b'|'; content_length]).unwrap() + ); + + server.write_all(request_str.as_bytes()).await?; + server.write_all(request_str.as_bytes()).await?; + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) + } + #[async_std::test] async fn request_close() -> Result<()> { let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); @@ -74,7 +104,7 @@ mod accept { let content_length = 10; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); @@ -130,7 +160,7 @@ mod accept { let content_length = 10000; let request_str = format!( - "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}\r\n\r\n", + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", content_length, std::str::from_utf8(&vec![b'|'; content_length]).unwrap() ); @@ -169,6 +199,33 @@ mod accept { "GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n" )) .await?; + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) + } + + #[async_std::test] + async fn echo_server() -> Result<()> { + let mut server = TestServer::new(|mut req| async move { + let mut resp = Response::new(200); + resp.set_body(req.take_body()); + Ok(resp) + }); + + let content_length = 10; + + let request_str = format!( + "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: {}\r\n\r\n{}", + content_length, + std::str::from_utf8(&vec![b'|'; content_length]).unwrap() + ); + + server.write_all(request_str.as_bytes()).await?; assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); server.close(); diff --git a/tests/continue.rs b/tests/continue.rs index 933fbfe..5d6518a 100644 --- a/tests/continue.rs +++ b/tests/continue.rs @@ -1,9 +1,13 @@ mod test_utils; +use async_h1::server::ConnectionStatus; +use async_h1::Sequenced; +use async_std::future::timeout; +use async_std::io::BufReader; use async_std::{io, prelude::*, task}; -use http_types::Result; +use http_types::{Response, Result}; use std::time::Duration; -use test_utils::TestIO; +use test_utils::{TestIO, TestServer}; const REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\ Host: example.com\r\n\ @@ -16,7 +20,12 @@ async fn test_with_expect_when_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (mut request, _) = async_h1::server::decode(server).await?.unwrap(); + let (mut request, _notify_write) = async_h1::server::decode_rw( + Sequenced::new(BufReader::new(server.clone())), + Sequenced::new(server.clone()), + ) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; //prove we're not just testing before we've written @@ -44,11 +53,202 @@ async fn test_without_expect_when_not_reading_body() -> Result<()> { let (mut client, server) = TestIO::new(); client.write_all(REQUEST_WITH_EXPECT).await?; - let (_, _) = async_h1::server::decode(server).await?.unwrap(); + let _ = async_h1::server::decode_rw( + Sequenced::new(BufReader::new(server.clone())), + Sequenced::new(server.clone()), + ) + .await? + .unwrap(); task::sleep(SLEEP_DURATION).await; // just long enough to wait for the channel assert_eq!("", &client.read.to_string()); // we haven't written 100-continue + client.write_all(REQUEST_WITH_EXPECT).await?; + + // Make sure the server doesn't try to read the body before processing the next request + task::sleep(SLEEP_DURATION).await; + let (_, _) = async_h1::server::decode(server).await?.unwrap(); + + Ok(()) +} + +#[async_std::test] +async fn test_accept_unread_body() -> Result<()> { + let mut server = TestServer::new(|_| async { Ok(Response::new(200)) }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_echo_server() -> Result<()> { + let mut server = TestServer::new(|mut req| async move { + let mut resp = Response::new(200); + resp.set_body(req.take_body()); + Ok(resp) + }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + server.write_all(b"0123456789").await?; + assert_eq!(server.accept_one().await?, ConnectionStatus::KeepAlive); + + task::sleep(SLEEP_DURATION).await; // wait for "continue" to be sent + + server.close(); + + assert!(server + .client + .read + .to_string() + .starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); + + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_delayed_read() -> Result<()> { + let mut server = TestServer::new(|mut req| async move { + let mut body = req.take_body(); + task::spawn(async move { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.unwrap(); + }); + Ok(Response::new(200)) + }); + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + server.write_all(b"0123456789").await?; + + server.write_all(REQUEST_WITH_EXPECT).await?; + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + server.write_all(b"0123456789").await?; + + server.close(); + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_accept_fast_unread_sequential_requests() -> Result<()> { + let mut server = TestServer::new(|_| async move { Ok(Response::new(200)) }); + let mut client = server.client.clone(); + + task::spawn(async move { + let mut reader = BufReader::new(client.clone()); + for _ in 0..10 { + let mut buf = String::new(); + client.write_all(REQUEST_WITH_EXPECT).await.unwrap(); + + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + + assert!(buf.starts_with("HTTP/1.1 200 OK\r\n")); + } + client.close(); + }); + + for _ in 0..10 { + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + } + + assert_eq!(server.accept_one().await?, ConnectionStatus::Close); + + assert!(server.all_read()); + + Ok(()) +} + +#[async_std::test] +async fn test_accept_partial_read_sequential_requests() -> Result<()> { + const LARGE_REQUEST_WITH_EXPECT: &[u8] = b"POST / HTTP/1.1\r\n\ + Host: example.com\r\n\ + Content-Length: 1000\r\n\ + Expect: 100-continue\r\n\r\n"; + + let mut server = TestServer::new(|mut req| async move { + let mut body = req.take_body(); + let mut buf = [0]; + body.read(&mut buf).await.unwrap(); + Ok(Response::new(200)) + }); + let mut client = server.client.clone(); + + task::spawn(async move { + let mut reader = BufReader::new(client.clone()); + for _ in 0..10 { + let mut buf = String::new(); + client.write_all(LARGE_REQUEST_WITH_EXPECT).await.unwrap(); + + // Wait for body to be requested + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + assert!(buf.starts_with("HTTP/1.1 100 Continue\r\n")); + + // Write body + for _ in 0..100 { + client.write_all(b"0123456789").await.unwrap(); + } + + // Wait for response + buf.clear(); + while !buf.ends_with("\r\n\r\n") { + reader.read_line(&mut buf).await.unwrap(); + } + + assert!(buf.starts_with("HTTP/1.1 200 OK\r\n")); + } + client.close(); + }); + + for _ in 0..10 { + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::KeepAlive + ); + } + + assert_eq!( + timeout(Duration::from_secs(1), server.accept_one()).await??, + ConnectionStatus::Close + ); + + assert!(server.all_read()); + Ok(()) } diff --git a/tests/server_decode.rs b/tests/server_decode.rs index 10c6701..46f3882 100644 --- a/tests/server_decode.rs +++ b/tests/server_decode.rs @@ -67,6 +67,7 @@ mod server_decode { "llo", "0", "", + "", ]) .await? .unwrap(); @@ -93,6 +94,7 @@ mod server_decode { "0", "x-invalid: å", "", + "", ]) .await? .unwrap(); diff --git a/tests/test_utils.rs b/tests/test_utils.rs index 8194590..7805b85 100644 --- a/tests/test_utils.rs +++ b/tests/test_utils.rs @@ -2,7 +2,7 @@ use async_h1::{ client::Encoder, server::{ConnectionStatus, Server}, }; -use async_std::io::{Read, Write}; +use async_std::io::{BufReader, Read, Write}; use http_types::{Request, Response, Result}; use std::{ fmt::{Debug, Display}, @@ -17,9 +17,9 @@ use async_dup::Arc; #[pin_project::pin_project] pub struct TestServer { - server: Server, + server: Server, TestIO, F, Fut>, #[pin] - client: TestIO, + pub(crate) client: TestIO, } impl TestServer @@ -102,35 +102,47 @@ pub struct TestIO { } #[derive(Default)] -pub struct CloseableCursor { - data: RwLock>, - cursor: RwLock, - waker: RwLock>, - closed: RwLock, +struct CloseableCursorInner { + data: Vec, + cursor: usize, + waker: Option, + closed: bool, } +#[derive(Default)] +pub struct CloseableCursor(RwLock); + impl CloseableCursor { - fn len(&self) -> usize { - self.data.read().unwrap().len() + pub fn len(&self) -> usize { + self.0.read().unwrap().data.len() + } + + pub fn cursor(&self) -> usize { + self.0.read().unwrap().cursor } - fn cursor(&self) -> usize { - *self.cursor.read().unwrap() + pub fn is_empty(&self) -> bool { + self.len() == 0 } - fn current(&self) -> bool { - self.len() == self.cursor() + pub fn current(&self) -> bool { + let inner = self.0.read().unwrap(); + inner.data.len() == inner.cursor } - fn close(&self) { - *self.closed.write().unwrap() = true; + pub fn close(&self) { + let mut inner = self.0.write().unwrap(); + inner.closed = true; + if let Some(waker) = inner.waker.take() { + waker.wake(); + } } } impl Display for CloseableCursor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let data = &*self.data.read().unwrap(); - let s = std::str::from_utf8(data).unwrap_or("not utf8"); + let inner = self.0.read().unwrap(); + let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8"); write!(f, "{}", s) } } @@ -163,13 +175,14 @@ impl TestIO { impl Debug for CloseableCursor { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner = self.0.read().unwrap(); f.debug_struct("CloseableCursor") .field( "data", - &std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"), + &std::str::from_utf8(&inner.data).unwrap_or("not utf8"), ) - .field("closed", &*self.closed.read().unwrap()) - .field("cursor", &*self.cursor.read().unwrap()) + .field("closed", &inner.closed) + .field("cursor", &inner.cursor) .finish() } } @@ -180,18 +193,17 @@ impl Read for &CloseableCursor { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let len = self.len(); - let cursor = self.cursor(); - if cursor < len { - let data = &*self.data.read().unwrap(); - let bytes_to_copy = buf.len().min(len - cursor); - buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]); - *self.cursor.write().unwrap() += bytes_to_copy; + let mut inner = self.0.write().unwrap(); + if inner.cursor < inner.data.len() { + let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor); + buf[..bytes_to_copy] + .copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]); + inner.cursor += bytes_to_copy; Poll::Ready(Ok(bytes_to_copy)) - } else if *self.closed.read().unwrap() { + } else if inner.closed { Poll::Ready(Ok(0)) } else { - *self.waker.write().unwrap() = Some(cx.waker().clone()); + inner.waker = Some(cx.waker().clone()); Poll::Pending } } @@ -203,11 +215,12 @@ impl Write for &CloseableCursor { _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - if *self.closed.read().unwrap() { + let mut inner = self.0.write().unwrap(); + if inner.closed { Poll::Ready(Ok(0)) } else { - self.data.write().unwrap().extend_from_slice(buf); - if let Some(waker) = self.waker.write().unwrap().take() { + inner.data.extend_from_slice(buf); + if let Some(waker) = inner.waker.take() { waker.wake(); } Poll::Ready(Ok(buf.len())) @@ -219,10 +232,7 @@ impl Write for &CloseableCursor { } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if let Some(waker) = self.waker.write().unwrap().take() { - waker.wake(); - } - *self.closed.write().unwrap() = true; + self.close(); Poll::Ready(Ok(())) } }