Skip to content

Commit

Permalink
feat(server): add AutoConnection (#11)
Browse files Browse the repository at this point in the history
A way to auto detect HTTP version from the client.
  • Loading branch information
programatik29 authored Sep 16, 2023
1 parent 229757e commit 334209d
Show file tree
Hide file tree
Showing 7 changed files with 684 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ hyper = "=1.0.0-rc.4"
futures-channel = "0.3"
futures-util = { version = "0.3", default-features = false }
http = "0.2"
http-body = "1.0.0-rc.2"
bytes = "1"

once_cell = "1.14"

Expand All @@ -30,9 +32,11 @@ tower-service = "0.3"
tower = { version = "0.4", features = ["make", "util"] }

[dev-dependencies]
hyper = { version = "1.0.0-rc.3", features = ["full"] }
bytes = "1"
http-body-util = "0.1.0-rc.3"
tokio = { version = "1", features = ["macros", "test-util"] }
tokio-test = "0.4"

[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
pnet_datalink = "0.27.2"
Expand All @@ -50,6 +54,7 @@ http1 = ["hyper/http1"]
http2 = ["hyper/http2"]

tcp = []
auto = ["hyper/server", "http1", "http2"]
runtime = []

# internal features used in CI
Expand Down
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod exec;
#[cfg(feature = "client")]
mod lazy;
pub(crate) mod never;
pub(crate) mod rewind;
#[cfg(feature = "client")]
mod sync;

Expand Down
161 changes: 161 additions & 0 deletions src/common/rewind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use std::marker::Unpin;
use std::{cmp, io};

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::{
pin::Pin,
task::{self, Poll},
};

/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}

impl<T> Rewind<T> {
#[cfg(test)]
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}

#[allow(dead_code)]
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}

#[cfg(test)]
pub(crate) fn rewind(&mut self, bs: Bytes) {
debug_assert!(self.pre.is_none());
self.pre = Some(bs);
}

// pub(crate) fn into_inner(self) -> (T, Bytes) {
// (self.inner, self.pre.unwrap_or_else(Bytes::new))
// }

// pub(crate) fn get_mut(&mut self) -> &mut T {
// &mut self.inner
// }
}

impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}

return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

#[cfg(test)]
mod tests {
// FIXME: re-implement tests with `async/await`, this import should
// trigger a warning to remind us
use super::Rewind;
use bytes::Bytes;
use tokio::io::AsyncReadExt;

#[cfg(not(miri))]
#[tokio::test]
async fn partial_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

// Read off some bytes, ensure we filled o1
let mut buf = [0; 2];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// At this point we should have read everything that was in the MockStream
assert_eq!(&buf, &underlying);
}

#[cfg(not(miri))]
#[tokio::test]
async fn full_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]

//! hyper-util
#[cfg(feature = "client")]
pub mod client;
mod common;
pub mod rt;
pub mod server;

mod error;
Loading

0 comments on commit 334209d

Please sign in to comment.