From aca53acddc200ff53350760d99966fb0b87afaa5 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 11:49:33 +1100 Subject: [PATCH 01/20] First stage of streaming implementation --- bhttp-convert/Cargo.toml | 2 +- bhttp/Cargo.toml | 14 +-- bhttp/src/err.rs | 12 +- bhttp/src/lib.rs | 73 ++++------- bhttp/src/parse.rs | 23 ++-- bhttp/src/rw.rs | 58 ++++----- bhttp/src/stream/mod.rs | 233 ++++++++++++++++++++++++++++++++++++ bhttp/tests/test.rs | 2 +- ohttp-client-cli/Cargo.toml | 2 +- ohttp-client/Cargo.toml | 2 +- ohttp-server/Cargo.toml | 2 +- ohttp/build.rs | 20 ++-- 12 files changed, 307 insertions(+), 136 deletions(-) create mode 100644 bhttp/src/stream/mod.rs diff --git a/bhttp-convert/Cargo.toml b/bhttp-convert/Cargo.toml index 8a82101..47e11f1 100644 --- a/bhttp-convert/Cargo.toml +++ b/bhttp-convert/Cargo.toml @@ -9,4 +9,4 @@ structopt = "0.3" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 8c89536..14aff2b 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -9,17 +9,15 @@ description = "Binary HTTP messages (RFC 9292)" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["bhttp"] -bhttp = ["read-bhttp", "write-bhttp"] -http = ["read-http", "write-http"] -read-bhttp = [] -write-bhttp = [] -read-http = ["url"] -write-http = [] +default = ["stream"] +http = ["url"] +stream = ["futures", "pin-project"] [dependencies] +futures = {version = "0.3", optional = true} +pin-project = {version = "1.1", optional = true} thiserror = "1" url = {version = "2", optional = true} [dev-dependencies] -hex = "0.4" +hex = "0.4" \ No newline at end of file diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..3a457da 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -1,6 +1,4 @@ -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum Error { #[error("a request used the CONNECT method")] ConnectUnsupported, @@ -34,14 +32,8 @@ pub enum Error { #[error("a message included the Upgrade field")] UpgradeUnsupported, #[error("a URL could not be parsed into components: {0}")] - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] UrlParse(#[from] url::ParseError), } -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] pub type Res = Result; diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index 3c8fbde..f92b2c4 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -1,46 +1,26 @@ #![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // Too lazy to document these. 
-#[cfg(feature = "read-bhttp")] -use std::convert::TryFrom; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] -use std::io; - -#[cfg(feature = "read-http")] +use std::{borrow::BorrowMut, io}; + +#[cfg(feature = "http")] use url::Url; mod err; mod parse; -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] mod rw; - -#[cfg(any(feature = "read-http", feature = "read-bhttp",))] -use std::borrow::BorrowMut; +#[cfg(feature = "stream")] +pub mod stream; pub use err::Error; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] use err::Res; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use parse::{downcase, is_ows, read_line, split_at, COLON, SEMICOLON, SLASH, SP}; use parse::{index_of, trim_ows, COMMA}; -#[cfg(feature = "read-bhttp")] -use rw::{read_varint, read_vec}; -#[cfg(feature = "write-bhttp")] -use rw::{write_len, write_varint, write_vec}; +use rw::{read_varint, read_vec, write_len, write_varint, write_vec}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] const CONTENT_LENGTH: &[u8] = b"content-length"; -#[cfg(feature = "read-bhttp")] const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; @@ -93,7 +73,6 @@ impl ReadSeek for io::Cursor where T: AsRef<[u8]> {} impl ReadSeek for io::BufReader where T: io::Read + io::Seek {} #[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] pub enum Mode { KnownLength, IndeterminateLength, @@ -120,7 +99,7 @@ impl Field { &self.value } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { w.write_all(&self.name)?; w.write_all(b": ")?; @@ -129,14 +108,13 @@ impl Field { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { write_vec(&self.name, w)?; write_vec(&self.value, w)?; Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn obs_fold(&mut self, extra: &[u8]) { self.value.push(SP); self.value.extend(trim_ows(extra)); @@ -192,7 +170,7 @@ impl FieldSection { /// As required by the HTTP specification, remove the Connection header /// field, everything it refers to, and a few extra fields. 
- #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn strip_connection_headers(&mut self) { const CONNECTION: &[u8] = b"connection"; const PROXY_CONNECTION: &[u8] = b"proxy-connection"; @@ -232,7 +210,7 @@ impl FieldSection { }); } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn parse_line(fields: &mut Vec, line: Vec) -> Res<()> { // obs-fold is helpful in specs, so support it here too let f = if is_ows(line[0]) { @@ -251,7 +229,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -267,7 +245,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] fn read_bhttp_fields(terminator: bool, r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -302,7 +279,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(mode: Mode, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -320,7 +296,6 @@ impl FieldSection { Ok(Self(fields)) } - #[cfg(feature = "write-bhttp")] fn write_bhttp_headers(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_bhttp(w)?; @@ -328,7 +303,6 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { if mode == Mode::KnownLength { let mut buf = Vec::new(); @@ -341,7 +315,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_http(w)?; @@ -420,7 +394,7 @@ impl ControlData { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(line: Vec) -> Res { // request-line = method SP request-target SP HTTP-version // status-line = HTTP-version SP status-code SP [reason-phrase] @@ -467,7 +441,6 @@ impl ControlData { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(request: bool, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -493,7 +466,6 @@ impl ControlData { } /// If this is an informational response. 
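    /// Informational responses carry 1xx status codes (for example 100 Continue
    /// or 103 Early Hints) and may appear any number of times before the final
    /// response.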
- #[cfg(any(feature = "read-bhttp", feature = "read-http"))] #[must_use] fn informational(&self) -> Option { match self { @@ -502,7 +474,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] #[must_use] fn code(&self, mode: Mode) -> u64 { match (self, mode) { @@ -513,7 +484,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -532,7 +502,7 @@ impl ControlData { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -581,7 +551,6 @@ impl InformationalResponse { &self.fields } - #[cfg(feature = "write-bhttp")] fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { write_varint(self.status.code(), w)?; self.fields.write_bhttp(mode, w)?; @@ -662,7 +631,7 @@ impl Message { &self.trailer } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn read_chunked(r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -686,7 +655,7 @@ impl Message { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] #[allow(clippy::read_zero_byte_vec)] // https://github.com/rust-lang/rust-clippy/issues/9274 pub fn read_http(r: &mut T) -> Res where @@ -741,7 +710,7 @@ impl Message { }) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for info in &self.informational { ControlData::Response(info.status()).write_http(w)?; @@ -770,7 +739,6 @@ impl Message { } /// Read a BHTTP message. - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -815,7 +783,6 @@ impl Message { }) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { write_varint(self.control.code(mode), w)?; for info in &self.informational { @@ -833,7 +800,7 @@ impl Message { } } -#[cfg(feature = "write-http")] +#[cfg(feature = "http")] impl std::fmt::Debug for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { let mut buf = Vec::new(); diff --git a/bhttp/src/parse.rs b/bhttp/src/parse.rs index ee52493..06165fc 100644 --- a/bhttp/src/parse.rs +++ b/bhttp/src/parse.rs @@ -1,20 +1,21 @@ -#[cfg(feature = "read-http")] -use crate::{Error, ReadSeek, Res}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use std::borrow::BorrowMut; +#[cfg(feature = "http")] +use crate::{Error, ReadSeek, Res}; + pub const HTAB: u8 = 0x09; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const NL: u8 = 0x0a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const CR: u8 = 0x0d; pub const SP: u8 = 0x20; pub const COMMA: u8 = 0x2c; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SLASH: u8 = 0x2f; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const COLON: u8 = 0x3a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SEMICOLON: u8 = 0x3b; pub fn is_ows(x: u8) -> bool { @@ -34,7 +35,7 @@ pub fn trim_ows(v: &[u8]) -> &[u8] { &v[..0] } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn downcase(n: &mut [u8]) { for i in n { if *i >= 0x41 && *i <= 0x5a { @@ -52,7 +53,7 @@ pub fn index_of(v: u8, line: &[u8]) -> Option { None } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn split_at(v: u8, mut line: Vec) -> Option<(Vec, Vec)> { index_of(v, &line).map(|i| { let tail = line.split_off(i + 1); @@ -61,7 +62,7 @@ pub fn split_at(v: u8, 
mut line: Vec) -> Option<(Vec, Vec)> { }) } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn read_line(r: &mut T) -> Res> where T: BorrowMut + ?Sized, diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..fa7c717 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -1,79 +1,66 @@ -#[cfg(feature = "read-bhttp")] -use std::borrow::BorrowMut; -use std::{convert::TryFrom, io}; +use std::{borrow::BorrowMut, convert::TryFrom, io}; -use crate::err::Res; -#[cfg(feature = "read-bhttp")] -use crate::{err::Error, ReadSeek}; +use crate::{ + err::{Error, Res}, + ReadSeek, +}; -#[cfg(feature = "write-bhttp")] #[allow(clippy::cast_possible_truncation)] -fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { - let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); - for i in 0..n { - w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; - } +pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { + let v = v.into().to_be_bytes(); + assert!((1..=std::mem::size_of::()).contains(&N)); + w.write_all(&v[8 - N..])?; Ok(()) } -#[cfg(feature = "write-bhttp")] pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); match () { - () if v < (1 << 6) => write_uint(1, v, w), - () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), - () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), - () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), + () if v < (1 << 6) => write_uint::<1>(v, w), + () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), + () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), + () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), () => panic!("Varint value too large"), } } -#[cfg(feature = "write-bhttp")] pub fn write_len(len: usize, w: &mut impl io::Write) -> Res<()> { write_varint(u64::try_from(len).unwrap(), w) } -#[cfg(feature = "write-bhttp")] pub fn write_vec(v: &[u8], w: &mut impl io::Write) -> Res<()> { write_len(v.len(), w)?; w.write_all(v)?; Ok(()) } -#[cfg(feature = "read-bhttp")] -fn read_uint(n: usize, r: &mut T) -> Res> +fn read_uint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - let mut buf = [0; 7]; - let count = r.borrow_mut().read(&mut buf[..n])?; + let mut buf = [0; 8]; + let count = r.borrow_mut().read(&mut buf[(8 - N)..])?; if count == 0 { Ok(None) - } else if count < n { + } else if count < N { Err(Error::Truncated) } else { - let mut v = 0; - for i in &buf[..n] { - v = (v << 8) | u64::from(*i); - } - Ok(Some(v)) + Ok(Some(u64::from_be_bytes(buf))) } } -#[cfg(feature = "read-bhttp")] pub fn read_varint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - if let Some(b1) = read_uint(1, r)? { + if let Some(b1) = read_uint::<_, _, 1>(r)? 
{ Ok(Some(match b1 >> 6 { 0 => b1 & 0x3f, - 1 => ((b1 & 0x3f) << 8) | read_uint(1, r)?.ok_or(Error::Truncated)?, - 2 => ((b1 & 0x3f) << 24) | read_uint(3, r)?.ok_or(Error::Truncated)?, - 3 => ((b1 & 0x3f) << 56) | read_uint(7, r)?.ok_or(Error::Truncated)?, + 1 => ((b1 & 0x3f) << 8) | read_uint::<_, _, 1>(r)?.ok_or(Error::Truncated)?, + 2 => ((b1 & 0x3f) << 24) | read_uint::<_, _, 3>(r)?.ok_or(Error::Truncated)?, + 3 => ((b1 & 0x3f) << 56) | read_uint::<_, _, 7>(r)?.ok_or(Error::Truncated)?, _ => unreachable!(), })) } else { @@ -81,7 +68,6 @@ where } } -#[cfg(feature = "read-bhttp")] pub fn read_vec(r: &mut T) -> Res>> where T: BorrowMut + ?Sized, diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs new file mode 100644 index 0000000..982eda7 --- /dev/null +++ b/bhttp/src/stream/mod.rs @@ -0,0 +1,233 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::io::AsyncRead; + +use crate::{Error, Res}; + +#[pin_project::pin_project] +pub struct ReadUint<'a, S, const N: usize> { + /// The source of data. + src: Pin<&'a mut S>, + /// A buffer that holds the bytes that have been read so far. + v: [u8; 8], + /// A counter of the number of bytes that are already in place. + /// This starts out at `8-N`. + read: usize, +} + +impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> +where + S: AsyncRead, +{ + type Output = Res; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(count)) => { + if count == 0 { + return Poll::Ready(Err(Error::Truncated)); + } + *this.read += count; + if *this.read == 8 { + Poll::Ready(Ok(u64::from_be_bytes(*this.v))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + } + } +} + +pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { + ReadUint { + src: Pin::new(src), + v: [0; 8], + read: 8 - N, + } +} + +#[pin_project::pin_project(project = ReadVariantProj)] +pub enum ReadVarint<'a, S> { + First(Option>), + Extra1(#[pin] ReadUint<'a, S, 8>), + Extra3(#[pin] ReadUint<'a, S, 8>), + Extra7(#[pin] ReadUint<'a, S, 8>), +} + +impl<'a, S> Future for ReadVarint<'a, S> +where + S: AsyncRead, +{ + type Output = Res>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::First(src) = this.get_mut() { + let mut src = src.take().unwrap(); + let mut buf = [0; 1]; + if let Poll::Ready(Ok(c)) = src.as_mut().poll_read(cx, &mut buf[..]) { + if c == 0 { + return Poll::Ready(Ok(None)); + } + let b1 = buf[0]; + let mut v = [0; 8]; + let next = match b1 >> 6 { + 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), + 1 => { + v[6] = b1 & 0x3f; + Self::Extra1(ReadUint { src, v, read: 7 }) + } + 2 => { + v[4] = b1 & 0x3f; + Self::Extra3(ReadUint { src, v, read: 5 }) + } + 3 => { + v[0] = b1 & 0x3f; + Self::Extra7(ReadUint { src, v, read: 1 }) + } + _ => unreachable!(), + }; + + self.set(next); + } + } + let extra = match self.project() { + ReadVariantProj::Extra1(s) + | ReadVariantProj::Extra3(s) + | ReadVariantProj::Extra7(s) => s.poll(cx), + ReadVariantProj::First(_) => return Poll::Pending, + }; + if let Poll::Ready(v) = extra { + Poll::Ready(v.map(Some)) + } else { + Poll::Pending + } + } +} + +pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { + ReadVarint::First(Some(Pin::new(src))) +} + +#[cfg(test)] +mod test { + use std::task::{Context, Poll}; + + use futures::{Future, FutureExt}; 
+ + use crate::{ + rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, + stream::{read_uint as stream_read_uint, read_varint as stream_read_varint}, + }; + + pub fn noop_context() -> Context<'static> { + use std::{ + ptr::null, + task::{RawWaker, RawWakerVTable, Waker}, + }; + + const fn noop_raw_waker() -> RawWaker { + unsafe fn noop_clone(_data: *const ()) -> RawWaker { + noop_raw_waker() + } + + unsafe fn noop(_data: *const ()) {} + + const NOOP_WAKER_VTABLE: RawWakerVTable = + RawWakerVTable::new(noop_clone, noop, noop, noop); + RawWaker::new(null(), &NOOP_WAKER_VTABLE) + } + + pub fn noop_waker_ref() -> &'static Waker { + struct SyncRawWaker(RawWaker); + unsafe impl Sync for SyncRawWaker {} + + static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); + + // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. + unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } + } + + Context::from_waker(noop_waker_ref()) + } + + fn assert_unpin(v: T) -> T { + v + } + + fn read_uint(mut buf: &[u8]) -> u64 { + println!("{buf:?}"); + let mut cx = noop_context(); + let mut fut = assert_unpin(stream_read_uint::<_, N>(&mut buf)); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(Ok(v)) = v { + v + } else { + panic!("v is not OK: {v:?}"); + } + } + + #[test] + fn read_uint_values() { + macro_rules! validate_uint_range { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + assert_eq!(v, read_uint::<$n>(&buf[..])); + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_range!(@ $n); + )+ + } + } + validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); + } + + fn read_varint(mut buf: &[u8]) -> u64 { + let mut cx = noop_context(); + let mut fut = assert_unpin(stream_read_varint(&mut buf)); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(Ok(Some(v))) = v { + v + } else { + panic!("v is not OK: {v:?}"); + } + } + + #[test] + fn read_varint_values() { + for i in [ + 0, + 1, + 63, + 64, + (1 << 14) - 1, + 1 << 14, + (1 << 30) - 1, + 1 << 30, + (1 << 62) - 1, + ] { + let mut buf = Vec::new(); + sync_write_varint(i, &mut buf).unwrap(); + assert_eq!(i, read_varint(&buf[..])); + } + } +} diff --git a/bhttp/tests/test.rs b/bhttp/tests/test.rs index c6729c6..9ed9731 100644 --- a/bhttp/tests/test.rs +++ b/bhttp/tests/test.rs @@ -1,5 +1,5 @@ // Rather than grapple with #[cfg(...)] for every variable and import. 
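// Gate the whole file so it only builds when the features it exercises are enabled.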
-#![cfg(all(feature = "http", feature = "bhttp"))] +#![cfg(feature = "http")] use std::{io::Cursor, mem::drop}; diff --git a/ohttp-client-cli/Cargo.toml b/ohttp-client-cli/Cargo.toml index 4f530f6..f40b198 100644 --- a/ohttp-client-cli/Cargo.toml +++ b/ohttp-client-cli/Cargo.toml @@ -15,7 +15,7 @@ hex = "0.4" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-client/Cargo.toml b/ohttp-client/Cargo.toml index 7677608..1fa27fc 100644 --- a/ohttp-client/Cargo.toml +++ b/ohttp-client/Cargo.toml @@ -19,7 +19,7 @@ tokio = { version = "1", features = ["full"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index 026df48..87b595e 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -18,7 +18,7 @@ warp = { version = "0.3", features = ["tls"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "write-http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp/build.rs b/ohttp/build.rs index 1c01e3f..312cce2 100644 --- a/ohttp/build.rs +++ b/ohttp/build.rs @@ -8,8 +8,6 @@ #[cfg(feature = "nss")] mod nss { - use bindgen::Builder; - use serde_derive::Deserialize; use std::{ collections::HashMap, env, fs, @@ -17,6 +15,9 @@ mod nss { process::Command, }; + use bindgen::Builder; + use serde_derive::Deserialize; + const BINDINGS_DIR: &str = "bindings"; const BINDINGS_CONFIG: &str = "bindings.toml"; @@ -114,7 +115,6 @@ mod nss { let mut build_nss = vec![ String::from("./build.sh"), String::from("-Ddisable_tests=1"), - String::from("-Denable_draft_hpke=1"), ]; if is_debug() { build_nss.push(String::from("--static")); @@ -191,16 +191,8 @@ mod nss { } fn static_link(nsslibdir: &Path, use_static_softoken: bool, use_static_nspr: bool) { - let mut static_libs = vec![ - "certdb", - "certhi", - "cryptohi", - "nss_static", - "nssb", - "nssdev", - "nsspki", - "nssutil", - ]; + // The ordering of these libraries is critical for the linker. 
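+        // A single-pass linker only pulls archive members that satisfy symbols which are
+        // already undefined, so each library must appear before the archives it depends on.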
+ let mut static_libs = vec!["cryptohi", "nss_static"]; let mut dynamic_libs = vec![]; if use_static_softoken { @@ -211,6 +203,8 @@ mod nss { static_libs.push("pk11wrap"); } + static_libs.extend_from_slice(&["nsspki", "nssdev", "nssb", "certhi", "certdb", "nssutil"]); + if use_static_nspr { static_libs.append(&mut nspr_libs()); } else { From a9d76f7f1b4603b956f54bc76b36030c0a6afdb5 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 16:15:06 +1100 Subject: [PATCH 02/20] Adding vector reading capabilities --- bhttp/src/err.rs | 3 + bhttp/src/stream/context.rs | 62 +++++++++ bhttp/src/stream/int.rs | 261 ++++++++++++++++++++++++++++++++++++ bhttp/src/stream/mod.rs | 236 +------------------------------- bhttp/src/stream/vec.rs | 229 +++++++++++++++++++++++++++++++ 5 files changed, 559 insertions(+), 232 deletions(-) create mode 100644 bhttp/src/stream/context.rs create mode 100644 bhttp/src/stream/int.rs create mode 100644 bhttp/src/stream/vec.rs diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 3a457da..e53e0a1 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -19,6 +19,9 @@ pub enum Error { InvalidStatus, #[error("IO error {0}")] Io(#[from] std::io::Error), + #[cfg(feature = "stream")] + #[error("the size of a vector exceeded the limit that was set")] + LimitExceeded, #[error("a field or line was missing a necessary character 0x{0:x}")] Missing(u8), #[error("a URL was missing a key component")] diff --git a/bhttp/src/stream/context.rs b/bhttp/src/stream/context.rs new file mode 100644 index 0000000..c6987b0 --- /dev/null +++ b/bhttp/src/stream/context.rs @@ -0,0 +1,62 @@ +use std::{ + future::Future, + task::{Context, Poll}, +}; + +use futures::FutureExt; + +fn noop_context() -> Context<'static> { + use std::{ + ptr::null, + task::{RawWaker, RawWakerVTable, Waker}, + }; + + const fn noop_raw_waker() -> RawWaker { + unsafe fn noop_clone(_data: *const ()) -> RawWaker { + noop_raw_waker() + } + + unsafe fn noop(_data: *const ()) {} + + const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + RawWaker::new(null(), &NOOP_WAKER_VTABLE) + } + + pub fn noop_waker_ref() -> &'static Waker { + struct SyncRawWaker(RawWaker); + unsafe impl Sync for SyncRawWaker {} + + static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); + + // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. + unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } + } + + Context::from_waker(noop_waker_ref()) +} + +fn assert_unpin(v: F) -> F { + v +} + +/// Drives the given future (`f`) until it resolves. +/// Executes the indicated function (`p`) each time the +/// poll returned `Poll::Pending`. +pub fn sync_resolve_with(f: F, p: P) -> F::Output { + let mut cx = noop_context(); + let mut fut = assert_unpin(f); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + p(&mut fut); + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(v) = v { + v + } else { + unreachable!(); + } +} + +pub fn sync_resolve(f: F) -> F::Output { + sync_resolve_with(f, |_| {}) +} diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs new file mode 100644 index 0000000..02b05df --- /dev/null +++ b/bhttp/src/stream/int.rs @@ -0,0 +1,261 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::io::AsyncRead; + +use crate::{Error, Res}; + +#[pin_project::pin_project] +pub struct ReadUint<'a, S, const N: usize> { + /// The source of data. 
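+    /// Held as `Pin<&mut S>` so the same underlying reader can be polled
+    /// again on every call to `poll`.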
+ src: Pin<&'a mut S>, + /// A buffer that holds the bytes that have been read so far. + v: [u8; 8], + /// A counter of the number of bytes that are already in place. + /// This starts out at `8-N`. + read: usize, +} + +impl<'a, S, const N: usize> ReadUint<'a, S, N> { + pub fn stream(self) -> Pin<&'a mut S> { + self.src + } +} + +impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> +where + S: AsyncRead, +{ + type Output = Res; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(count)) => { + if count == 0 { + return Poll::Ready(Err(Error::Truncated)); + } + *this.read += count; + if *this.read == 8 { + Poll::Ready(Ok(u64::from_be_bytes(*this.v))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + } + } +} + +pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { + ReadUint { + src: Pin::new(src), + v: [0; 8], + read: 8 - N, + } +} + +#[pin_project::pin_project(project = ReadVarintProj)] +pub enum ReadVarint<'a, S> { + // Invariant: this Option always contains Some. + First(Option>), + Extra1(#[pin] ReadUint<'a, S, 8>), + Extra3(#[pin] ReadUint<'a, S, 8>), + Extra7(#[pin] ReadUint<'a, S, 8>), +} + +impl<'a, S> ReadVarint<'a, S> { + pub fn stream(self) -> Pin<&'a mut S> { + match self { + Self::Extra1(s) | Self::Extra3(s) | Self::Extra7(s) => s.stream(), + Self::First(mut s) => s.take().unwrap(), + } + } +} + +impl<'a, S> Future for ReadVarint<'a, S> +where + S: AsyncRead, +{ + type Output = Res>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::First(ref mut src) = this.get_mut() { + let mut buf = [0; 1]; + let src_ref = src.as_mut().unwrap().as_mut(); + if let Poll::Ready(res) = src_ref.poll_read(cx, &mut buf[..]) { + match res { + Ok(0) => return Poll::Ready(Ok(None)), + Ok(_) => (), + Err(e) => return Poll::Ready(Err(Error::from(e))), + } + + let b1 = buf[0]; + let mut v = [0; 8]; + let next = match b1 >> 6 { + 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), + 1 => { + let src = src.take().unwrap(); + v[6] = b1 & 0x3f; + Self::Extra1(ReadUint { src, v, read: 7 }) + } + 2 => { + let src = src.take().unwrap(); + v[4] = b1 & 0x3f; + Self::Extra3(ReadUint { src, v, read: 5 }) + } + 3 => { + let src = src.take().unwrap(); + v[0] = b1 & 0x3f; + Self::Extra7(ReadUint { src, v, read: 1 }) + } + _ => unreachable!(), + }; + + self.set(next); + } + } + let extra = match self.project() { + ReadVarintProj::Extra1(s) | ReadVarintProj::Extra3(s) | ReadVarintProj::Extra7(s) => { + s.poll(cx) + } + ReadVarintProj::First(_) => return Poll::Pending, + }; + if let Poll::Ready(v) = extra { + Poll::Ready(v.map(Some)) + } else { + Poll::Pending + } + } +} + +pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { + ReadVarint::First(Some(Pin::new(src))) +} + +#[cfg(test)] +mod test { + use crate::{ + err::Error, + rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, + stream::{ + context::sync_resolve, + int::{read_uint, read_varint}, + }, + }; + + const VARINTS: &[u64] = &[ + 0, + 1, + 63, + 64, + (1 << 14) - 1, + 1 << 14, + (1 << 30) - 1, + 1 << 30, + (1 << 62) - 1, + ]; + + #[test] + fn read_uint_values() { + macro_rules! 
validate_uint_range { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_uint::<_, $n>(&mut buf_ref); + assert_eq!(v, sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_range!(@ $n); + )+ + } + } + validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_uint_truncated() { + macro_rules! validate_uint_truncated { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = sync_resolve(read_uint::<_, $n>(&mut &buf[..i])).unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_truncated!(@ $n); + )+ + } + } + validate_uint_truncated!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_varint_values() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + } + + #[test] + fn read_varint_none() { + assert!(sync_resolve(read_varint(&mut &[][..])).unwrap().is_none()); + } + + #[test] + fn read_varint_truncated() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = { + let mut buf: &[u8] = &buf[..i]; + sync_resolve(read_varint(&mut buf)) + } + .unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + } + + #[test] + fn read_varint_extra() { + const EXTRA: &[u8] = &[161, 2, 49]; + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + buf.extend_from_slice(EXTRA); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert_eq!(&s[..], EXTRA); + } + } +} diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 982eda7..f00b008 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,233 +1,5 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use futures::io::AsyncRead; - -use crate::{Error, Res}; - -#[pin_project::pin_project] -pub struct ReadUint<'a, S, const N: usize> { - /// The source of data. - src: Pin<&'a mut S>, - /// A buffer that holds the bytes that have been read so far. - v: [u8; 8], - /// A counter of the number of bytes that are already in place. - /// This starts out at `8-N`. 
- read: usize, -} - -impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> -where - S: AsyncRead, -{ - type Output = Res; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(count)) => { - if count == 0 { - return Poll::Ready(Err(Error::Truncated)); - } - *this.read += count; - if *this.read == 8 { - Poll::Ready(Ok(u64::from_be_bytes(*this.v))) - } else { - Poll::Pending - } - } - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), - } - } -} - -pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { - ReadUint { - src: Pin::new(src), - v: [0; 8], - read: 8 - N, - } -} - -#[pin_project::pin_project(project = ReadVariantProj)] -pub enum ReadVarint<'a, S> { - First(Option>), - Extra1(#[pin] ReadUint<'a, S, 8>), - Extra3(#[pin] ReadUint<'a, S, 8>), - Extra7(#[pin] ReadUint<'a, S, 8>), -} - -impl<'a, S> Future for ReadVarint<'a, S> -where - S: AsyncRead, -{ - type Output = Res>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut(); - if let Self::First(src) = this.get_mut() { - let mut src = src.take().unwrap(); - let mut buf = [0; 1]; - if let Poll::Ready(Ok(c)) = src.as_mut().poll_read(cx, &mut buf[..]) { - if c == 0 { - return Poll::Ready(Ok(None)); - } - let b1 = buf[0]; - let mut v = [0; 8]; - let next = match b1 >> 6 { - 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), - 1 => { - v[6] = b1 & 0x3f; - Self::Extra1(ReadUint { src, v, read: 7 }) - } - 2 => { - v[4] = b1 & 0x3f; - Self::Extra3(ReadUint { src, v, read: 5 }) - } - 3 => { - v[0] = b1 & 0x3f; - Self::Extra7(ReadUint { src, v, read: 1 }) - } - _ => unreachable!(), - }; - - self.set(next); - } - } - let extra = match self.project() { - ReadVariantProj::Extra1(s) - | ReadVariantProj::Extra3(s) - | ReadVariantProj::Extra7(s) => s.poll(cx), - ReadVariantProj::First(_) => return Poll::Pending, - }; - if let Poll::Ready(v) = extra { - Poll::Ready(v.map(Some)) - } else { - Poll::Pending - } - } -} - -pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { - ReadVarint::First(Some(Pin::new(src))) -} - +#![allow(dead_code)] // TODO #[cfg(test)] -mod test { - use std::task::{Context, Poll}; - - use futures::{Future, FutureExt}; - - use crate::{ - rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, - stream::{read_uint as stream_read_uint, read_varint as stream_read_varint}, - }; - - pub fn noop_context() -> Context<'static> { - use std::{ - ptr::null, - task::{RawWaker, RawWakerVTable, Waker}, - }; - - const fn noop_raw_waker() -> RawWaker { - unsafe fn noop_clone(_data: *const ()) -> RawWaker { - noop_raw_waker() - } - - unsafe fn noop(_data: *const ()) {} - - const NOOP_WAKER_VTABLE: RawWakerVTable = - RawWakerVTable::new(noop_clone, noop, noop, noop); - RawWaker::new(null(), &NOOP_WAKER_VTABLE) - } - - pub fn noop_waker_ref() -> &'static Waker { - struct SyncRawWaker(RawWaker); - unsafe impl Sync for SyncRawWaker {} - - static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); - - // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. 
- unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } - } - - Context::from_waker(noop_waker_ref()) - } - - fn assert_unpin(v: T) -> T { - v - } - - fn read_uint(mut buf: &[u8]) -> u64 { - println!("{buf:?}"); - let mut cx = noop_context(); - let mut fut = assert_unpin(stream_read_uint::<_, N>(&mut buf)); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - v = fut.poll_unpin(&mut cx); - } - if let Poll::Ready(Ok(v)) = v { - v - } else { - panic!("v is not OK: {v:?}"); - } - } - - #[test] - fn read_uint_values() { - macro_rules! validate_uint_range { - (@ $n:expr) => { - let m = u64::MAX >> (64 - 8 * $n); - for v in [0, 1, m] { - println!("{n} byte encoding of 0x{v:x}", n = $n); - let mut buf = Vec::with_capacity($n); - sync_write_uint::<$n>(v, &mut buf).unwrap(); - assert_eq!(v, read_uint::<$n>(&buf[..])); - } - }; - ($($n:expr),+ $(,)?) => { - $( - validate_uint_range!(@ $n); - )+ - } - } - validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); - } - - fn read_varint(mut buf: &[u8]) -> u64 { - let mut cx = noop_context(); - let mut fut = assert_unpin(stream_read_varint(&mut buf)); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - v = fut.poll_unpin(&mut cx); - } - if let Poll::Ready(Ok(Some(v))) = v { - v - } else { - panic!("v is not OK: {v:?}"); - } - } - - #[test] - fn read_varint_values() { - for i in [ - 0, - 1, - 63, - 64, - (1 << 14) - 1, - 1 << 14, - (1 << 30) - 1, - 1 << 30, - (1 << 62) - 1, - ] { - let mut buf = Vec::new(); - sync_write_varint(i, &mut buf).unwrap(); - assert_eq!(i, read_varint(&buf[..])); - } - } -} +mod context; +mod int; +mod vec; diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs new file mode 100644 index 0000000..8b02474 --- /dev/null +++ b/bhttp/src/stream/vec.rs @@ -0,0 +1,229 @@ +use std::{ + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{io::AsyncRead, FutureExt}; + +use super::int::{read_varint, ReadVarint}; +use crate::{Error, Res}; + +#[pin_project::pin_project(project = ReadVecProj)] +#[allow(clippy::module_name_repetitions)] +pub enum ReadVec<'a, S> { + // Invariant: This Option is always Some. + ReadLen { + src: Option>, + cap: u64, + }, + ReadBody { + src: Pin<&'a mut S>, + buf: Vec, + remaining: usize, + }, +} + +impl<'a, S> ReadVec<'a, S> { + /// # Panics + /// If `limit` is more than `usize::MAX` or + /// if this is called after the length is read. + fn limit(&mut self, limit: u64) { + usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); + if let Self::ReadLen { ref mut cap, .. } = self { + *cap = limit; + } else { + panic!("cannot set a limit once the size has been read"); + } + } + + fn stream(self) -> Pin<&'a mut S> { + match self { + Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), + Self::ReadBody { src, .. } => src, + } + } +} + +impl<'a, S> Future for ReadVec<'a, S> +where + S: AsyncRead, +{ + type Output = Res>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::ReadLen { src, cap } = this.get_mut() { + match src.as_mut().unwrap().poll_unpin(cx) { + Poll::Ready(Ok(None)) => return Poll::Ready(Ok(None)), + Poll::Ready(Ok(Some(0))) => return Poll::Ready(Ok(Some(Vec::new()))), + Poll::Ready(Ok(Some(sz))) => { + if sz > *cap { + return Poll::Ready(Err(Error::LimitExceeded)); + } + // `cap` cannot exceed min(usize::MAX, u64::MAX). 
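+                    // and `sz <= cap` was checked just above, so this conversion cannot fail.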
+ let sz = usize::try_from(sz).unwrap(); + let body = Self::ReadBody { + src: src.take().unwrap().stream(), + buf: vec![0; sz], + remaining: sz, + }; + self.set(body); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + let ReadVecProj::ReadBody { + src, + buf, + remaining, + } = self.project() + else { + return Poll::Pending; + }; + + let offset = buf.len() - *remaining; + match src.as_mut().poll_read(cx, &mut buf[offset..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), + Poll::Ready(Ok(c)) => { + *remaining -= c; + if *remaining > 0 { + Poll::Pending + } else { + Poll::Ready(Ok(Some(mem::take(buf)))) + } + } + } + } +} + +#[allow(clippy::module_name_repetitions)] +pub fn read_vec(src: &mut S) -> ReadVec<'_, S> { + ReadVec::ReadLen { + src: Some(read_varint(src)), + cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), + } +} + +#[cfg(test)] +mod test { + + use std::{ + cmp, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, + }; + + use futures::AsyncRead; + + use crate::{ + rw::write_varint as sync_write_varint, + stream::{ + context::{sync_resolve, sync_resolve_with}, + vec::read_vec, + }, + Error, + }; + + const FILL_VALUE: u8 = 90; + + fn fill(len: T) -> Vec + where + u64: TryFrom, + >::Error: Debug, + usize: TryFrom, + >::Error: Debug, + T: Debug + Copy, + { + let mut buf = Vec::new(); + sync_write_varint(u64::try_from(len).unwrap(), &mut buf).unwrap(); + buf.resize(buf.len() + usize::try_from(len).unwrap(), FILL_VALUE); + buf + } + + #[test] + fn read_vecs() { + for len in [0, 1, 2, 3, 64] { + let buf = fill(len); + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + if let Ok(Some(out)) = sync_resolve(&mut fut) { + assert_eq!(len, out.len()); + assert!(out.iter().all(|&v| v == FILL_VALUE)); + + assert!(fut.stream().is_empty()); + } + } + } + + #[test] + fn exceed_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(LEN - 1); + assert!(matches!(sync_resolve(&mut fut), Err(Error::LimitExceeded))); + } + + /// This class implements `AsyncRead`, but + /// always blocks after returning a fixed value. 
+ #[derive(Default)] + struct IncompleteRead<'a> { + data: &'a [u8], + consumed: usize, + } + + impl<'a> IncompleteRead<'a> { + fn new(data: &'a [u8]) -> Self { + Self { data, consumed: 0 } + } + } + + impl AsyncRead for IncompleteRead<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let remaining = &self.data[self.consumed..]; + if remaining.is_empty() { + Poll::Pending + } else { + let copied = cmp::min(buf.len(), remaining.len()); + buf[..copied].copy_from_slice(&remaining[..copied]); + self.as_mut().consumed += copied; + Poll::Ready(std::io::Result::Ok(copied)) + } + } + } + + #[test] + #[should_panic(expected = "cannot set a limit once the size has been read")] + fn late_cap() { + let mut buf = IncompleteRead::new(&[2, 1]); + _ = sync_resolve_with(read_vec(&mut buf), |f| { + println!("pending"); + f.limit(100); + }); + } + + #[test] + #[cfg(target_pointer_width = "32")] + #[should_panic(expected = "cannot set a limit larger than usize::MAX")] + fn too_large_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); + + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(u64::try_from(usize::MAX).unwrap() + 1); + } +} From 30b7ef53cb04338acfd18958b31bc83737e88147 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 16:40:58 +1100 Subject: [PATCH 03/20] Add 16-bit arch as well, I guess --- bhttp/src/stream/vec.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index 8b02474..d71040a 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -216,7 +216,7 @@ mod test { } #[test] - #[cfg(target_pointer_width = "32")] + #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] #[should_panic(expected = "cannot set a limit larger than usize::MAX")] fn too_large_cap() { const LEN: u64 = 20; From c432a05a2b11b13a21f7b85c4b73dfcfd5766a08 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Fri, 25 Oct 2024 10:41:22 +1100 Subject: [PATCH 04/20] Checkpoint --- bhttp/src/err.rs | 3 + bhttp/src/lib.rs | 96 +++++--- bhttp/src/stream/{context.rs => future.rs} | 62 +++-- bhttp/src/stream/int.rs | 14 +- bhttp/src/stream/mod.rs | 260 ++++++++++++++++++++- bhttp/src/stream/vec.rs | 13 +- 6 files changed, 386 insertions(+), 62 deletions(-) rename bhttp/src/stream/{context.rs => future.rs} (50%) diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index e53e0a1..eb14acc 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -17,6 +17,9 @@ pub enum Error { InvalidMode, #[error("the status code of a response needs to be in 100..=599")] InvalidStatus, + #[cfg(feature = "stream")] + #[error("a method was called when the message was in the wrong state")] + InvalidState, #[error("IO error {0}")] Io(#[from] std::io::Error), #[cfg(feature = "stream")] diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index f92b2c4..be299a2 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -25,7 +25,7 @@ const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct StatusCode(u16); impl StatusCode { @@ -78,6 +78,17 @@ pub enum Mode { IndeterminateLength, } +impl TryFrom for Mode { + type Error = Error; + fn try_from(t: u64) -> Result { + match t { + 0 | 1 => Ok(Self::KnownLength), + 2 | 3 => Ok(Self::IndeterminateLength), + _ => Err(Error::InvalidMode), + } + } +} + pub struct Field { 
name: Vec, value: Vec, @@ -558,10 +569,49 @@ impl InformationalResponse { } } +pub struct Header { + control: ControlData, + fields: FieldSection, +} + +impl Header { + #[must_use] + pub fn control(&self) -> &ControlData { + &self.control + } +} + +impl From for Header { + fn from(control: ControlData) -> Self { + Self { + control, + fields: FieldSection::default(), + } + } +} + +impl From<(ControlData, FieldSection)> for Header { + fn from((control, fields): (ControlData, FieldSection)) -> Self { + Self { control, fields } + } +} + +impl std::ops::Deref for Header { + type Target = FieldSection; + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl std::ops::DerefMut for Header { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + pub struct Message { informational: Vec, - control: ControlData, - header: FieldSection, + header: Header, content: Vec, trailer: FieldSection, } @@ -571,13 +621,12 @@ impl Message { pub fn request(method: Vec, scheme: Vec, authority: Vec, path: Vec) -> Self { Self { informational: Vec::new(), - control: ControlData::Request { + header: Header::from(ControlData::Request { method, scheme, authority, path, - }, - header: FieldSection::default(), + }), content: Vec::new(), trailer: FieldSection::default(), } @@ -587,8 +636,7 @@ impl Message { pub fn response(status: StatusCode) -> Self { Self { informational: Vec::new(), - control: ControlData::Response(status), - header: FieldSection::default(), + header: Header::from(ControlData::Response(status)), content: Vec::new(), trailer: FieldSection::default(), } @@ -613,11 +661,11 @@ impl Message { #[must_use] pub fn control(&self) -> &ControlData { - &self.control + self.header.control() } #[must_use] - pub fn header(&self) -> &FieldSection { + pub fn header(&self) -> &Header { &self.header } @@ -672,20 +720,20 @@ impl Message { control = ControlData::read_http(line)?; } - let mut header = FieldSection::read_http(r)?; + let mut hfields = FieldSection::read_http(r)?; let (content, trailer) = if matches!(control.status().map(StatusCode::code), Some(204 | 304)) { // 204 and 304 have no body, no matter what Content-Length says. // Unfortunately, we can't do the same for responses to HEAD. 
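            // (A response to HEAD repeats the would-be body's header fields, including any
            // Content-Length, without any body octets, and the request method is not visible here.)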
(Vec::new(), FieldSection::default()) - } else if header.is_chunked() { + } else if hfields.is_chunked() { let content = Self::read_chunked(r)?; let trailer = FieldSection::read_http(r)?; (content, trailer) } else { let mut content = Vec::new(); - if let Some(cl) = header.get(CONTENT_LENGTH) { + if let Some(cl) = hfields.get(CONTENT_LENGTH) { let cl_str = String::from_utf8(Vec::from(cl))?; let cl_int = cl_str.parse::()?; if cl_int > 0 { @@ -700,11 +748,10 @@ impl Message { (content, FieldSection::default()) }; - header.strip_connection_headers(); + hfields.strip_connection_headers(); Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) @@ -716,7 +763,7 @@ impl Message { ControlData::Response(info.status()).write_http(w)?; info.fields().write_http(w)?; } - self.control.write_http(w)?; + self.header.control.write_http(w)?; if !self.content.is_empty() { if self.trailer.is_empty() { write!(w, "Content-Length: {}\r\n", self.content.len())?; @@ -746,11 +793,7 @@ impl Message { { let t = read_varint(r)?.ok_or(Error::Truncated)?; let request = t == 0 || t == 2; - let mode = match t { - 0 | 1 => Mode::KnownLength, - 2 | 3 => Mode::IndeterminateLength, - _ => return Err(Error::InvalidMode), - }; + let mode = Mode::try_from(t)?; let mut control = ControlData::read_bhttp(request, r)?; let mut informational = Vec::new(); @@ -759,7 +802,7 @@ impl Message { informational.push(InformationalResponse::new(status, fields)); control = ControlData::read_bhttp(request, r)?; } - let header = FieldSection::read_bhttp(mode, r)?; + let hfields = FieldSection::read_bhttp(mode, r)?; let mut content = read_vec(r)?.unwrap_or_default(); if mode == Mode::IndeterminateLength && !content.is_empty() { @@ -776,19 +819,18 @@ impl Message { Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) } pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { - write_varint(self.control.code(mode), w)?; + write_varint(self.header.control.code(mode), w)?; for info in &self.informational { info.write_bhttp(mode, w)?; } - self.control.write_bhttp(w)?; + self.header.control.write_bhttp(w)?; self.header.write_bhttp(mode, w)?; write_vec(&self.content, w)?; diff --git a/bhttp/src/stream/context.rs b/bhttp/src/stream/future.rs similarity index 50% rename from bhttp/src/stream/context.rs rename to bhttp/src/stream/future.rs index c6987b0..ee26fc9 100644 --- a/bhttp/src/stream/context.rs +++ b/bhttp/src/stream/future.rs @@ -1,9 +1,12 @@ use std::{ future::Future, + pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::FutureExt; +use futures::{TryStream, TryStreamExt}; + +use crate::Error; fn noop_context() -> Context<'static> { use std::{ @@ -35,28 +38,51 @@ fn noop_context() -> Context<'static> { Context::from_waker(noop_waker_ref()) } -fn assert_unpin(v: F) -> F { - v -} - /// Drives the given future (`f`) until it resolves. /// Executes the indicated function (`p`) each time the /// poll returned `Poll::Pending`. 
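/// For example (a sketch mirroring this crate's tests, assuming the stream
/// helpers are in scope):
/// `assert_eq!(read_varint(&mut &[0x25][..]).sync_resolve().unwrap(), Some(0x25));`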
-pub fn sync_resolve_with(f: F, p: P) -> F::Output { - let mut cx = noop_context(); - let mut fut = assert_unpin(f); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - p(&mut fut); - v = fut.poll_unpin(&mut cx); +pub trait SyncResolve { + type Output; + + fn sync_resolve(&mut self) -> Self::Output { + self.sync_resolve_with(|_| {}) } - if let Poll::Ready(v) = v { - v - } else { - unreachable!(); + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output; +} + +impl SyncResolve for F { + type Output = F::Output; + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output { + let mut cx = noop_context(); + let mut fut = Pin::new(self); + let mut v = fut.as_mut().poll(&mut cx); + while v.is_pending() { + p(fut.as_mut()); + v = fut.as_mut().poll(&mut cx); + } + if let Poll::Ready(v) = v { + v + } else { + unreachable!(); + } } } -pub fn sync_resolve(f: F) -> F::Output { - sync_resolve_with(f, |_| {}) +pub trait SyncCollect { + type Item; + + fn sync_collect(self) -> Result, Error>; +} + +impl SyncCollect for S +where + S: TryStream, +{ + type Item = S::Ok; + + fn sync_collect(self) -> Result, Error> { + pin!(self.try_collect::>()).sync_resolve() + } } diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 02b05df..460e3f1 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -144,7 +144,7 @@ mod test { err::Error, rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, stream::{ - context::sync_resolve, + future::SyncResolve, int::{read_uint, read_varint}, }, }; @@ -172,7 +172,7 @@ mod test { sync_write_uint::<$n>(v, &mut buf).unwrap(); let mut buf_ref = &buf[..]; let mut fut = read_uint::<_, $n>(&mut buf_ref); - assert_eq!(v, sync_resolve(&mut fut).unwrap()); + assert_eq!(v, fut.sync_resolve().unwrap()); let s = fut.stream(); assert!(s.is_empty()); } @@ -196,7 +196,7 @@ mod test { let mut buf = Vec::with_capacity($n); sync_write_uint::<$n>(v, &mut buf).unwrap(); for i in 1..buf.len() { - let err = sync_resolve(read_uint::<_, $n>(&mut &buf[..i])).unwrap_err(); + let err = read_uint::<_, $n>(&mut &buf[..i]).sync_resolve().unwrap_err(); assert!(matches!(err, Error::Truncated)); } } @@ -217,7 +217,7 @@ mod test { sync_write_varint(v, &mut buf).unwrap(); let mut buf_ref = &buf[..]; let mut fut = read_varint(&mut buf_ref); - assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); let s = fut.stream(); assert!(s.is_empty()); } @@ -225,7 +225,7 @@ mod test { #[test] fn read_varint_none() { - assert!(sync_resolve(read_varint(&mut &[][..])).unwrap().is_none()); + assert!(read_varint(&mut &[][..]).sync_resolve().unwrap().is_none()); } #[test] @@ -236,7 +236,7 @@ mod test { for i in 1..buf.len() { let err = { let mut buf: &[u8] = &buf[..i]; - sync_resolve(read_varint(&mut buf)) + read_varint(&mut buf).sync_resolve() } .unwrap_err(); assert!(matches!(err, Error::Truncated)); @@ -253,7 +253,7 @@ mod test { buf.extend_from_slice(EXTRA); let mut buf_ref = &buf[..]; let mut fut = read_varint(&mut buf_ref); - assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); let s = fut.stream(); assert_eq!(&s[..], EXTRA); } diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index f00b008..011e0d1 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,5 +1,261 @@ -#![allow(dead_code)] // TODO +#![allow(dead_code)] + +use std::{ + io::{Cursor, Result as IoResult}, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use 
futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; +use int::ReadVarint; + +use crate::{ + err::Res, + stream::{int::read_varint, vec::read_vec}, + ControlData, Error, Field, FieldSection, Header, InformationalResponse, Message, Mode, COOKIE, +}; #[cfg(test)] -mod context; +mod future; mod int; mod vec; + +trait AsyncReadControlData: Sized { + async fn async_read(request: bool, src: &mut S) -> Res; +} + +impl AsyncReadControlData for ControlData { + async fn async_read(request: bool, src: &mut S) -> Res { + let v = if request { + let method = read_vec(src).await?.ok_or(Error::Truncated)?; + let scheme = read_vec(src).await?.ok_or(Error::Truncated)?; + let authority = read_vec(src).await?.ok_or(Error::Truncated)?; + let path = read_vec(src).await?.ok_or(Error::Truncated)?; + Self::Request { + method, + scheme, + authority, + path, + } + } else { + Self::Response(crate::StatusCode::try_from( + read_varint(src).await?.ok_or(Error::Truncated)?, + )?) + }; + Ok(v) + } +} + +trait AsyncReadFieldSection: Sized { + async fn async_read(mode: Mode, src: &mut S) -> Res; +} + +impl AsyncReadFieldSection for FieldSection { + async fn async_read(mode: Mode, src: &mut S) -> Res { + let fields = if mode == Mode::KnownLength { + // Known-length fields can just be read into a buffer. + if let Some(buf) = read_vec(src).await? { + Self::read_bhttp_fields(false, &mut Cursor::new(&buf[..]))? + } else { + Vec::new() + } + } else { + // The async version needs to be implemented directly. + let mut fields: Vec = Vec::new(); + let mut cookie_index: Option = None; + loop { + if let Some(n) = read_vec(src).await? { + if n.is_empty() { + break fields; + } + let mut v = read_vec(src).await?.ok_or(Error::Truncated)?; + if n == COOKIE { + if let Some(i) = &cookie_index { + fields[*i].value.extend_from_slice(b"; "); + fields[*i].value.append(&mut v); + continue; + } + cookie_index = Some(fields.len()); + } + fields.push(Field::new(n, v)); + } else { + return Err(Error::Truncated); + } + } + }; + Ok(Self(fields)) + } +} + +enum BodyState<'a, S> { + // When reading the length, use this. + ReadLength(ReadVarint<'a, S>), + // When reading the data, track how much is left. + ReadData { + remaining: usize, + src: Pin<&'a mut S>, + }, +} + +#[pin_project::pin_project] +struct Body<'a, S> { + mode: Mode, + state: BodyState<'a, S>, +} + +impl<'a, S: AsyncRead> AsyncRead for Body<'a, S> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().src.as_mut().poll_read(cx, buf) + } +} + +enum AsyncMessageState { + // Processing Informational responses (or before that). + Informational, + // Having obtained the control data for the header, this is it. + Header(ControlData), + // Processing the Body. + Body, + // Processing the trailer. + Trailer, +} + +struct AsyncMessage<'a, S> { + // Whether this is a request and which mode. + framing: Option<(bool, Mode)>, + state: AsyncMessageState, + src: Pin<&'a mut S>, +} + +impl<'a, S: AsyncRead> AsyncMessage<'a, S> { + /// Get the mode. This panics if the header hasn't been read yet. 
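+    /// The initial framing varint encodes both properties: 0 and 1 are known-length
+    /// requests and responses, while 2 and 3 are the indeterminate-length forms.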
+ fn mode(&self) -> Mode { + self.framing.unwrap().1 + } + + async fn next_info(&mut self) -> Res> { + if !matches!(self.state, AsyncMessageState::Informational) { + return Ok(None); + } + + let (request, mode) = if let Some((request, mode)) = self.framing { + (request, mode) + } else { + let t = read_varint(&mut self.src).await?.ok_or(Error::Truncated)?; + let request = t == 0 || t == 2; + let mode = Mode::try_from(t)?; + self.framing = Some((request, mode)); + (request, mode) + }; + + let control = ControlData::async_read(request, &mut self.src).await?; + if let Some(status) = control.informational() { + let fields = FieldSection::async_read(mode, &mut self.src).await?; + Ok(Some(InformationalResponse::new(status, fields))) + } else { + self.state = AsyncMessageState::Header(control); + Ok(None) + } + } + + /// Produces a stream of informational responses from a fresh message. + /// Returns an empty stream if called at other times. + /// Error values on the stream indicate failures. + /// + /// There is no need to call this method to read a request, though + /// doing so is harmless. + /// + /// You can discard the stream that this function returns + /// without affecting the message. You can then either call this + /// method again to get any additional informational responses or + /// call `header()` to get the message header. + pub fn informational( + &mut self, + ) -> impl Stream> + use<'_, 'a, S> { + unfold(self, |this| async move { + this.next_info().await.transpose().map(|info| (info, this)) + }) + } + + /// This reads the header. If you have not called `informational` + /// and drained the resulting stream, this will do that for you. + pub async fn header(&mut self) -> Res
{ + if matches!(self.state, AsyncMessageState::Informational) { + // Need to scrub for errors, + // so that this can abort properly if there is one. + // The `try_any` usage is there to ensure that the stream is fully drained. + _ = self.informational().try_any(|_| async { false }).await?; + } + if matches!(self.state, AsyncMessageState::Header(_)) { + let AsyncMessageState::Header(control) = + mem::replace(&mut self.state, AsyncMessageState::Body) + else { + unreachable!(); + }; + let mode = self.mode(); + let hfields = FieldSection::async_read(mode, &mut self.src).await?; + Ok(Header::from((control, hfields))) + } else { + Err(Error::InvalidState) + } + } + + pub fn body<'s>(&'s mut self) -> Res> + where + 'a: 's, + { + if matches!(self.state, AsyncMessageState::Body) { + Ok(Body { + mode: self.mode(), + state: BodyState::ReadLength(read_varint(self.src.as_mut())), + }) + } else { + Err(Error::InvalidState) + } + } +} + +trait AsyncReadMessage: Sized { + fn async_read(src: &mut S) -> AsyncMessage<'_, S>; +} + +impl AsyncReadMessage for Message { + fn async_read(src: &mut S) -> AsyncMessage<'_, S> { + AsyncMessage { + framing: None, + state: AsyncMessageState::Informational, + src: Pin::new(src), + } + } +} + +#[cfg(test)] +mod test { + use std::pin::pin; + + use crate::{ + stream::{ + future::{SyncCollect, SyncResolve}, + AsyncReadMessage, + }, + Message, + }; + + #[test] + fn informational() { + const INFO: &[u8] = &[1, 64, 100, 0, 64, 200, 0]; + let mut buf_alias = INFO; + let mut msg = Message::async_read(&mut buf_alias); + let info = msg.informational().sync_collect().unwrap(); + assert_eq!(info.len(), 1); + let info = msg.informational().sync_collect().unwrap(); + assert!(info.is_empty()); + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control().status().unwrap().code(), 200); + assert!(hdr.is_empty()); + } +} diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index d71040a..16dd433 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -125,10 +125,7 @@ mod test { use crate::{ rw::write_varint as sync_write_varint, - stream::{ - context::{sync_resolve, sync_resolve_with}, - vec::read_vec, - }, + stream::{future::SyncResolve, vec::read_vec}, Error, }; @@ -154,7 +151,7 @@ mod test { let buf = fill(len); let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); - if let Ok(Some(out)) = sync_resolve(&mut fut) { + if let Ok(Some(out)) = fut.sync_resolve() { assert_eq!(len, out.len()); assert!(out.iter().all(|&v| v == FILL_VALUE)); @@ -170,7 +167,7 @@ mod test { let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); fut.limit(LEN - 1); - assert!(matches!(sync_resolve(&mut fut), Err(Error::LimitExceeded))); + assert!(matches!(fut.sync_resolve(), Err(Error::LimitExceeded))); } /// This class implements `AsyncRead`, but @@ -209,9 +206,9 @@ mod test { #[should_panic(expected = "cannot set a limit once the size has been read")] fn late_cap() { let mut buf = IncompleteRead::new(&[2, 1]); - _ = sync_resolve_with(read_vec(&mut buf), |f| { + _ = read_vec(&mut buf).sync_resolve_with(|mut f| { println!("pending"); - f.limit(100); + f.as_mut().limit(100); }); } From 7d78bc7b48fd7f15f27072cd18db509b95273363 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 28 Oct 2024 11:51:49 +1100 Subject: [PATCH 05/20] Checkpoint, fuck lifetimes --- bhttp/src/lib.rs | 153 ++++++++++++++++- bhttp/src/stream/future.rs | 28 ++- bhttp/src/stream/int.rs | 49 +++--- bhttp/src/stream/mod.rs | 342 +++++++++++++++++++++++++++++++------ 
bhttp/src/stream/vec.rs | 25 ++- 5 files changed, 495 insertions(+), 102 deletions(-) diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index be299a2..2082b0f 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -1,7 +1,11 @@ #![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // Too lazy to document these. -use std::{borrow::BorrowMut, io}; +use std::{ + borrow::BorrowMut, + io, + ops::{Deref, DerefMut}, +}; #[cfg(feature = "http")] use url::Url; @@ -25,7 +29,7 @@ const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, Debug)] pub struct StatusCode(u16); impl StatusCode { @@ -68,6 +72,26 @@ impl From for u16 { } } +#[cfg(test)] +impl PartialEq for StatusCode +where + Self: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + StatusCode::try_from(*other).map_or(false, |o| o.0 == self.0) + } +} + +#[cfg(not(test))] +impl PartialEq for StatusCode { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for StatusCode {} + pub trait ReadSeek: io::BufRead + io::Seek {} impl ReadSeek for io::Cursor where T: AsRef<[u8]> {} impl ReadSeek for io::BufReader where T: io::Read + io::Seek {} @@ -132,6 +156,18 @@ impl Field { } } +#[cfg(test)] +impl std::fmt::Debug for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "{n}: {v}", + n = String::from_utf8_lossy(&self.name), + v = String::from_utf8_lossy(&self.value), + ) + } +} + #[derive(Default)] pub struct FieldSection(Vec); impl FieldSection { @@ -140,15 +176,26 @@ impl FieldSection { self.0.is_empty() } + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + /// Gets the value from the first instance of the field. #[must_use] pub fn get(&self, n: &[u8]) -> Option<&[u8]> { - for f in &self.0 { + self.get_all(n).next() + } + + /// Gets all of the values of the named field. + pub fn get_all<'a, 'b>(&'a self, n: &'b [u8]) -> impl Iterator + use<'a, 'b> { + self.0.iter().filter_map(move |f| { if &f.name[..] == n { - return Some(&f.value); + Some(&f.value[..]) + } else { + None } - } - None + }) } pub fn put(&mut self, name: impl Into>, value: impl Into>) { @@ -336,6 +383,16 @@ impl FieldSection { } } +#[cfg(test)] +impl std::fmt::Debug for FieldSection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + for fv in self.fields() { + fv.fmt(f)?; + } + Ok(()) + } +} + pub enum ControlData { Request { method: Vec, @@ -541,6 +598,68 @@ impl ControlData { } } +#[cfg(test)] +impl PartialEq<(M, S, A, P)> for ControlData +where + M: AsRef<[u8]>, + S: AsRef<[u8]>, + A: AsRef<[u8]>, + P: AsRef<[u8]>, +{ + fn eq(&self, other: &(M, S, A, P)) -> bool { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => { + method == other.0.as_ref() + && scheme == other.1.as_ref() + && authority == other.2.as_ref() + && path == other.3.as_ref() + } + Self::Response(_) => false, + } + } +} + +#[cfg(test)] +impl PartialEq for ControlData +where + StatusCode: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + match self { + Self::Request { .. 
} => false, + Self::Response(code) => code == other, + } + } +} + +#[cfg(test)] +impl std::fmt::Debug for ControlData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => write!( + f, + "{m} {s}://{a}{p}", + m = String::from_utf8_lossy(method), + s = String::from_utf8_lossy(scheme), + a = String::from_utf8_lossy(authority), + p = String::from_utf8_lossy(path), + ), + Self::Response(code) => write!(f, "{code:?}"), + } + } +} + pub struct InformationalResponse { status: StatusCode, fields: FieldSection, @@ -569,6 +688,20 @@ impl InformationalResponse { } } +impl Deref for InformationalResponse { + type Target = FieldSection; + + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl DerefMut for InformationalResponse { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + pub struct Header { control: ControlData, fields: FieldSection, @@ -609,6 +742,14 @@ impl std::ops::DerefMut for Header { } } +#[cfg(test)] +impl std::fmt::Debug for Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + self.control.fmt(f)?; + self.fields.fmt(f) + } +} + pub struct Message { informational: Vec, header: Header, diff --git a/bhttp/src/stream/future.rs b/bhttp/src/stream/future.rs index ee26fc9..3eae2bb 100644 --- a/bhttp/src/stream/future.rs +++ b/bhttp/src/stream/future.rs @@ -4,7 +4,7 @@ use std::{ task::{Context, Poll}, }; -use futures::{TryStream, TryStreamExt}; +use futures::{AsyncRead, AsyncReadExt, TryStream, TryStreamExt}; use crate::Error; @@ -76,13 +76,31 @@ pub trait SyncCollect { fn sync_collect(self) -> Result, Error>; } -impl SyncCollect for S -where - S: TryStream, -{ +impl> SyncCollect for S { type Item = S::Ok; fn sync_collect(self) -> Result, Error> { pin!(self.try_collect::>()).sync_resolve() } } + +pub trait SyncRead { + fn sync_read_exact(&mut self, amount: usize) -> Vec; + fn sync_read_to_end(&mut self) -> Vec; +} + +impl SyncRead for S { + fn sync_read_exact(&mut self, amount: usize) -> Vec { + let mut buf = vec![0; amount]; + let res = self.read_exact(&mut buf[..]); + pin!(res).sync_resolve().unwrap(); + buf + } + + fn sync_read_to_end(&mut self) -> Vec { + let mut buf = Vec::new(); + let res = self.read_to_end(&mut buf); + pin!(res).sync_resolve().unwrap(); + buf + } +} diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 460e3f1..5b248df 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -1,6 +1,6 @@ use std::{ future::Future, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; @@ -9,9 +9,9 @@ use futures::io::AsyncRead; use crate::{Error, Res}; #[pin_project::pin_project] -pub struct ReadUint<'a, S, const N: usize> { +pub struct ReadUint { /// The source of data. - src: Pin<&'a mut S>, + src: S, /// A buffer that holds the bytes that have been read so far. v: [u8; 8], /// A counter of the number of bytes that are already in place. 
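// Aside (illustrative, inferred from the surrounding diff): ReadUint accumulates an
// N-byte network-order integer into the tail of an 8-byte buffer. `read` starts at
// 8 - N, each poll writes into v[read..], and once `read` reaches 8 the finished
// value can presumably be taken as u64::from_be_bytes(v) — the same layout that
// read_body_len uses later with its own [u8; 8] buffer.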
@@ -19,21 +19,18 @@ pub struct ReadUint<'a, S, const N: usize> { read: usize, } -impl<'a, S, const N: usize> ReadUint<'a, S, N> { - pub fn stream(self) -> Pin<&'a mut S> { +impl ReadUint { + pub fn stream(self) -> S { self.src } } -impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> -where - S: AsyncRead, -{ +impl Future for ReadUint { type Output = Res; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + match pin!(this.src).poll_read(cx, &mut this.v[*this.read..]) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(count)) => { if count == 0 { @@ -51,25 +48,26 @@ where } } -pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { +#[cfg(test)] +fn read_uint(src: S) -> ReadUint { ReadUint { - src: Pin::new(src), + src, v: [0; 8], read: 8 - N, } } #[pin_project::pin_project(project = ReadVarintProj)] -pub enum ReadVarint<'a, S> { +pub enum ReadVarint { // Invariant: this Option always contains Some. - First(Option>), - Extra1(#[pin] ReadUint<'a, S, 8>), - Extra3(#[pin] ReadUint<'a, S, 8>), - Extra7(#[pin] ReadUint<'a, S, 8>), + First(Option), + Extra1(#[pin] ReadUint), + Extra3(#[pin] ReadUint), + Extra7(#[pin] ReadUint), } -impl<'a, S> ReadVarint<'a, S> { - pub fn stream(self) -> Pin<&'a mut S> { +impl ReadVarint { + pub fn stream(self) -> S { match self { Self::Extra1(s) | Self::Extra3(s) | Self::Extra7(s) => s.stream(), Self::First(mut s) => s.take().unwrap(), @@ -77,18 +75,15 @@ impl<'a, S> ReadVarint<'a, S> { } } -impl<'a, S> Future for ReadVarint<'a, S> -where - S: AsyncRead, -{ +impl Future for ReadVarint { type Output = Res>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut(); if let Self::First(ref mut src) = this.get_mut() { let mut buf = [0; 1]; - let src_ref = src.as_mut().unwrap().as_mut(); - if let Poll::Ready(res) = src_ref.poll_read(cx, &mut buf[..]) { + let src_ref = src.as_mut().unwrap(); + if let Poll::Ready(res) = pin!(src_ref).poll_read(cx, &mut buf[..]) { match res { Ok(0) => return Poll::Ready(Ok(None)), Ok(_) => (), @@ -134,8 +129,8 @@ where } } -pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { - ReadVarint::First(Some(Pin::new(src))) +pub fn read_varint(src: S) -> ReadVarint { + ReadVarint::First(Some(src)) } #[cfg(test)] diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 011e0d1..9839059 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,13 +1,14 @@ -#![allow(dead_code)] +#![allow(clippy::incompatible_msrv)] // This module uses features from rust 1.82 use std::{ - io::{Cursor, Result as IoResult}, + cmp::min, + io::{Cursor, Error as IoError, Result as IoResult}, mem, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; +use futures::{stream::unfold, AsyncRead, FutureExt, Stream, TryStreamExt}; use int::ReadVarint; use crate::{ @@ -21,16 +22,16 @@ mod int; mod vec; trait AsyncReadControlData: Sized { - async fn async_read(request: bool, src: &mut S) -> Res; + async fn async_read(request: bool, src: S) -> Res; } impl AsyncReadControlData for ControlData { - async fn async_read(request: bool, src: &mut S) -> Res { + async fn async_read(request: bool, mut src: S) -> Res { let v = if request { - let method = read_vec(src).await?.ok_or(Error::Truncated)?; - let scheme = read_vec(src).await?.ok_or(Error::Truncated)?; - let authority = read_vec(src).await?.ok_or(Error::Truncated)?; - let path = 
read_vec(src).await?.ok_or(Error::Truncated)?; + let method = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let scheme = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let authority = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let path = read_vec(&mut src).await?.ok_or(Error::Truncated)?; Self::Request { method, scheme, @@ -38,23 +39,22 @@ impl AsyncReadControlData for ControlData { path, } } else { - Self::Response(crate::StatusCode::try_from( - read_varint(src).await?.ok_or(Error::Truncated)?, - )?) + let code = read_varint(&mut src).await?.ok_or(Error::Truncated)?; + Self::Response(crate::StatusCode::try_from(code)?) }; Ok(v) } } trait AsyncReadFieldSection: Sized { - async fn async_read(mode: Mode, src: &mut S) -> Res; + async fn async_read(mode: Mode, src: S) -> Res; } impl AsyncReadFieldSection for FieldSection { - async fn async_read(mode: Mode, src: &mut S) -> Res { + async fn async_read(mode: Mode, mut src: S) -> Res { let fields = if mode == Mode::KnownLength { // Known-length fields can just be read into a buffer. - if let Some(buf) = read_vec(src).await? { + if let Some(buf) = read_vec(&mut src).await? { Self::read_bhttp_fields(false, &mut Cursor::new(&buf[..]))? } else { Vec::new() @@ -64,11 +64,11 @@ impl AsyncReadFieldSection for FieldSection { let mut fields: Vec = Vec::new(); let mut cookie_index: Option = None; loop { - if let Some(n) = read_vec(src).await? { + if let Some(n) = read_vec(&mut src).await? { if n.is_empty() { break fields; } - let mut v = read_vec(src).await?.ok_or(Error::Truncated)?; + let mut v = read_vec(&mut src).await?.ok_or(Error::Truncated)?; if n == COOKIE { if let Some(i) = &cookie_index { fields[*i].value.extend_from_slice(b"; "); @@ -78,6 +78,8 @@ impl AsyncReadFieldSection for FieldSection { cookie_index = Some(fields.len()); } fields.push(Field::new(n, v)); + } else if fields.is_empty() { + break fields; } else { return Err(Error::Truncated); } @@ -87,51 +89,113 @@ impl AsyncReadFieldSection for FieldSection { } } -enum BodyState<'a, S> { +#[allow(clippy::mut_mut)] // TODO look into this more. +enum BodyState<'a, 'b, S> { // When reading the length, use this. - ReadLength(ReadVarint<'a, S>), + // Invariant: This is always `Some`. + ReadLength(Option>), // When reading the data, track how much is left. + // Invariant: `src` is always `Some`. 
ReadData { remaining: usize, - src: Pin<&'a mut S>, + src: Option<&'b mut &'a mut S>, }, } -#[pin_project::pin_project] -struct Body<'a, S> { +pub struct Body<'a, 'b, S> { mode: Mode, - state: BodyState<'a, S>, + state: &'b mut AsyncMessageState<'a, 'b, S>, } -impl<'a, S: AsyncRead> AsyncRead for Body<'a, S> { +impl<'a, 'b, S> Body<'a, 'b, S> { + fn set_state(&mut self, s: BodyState<'a, 'b, S>) { + *self.state = AsyncMessageState::Body(s); + } + + fn done(&mut self) { + *self.state = AsyncMessageState::Trailer; + } +} + +impl<'a, 'b, S: AsyncRead + Unpin> AsyncRead for Body<'a, 'b, S> { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - self.project().src.as_mut().poll_read(cx, buf) + fn poll_error(e: Error) -> Poll> { + Poll::Ready(Err(IoError::other(e))) + } + + let mode = self.mode; + if let AsyncMessageState::Body(BodyState::ReadLength(r)) = &mut self.state { + match r.as_mut().unwrap().poll_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(Some(0) | None)) => { + self.done(); + return Poll::Ready(Ok(0)); + } + Poll::Ready(Ok(Some(len))) => { + match usize::try_from(len) { + Ok(remaining) => { + let src = r.take().map(ReadVarint::stream); + self.set_state(BodyState::ReadData { remaining, src }); + // fall through to maybe read the body + } + Err(e) => return poll_error(Error::IntRange(e)), + } + } + Poll::Ready(Err(e)) => return poll_error(e), + } + } + + if let AsyncMessageState::Body(BodyState::ReadData { remaining, src }) = &mut self.state { + let amount = min(*remaining, buf.len()); + let res = pin!(src.as_mut().unwrap()).poll_read(cx, &mut buf[..amount]); + match res { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) => poll_error(Error::Truncated), + Poll::Ready(Ok(len)) => { + *remaining -= len; + if *remaining == 0 { + if mode == Mode::IndeterminateLength { + let src = src.take().map(read_varint); + self.set_state(BodyState::ReadLength(src)); + } else { + self.done(); + } + } + Poll::Ready(Ok(len)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } else { + Poll::Pending + } } } -enum AsyncMessageState { +enum AsyncMessageState<'a, 'b, S> { // Processing Informational responses (or before that). Informational, // Having obtained the control data for the header, this is it. Header(ControlData), // Processing the Body. - Body, + Body(BodyState<'a, 'b, S>), // Processing the trailer. Trailer, } -struct AsyncMessage<'a, S> { +pub struct AsyncMessage<'a, 'b, S> { // Whether this is a request and which mode. framing: Option<(bool, Mode)>, - state: AsyncMessageState, - src: Pin<&'a mut S>, + state: AsyncMessageState<'a, 'b, S>, + src: &'a mut S, } -impl<'a, S: AsyncRead> AsyncMessage<'a, S> { +unsafe impl Send for AsyncMessage<'_, '_, S> {} + +impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// Get the mode. This panics if the header hasn't been read yet. fn mode(&self) -> Mode { self.framing.unwrap().1 @@ -175,7 +239,7 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { /// call `header()` to get the message header. pub fn informational( &mut self, - ) -> impl Stream> + use<'_, 'a, S> { + ) -> impl Stream> + use<'_, 'a, 'b, S> { unfold(self, |this| async move { this.next_info().await.transpose().map(|info| (info, this)) }) @@ -183,7 +247,7 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { /// This reads the header. If you have not called `informational` /// and drained the resulting stream, this will do that for you. - pub async fn header(&mut self) -> Res
<Header> { + pub async fn header(&'b mut self) -> Res<Header>
{ if matches!(self.state, AsyncMessageState::Informational) { // Need to scrub for errors, // so that this can abort properly if there is one. @@ -191,44 +255,62 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { _ = self.informational().try_any(|_| async { false }).await?; } if matches!(self.state, AsyncMessageState::Header(_)) { + let mode = self.mode(); + let hfields = FieldSection::async_read(mode, &mut self.src).await?; + + let bs: BodyState<'a, 'b, S> = BodyState::ReadLength(Some(read_varint(&mut self.src))); let AsyncMessageState::Header(control) = - mem::replace(&mut self.state, AsyncMessageState::Body) + mem::replace(&mut self.state, AsyncMessageState::Body(bs)) else { unreachable!(); }; - let mode = self.mode(); - let hfields = FieldSection::async_read(mode, &mut self.src).await?; Ok(Header::from((control, hfields))) } else { Err(Error::InvalidState) } } - pub fn body<'s>(&'s mut self) -> Res> - where - 'a: 's, - { - if matches!(self.state, AsyncMessageState::Body) { + /// Read the body. + /// This produces an implementation of `AsyncRead` that filters out + /// the framing from the message body. + /// # Errors + /// This errors when the header has not been read. + /// Any IO errors are generated by the returned `Body` instance. + pub fn body(&'b mut self) -> Res> { + if matches!(self.state, AsyncMessageState::Body(_)) { + let mode = self.mode(); Ok(Body { - mode: self.mode(), - state: BodyState::ReadLength(read_varint(self.src.as_mut())), + mode, + state: &mut self.state, }) } else { Err(Error::InvalidState) } } + + /// Read any trailer. + /// This might be empty. + /// # Errors + /// This errors when the body has not been read. + pub async fn trailer(&mut self) -> Res { + if matches!(self.state, AsyncMessageState::Trailer) { + Ok(FieldSection::async_read(self.mode(), &mut self.src).await?) + } else { + Err(Error::InvalidState) + } + } } -trait AsyncReadMessage: Sized { - fn async_read(src: &mut S) -> AsyncMessage<'_, S>; +pub trait AsyncReadMessage: Sized { + fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S>; } impl AsyncReadMessage for Message { - fn async_read(src: &mut S) -> AsyncMessage<'_, S> { + fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S> { AsyncMessage { framing: None, state: AsyncMessageState::Informational, - src: Pin::new(src), + src, } } } @@ -237,14 +319,41 @@ impl AsyncReadMessage for Message { mod test { use std::pin::pin; + use futures::TryStreamExt; + use crate::{ stream::{ - future::{SyncCollect, SyncResolve}, + future::{SyncCollect, SyncRead, SyncResolve}, AsyncReadMessage, }, - Message, + Error, Message, }; + // Example from Section 5.1 of RFC 9292. 
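    // For reference (decoded form, matching the assertions in validate_sample_request
    // below): a GET for /hello.txt on www.example.com over https, carrying user-agent,
    // host, and accept-language fields. REQUEST1 is the known-length encoding (framing
    // indicator 0); REQUEST2 is the indeterminate-length encoding (indicator 2) with
    // trailing padding.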
+ const REQUEST1: &[u8] = &[ + 0x00, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x40, 0x6c, 0x0a, 0x75, 0x73, 0x65, 0x72, + 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, + 0x2e, 0x37, 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, + 0x68, 0x6f, 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, + 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, + ]; + const REQUEST2: &[u8] = &[ + 0x02, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x2d, 0x61, + 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, 0x2e, 0x37, + 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, 0x68, 0x6f, + 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, + 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, + 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + #[test] fn informational() { const INFO: &[u8] = &[1, 64, 100, 0, 64, 200, 0]; @@ -258,4 +367,135 @@ mod test { assert_eq!(hdr.control().status().unwrap().code(), 200); assert!(hdr.is_empty()); } + + #[test] + fn sample_requests() { + fn validate_sample_request(mut buf: &[u8]) { + let mut msg = Message::async_read(&mut buf); + let info = msg.informational().sync_collect().unwrap(); + assert!(info.is_empty()); + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &(b"GET", b"https", b"", b"/hello.txt")); + assert_eq!( + hdr.get(b"user-agent"), + Some(&b"curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"[..]), + ); + assert_eq!(hdr.get(b"host"), Some(&b"www.example.com"[..])); + assert_eq!(hdr.get(b"accept-language"), Some(&b"en, mi"[..])); + assert_eq!(hdr.len(), 3); + + let body = pin!(msg.body().unwrap()).sync_read_to_end(); + assert!(body.is_empty()); + + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } + + validate_sample_request(REQUEST1); + validate_sample_request(REQUEST2); + validate_sample_request(&REQUEST2[..REQUEST2.len() - 12]); + } + + #[test] + fn truncated_header() { + // The indefinite-length request example includes 10 bytes of padding. + // The three additional zero values at the end represent: + // 1. The terminating zero for the header field section. + // 2. The terminating zero for the (empty) body. + // 3. The terminating zero for the (absent) trailer field section. + // The latter two (body and trailer) can be cut and the message will still work. + // The first is not optional; dropping it means that the message is truncated. 
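    // In byte terms: REQUEST2 ends with b"mi" (the accept-language value) followed by
    // thirteen 0x00 bytes — the three terminators listed above plus ten bytes of
    // padding. Cutting 12 bytes (as sample_requests does) keeps the field-section
    // terminator; cutting 13, as here, removes it as well.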
+ let mut buf = &mut &REQUEST2[..REQUEST2.len() - 13]; + let mut msg = Message::async_read(&mut buf); + // Use this test to test skipping a few things. + let err = pin!(msg.header()).sync_resolve().unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + + #[test] + fn sample_responses() { + const RESPONSE: &[u8] = &[ + 0x03, 0x40, 0x66, 0x07, 0x72, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x0a, 0x22, 0x73, + 0x6c, 0x65, 0x65, 0x70, 0x20, 0x31, 0x35, 0x22, 0x00, 0x40, 0x67, 0x04, 0x6c, 0x69, + 0x6e, 0x6b, 0x23, 0x3c, 0x2f, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x2e, 0x63, 0x73, 0x73, + 0x3e, 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, + 0x3b, 0x20, 0x61, 0x73, 0x3d, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x04, 0x6c, 0x69, 0x6e, + 0x6b, 0x24, 0x3c, 0x2f, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x2e, 0x6a, 0x73, 0x3e, + 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, 0x3b, + 0x20, 0x61, 0x73, 0x3d, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x00, 0x40, 0xc8, 0x04, + 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x37, 0x20, 0x4a, + 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x32, 0x3a, 0x32, 0x38, 0x3a, + 0x35, 0x33, 0x20, 0x47, 0x4d, 0x54, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x06, + 0x41, 0x70, 0x61, 0x63, 0x68, 0x65, 0x0d, 0x6c, 0x61, 0x73, 0x74, 0x2d, 0x6d, 0x6f, + 0x64, 0x69, 0x66, 0x69, 0x65, 0x64, 0x1d, 0x57, 0x65, 0x64, 0x2c, 0x20, 0x32, 0x32, + 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x39, 0x3a, 0x31, + 0x35, 0x3a, 0x35, 0x36, 0x20, 0x47, 0x4d, 0x54, 0x04, 0x65, 0x74, 0x61, 0x67, 0x14, + 0x22, 0x33, 0x34, 0x61, 0x61, 0x33, 0x38, 0x37, 0x2d, 0x64, 0x2d, 0x31, 0x35, 0x36, + 0x38, 0x65, 0x62, 0x30, 0x30, 0x22, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, + 0x72, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x0e, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x02, + 0x35, 0x31, 0x04, 0x76, 0x61, 0x72, 0x79, 0x0f, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, + 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x0c, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x0a, 0x74, 0x65, 0x78, 0x74, 0x2f, + 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x00, 0x33, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, + 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x20, 0x4d, 0x79, 0x20, 0x63, 0x6f, 0x6e, 0x74, 0x65, + 0x6e, 0x74, 0x20, 0x69, 0x6e, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x73, 0x20, 0x61, 0x20, + 0x74, 0x72, 0x61, 0x69, 0x6c, 0x69, 0x6e, 0x67, 0x20, 0x43, 0x52, 0x4c, 0x46, 0x2e, + 0x0d, 0x0a, 0x00, 0x00, + ]; + + let mut buf = RESPONSE; + let mut msg = Message::async_read(&mut buf); + + { + // Need to scope access to `info` or it will hold the reference to `msg`. 
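            // (The stream returned by informational() captures &mut msg, so it has to
            //  go out of scope before header() can borrow msg again.)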
+ let mut info = pin!(msg.informational()); + + let info1 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info1.status(), 102_u16); + assert_eq!(info1.len(), 1); + assert_eq!(info1.get(b"running"), Some(&b"\"sleep 15\""[..])); + + let info2 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info2.status(), 103_u16); + assert_eq!(info2.len(), 2); + let links = info2.get_all(b"link").collect::>(); + assert_eq!( + &links, + &[ + &b"; rel=preload; as=style"[..], + &b"; rel=preload; as=script"[..], + ] + ); + + assert!(info.try_next().sync_resolve().unwrap().is_none()); + } + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &200_u16); + assert_eq!(hdr.len(), 8); + assert_eq!(hdr.get(b"vary"), Some(&b"Accept-Encoding"[..])); + assert_eq!(hdr.get(b"etag"), Some(&b"\"34aa387-d-1568eb00\""[..])); + + { + let mut body = pin!(msg.body().unwrap()); + assert_eq!(body.sync_read_exact(12), b"Hello World!"); + } + // Attempting to read the trailer before finishing the body should fail. + assert!(matches!( + pin!(msg.trailer()).sync_resolve(), + Err(Error::InvalidState) + )); + { + // Picking up the body again should work fine. + let mut body = pin!(msg.body().unwrap()); + assert_eq!( + body.sync_read_to_end(), + b" My content includes a trailing CRLF.\r\n" + ); + } + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } } diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index 16dd433..05a4e24 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -1,7 +1,7 @@ use std::{ future::Future, mem, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; @@ -12,24 +12,26 @@ use crate::{Error, Res}; #[pin_project::pin_project(project = ReadVecProj)] #[allow(clippy::module_name_repetitions)] -pub enum ReadVec<'a, S> { +pub enum ReadVec { // Invariant: This Option is always Some. ReadLen { - src: Option>, + src: Option>, cap: u64, }, ReadBody { - src: Pin<&'a mut S>, + src: S, buf: Vec, remaining: usize, }, } -impl<'a, S> ReadVec<'a, S> { +impl ReadVec { + #![allow(dead_code)] // TODO these really need to be used. + /// # Panics /// If `limit` is more than `usize::MAX` or /// if this is called after the length is read. - fn limit(&mut self, limit: u64) { + pub fn limit(&mut self, limit: u64) { usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); if let Self::ReadLen { ref mut cap, .. } = self { *cap = limit; @@ -38,7 +40,7 @@ impl<'a, S> ReadVec<'a, S> { } } - fn stream(self) -> Pin<&'a mut S> { + pub fn stream(self) -> S { match self { Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), Self::ReadBody { src, .. 
} => src, @@ -46,10 +48,7 @@ impl<'a, S> ReadVec<'a, S> { } } -impl<'a, S> Future for ReadVec<'a, S> -where - S: AsyncRead, -{ +impl Future for ReadVec { type Output = Res>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -86,7 +85,7 @@ where }; let offset = buf.len() - *remaining; - match src.as_mut().poll_read(cx, &mut buf[offset..]) { + match pin!(src).poll_read(cx, &mut buf[offset..]) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), @@ -103,7 +102,7 @@ where } #[allow(clippy::module_name_repetitions)] -pub fn read_vec(src: &mut S) -> ReadVec<'_, S> { +pub fn read_vec(src: S) -> ReadVec { ReadVec::ReadLen { src: Some(read_varint(src)), cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), From 6f4ec9a616fee6d139f3d2b3a5a2247b02e22782 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 28 Oct 2024 18:20:11 +1100 Subject: [PATCH 06/20] Working enough for now --- bhttp/src/stream/future.rs | 20 +++ bhttp/src/stream/mod.rs | 340 +++++++++++++++++++++++-------------- 2 files changed, 237 insertions(+), 123 deletions(-) diff --git a/bhttp/src/stream/future.rs b/bhttp/src/stream/future.rs index 3eae2bb..9d0362c 100644 --- a/bhttp/src/stream/future.rs +++ b/bhttp/src/stream/future.rs @@ -104,3 +104,23 @@ impl SyncRead for S { buf } } + +pub struct Dribble { + src: S, +} + +impl Dribble { + pub fn new(src: S) -> Self { + Self { src } + } +} + +impl AsyncRead for Dribble { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(&mut self.src).poll_read(cx, &mut buf[..1]) + } +} diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 9839059..7496d8f 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] #![allow(clippy::incompatible_msrv)] // This module uses features from rust 1.82 use std::{ @@ -8,8 +9,7 @@ use std::{ task::{Context, Poll}, }; -use futures::{stream::unfold, AsyncRead, FutureExt, Stream, TryStreamExt}; -use int::ReadVarint; +use futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; use crate::{ err::Res, @@ -89,135 +89,95 @@ impl AsyncReadFieldSection for FieldSection { } } -#[allow(clippy::mut_mut)] // TODO look into this more. -enum BodyState<'a, 'b, S> { +#[derive(Default)] +enum BodyState { + // The starting state. + #[default] + Init, // When reading the length, use this. - // Invariant: This is always `Some`. - ReadLength(Option>), + ReadLength { + buf: [u8; 8], + read: usize, + }, // When reading the data, track how much is left. - // Invariant: `src` is always `Some`. 
ReadData { remaining: usize, - src: Option<&'b mut &'a mut S>, }, } -pub struct Body<'a, 'b, S> { - mode: Mode, - state: &'b mut AsyncMessageState<'a, 'b, S>, -} - -impl<'a, 'b, S> Body<'a, 'b, S> { - fn set_state(&mut self, s: BodyState<'a, 'b, S>) { - *self.state = AsyncMessageState::Body(s); +impl BodyState { + fn read_len() -> Self { + Self::ReadLength { + buf: [0; 8], + read: 0, + } } +} - fn done(&mut self) { - *self.state = AsyncMessageState::Trailer; - } +pub struct Body<'b, S> { + msg: &'b mut AsyncMessage, } -impl<'a, 'b, S: AsyncRead + Unpin> AsyncRead for Body<'a, 'b, S> { +impl<'b, S> Body<'b, S> {} + +impl<'b, S: AsyncRead + Unpin> AsyncRead for Body<'b, S> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - fn poll_error(e: Error) -> Poll> { - Poll::Ready(Err(IoError::other(e))) - } - - let mode = self.mode; - if let AsyncMessageState::Body(BodyState::ReadLength(r)) = &mut self.state { - match r.as_mut().unwrap().poll_unpin(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(Some(0) | None)) => { - self.done(); - return Poll::Ready(Ok(0)); - } - Poll::Ready(Ok(Some(len))) => { - match usize::try_from(len) { - Ok(remaining) => { - let src = r.take().map(ReadVarint::stream); - self.set_state(BodyState::ReadData { remaining, src }); - // fall through to maybe read the body - } - Err(e) => return poll_error(Error::IntRange(e)), - } - } - Poll::Ready(Err(e)) => return poll_error(e), - } - } - - if let AsyncMessageState::Body(BodyState::ReadData { remaining, src }) = &mut self.state { - let amount = min(*remaining, buf.len()); - let res = pin!(src.as_mut().unwrap()).poll_read(cx, &mut buf[..amount]); - match res { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(0)) => poll_error(Error::Truncated), - Poll::Ready(Ok(len)) => { - *remaining -= len; - if *remaining == 0 { - if mode == Mode::IndeterminateLength { - let src = src.take().map(read_varint); - self.set_state(BodyState::ReadLength(src)); - } else { - self.done(); - } - } - Poll::Ready(Ok(len)) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - } - } else { - Poll::Pending - } + self.msg.read_body(cx, buf).map_err(IoError::other) } } -enum AsyncMessageState<'a, 'b, S> { +/// A helper function for the more complex body-reading code. +fn poll_error(e: Error) -> Poll> { + Poll::Ready(Err(IoError::other(e))) +} + +enum AsyncMessageState { + Init, // Processing Informational responses (or before that). - Informational, + Informational(bool), // Having obtained the control data for the header, this is it. Header(ControlData), // Processing the Body. - Body(BodyState<'a, 'b, S>), + Body(BodyState), // Processing the trailer. Trailer, + // All done. + Done, } -pub struct AsyncMessage<'a, 'b, S> { +pub struct AsyncMessage { // Whether this is a request and which mode. - framing: Option<(bool, Mode)>, - state: AsyncMessageState<'a, 'b, S>, - src: &'a mut S, + mode: Option, + state: AsyncMessageState, + src: S, } -unsafe impl Send for AsyncMessage<'_, '_, S> {} - -impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { - /// Get the mode. This panics if the header hasn't been read yet. 
- fn mode(&self) -> Mode { - self.framing.unwrap().1 - } +unsafe impl Send for AsyncMessage {} +impl AsyncMessage { async fn next_info(&mut self) -> Res> { - if !matches!(self.state, AsyncMessageState::Informational) { - return Ok(None); - } - - let (request, mode) = if let Some((request, mode)) = self.framing { - (request, mode) - } else { + let request = if matches!(self.state, AsyncMessageState::Init) { + // Read control data ... let t = read_varint(&mut self.src).await?.ok_or(Error::Truncated)?; let request = t == 0 || t == 2; - let mode = Mode::try_from(t)?; - self.framing = Some((request, mode)); - (request, mode) + self.mode = Some(Mode::try_from(t)?); + self.state = AsyncMessageState::Informational(request); + request + } else { + // ... or recover it. + let AsyncMessageState::Informational(request) = self.state else { + return Err(Error::InvalidState); + }; + request }; let control = ControlData::async_read(request, &mut self.src).await?; if let Some(status) = control.informational() { + let mode = self.mode.unwrap(); let fields = FieldSection::async_read(mode, &mut self.src).await?; Ok(Some(InformationalResponse::new(status, fields))) } else { @@ -227,7 +187,7 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } /// Produces a stream of informational responses from a fresh message. - /// Returns an empty stream if called at other times. + /// Returns an empty stream if passed a request (or if there are no informational responses). /// Error values on the stream indicate failures. /// /// There is no need to call this method to read a request, though @@ -237,9 +197,7 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// without affecting the message. You can then either call this /// method again to get any additional informational responses or /// call `header()` to get the message header. - pub fn informational( - &mut self, - ) -> impl Stream> + use<'_, 'a, 'b, S> { + pub fn informational(&mut self) -> impl Stream> + use<'_, S> { unfold(self, |this| async move { this.next_info().await.transpose().map(|info| (info, this)) }) @@ -247,21 +205,27 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// This reads the header. If you have not called `informational` /// and drained the resulting stream, this will do that for you. - pub async fn header(&'b mut self) -> Res
<Header> { - if matches!(self.state, AsyncMessageState::Informational) { + /// # Panics + /// Never. + pub async fn header(&mut self) -> Res<Header>
{ + if matches!( + self.state, + AsyncMessageState::Init | AsyncMessageState::Informational(_) + ) { // Need to scrub for errors, // so that this can abort properly if there is one. // The `try_any` usage is there to ensure that the stream is fully drained. _ = self.informational().try_any(|_| async { false }).await?; } + if matches!(self.state, AsyncMessageState::Header(_)) { - let mode = self.mode(); + let mode = self.mode.unwrap(); let hfields = FieldSection::async_read(mode, &mut self.src).await?; - let bs: BodyState<'a, 'b, S> = BodyState::ReadLength(Some(read_varint(&mut self.src))); - let AsyncMessageState::Header(control) = - mem::replace(&mut self.state, AsyncMessageState::Body(bs)) - else { + let AsyncMessageState::Header(control) = mem::replace( + &mut self.state, + AsyncMessageState::Body(BodyState::default()), + ) else { unreachable!(); }; Ok(Header::from((control, hfields))) @@ -270,21 +234,146 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } } + fn body_state(&mut self, s: BodyState) { + self.state = AsyncMessageState::Body(s); + } + + fn body_done(&mut self) { + self.state = AsyncMessageState::Trailer; + } + + /// Read the length of a body chunk. + /// This updates the values of `read` and `buf` to track the portion of the length + /// that was successfully read. + /// Returns `Some` with the error code that should be used if the reading + /// resulted in a conclusive outcome. + fn read_body_len( + cx: &mut Context<'_>, + mut src: &mut S, + first: bool, + read: &mut usize, + buf: &mut [u8; 8], + ) -> Option>> { + let mut src = pin!(src); + if *read == 0 { + let mut b = [0; 1]; + match src.as_mut().poll_read(cx, &mut b[..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return if first { + // It's OK for the first length to be absent. + // Just skip to the end. + *read = 8; + None + } else { + // ...it's not OK to drop length when continuing. + Some(poll_error(Error::Truncated)) + }; + } + Poll::Ready(Ok(1)) => match b[0] >> 6 { + 0 => { + buf[7] = b[0] & 0x3f; + *read = 8; + } + 1 => { + buf[6] = b[0] & 0x3f; + *read = 7; + } + 2 => { + buf[4] = b[0] & 0x3f; + *read = 5; + } + 3 => { + buf[0] = b[0] & 0x3f; + *read = 1; + } + _ => unreachable!(), + }, + Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + if *read < 8 { + match src.as_mut().poll_read(cx, &mut buf[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => return Some(poll_error(Error::Truncated)), + Poll::Ready(Ok(len)) => { + *read += len; + } + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + None + } + + fn read_body(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + // The length that precedes the first chunk can be absent. + // Only allow that for the first chunk (if indeterminate length). + let first = if let AsyncMessageState::Body(BodyState::Init) = &self.state { + self.body_state(BodyState::read_len()); + true + } else { + false + }; + + // Read the length. This uses `read_body_len` to track the state of this reading. + // This doesn't use `ReadVarint` or any convenience functions because we + // need to track the state and we don't want the borrow checker to flip out. 
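            // For context: RFC 9292 reuses QUIC variable-length integers
            // (RFC 9000, Section 16). The top two bits of the first byte give the
            // encoded length — 0b00: 1 byte, 0b01: 2 bytes, 0b10: 4 bytes,
            // 0b11: 8 bytes — and the remaining bits are the high-order value bits.
            // read_body_len above mirrors that table when it matches on b[0] >> 6 and
            // places the first byte at the matching offset of the 8-byte big-endian
            // buffer, so the finished value is u64::from_be_bytes(buf).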
+ if let AsyncMessageState::Body(BodyState::ReadLength { buf, read }) = &mut self.state { + if let Some(res) = Self::read_body_len(cx, &mut self.src, first, read, buf) { + return res; + } + if *read == 8 { + match usize::try_from(u64::from_be_bytes(*buf)) { + Ok(0) => { + self.body_done(); + return Poll::Ready(Ok(0)); + } + Ok(remaining) => { + self.body_state(BodyState::ReadData { remaining }); + } + Err(e) => return poll_error(Error::IntRange(e)), + } + } + } + + match &mut self.state { + AsyncMessageState::Body(BodyState::ReadData { remaining }) => { + let amount = min(*remaining, buf.len()); + let res = pin!(&mut self.src).poll_read(cx, &mut buf[..amount]); + match res { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) => poll_error(Error::Truncated), + Poll::Ready(Ok(len)) => { + *remaining -= len; + if *remaining == 0 { + let mode = self.mode.unwrap(); + if mode == Mode::IndeterminateLength { + self.body_state(BodyState::read_len()); + } else { + self.body_done(); + } + } + Poll::Ready(Ok(len)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + AsyncMessageState::Trailer => Poll::Ready(Ok(0)), + _ => Poll::Pending, + } + } + /// Read the body. /// This produces an implementation of `AsyncRead` that filters out /// the framing from the message body. /// # Errors /// This errors when the header has not been read. /// Any IO errors are generated by the returned `Body` instance. - pub fn body(&'b mut self) -> Res> { - if matches!(self.state, AsyncMessageState::Body(_)) { - let mode = self.mode(); - Ok(Body { - mode, - state: &mut self.state, - }) - } else { - Err(Error::InvalidState) + pub fn body(&mut self) -> Res> { + match self.state { + AsyncMessageState::Body(_) => Ok(Body { msg: self }), + _ => Err(Error::InvalidState), } } @@ -292,9 +381,13 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// This might be empty. /// # Errors /// This errors when the body has not been read. + /// # Panics + /// Never. pub async fn trailer(&mut self) -> Res { if matches!(self.state, AsyncMessageState::Trailer) { - Ok(FieldSection::async_read(self.mode(), &mut self.src).await?) 
+ let trailer = FieldSection::async_read(self.mode.unwrap(), &mut self.src).await?; + self.state = AsyncMessageState::Done; + Ok(trailer) } else { Err(Error::InvalidState) } @@ -302,14 +395,14 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } pub trait AsyncReadMessage: Sized { - fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S>; + fn async_read(src: S) -> AsyncMessage; } impl AsyncReadMessage for Message { - fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S> { + fn async_read(src: S) -> AsyncMessage { AsyncMessage { - framing: None, - state: AsyncMessageState::Informational, + mode: None, + state: AsyncMessageState::Init, src, } } @@ -323,7 +416,7 @@ mod test { use crate::{ stream::{ - future::{SyncCollect, SyncRead, SyncResolve}, + future::{Dribble, SyncCollect, SyncRead, SyncResolve}, AsyncReadMessage, }, Error, Message, @@ -361,8 +454,8 @@ mod test { let mut msg = Message::async_read(&mut buf_alias); let info = msg.informational().sync_collect().unwrap(); assert_eq!(info.len(), 1); - let info = msg.informational().sync_collect().unwrap(); - assert!(info.is_empty()); + let err = msg.informational().sync_collect(); + assert!(matches!(err, Err(Error::InvalidState))); let hdr = pin!(msg.header()).sync_resolve().unwrap(); assert_eq!(hdr.control().status().unwrap().code(), 200); assert!(hdr.is_empty()); @@ -413,8 +506,9 @@ mod test { assert!(matches!(err, Error::Truncated)); } + /// This test is crazy. It reads a byte at a time and checks the state constantly. #[test] - fn sample_responses() { + fn sample_response() { const RESPONSE: &[u8] = &[ 0x03, 0x40, 0x66, 0x07, 0x72, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x0a, 0x22, 0x73, 0x6c, 0x65, 0x65, 0x70, 0x20, 0x31, 0x35, 0x22, 0x00, 0x40, 0x67, 0x04, 0x6c, 0x69, @@ -446,7 +540,7 @@ mod test { ]; let mut buf = RESPONSE; - let mut msg = Message::async_read(&mut buf); + let mut msg = Message::async_read(Dribble::new(&mut buf)); { // Need to scope access to `info` or it will hold the reference to `msg`. 
From aac02a98a940153ab8fe1a8109612f79931f51f2 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 11:32:55 +1100 Subject: [PATCH 07/20] Update to use dep: syntax for dependencies --- bhttp/Cargo.toml | 6 +++--- ohttp/Cargo.toml | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 14aff2b..21776f9 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -9,9 +9,9 @@ description = "Binary HTTP messages (RFC 9292)" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["stream"] -http = ["url"] -stream = ["futures", "pin-project"] +default = [] +http = ["dep:url"] +stream = ["dep:futures", "dep:pin-project"] [dependencies] futures = {version = "0.3", optional = true} diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index ffaabd5..237e3b1 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -14,11 +14,11 @@ default = ["client", "server", "rust-hpke"] app-svc = ["nss"] client = [] external-sqlite = [] -gecko = ["nss", "mozbuild"] -nss = ["bindgen", "regex-mess"] -pq = ["hpke-pq"] -regex-mess = ["regex", "regex-automata", "regex-syntax"] -rust-hpke = ["rand", "aead", "aes-gcm", "chacha20poly1305", "hkdf", "sha2", "hpke"] +gecko = ["nss", "dep:mozbuild"] +nss = ["dep:bindgen", "regex-mess"] +pq = ["dep:hpke-pq"] +regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] +rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] [dependencies] From ab494ddca7491ad2047779d8dd8ebfbe0432c830 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:24:20 +1100 Subject: [PATCH 08/20] Fix NSS dependency ordering; remove dead conditional code --- bhttp/src/err.rs | 3 --- ohttp/build.rs | 20 +++++++------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..d9d9b6f 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -6,9 +6,6 @@ pub enum Error { ConnectUnsupported, #[error("a field contained invalid Unicode: {0}")] CharacterEncoding(#[from] std::string::FromUtf8Error), - #[error("a chunk of data of {0} bytes is too large")] - #[cfg(feature = "stream")] - ChunkTooLarge(u64), #[error("read a response when expecting a request")] ExpectedRequest, #[error("read a request when expecting a response")] diff --git a/ohttp/build.rs b/ohttp/build.rs index 1c01e3f..312cce2 100644 --- a/ohttp/build.rs +++ b/ohttp/build.rs @@ -8,8 +8,6 @@ #[cfg(feature = "nss")] mod nss { - use bindgen::Builder; - use serde_derive::Deserialize; use std::{ collections::HashMap, env, fs, @@ -17,6 +15,9 @@ mod nss { process::Command, }; + use bindgen::Builder; + use serde_derive::Deserialize; + const BINDINGS_DIR: &str = "bindings"; const BINDINGS_CONFIG: &str = "bindings.toml"; @@ -114,7 +115,6 @@ mod nss { let mut build_nss = vec![ String::from("./build.sh"), String::from("-Ddisable_tests=1"), - String::from("-Denable_draft_hpke=1"), ]; if is_debug() { build_nss.push(String::from("--static")); @@ -191,16 +191,8 @@ mod nss { } fn static_link(nsslibdir: &Path, use_static_softoken: bool, use_static_nspr: bool) { - let mut static_libs = vec![ - "certdb", - "certhi", - "cryptohi", - "nss_static", - "nssb", - "nssdev", - "nsspki", - "nssutil", - ]; + // The ordering of these libraries is critical for the linker. 
+ let mut static_libs = vec!["cryptohi", "nss_static"]; let mut dynamic_libs = vec![]; if use_static_softoken { @@ -211,6 +203,8 @@ mod nss { static_libs.push("pk11wrap"); } + static_libs.extend_from_slice(&["nsspki", "nssdev", "nssb", "certhi", "certdb", "nssutil"]); + if use_static_nspr { static_libs.append(&mut nspr_libs()); } else { From 69a22b95843b8f50f831eeaad620e670664940f6 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:24:42 +1100 Subject: [PATCH 09/20] Improve varint codec test coverage --- bhttp/src/rw.rs | 84 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..b081dea 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -8,12 +8,10 @@ use crate::{err::Error, ReadSeek}; #[cfg(feature = "write-bhttp")] #[allow(clippy::cast_possible_truncation)] -fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { - let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); - for i in 0..n { - w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; - } +pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { + let v = v.into().to_be_bytes(); + assert!((1..=std::mem::size_of::()).contains(&N)); + w.write_all(&v[std::mem::size_of::() - N..])?; Ok(()) } @@ -21,11 +19,11 @@ fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); match () { - () if v < (1 << 6) => write_uint(1, v, w), - () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), - () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), - () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), - () => panic!("Varint value too large"), + () if v < (1 << 6) => write_uint::<1>(v, w), + () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), + () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), + () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), + () => panic!("varint value too large"), } } @@ -106,3 +104,67 @@ where Ok(None) } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::{read_varint, write_varint}; + use crate::{rw::read_vec, Error}; + + #[test] + fn basics() { + for i in [ + 0_u64, + 1, + 17, + 63, + 64, + 100, + 0x3fff, + 0x4000, + 0x1_0002, + 0x3fff_ffff, + 0x4000_0000, + 0x3456_dead_beef, + 0x3fff_ffff_ffff_ffff, + ] { + let mut buf = Vec::new(); + write_varint(i, &mut buf).unwrap(); + let sz_bytes = (64 - i.leading_zeros() + 2 + 7) / 8; // +2 size bits, +7 to round up + assert_eq!( + buf.len(), + usize::try_from(sz_bytes.next_power_of_two()).unwrap() + ); + + let o = read_varint(&mut Cursor::new(buf.clone())).unwrap(); + assert_eq!(Some(i), o); + + for cut in 1..buf.len() { + let e = read_varint(&mut Cursor::new(buf[..cut].to_vec())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } + } + } + + #[test] + fn read_nothing() { + let o = read_varint(&mut Cursor::new(Vec::new())).unwrap(); + assert!(o.is_none()); + } + + #[test] + #[should_panic(expected = "varint value too large")] + fn too_big() { + _ = write_varint(0x4000_0000_0000_0000_u64, &mut Vec::new()); + } + + #[test] + fn too_big_vec() { + let mut buf = Vec::new(); + write_varint(10_u64, &mut buf).unwrap(); + buf.resize(10, 0); // Not enough extra for the promised length. 
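        // (write_varint(10) emits the single byte 0x0a — a promise of ten content
        //  bytes — so resizing the whole buffer to ten leaves only nine bytes of
        //  payload after the length, and read_vec must report Error::Truncated.)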
+ let e = read_vec(&mut Cursor::new(buf.clone())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } +} From fc800e300e17fe18b7ae7eacae9c3c2f98e789a4 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:25:45 +1100 Subject: [PATCH 10/20] Add mutants output to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 200abe7..39ac3d3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *~ *.swp /.vscode/ +/mutants.out*/ From 00cfe4f018b52c52f8ae5b00273bda4af12ef8d8 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 15:00:05 +1100 Subject: [PATCH 11/20] Merge main --- .gitignore | 1 + bhttp/src/err.rs | 3 --- bhttp/src/rw.rs | 68 ++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 200abe7..39ac3d3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *~ *.swp /.vscode/ +/mutants.out*/ diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index eb14acc..45d6fab 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -4,9 +4,6 @@ pub enum Error { ConnectUnsupported, #[error("a field contained invalid Unicode: {0}")] CharacterEncoding(#[from] std::string::FromUtf8Error), - #[error("a chunk of data of {0} bytes is too large")] - #[cfg(feature = "stream")] - ChunkTooLarge(u64), #[error("read a response when expecting a request")] ExpectedRequest, #[error("read a request when expecting a response")] diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index fa7c717..4659dd4 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -9,7 +9,7 @@ use crate::{ pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into().to_be_bytes(); assert!((1..=std::mem::size_of::()).contains(&N)); - w.write_all(&v[8 - N..])?; + w.write_all(&v[std::mem::size_of::() - N..])?; Ok(()) } @@ -20,7 +20,7 @@ pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), - () => panic!("Varint value too large"), + () => panic!("varint value too large"), } } @@ -92,3 +92,67 @@ where Ok(None) } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::{read_varint, write_varint}; + use crate::{rw::read_vec, Error}; + + #[test] + fn basics() { + for i in [ + 0_u64, + 1, + 17, + 63, + 64, + 100, + 0x3fff, + 0x4000, + 0x1_0002, + 0x3fff_ffff, + 0x4000_0000, + 0x3456_dead_beef, + 0x3fff_ffff_ffff_ffff, + ] { + let mut buf = Vec::new(); + write_varint(i, &mut buf).unwrap(); + let sz_bytes = (64 - i.leading_zeros() + 2 + 7) / 8; // +2 size bits, +7 to round up + assert_eq!( + buf.len(), + usize::try_from(sz_bytes.next_power_of_two()).unwrap() + ); + + let o = read_varint(&mut Cursor::new(buf.clone())).unwrap(); + assert_eq!(Some(i), o); + + for cut in 1..buf.len() { + let e = read_varint(&mut Cursor::new(buf[..cut].to_vec())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } + } + } + + #[test] + fn read_nothing() { + let o = read_varint(&mut Cursor::new(Vec::new())).unwrap(); + assert!(o.is_none()); + } + + #[test] + #[should_panic(expected = "varint value too large")] + fn too_big() { + _ = write_varint(0x4000_0000_0000_0000_u64, &mut Vec::new()); + } + + #[test] + fn too_big_vec() { + let mut buf = Vec::new(); + write_varint(10_u64, &mut buf).unwrap(); + buf.resize(10, 0); // Not enough extra for the promised length. 
+ let e = read_vec(&mut Cursor::new(buf.clone())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } +} From e19de296dff3e35cae5b828f512af7bc2863fa1b Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 15:40:19 +1100 Subject: [PATCH 12/20] Better formatting --- bhttp-convert/src/main.rs | 3 +- ohttp-client-cli/src/main.rs | 3 +- ohttp-client/src/main.rs | 3 +- ohttp/src/config.rs | 17 +++++----- ohttp/src/lib.rs | 29 ++++++++-------- ohttp/src/nss/aead.rs | 14 ++++---- ohttp/src/nss/err.rs | 3 +- ohttp/src/nss/hkdf.rs | 6 ++-- ohttp/src/nss/hpke.rs | 15 +++++---- ohttp/src/nss/mod.rs | 6 ++-- ohttp/src/nss/p11.rs | 5 +-- ohttp/src/rh/aead.rs | 8 +++-- ohttp/src/rh/hkdf.rs | 7 ++-- ohttp/src/rh/hpke.rs | 23 ++++++------- pre-commit | 65 +++++++++++++++++++++++++----------- 15 files changed, 123 insertions(+), 84 deletions(-) diff --git a/bhttp-convert/src/main.rs b/bhttp-convert/src/main.rs index 763fadf..054c64f 100644 --- a/bhttp-convert/src/main.rs +++ b/bhttp-convert/src/main.rs @@ -1,11 +1,12 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{ fs::File, io::{self, Read}, path::PathBuf, }; + +use bhttp::{Message, Mode}; use structopt::StructOpt; #[derive(Debug, StructOpt)] diff --git a/ohttp-client-cli/src/main.rs b/ohttp-client-cli/src/main.rs index 7acc0f1..37a948d 100644 --- a/ohttp-client-cli/src/main.rs +++ b/ohttp-client-cli/src/main.rs @@ -1,8 +1,9 @@ #![deny(warnings, clippy::pedantic)] +use std::io::{self, BufRead, Write}; + use bhttp::{Message, Mode}; use ohttp::{init, ClientRequest}; -use std::io::{self, BufRead, Write}; fn main() { init(); diff --git a/ohttp-client/src/main.rs b/ohttp-client/src/main.rs index 1d71dc2..4e8a431 100644 --- a/ohttp-client/src/main.rs +++ b/ohttp-client/src/main.rs @@ -1,7 +1,8 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{fs::File, io, io::Read, ops::Deref, path::PathBuf, str::FromStr}; + +use bhttp::{Message, Mode}; use structopt::StructOpt; type Res = Result>; diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index 9b55755..ebadb3c 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -1,24 +1,24 @@ -use crate::{ - err::{Error, Res}, - hpke::{Aead as AeadId, Kdf, Kem}, - KeyId, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::{ convert::TryFrom, io::{BufRead, BufReader, Cursor, Read}, }; +use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; + #[cfg(feature = "nss")] use crate::nss::{ hpke::{generate_key_pair, Config as HpkeConfig, HpkeR}, PrivateKey, PublicKey, }; - #[cfg(feature = "rust-hpke")] use crate::rh::hpke::{ derive_key_pair, generate_key_pair, Config as HpkeConfig, HpkeR, PrivateKey, PublicKey, }; +use crate::{ + err::{Error, Res}, + hpke::{Aead as AeadId, Kdf, Kem}, + KeyId, +}; /// A tuple of KDF and AEAD identifiers. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -270,11 +270,12 @@ impl AsRef for KeyConfig { #[cfg(test)] mod test { + use std::iter::zip; + use crate::{ hpke::{Aead, Kdf, Kem}, init, Error, KeyConfig, KeyId, SymmetricSuite, }; - use std::iter::zip; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 38e3666..8c07fa3 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -15,17 +15,6 @@ mod rand; #[cfg(feature = "rust-hpke")] mod rh; -pub use crate::{ - config::{KeyConfig, SymmetricSuite}, - err::Error, -}; - -use crate::{ - err::Res, - hpke::{Aead as AeadId, Kdf, Kem}, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -use log::trace; use std::{ cmp::max, convert::TryFrom, @@ -33,6 +22,9 @@ use std::{ mem::size_of, }; +use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use log::trace; + #[cfg(feature = "nss")] use crate::nss::random; #[cfg(feature = "nss")] @@ -41,7 +33,6 @@ use crate::nss::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; - #[cfg(feature = "rust-hpke")] use crate::rand::random; #[cfg(feature = "rust-hpke")] @@ -50,6 +41,14 @@ use crate::rh::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; +pub use crate::{ + config::{KeyConfig, SymmetricSuite}, + err::Error, +}; +use crate::{ + err::Res, + hpke::{Aead as AeadId, Kdf, Kem}, +}; /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; @@ -312,14 +311,16 @@ impl ClientResponse { #[cfg(all(test, feature = "client", feature = "server"))] mod test { + use std::{fmt::Debug, io::ErrorKind}; + + use log::trace; + use crate::{ config::SymmetricSuite, err::Res, hpke::{Aead, Kdf, Kem}, ClientRequest, Error, KeyConfig, KeyId, Server, }; - use log::trace; - use std::{fmt::Debug, io::ErrorKind}; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index 18f0b66..aecfc90 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -1,3 +1,11 @@ +use std::{ + convert::{TryFrom, TryInto}, + mem, + os::raw::c_int, +}; + +use log::trace; + use super::{ err::secstatus_to_res, p11::{ @@ -13,12 +21,6 @@ use crate::{ err::{Error, Res}, hpke::Aead as AeadId, }; -use log::trace; -use std::{ - convert::{TryFrom, TryInto}, - mem, - os::raw::c_int, -}; /// All the nonces are the same length. Exploit that. 
pub const NONCE_LEN: usize = 12; diff --git a/ohttp/src/nss/err.rs b/ohttp/src/nss/err.rs index af85066..bc7c86d 100644 --- a/ohttp/src/nss/err.rs +++ b/ohttp/src/nss/err.rs @@ -10,9 +10,10 @@ clippy::module_name_repetitions )] +use std::os::raw::c_char; + use super::{SECStatus, SECSuccess}; use crate::err::Res; -use std::os::raw::c_char; include!(concat!(env!("OUT_DIR"), "/nspr_error.rs")); mod codes { diff --git a/ohttp/src/nss/hkdf.rs b/ohttp/src/nss/hkdf.rs index 470b1dd..af38ba9 100644 --- a/ohttp/src/nss/hkdf.rs +++ b/ohttp/src/nss/hkdf.rs @@ -1,3 +1,7 @@ +use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; + +use log::trace; + use super::{ super::hpke::{Aead, Kdf}, p11::{ @@ -10,8 +14,6 @@ use super::{ }, }; use crate::err::Res; -use log::trace; -use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index b7ef845..16d018d 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -1,10 +1,3 @@ -use super::{ - super::hpke::{Aead, Kdf, Kem}, - err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, - p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, -}; -use crate::err::Res; -use log::{log_enabled, trace}; use std::{ convert::TryFrom, ops::Deref, @@ -12,8 +5,16 @@ use std::{ ptr::{addr_of_mut, null, null_mut}, }; +use log::{log_enabled, trace}; pub use sys::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId}; +use super::{ + super::hpke::{Aead, Kdf, Kem}, + err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, + p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, +}; +use crate::err::Res; + /// Configuration for `Hpke`. #[derive(Clone, Copy)] pub struct Config { diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index 1b60c9e..c9d4c7e 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -11,11 +11,13 @@ pub mod aead; pub mod hkdf; pub mod hpke; -pub use self::p11::{random, PrivateKey, PublicKey}; +use std::ptr::null; + use err::secstatus_to_res; pub use err::Error; use lazy_static::lazy_static; -use std::ptr::null; + +pub use self::p11::{random, PrivateKey, PublicKey}; #[allow(clippy::pedantic, non_upper_case_globals, clippy::upper_case_acronyms)] mod nss_init { diff --git a/ohttp/src/nss/p11.rs b/ohttp/src/nss/p11.rs index 9bddef6..19d9420 100644 --- a/ohttp/src/nss/p11.rs +++ b/ohttp/src/nss/p11.rs @@ -4,8 +4,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::err::{secstatus_to_res, Error}; -use crate::err::Res; use std::{ convert::TryFrom, marker::PhantomData, @@ -14,6 +12,9 @@ use std::{ ptr::null_mut, }; +use super::err::{secstatus_to_res, Error}; +use crate::err::Res; + #[allow( clippy::pedantic, clippy::upper_case_acronyms, diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index f209086..9bfe13a 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -1,11 +1,13 @@ #![allow(dead_code)] // TODO: remove -use super::SymKey; -use crate::{err::Res, hpke::Aead as AeadId}; +use std::convert::TryFrom; + use aead::{AeadMut, Key, NewAead, Nonce, Payload}; use aes_gcm::{Aes128Gcm, Aes256Gcm}; use chacha20poly1305::ChaCha20Poly1305; -use std::convert::TryFrom; + +use super::SymKey; +use crate::{err::Res, hpke::Aead as AeadId}; /// All the nonces are the same length. Exploit that. 
pub const NONCE_LEN: usize = 12; diff --git a/ohttp/src/rh/hkdf.rs b/ohttp/src/rh/hkdf.rs index aeb3a8d..a88f60e 100644 --- a/ohttp/src/rh/hkdf.rs +++ b/ohttp/src/rh/hkdf.rs @@ -1,13 +1,14 @@ #![allow(dead_code)] // TODO: remove +use hkdf::Hkdf as HkdfImpl; +use log::trace; +use sha2::{Sha256, Sha384, Sha512}; + use super::SymKey; use crate::{ err::{Error, Res}, hpke::{Aead, Kdf}, }; -use hkdf::Hkdf as HkdfImpl; -use log::trace; -use sha2::{Sha256, Sha384, Sha512}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 4b81152..ce8ee78 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -1,15 +1,13 @@ -use super::SymKey; -use crate::{ - hpke::{Aead, Kdf, Kem}, - Error, Res, -}; +use std::ops::Deref; #[cfg(not(feature = "pq"))] use ::hpke as rust_hpke; - #[cfg(feature = "pq")] use ::hpke_pq as rust_hpke; - +use ::rand::thread_rng; +use log::trace; +#[cfg(feature = "pq")] +use rust_hpke::kem::X25519Kyber768Draft00; use rust_hpke::{ aead::{AeadCtxR, AeadCtxS, AeadTag, AesGcm128, ChaCha20Poly1305}, kdf::HkdfSha256, @@ -17,12 +15,11 @@ use rust_hpke::{ setup_receiver, setup_sender, Deserializable, OpModeR, OpModeS, Serializable, }; -#[cfg(feature = "pq")] -use rust_hpke::kem::X25519Kyber768Draft00; - -use ::rand::thread_rng; -use log::trace; -use std::ops::Deref; +use super::SymKey; +use crate::{ + hpke::{Aead, Kdf, Kem}, + Error, Res, +}; /// Configuration for `Hpke`. #[derive(Clone, Copy)] diff --git a/pre-commit b/pre-commit index 758b923..18e917d 100755 --- a/pre-commit +++ b/pre-commit @@ -6,10 +6,12 @@ # $ ln -s ../../hooks/pre-commit .git/hooks/pre-commit root="$(git rev-parse --show-toplevel 2>/dev/null)" +RUST_FMT_CFG="imports_granularity=Crate,group_imports=StdExternalCrate" # Some sanity checking. -hash cargo || exit 1 -[[ -n "$root" ]] || exit 1 +set -e +hash cargo +[[ -n "$root" ]] # Installation. if [[ "$1" == "install" ]]; then @@ -23,31 +25,54 @@ if [[ "$1" == "install" ]]; then exit fi -# Check formatting. +# Stash unstaged changes if [[ "$1" != "all" ]]; then - msg="pre-commit stash @$(git rev-parse --short @) $RANDOM" - trap 'git stash list -1 --format="format:%s" | grep -q "'"$msg"'" && git stash pop -q' EXIT - git stash push -k -u -q -m "$msg" + stashdir="$(mktemp -d "$root"/.pre-commit.stashXXXXXX)" + msg="pre-commit stash @$(git rev-parse --short @) ${stashdir##*.stash}" + gitdir="$(git rev-parse --git-dir 2>/dev/null)" + + stash() { + # Move MERGE_[HEAD|MODE|MSG] files to the root directory, and let `git stash push` save them. + find "$gitdir" -maxdepth 1 -name 'MERGE_*' -exec mv \{\} "$stashdir" \; + git stash push -k -u -q -m "$msg" + } + + unstash() { + git stash list -1 --format="format:%s" | grep -q "$msg" && git stash pop -q + # Moves MERGE files restored by `git stash pop` back into .git/ directory. + if [[ -d "$stashdir" ]]; then + find "$stashdir" -exec mv -n \{\} "$gitdir" \; + rmdir "$stashdir" + fi + } + + trap unstash EXIT + stash fi -if ! errors=($(cargo fmt -- --check --config imports_granularity=crate -l)); then - echo "Formatting errors found." - echo "Run \`cargo fmt\` to fix the following files:" + +# Check formatting +if ! errors=($(cargo fmt -- --check --config "$RUST_FMT_CFG" -l)); then + echo "Formatting errors found in:" for err in "${errors[@]}"; do echo " $err" done + echo "To fix, run \`cargo fmt -- --config $RUST_FMT_CFG\`" exit 1 fi -if ! cargo clippy --tests; then - exit 1 -fi -if ! cargo test; then - exit 1 -fi -if [[ -n "$NSS_DIR" ]]; then - if ! 
cargo clippy --tests --no-default-features --features nss; then - exit 1 - fi - if ! cargo test --no-default-features --features nss; then + +check() { + msg="$1" + shift + if ! echo "$@"; then + echo "${msg}: Failed command:" + echo " ${@@Q}" exit 1 fi +} + +check "clippy" cargo clippy --tests +check "test" cargo test +if [[ -n "$NSS_DIR" ]]; then + check "clippy(NSS)" cargo clippy --tests --no-default-features --features nss + check "test(NSS)" cargo test --no-default-features --features nss fi From c91933c7c21e066ed313b40d87e83af0a63612eb Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Sat, 2 Nov 2024 10:34:24 +0000 Subject: [PATCH 13/20] Refactor stream helper functions --- Cargo.toml | 2 +- bhttp/Cargo.toml | 5 ++++- bhttp/src/stream/int.rs | 7 +++---- sync-async/Cargo.toml | 12 ++++++++++++ bhttp/src/stream/future.rs => sync-async/src/lib.rs | 11 ++++++----- 5 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 sync-async/Cargo.toml rename bhttp/src/stream/future.rs => sync-async/src/lib.rs (92%) diff --git a/Cargo.toml b/Cargo.toml index 0621518..fb7825e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,5 @@ members = [ "ohttp", "ohttp-client", "ohttp-client-cli", - "ohttp-server", + "ohttp-server", "sync-async", ] diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 21776f9..fe3d325 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -20,4 +20,7 @@ thiserror = "1" url = {version = "2", optional = true} [dev-dependencies] -hex = "0.4" \ No newline at end of file +hex = "0.4" + +[dev-dependencies.sync-async] +path= "../sync-async" diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 5b248df..f70231b 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -135,13 +135,12 @@ pub fn read_varint(src: S) -> ReadVarint { #[cfg(test)] mod test { + use sync_async::SyncResolve; + use crate::{ err::Error, rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, - stream::{ - future::SyncResolve, - int::{read_uint, read_varint}, - }, + stream::int::{read_uint, read_varint}, }; const VARINTS: &[u64] = &[ diff --git a/sync-async/Cargo.toml b/sync-async/Cargo.toml new file mode 100644 index 0000000..c419df1 --- /dev/null +++ b/sync-async/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sync-async" +version = "0.5.3" +authors = ["Martin Thomson "] +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Synchronous Helpers for Async Code" +repository = "https://github.com/martinthomson/ohttp" + +[dependencies] +futures = "0.3" +pin-project = "1.1" \ No newline at end of file diff --git a/bhttp/src/stream/future.rs b/sync-async/src/lib.rs similarity index 92% rename from bhttp/src/stream/future.rs rename to sync-async/src/lib.rs index 9d0362c..e21ca07 100644 --- a/bhttp/src/stream/future.rs +++ b/sync-async/src/lib.rs @@ -6,8 +6,6 @@ use std::{ use futures::{AsyncRead, AsyncReadExt, TryStream, TryStreamExt}; -use crate::Error; - fn noop_context() -> Context<'static> { use std::{ ptr::null, @@ -26,6 +24,7 @@ fn noop_context() -> Context<'static> { } pub fn noop_waker_ref() -> &'static Waker { + #[repr(transparent)] struct SyncRawWaker(RawWaker); unsafe impl Sync for SyncRawWaker {} @@ -72,14 +71,16 @@ impl SyncResolve for F { pub trait SyncCollect { type Item; + type Error; - fn sync_collect(self) -> Result, Error>; + fn sync_collect(self) -> Result, Self::Error>; } -impl> SyncCollect for S { +impl SyncCollect for S { type Item = S::Ok; + type Error = S::Error; - fn sync_collect(self) -> Result, Error> { + fn 
sync_collect(self) -> Result, Self::Error> { pin!(self.try_collect::>()).sync_resolve() } } From 880d958e94341eefa01e7cb38cffd98ed8630c28 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Sat, 2 Nov 2024 10:51:42 +0000 Subject: [PATCH 14/20] Move to use OnceLock --- ohttp/Cargo.toml | 3 ++- ohttp/src/nss/mod.rs | 23 ++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 237e3b1..5e8c484 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -20,6 +20,7 @@ pq = ["dep:hpke-pq"] regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] +stream = [] [dependencies] aead = {version = "0.4", optional = true, features = ["std"]} @@ -29,7 +30,6 @@ chacha20poly1305 = {version = "0.8", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} -lazy_static = "1.4" log = {version = "0.4", default-features = false} rand = {version = "0.8", optional = true} # bindgen uses regex and friends, which have been updated past our MSRV @@ -64,3 +64,4 @@ features = ["runtime"] [dev-dependencies] env_logger = {version = "0.10", default-features = false} +sync-async = {path = "../sync-async"} diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index c9d4c7e..c282a13 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -15,7 +15,6 @@ use std::ptr::null; use err::secstatus_to_res; pub use err::Error; -use lazy_static::lazy_static; pub use self::p11::{random, PrivateKey, PublicKey}; @@ -47,17 +46,7 @@ impl Drop for NssLoaded { } } -lazy_static! { - static ref INITIALIZED: NssLoaded = { - if already_initialized() { - return NssLoaded::External; - } - - secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }).expect("NSS_NoDB_Init failed"); - - NssLoaded::NoDb - }; -} +static INITIALIZED: OnceLock = OnceLock::new(); fn already_initialized() -> bool { unsafe { nss_init::NSS_IsInitialized() != 0 } @@ -65,5 +54,13 @@ fn already_initialized() -> bool { /// Initialize NSS. This only executes the initialization routines once. 
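+/// Later calls are no-ops: the result of initialization is kept in a `OnceLock`.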
pub fn init() { - lazy_static::initialize(&INITIALIZED); + INITIALIZED.get_or_init(|| { + if already_initialized() { + NssLoaded::External + } else { + secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }) + .expect("NSS_NoDB_Init failed"); + NssLoaded::NoDb + } + }); } From 69402f463e8d041aecbb6a08d6196e35620febd6 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 11 Nov 2024 12:01:39 +0000 Subject: [PATCH 15/20] Checkpoint for request streaming --- ohttp/Cargo.toml | 6 +++-- ohttp/src/lib.rs | 68 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 5e8c484..01a2027 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -10,7 +10,7 @@ description = "Oblivious HTTP" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["client", "server", "rust-hpke"] +default = ["client", "server", "rust-hpke", "stream"] app-svc = ["nss"] client = [] external-sqlite = [] @@ -20,17 +20,19 @@ pq = ["dep:hpke-pq"] regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] -stream = [] +stream = ["dep:futures", "dep:pin-project"] [dependencies] aead = {version = "0.4", optional = true, features = ["std"]} aes-gcm = {version = "0.9", optional = true} byteorder = "1.4" chacha20poly1305 = {version = "0.8", optional = true} +futures = {version = "0.3", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} log = {version = "0.4", default-features = false} +pin-project = {version = "1.1", optional = true} rand = {version = "0.8", optional = true} # bindgen uses regex and friends, which have been updated past our MSRV # however, the cargo resolver happily resolves versions that it can't compile diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 8c07fa3..0ce3e79 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -14,6 +14,8 @@ mod nss; mod rand; #[cfg(feature = "rust-hpke")] mod rh; +#[cfg(feature = "stream")] +mod stream; use std::{ cmp::max, @@ -23,7 +25,10 @@ use std::{ }; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +#[cfg(feature = "stream")] +use futures::AsyncRead; use log::trace; +use rh::hpke::PublicKey; #[cfg(feature = "nss")] use crate::nss::random; @@ -41,6 +46,8 @@ use crate::rh::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; +#[cfg(feature = "stream")] +use crate::stream::ClientRequestStream; pub use crate::{ config::{KeyConfig, SymmetricSuite}, err::Error, @@ -53,8 +60,6 @@ use crate::{ /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; const INFO_REQUEST: &[u8] = b"message/bhttp request"; -/// The info used for HPKE export is `INFO_REQUEST`, a zero byte, and the header. -const INFO_LEN: usize = INFO_REQUEST.len() + 1 + REQUEST_HEADER_LEN; const LABEL_RESPONSE: &[u8] = b"message/bhttp response"; const INFO_KEY: &[u8] = b"key"; const INFO_NONCE: &[u8] = b"nonce"; @@ -68,9 +73,9 @@ pub fn init() { } /// Construct the info parameter we use to initialize an `HpkeS` instance. 
-fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { - let mut info = Vec::with_capacity(INFO_LEN); - info.extend_from_slice(INFO_REQUEST); +fn build_info(label: &[u8], key_id: KeyId, config: HpkeConfig) -> Res> { + let mut info = Vec::with_capacity(label.len() + 1 + REQUEST_HEADER_LEN); + info.extend_from_slice(label); info.push(0); info.write_u8(key_id)?; info.write_u16::(u16::from(config.kem()))?; @@ -84,8 +89,9 @@ fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { /// This might not be necessary if we agree on a format. #[cfg(feature = "client")] pub struct ClientRequest { - hpke: HpkeS, - header: Vec, + key_id: KeyId, + config: HpkeConfig, + pk: PublicKey, } #[cfg(feature = "client")] @@ -94,14 +100,11 @@ impl ClientRequest { pub fn from_config(config: &mut KeyConfig) -> Res { // TODO(mt) choose the best config, not just the first. let selected = config.select(config.symmetric[0])?; - - // Build the info, which contains the message header. - let info = build_info(config.key_id, selected)?; - let hpke = HpkeS::new(selected, &mut config.pk, &info)?; - - let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); - Ok(Self { hpke, header }) + Ok(Self { + key_id: config.key_id, + config: selected, + pk: config.pk.clone(), + }) } /// Reads an encoded configuration and constructs a single use client sender. @@ -126,21 +129,41 @@ impl ClientRequest { /// Encapsulate a request. This consumes this object. /// This produces a response handler and the bytes of an encapsulated request. pub fn encapsulate(mut self, request: &[u8]) -> Res<(Vec, ClientResponse)> { - let extra = - self.hpke.config().kem().n_enc() + self.hpke.config().aead().n_t() + request.len(); - let expected_len = self.header.len() + extra; + // Build the info, which contains the message header. 
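+ // The info is the label, a zero byte, and then the key identifier plus the
+ // KEM, KDF, and AEAD identifiers; the part after the label and zero byte is
+ // reused as the wire header below.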
+ let info = build_info(INFO_REQUEST, self.key_id, self.config)?; + let mut hpke = HpkeS::new(self.config, &mut self.pk, &info)?; - let mut enc_request = self.header; + let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let extra = hpke.config().kem().n_enc() + hpke.config().aead().n_t() + request.len(); + let expected_len = header.len() + extra; + + let mut enc_request = header; enc_request.reserve_exact(extra); - let enc = self.hpke.enc()?; + let enc = hpke.enc()?; enc_request.extend_from_slice(&enc); - let mut ct = self.hpke.seal(&[], request)?; + let mut ct = hpke.seal(&[], request)?; enc_request.append(&mut ct); debug_assert_eq!(expected_len, enc_request.len()); - Ok((enc_request, ClientResponse::new(self.hpke, enc))) + Ok((enc_request, ClientResponse::new(hpke, enc))) + } + + #[cfg(feature = "stream")] + pub fn encapsulate_stream(mut self, src: S) -> Res> { + let info = build_info(crate::stream::INFO_REQUEST, self.key_id, self.config)?; + let hpke = HpkeS::new(self.config, &mut self.pk, &info)?; + + let mut header = Vec::from(&info[crate::stream::INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let mut e = hpke.enc()?; + header.append(&mut e); + + Ok(ClientRequestStream::new(src, hpke, header)) } } @@ -191,6 +214,7 @@ impl Server { let sym = SymmetricSuite::new(kdf_id, aead_id); let info = build_info( + INFO_REQUEST, key_id, HpkeConfig::new(self.config.kem, sym.kdf(), sym.aead()), )?; From 7222a26f7ddbedd53690252893d933963501da2d Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Fri, 13 Dec 2024 05:41:28 +0000 Subject: [PATCH 16/20] Add missing file --- ohttp/src/stream.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 ohttp/src/stream.rs diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs new file mode 100644 index 0000000..1f9bd90 --- /dev/null +++ b/ohttp/src/stream.rs @@ -0,0 +1,110 @@ +use std::{ + cmp::min, + io::{Error as IoError, Result as IoResult}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::AsyncRead; + +use crate::HpkeS; + +pub(crate) const INFO_REQUEST: &[u8] = b"message/bhttp chunked request"; + +fn write_len(w: &mut [u8], len: usize) -> usize { + let v: u64 = len.try_into().unwrap(); + let (v, len) = match () { + () if v < (1 << 6) => (v, 1), + () if v < (1 << 14) => (v | 1 << 14, 2), + () if v < (1 << 30) => (v | (2 << 30), 4), + () if v < (1 << 62) => (v | (3 << 62), 8), + () => panic!("varint value too large"), + }; + w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); + len +} + +#[pin_project::pin_project] +pub struct ClientRequestStream { + #[pin] + src: S, + hpke: HpkeS, + buf: Vec, +} + +impl ClientRequestStream { + pub fn new(src: S, hpke: HpkeS, header: Vec) -> Self { + Self { + src, + hpke, + buf: header, + } + } +} + +impl AsyncRead for ClientRequestStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + // We have buffered data, so dump it into the output directly. + let mut written = if this.buf.is_empty() { + 0 + } else { + let amnt = min(this.buf.len(), buf.len()); + buf[..amnt].copy_from_slice(&this.buf[..amnt]); + buf = &mut buf[amnt..]; + *this.buf = this.buf.split_off(amnt); + if buf.is_empty() { + return Poll::Ready(Ok(amnt)); + } + amnt + }; + + // Now read into the buffer. 
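+ // Each chunk we emit is a varint length followed by the sealed ciphertext,
+ // so the output is always a little longer than what was read from the source.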
+ // Because we are expanding the data, when the buffer we are provided is too small, + // we have to use a temporary buffer so that we can save some bytes. + let mut tmp = [0; 64]; + let read_buf = if buf.len() < tmp.len() { + // Use the provided buffer, but leave room for AEAD tag and a varint. + let read_len = min(buf.len(), 1 << 62) - this.hpke.aead().n_t(); + &mut buf[8..read_len] + } else { + &mut tmp[..] + }; + let (aad, len): (&[u8], _) = match this.src.poll_read(cx, read_buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => (&b"final"[..], 0), + Poll::Ready(Ok(len)) => (&[], len), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + }; + + let ct = this + .hpke + .seal(aad, &mut read_buf[..len]) + .map_err(IoError::other)?; + + // Now we need to write the length of the chunk. + let len_len = write_len(&mut tmp, ct.len()); + if len_len <= buf.len() { + // If the length fits in the buffer, that's easy. + buf[..len_len].copy_from_slice(&tmp[..len_len]); + written += len_len; + buf = &mut buf[len_len..]; + } else { + // Otherwise, we need to save any remainder in our own buffer. + buf.copy_from_slice(&tmp[..buf.len()]); + this.buf.extend_from_slice(&tmp[buf.len()..len_len]); + let amnt = buf.len(); + written += amnt; + buf = &mut buf[amnt..]; + } + + let amnt = min(ct.len(), buf.len()); + buf[..amnt].copy_from_slice(&ct[..amnt]); + this.buf.extend_from_slice(&ct[amnt..]); + Poll::Ready(Ok(amnt + written)) + } +} From 9bea3113ddaf53c5ddd9bb2e64da7c2955cb874b Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Thu, 19 Dec 2024 12:06:12 +1100 Subject: [PATCH 17/20] Checkpoint --- Cargo.toml | 3 +- bhttp/src/stream/mod.rs | 11 +- bhttp/src/stream/vec.rs | 7 +- ohttp/Cargo.toml | 2 +- ohttp/src/config.rs | 9 +- ohttp/src/err.rs | 6 + ohttp/src/lib.rs | 158 ++++++----- ohttp/src/nss/aead.rs | 42 ++- ohttp/src/nss/mod.rs | 4 +- ohttp/src/rh/aead.rs | 15 +- ohttp/src/stream.rs | 570 ++++++++++++++++++++++++++++++++++++++-- sync-async/src/lib.rs | 94 ++++++- 12 files changed, 787 insertions(+), 134 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fb7825e..eb9b298 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,6 @@ members = [ "ohttp", "ohttp-client", "ohttp-client-cli", - "ohttp-server", "sync-async", + "ohttp-server", + "sync-async", ] diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 7496d8f..009b1eb 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -16,8 +16,6 @@ use crate::{ stream::{int::read_varint, vec::read_vec}, ControlData, Error, Field, FieldSection, Header, InformationalResponse, Message, Mode, COOKIE, }; -#[cfg(test)] -mod future; mod int; mod vec; @@ -413,14 +411,9 @@ mod test { use std::pin::pin; use futures::TryStreamExt; + use sync_async::{Dribble, SyncRead, SyncResolve, SyncTryCollect}; - use crate::{ - stream::{ - future::{Dribble, SyncCollect, SyncRead, SyncResolve}, - AsyncReadMessage, - }, - Error, Message, - }; + use crate::{stream::AsyncReadMessage, Error, Message}; // Example from Section 5.1 of RFC 9292. 
const REQUEST1: &[u8] = &[ diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index 05a4e24..a0989a9 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -121,12 +121,9 @@ mod test { }; use futures::AsyncRead; + use sync_async::SyncResolve; - use crate::{ - rw::write_varint as sync_write_varint, - stream::{future::SyncResolve, vec::read_vec}, - Error, - }; + use crate::{rw::write_varint as sync_write_varint, stream::vec::read_vec, Error}; const FILL_VALUE: u8 = 90; diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 01a2027..17ab467 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -10,7 +10,7 @@ description = "Oblivious HTTP" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["client", "server", "rust-hpke", "stream"] +default = ["client", "server", "nss", "stream"] app-svc = ["nss"] client = [] external-sqlite = [] diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index ebadb3c..0f6fab3 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -95,18 +95,15 @@ impl KeyConfig { Self::strip_unsupported(&mut symmetric, kem); assert!(!symmetric.is_empty()); let (sk, pk) = derive_key_pair(kem, ikm)?; - Ok(Self { + return Ok(Self { key_id, kem, symmetric, sk: Some(sk), pk, - }) - } - #[cfg(not(feature = "rust-hpke"))] - { - Err(Error::Unsupported) + }); } + Err(Error::Unsupported) } /// Encode a list of key configurations. diff --git a/ohttp/src/err.rs b/ohttp/src/err.rs index 3c6ebd2..c52f56b 100644 --- a/ohttp/src/err.rs +++ b/ohttp/src/err.rs @@ -5,9 +5,15 @@ pub enum Error { #[cfg(feature = "rust-hpke")] #[error("a problem occurred with the AEAD")] Aead(#[from] aead::Error), + #[cfg(feature = "stream")] + #[error("a stream chunk was larger than the maximum allowed size")] + ChunkTooLarge, #[cfg(feature = "nss")] #[error("a problem occurred during cryptographic processing: {0}")] Crypto(#[from] crate::nss::Error), + #[cfg(feature = "stream")] + #[error("a stream contained data after the last chunk")] + ExtraData, #[error("an error was found in the format")] Format, #[cfg(all(feature = "rust-hpke", not(feature = "pq")))] diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 0ce3e79..6e9857d 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -4,6 +4,8 @@ not(all(feature = "client", feature = "server")), allow(dead_code, unused_imports) )] +#[cfg(all(feature = "nss", feature = "rust-hpke"))] +compile_error!("features \"nss\" and \"rust-hpke\" are mutually incompatible"); mod config; mod err; @@ -20,34 +22,22 @@ mod stream; use std::{ cmp::max, convert::TryFrom, - io::{BufReader, Read}, + io::{Cursor, Read}, mem::size_of, }; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -#[cfg(feature = "stream")] -use futures::AsyncRead; use log::trace; -use rh::hpke::PublicKey; -#[cfg(feature = "nss")] -use crate::nss::random; #[cfg(feature = "nss")] use crate::nss::{ aead::{Aead, Mode, NONCE_LEN}, hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, -}; -#[cfg(feature = "rust-hpke")] -use crate::rand::random; -#[cfg(feature = "rust-hpke")] -use crate::rh::{ - aead::{Aead, Mode, NONCE_LEN}, - hkdf::{Hkdf, KeyMechanism}, - hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, + random, PublicKey, SymKey, }; #[cfg(feature = "stream")] -use crate::stream::ClientRequestStream; +use crate::stream::{ClientRequest as StreamClient, ServerRequest as StreamServer}; pub use crate::{ config::{KeyConfig, SymmetricSuite}, err::Error, @@ -56,6 +46,16 @@ use crate::{ err::Res, hpke::{Aead as AeadId, Kdf, 
Kem}, }; +#[cfg(feature = "rust-hpke")] +use crate::{ + rand::random, + rh::{ + aead::{Aead, Mode, NONCE_LEN}, + hkdf::{Hkdf, KeyMechanism}, + hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS, PublicKey}, + SymKey, + }, +}; /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; @@ -153,17 +153,8 @@ impl ClientRequest { } #[cfg(feature = "stream")] - pub fn encapsulate_stream(mut self, src: S) -> Res> { - let info = build_info(crate::stream::INFO_REQUEST, self.key_id, self.config)?; - let hpke = HpkeS::new(self.config, &mut self.pk, &info)?; - - let mut header = Vec::from(&info[crate::stream::INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); - - let mut e = hpke.enc()?; - header.append(&mut e); - - Ok(ClientRequestStream::new(src, hpke, header)) + pub fn encapsulate_stream(self, dst: S) -> Res> { + StreamClient::start(dst, self.config, self.key_id, self.pk) } } @@ -192,15 +183,8 @@ impl Server { &self.config } - /// Remove encapsulation on a message. - /// # Panics - /// Not as a consequence of this code, but Rust won't know that for sure. #[allow(clippy::similar_names)] // for kem_id and key_id - pub fn decapsulate(&self, enc_request: &[u8]) -> Res<(Vec, ServerResponse)> { - if enc_request.len() < REQUEST_HEADER_LEN { - return Err(Error::Truncated); - } - let mut r = BufReader::new(enc_request); + fn decode_hpke_config(&self, r: &mut Cursor<&[u8]>) -> Res { let key_id = r.read_u8()?; if key_id != self.config.key_id { return Err(Error::KeyId); @@ -211,50 +195,72 @@ impl Server { } let kdf_id = Kdf::try_from(r.read_u16::()?)?; let aead_id = AeadId::try_from(r.read_u16::()?)?; - let sym = SymmetricSuite::new(kdf_id, aead_id); + let hpke_config = HpkeConfig::new(self.config.kem, kdf_id, aead_id); + Ok(hpke_config) + } - let info = build_info( - INFO_REQUEST, - key_id, - HpkeConfig::new(self.config.kem, sym.kdf(), sym.aead()), - )?; + fn decode_request_header(&self, r: &mut Cursor<&[u8]>, label: &[u8]) -> Res<(HpkeR, Vec)> { + let hpke_config = self.decode_hpke_config(r)?; + let sym = SymmetricSuite::new(hpke_config.kdf(), hpke_config.aead()); + let config = self.config.select(sym)?; + let info = build_info(label, self.config.key_id, hpke_config)?; - let cfg = self.config.select(sym)?; - let mut enc = vec![0; cfg.kem().n_enc()]; + let mut enc = vec![0; config.kem().n_enc()]; r.read_exact(&mut enc)?; - let mut hpke = HpkeR::new( - cfg, - &self.config.pk, - self.config.sk.as_ref().unwrap(), - &enc, - &info, - )?; - let mut ct = Vec::new(); - r.read_to_end(&mut ct)?; + Ok(( + HpkeR::new( + config, + &self.config.pk, + self.config.sk.as_ref().unwrap(), + &enc, + &info, + )?, + enc, + )) + } - let request = hpke.open(&[], &ct)?; + /// Remove encapsulation on a request. + /// # Panics + /// Not as a consequence of this code, but Rust won't know that for sure. + pub fn decapsulate(&self, enc_request: &[u8]) -> Res<(Vec, ServerResponse)> { + if enc_request.len() <= REQUEST_HEADER_LEN { + return Err(Error::Truncated); + } + let mut r = Cursor::new(enc_request); + let (mut hpke, enc) = self.decode_request_header(&mut r, INFO_REQUEST)?; + + let request = hpke.open(&[], &enc_request[usize::try_from(r.position())?..])?; Ok((request, ServerResponse::new(&hpke, enc)?)) } + + /// Remove encapsulation on a streamed request. 
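+ /// This is the streaming counterpart of [`Self::decapsulate`], reading the
+ /// encapsulated request from `src`.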
+ #[cfg(feature = "stream")] + pub fn decapsulate_stream(self, src: S) -> StreamServer { + StreamServer::new(self.config, src) + } } fn entropy(config: HpkeConfig) -> usize { max(config.aead().n_n(), config.aead().n_k()) } +fn export_secret(exp: &E, label: &[u8], cfg: HpkeConfig) -> Res { + exp.export(label, entropy(cfg)) +} + fn make_aead( mode: Mode, cfg: HpkeConfig, - exp: &impl Exporter, + secret: &SymKey, enc: Vec, - response_nonce: &[u8], + nonce: &[u8], ) -> Res { - let secret = exp.export(LABEL_RESPONSE, entropy(cfg))?; let mut salt = enc; - salt.extend_from_slice(response_nonce); + salt.extend_from_slice(nonce); let hkdf = Hkdf::new(cfg.kdf()); - let prk = hkdf.extract(&salt, &secret)?; + let prk = hkdf.extract(&salt, secret)?; let key = hkdf.expand_key(&prk, INFO_KEY, KeyMechanism::Aead(cfg.aead()))?; let iv = hkdf.expand_data(&prk, INFO_NONCE, cfg.aead().n_n())?; @@ -275,7 +281,13 @@ pub struct ServerResponse { impl ServerResponse { fn new(hpke: &HpkeR, enc: Vec) -> Res { let response_nonce = random(entropy(hpke.config())); - let aead = make_aead(Mode::Encrypt, hpke.config(), hpke, enc, &response_nonce)?; + let aead = make_aead( + Mode::Encrypt, + hpke.config(), + &export_secret(hpke, LABEL_RESPONSE, hpke.config())?, + enc, + &response_nonce, + )?; Ok(Self { response_nonce, aead, @@ -325,11 +337,11 @@ impl ClientResponse { let mut aead = make_aead( Mode::Decrypt, self.hpke.config(), - &self.hpke, + &export_secret(&self.hpke, LABEL_RESPONSE, self.hpke.config())?, self.enc, response_nonce, )?; - aead.open(&[], 0, ct) // 0 is the sequence number + aead.open(&[], ct) // 0 is the sequence number } } @@ -346,29 +358,33 @@ mod test { ClientRequest, Error, KeyConfig, KeyId, Server, }; - const KEY_ID: KeyId = 1; - const KEM: Kem = Kem::X25519Sha256; - const SYMMETRIC: &[SymmetricSuite] = &[ + pub const KEY_ID: KeyId = 1; + pub const KEM: Kem = Kem::X25519Sha256; + pub const SYMMETRIC: &[SymmetricSuite] = &[ SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), ]; - const REQUEST: &[u8] = &[ + pub const REQUEST: &[u8] = &[ 0x00, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x0b, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x01, 0x2f, ]; - const RESPONSE: &[u8] = &[0x01, 0x40, 0xc8]; + pub const RESPONSE: &[u8] = &[0x01, 0x40, 0xc8]; - fn init() { + pub fn init() { crate::init(); _ = env_logger::try_init(); // ignore errors here } + pub fn make_config() -> KeyConfig { + KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap() + } + #[test] fn request_response() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); trace!("Config: {}", hex::encode(&encoded_config)); @@ -393,7 +409,7 @@ mod test { fn two_requests() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -433,7 +449,7 @@ mod test { fn request_truncated(cut: usize) { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -464,7 +480,7 @@ mod test { fn response_truncated(cut: usize) { init(); - 
let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -523,7 +539,7 @@ mod test { fn request_from_config_list() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index aecfc90..d389b99 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -159,12 +159,44 @@ impl Aead { Ok(ct) } - pub fn open(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { + pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Decrypt); + let mut nonce = self.nonce_base; + let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. + let mut pt_len: c_int = 0; + let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; + secstatus_to_res(unsafe { + PK11_AEADOp( + *self.ctx, + CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), + c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. + nonce.as_mut_ptr(), + c_int_len(nonce.len()), + aad.as_ptr(), + c_int_len(aad.len()), + pt.as_mut_ptr(), + &mut pt_len, + c_int_len(pt.len()), // signed :( + ct.as_ptr().add(pt_expected).cast_mut(), // const cast :( + c_int_len(TAG_LEN), + ct.as_ptr(), + c_int_len(pt_expected), + ) + })?; + let len = usize::try_from(pt_len).unwrap(); + debug_assert_eq!(len, pt_expected); + pt.truncate(len); + Ok(pt) + } + + #[allow(dead_code)] + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Decrypt); let mut nonce = self.nonce_base; for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); } + let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. 
let mut pt_len: c_int = 0; let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; @@ -179,8 +211,8 @@ impl Aead { c_int_len(aad.len()), pt.as_mut_ptr(), &mut pt_len, - c_int_len(pt.len()), // signed :( - ct.as_ptr().add(pt_expected) as *mut _, // const cast :( + c_int_len(pt.len()), // signed :( + ct.as_ptr().add(pt_expected).cast_mut(), // const cast :( c_int_len(TAG_LEN), ct.as_ptr(), c_int_len(pt_expected), @@ -218,7 +250,7 @@ mod test { assert_eq!(&ciphertext[..], ct); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, 0, ct).unwrap(); + let plaintext = dec.open(aad, ct).unwrap(); assert_eq!(&plaintext[..], pt); } @@ -233,7 +265,7 @@ mod test { ) { let k = Aead::import_key(algorithm, key).unwrap(); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, seq, ct).unwrap(); + let plaintext = dec.open_seq(aad, seq, ct).unwrap(); assert_eq!(&plaintext[..], pt); } diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index c282a13..4f30b22 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -11,12 +11,12 @@ pub mod aead; pub mod hkdf; pub mod hpke; -use std::ptr::null; +use std::{ptr::null, sync::OnceLock}; use err::secstatus_to_res; pub use err::Error; -pub use self::p11::{random, PrivateKey, PublicKey}; +pub use self::p11::{random, PrivateKey, PublicKey, SymKey}; #[allow(clippy::pedantic, non_upper_case_globals, clippy::upper_case_acronyms)] mod nss_init { diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index 9bfe13a..24499f9 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] // TODO: remove - use std::convert::TryFrom; use aead::{AeadMut, Key, NewAead, Nonce, Payload}; @@ -12,7 +10,6 @@ use crate::{err::Res, hpke::Aead as AeadId}; /// All the nonces are the same length. Exploit that. pub const NONCE_LEN: usize = 12; const COUNTER_LEN: usize = 8; -const TAG_LEN: usize = 16; type SequenceNumber = u64; @@ -111,7 +108,13 @@ impl Aead { Ok(ct) } - pub fn open(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { + pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + let res = self.open_seq(aad, self.seq, ct); + self.seq += 1; + res + } + + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Decrypt); let nonce = self.nonce(seq); let pt = self.engine.decrypt(&nonce, Payload { msg: ct, aad })?; @@ -144,7 +147,7 @@ mod test { assert_eq!(&ciphertext[..], ct); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, 0, ct).unwrap(); + let plaintext = dec.open(aad, ct).unwrap(); assert_eq!(&plaintext[..], pt); } @@ -159,7 +162,7 @@ mod test { ) { let k = Aead::import_key(algorithm, key).unwrap(); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, seq, ct).unwrap(); + let plaintext = dec.open_seq(aad, seq, ct).unwrap(); assert_eq!(&plaintext[..], pt); } diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs index 1f9bd90..1865305 100644 --- a/ohttp/src/stream.rs +++ b/ohttp/src/stream.rs @@ -1,17 +1,30 @@ +#![allow(clippy::incompatible_msrv)] // Until I can make MSRV conditional on feature choice. 
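+//! Streamed HPKE encapsulation. Messages are carried as a sequence of chunks,
+//! each a varint length followed by AEAD-sealed data; chunk plaintext is capped
+//! at `MAX_CHUNK_PLAINTEXT` and the last chunk is sealed over the AAD "final".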
+ use std::{ cmp::min, io::{Error as IoError, Result as IoResult}, + mem, pin::Pin, task::{Context, Poll}, }; -use futures::AsyncRead; +use futures::{AsyncRead, AsyncWrite}; -use crate::HpkeS; +use crate::{ + build_info, entropy, err::Res, export_secret, make_aead, Aead, Error, HpkeConfig, HpkeR, HpkeS, + KeyConfig, KeyId, Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, +}; +/// The info string for a chunked request. pub(crate) const INFO_REQUEST: &[u8] = b"message/bhttp chunked request"; +/// The exporter label for a chunked response. +pub(crate) const LABEL_RESPONSE: &[u8] = b"message/bhttp chunked response"; +/// The length of the plaintext of the largest chunk that is permitted. +const MAX_CHUNK_PLAINTEXT: usize = 1 << 14; +const CHUNK_AAD: &[u8] = b""; +const FINAL_CHUNK_AAD: &[u8] = b"final"; -fn write_len(w: &mut [u8], len: usize) -> usize { +fn write_len(w: &mut [u8], len: usize) -> &[u8] { let v: u64 = len.try_into().unwrap(); let (v, len) = match () { () if v < (1 << 6) => (v, 1), @@ -21,28 +34,59 @@ fn write_len(w: &mut [u8], len: usize) -> usize { () => panic!("varint value too large"), }; w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); - len + &w[..len] } -#[pin_project::pin_project] -pub struct ClientRequestStream { +#[pin_project::pin_project(project = ClientProjection)] +pub struct ClientRequest { #[pin] - src: S, + dst: S, hpke: HpkeS, buf: Vec, } -impl ClientRequestStream { - pub fn new(src: S, hpke: HpkeS, header: Vec) -> Self { - Self { - src, +impl ClientRequest { + /// Start the processing of a stream. + pub fn start(dst: S, config: HpkeConfig, key_id: KeyId, mut pk: PublicKey) -> Res { + let info = build_info(INFO_REQUEST, key_id, config)?; + let hpke = HpkeS::new(config, &mut pk, &info)?; + + let mut header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let mut e = hpke.enc()?; + header.append(&mut e); + + Ok(Self { + dst, hpke, buf: header, - } + }) + } + + /// Get an object that can be used to process the response. + /// + /// While this can be used while sending the request, + /// doing so creates a risk of revealing unwanted information to the gateway. + /// That includes the round trip time between client and gateway, + /// which might reveal information about the location of the client. + pub fn response(&self, src: R) -> Res> { + let enc = self.hpke.enc()?; + let secret = export_secret(&self.hpke, LABEL_RESPONSE, self.hpke.config())?; + Ok(ClientResponse { + src, + config: self.hpke.config(), + state: ClientResponseState::Header { + enc, + secret, + nonce: [0; 16], + read: 0, + }, + }) } } -impl AsyncRead for ClientRequestStream { +impl AsyncRead for ClientRequest { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -68,26 +112,26 @@ impl AsyncRead for ClientRequestStream { // we have to use a temporary buffer so that we can save some bytes. let mut tmp = [0; 64]; let read_buf = if buf.len() < tmp.len() { - // Use the provided buffer, but leave room for AEAD tag and a varint. - let read_len = min(buf.len(), 1 << 62) - this.hpke.aead().n_t(); + // Use the provided buffer, but cap the amount we read to MAX_CHUNK_PLAINTEXT. + let read_len = min(buf.len(), MAX_CHUNK_PLAINTEXT); &mut buf[8..read_len] } else { &mut tmp[..] 
}; - let (aad, len): (&[u8], _) = match this.src.poll_read(cx, read_buf) { + let (aad, len): (&[u8], _) = match this.dst.poll_read(cx, read_buf) { Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(0)) => (&b"final"[..], 0), + Poll::Ready(Ok(0)) => (FINAL_CHUNK_AAD, 0), Poll::Ready(Ok(len)) => (&[], len), - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + e @ Poll::Ready(Err(_)) => return e, }; let ct = this .hpke - .seal(aad, &mut read_buf[..len]) + .seal(aad, &read_buf[..len]) .map_err(IoError::other)?; // Now we need to write the length of the chunk. - let len_len = write_len(&mut tmp, ct.len()); + let len_len = write_len(&mut tmp, ct.len()).len(); if len_len <= buf.len() { // If the length fits in the buffer, that's easy. buf[..len_len].copy_from_slice(&tmp[..len_len]); @@ -108,3 +152,489 @@ impl AsyncRead for ClientRequestStream { Poll::Ready(Ok(amnt + written)) } } + +impl ClientRequest { + /// Flush our buffer. + /// Returns `Some` if the flush blocks or is unsuccessful. + /// If that contains `Ready`, it does so only when there is an error. + fn flush(this: &mut ClientProjection<'_, S>, cx: &mut Context<'_>) -> Option> { + while !this.buf.is_empty() { + match this.dst.as_mut().poll_write(cx, &this.buf[..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(len)) => { + if len < this.buf.len() { + // We've written something to the underlying writer, + // which is probably blocked. + // We could return `Poll::Pending`, + // but that would mean taking responsibility + // for calling `cx.waker().wake()` + // when more space comes available. + // + // So, rather than do that, loop. + // If the underlying writer is truly blocked, + // it assumes responsibility for waking the task. + *this.buf = this.buf.split_off(len); + } else { + this.buf.clear(); + } + } + Poll::Ready(Err(e)) => return Some(Poll::Ready(e)), + } + } + None + } + + fn write_chunk( + this: &mut ClientProjection<'_, S>, + cx: &mut Context<'_>, + input: &[u8], + last: bool, + ) -> Poll> { + let aad = if last { FINAL_CHUNK_AAD } else { CHUNK_AAD }; + let mut ct = this.hpke.seal(aad, input).map_err(IoError::other)?; + let (len, written) = if last { + (0, 0) + } else { + (ct.len(), input.len()) + }; + + let mut len_buf = [0; 8]; + let len = write_len(&mut len_buf[..], len); + let w = match this.dst.as_mut().poll_write(cx, len) { + Poll::Pending => 0, + Poll::Ready(Ok(w)) => w, + e @ Poll::Ready(Err(_)) => return e, + }; + + if w < len.len() { + this.buf.extend_from_slice(&len[w..]); + this.buf.append(&mut ct); + } else { + match this.dst.as_mut().poll_write(cx, &ct[..]) { + Poll::Pending => { + *this.buf = ct; + } + Poll::Ready(Ok(w)) => { + *this.buf = ct.split_off(w); + } + e @ Poll::Ready(Err(_)) => return e, + } + } + Poll::Ready(Ok(written)) + } +} + +impl AsyncWrite for ClientRequest { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &[u8], + ) -> Poll> { + let mut this = self.project(); + // We have buffered data, so dump it into the output directly. + if let Some(value) = Self::flush(&mut this, cx) { + return value.map(Err); + } + + // Now encipher a chunk. + let len = min(input.len(), MAX_CHUNK_PLAINTEXT); + Self::write_chunk(&mut this, cx, &input[..len], false) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if let Some(p) = Self::flush(&mut this, cx) { + // Flushing our buffers either blocked or failed. 
+ p.map(Err) + } else { + this.dst.as_mut().poll_flush(cx) + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Self::write_chunk(&mut self.project(), cx, &[], true).map(|p| p.map(|_| ())) + } +} + +enum ChunkState { + Length { + len: [u8; 8], + offset: usize, + }, + Data { + buf: Vec, + offset: usize, + length: usize, + }, + Done, +} + +impl ChunkState { + fn length() -> Self { + Self::Length { + len: [0; 8], + offset: 0, + } + } + + fn data(length: usize) -> Self { + // Avoid use `with_capacity` here. Only allocate when necessary. + // We might be able to into the buffer we're given instead, to save an allocation. + Self::Data { + buf: Vec::new(), + // Note that because we're allocating the full chunk, + // we need to track what has been used. + offset: 0, + length, + } + } +} + +#[allow(dead_code)] // TODO +enum ServerRequestState { + HpkeConfig { + config: KeyConfig, + buf: [u8; 7], + read: usize, + }, + Enc { + config: HpkeConfig, + info: Vec, + buf: Vec, + }, + Body { + hpke: HpkeR, + state: ChunkState, + }, +} + +#[pin_project::pin_project(project = ServerRequestProjection)] +pub struct ServerRequest { + #[pin] + src: S, + state: ServerRequestState, +} + +impl ServerRequest { + pub fn new(config: KeyConfig, src: S) -> Self { + Self { + src, + state: ServerRequestState::HpkeConfig { + config, + buf: [0; 7], + read: 0, + }, + } + } +} + +enum ClientResponseState { + Header { + enc: Vec, + secret: SymKey, + nonce: [u8; 16], + read: usize, + }, + Body { + aead: Aead, + state: ChunkState, + }, +} + +impl ClientResponseState { + fn done(&self) -> bool { + matches!( + self, + Self::Body { + state: ChunkState::Done, + .. + } + ) + } +} + +#[pin_project::pin_project(project = ClientResponseProjection)] +pub struct ClientResponse { + #[pin] + src: S, + config: HpkeConfig, + state: ClientResponseState, +} + +impl ClientResponse { + fn read_nonce( + this: &mut ClientResponseProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + if let ClientResponseState::Header { + enc, + secret, + nonce, + read, + } = &mut this.state + { + let aead = match this.src.as_mut().poll_read(cx, &mut nonce[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) + } + Poll::Ready(Ok(len)) => { + *read += len; + if *read < entropy(*this.config) { + return Some(Poll::Pending); + } + match make_aead( + Mode::Decrypt, + *this.config, + secret, + mem::take(enc), + &nonce[..entropy(*this.config)], + ) { + Ok(aead) => aead, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + } + } + e @ Poll::Ready(Err(_)) => return Some(e), + }; + + *this.state = ClientResponseState::Body { + aead, + state: ChunkState::length(), + }; + }; + None + } + + fn read_length( + this: &mut ClientResponseProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + if let ClientResponseState::Body { aead: _, state } = this.state { + // Read the first byte. 
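+ // The top two bits of the first byte give the varint form:
+ // 0 => 1 byte, 1 => 2 bytes, 2 => 4 bytes, 3 => 8 bytes.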
+ if let ChunkState::Length { len, offset } = state { + if *offset == 0 { + match this.src.as_mut().poll_read(cx, &mut len[..1]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + } + Poll::Ready(Ok(1)) => { + let form = len[0] >> 6; + if form == 0 { + *state = ChunkState::data(usize::from(len[0])); + } else { + let v = mem::replace(&mut len[0], 0) & 0x3f; + let i = match form { + 1 => 6, + 2 => 4, + 3 => 0, + _ => unreachable!(), + }; + len[i] = v; + *offset = i + 1; + } + } + Poll::Ready(Ok(_)) => unreachable!(), + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + } + + // Read any remaining bytes of the length. + if let ChunkState::Length { len, offset } = state { + if *offset != 0 { + *state = match this.src.as_mut().poll_read(cx, &mut len[*offset..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + } + Poll::Ready(Ok(r)) => { + *offset += r; + if *offset < 8 { + return Some(Poll::Pending); + } + let remaining = match usize::try_from(u64::from_be_bytes(*len)) { + Ok(remaining) => remaining, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + if remaining > MAX_CHUNK_PLAINTEXT + this.config.aead().n_t() { + return Some(Poll::Ready(Err(IoError::other( + Error::ChunkTooLarge, + )))); + } + ChunkState::data(remaining) + } + e @ Poll::Ready(Err(_)) => return Some(e), + }; + } + } + } + + None + } + + /// Optional optimization that reads a single chunk into the output buffer. + fn read_into_output( + this: &mut ClientResponseProjection<'_, S>, + cx: &mut Context<'_>, + output: &mut [u8], + ) -> Option>> { + if let ClientResponseState::Body { aead, state } = this.state { + if let ChunkState::Data { + buf, + offset, + length, + } = state + { + if *length > 0 && *offset == 0 && output.len() + this.config.aead().n_t() >= *length + { + match this.src.as_mut().poll_read(cx, output) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + } + Poll::Ready(Ok(r)) => { + if r < *length { + buf.extend_from_slice(&output[..r]); + *offset += r; + return Some(Poll::Pending); + } + + let pt = match aead.open(CHUNK_AAD, &output[..r]) { + Ok(pt) => pt, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + output[..pt.len()].copy_from_slice(&pt); + *state = ChunkState::length(); + return Some(Poll::Ready(Ok(pt.len()))); + } + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + } + } + + None + } +} + +impl AsyncRead for ClientResponse { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + output: &mut [u8], + ) -> Poll> { + let mut this = self.project(); + if let Some(res) = Self::read_nonce(&mut this, cx) { + return res; + } + + while !this.state.done() { + if let Some(res) = Self::read_length(&mut this, cx) { + return res; + } + + // Read data. + if let Some(res) = Self::read_into_output(&mut this, cx, output) { + return res; + } + + if let ClientResponseState::Body { aead, state } = this.state { + if let ChunkState::Data { + buf, + offset, + length, + } = state + { + // Allocate now as needed. 
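+ // A zero length marks the final chunk: the remainder of the stream is
+ // a single ciphertext sealed over the "final" AAD.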
+ let last = *length == 0; + if buf.is_empty() { + let sz = if last { + MAX_CHUNK_PLAINTEXT + this.config.aead().n_t() + } else { + *length + }; + buf.resize(sz, 0); + } + + let aad = match this.src.as_mut().poll_read(cx, &mut buf[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => { + if !last { + return Poll::Ready(Err(IoError::other(Error::Truncated))); + } + + FINAL_CHUNK_AAD + } + Poll::Ready(Ok(r)) => { + if *offset + r < *length { + buf.extend_from_slice(&output[..r]); + *offset += r; + return Poll::Pending; + } + + CHUNK_AAD + } + e @ Poll::Ready(Err(_)) => return e, + }; + + let pt = aead.open(aad, buf).map_err(IoError::other)?; + output[..pt.len()].copy_from_slice(&pt); + *state = if last { + ChunkState::Done + } else { + ChunkState::length() + }; + if !pt.is_empty() { + return Poll::Ready(Ok(pt.len())); + } + } + } + } + Poll::Ready(Ok(0)) + } +} + +#[cfg(test)] +mod test { + use futures::{io::Cursor, AsyncReadExt, AsyncWriteExt}; + use log::trace; + use sync_async::{SyncRead, SyncResolve}; + + use crate::{ + test::{init, make_config, REQUEST, RESPONSE}, + ClientRequest, Server, + }; + + #[test] + fn request_response() { + init(); + + let server_config = make_config(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (mut request_read, request_write) = AsyncReadExt::split(Cursor::new(Vec::new())); + let mut client_request = client.encapsulate_stream(request_write).unwrap(); + client_request.write_all(REQUEST).sync_resolve().unwrap(); + client_request.close().sync_resolve().unwrap(); + + trace!("Request: {}", hex::encode(REQUEST)); + let request_buf = request_read.sync_read_to_end(); + trace!("Encapsulated Request: {}", hex::encode(&request_buf)); + + let (request, server_response) = server.decapsulate(&request_buf[..]).unwrap(); + assert_eq!(&request[..], REQUEST); + + let enc_response = server_response.encapsulate(RESPONSE).unwrap(); + trace!("Encapsulated Response: {}", hex::encode(&enc_response)); + + let mut client_response = client_request.response(&enc_response[..]).unwrap(); + + let response_buf = client_response.sync_read_to_end(); + assert_eq!(response_buf, RESPONSE); + trace!("Response: {}", hex::encode(response_buf)); + } +} diff --git a/sync-async/src/lib.rs b/sync-async/src/lib.rs index e21ca07..f421021 100644 --- a/sync-async/src/lib.rs +++ b/sync-async/src/lib.rs @@ -1,10 +1,12 @@ use std::{ future::Future, + io::Result as IoResult, pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::{AsyncRead, AsyncReadExt, TryStream, TryStreamExt}; +use futures::{AsyncRead, AsyncReadExt, AsyncWrite, TryStream, TryStreamExt}; +use pin_project::pin_project; fn noop_context() -> Context<'static> { use std::{ @@ -69,14 +71,17 @@ impl SyncResolve for F { } } -pub trait SyncCollect { +pub trait SyncTryCollect { type Item; type Error; + /// Synchronously gather all items from a stream. + /// # Errors + /// When the underlying source produces an error. 
fn sync_collect(self) -> Result, Self::Error>; } -impl SyncCollect for S { +impl SyncTryCollect for S { type Item = S::Ok; type Error = S::Error; @@ -106,13 +111,15 @@ impl SyncRead for S { } } +#[pin_project(project = DribbleProjection)] pub struct Dribble { - src: S, + #[pin] + s: S, } impl Dribble { - pub fn new(src: S) -> Self { - Self { src } + pub fn new(s: S) -> Self { + Self { s } } } @@ -121,7 +128,78 @@ impl AsyncRead for Dribble { mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], - ) -> Poll> { - pin!(&mut self.src).poll_read(cx, &mut buf[..1]) + ) -> Poll> { + pin!(&mut self.s).poll_read(cx, &mut buf[..1]) + } +} + +impl AsyncWrite for Dribble { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_write(cx, &buf[..1]) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_close(cx) + } +} + +#[pin_project(project = StutterProjection)] +pub struct Stutter { + stall: bool, + #[pin] + s: S, +} + +impl Stutter { + pub fn new(s: S) -> Self { + Self { stall: false, s } + } + + fn stutter(self: Pin<&mut Self>, cx: &mut Context<'_>, f: F) -> Poll + where + F: FnOnce(Pin<&mut S>, &mut Context<'_>) -> Poll, + { + let mut this = self.project(); + *this.stall = !*this.stall; + if *this.stall { + // When returning `Poll::Pending`, you have to wake the task. + // We aren't running code anywhere except here, + // so call it here and ensure that the task is picked up. + cx.waker().wake_by_ref(); + Poll::Pending + } else { + f(this.s.as_mut(), cx) + } + } +} + +impl AsyncRead for Stutter { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Self::stutter(self, cx, |s, cx| s.poll_read(cx, buf)) + } +} + +impl AsyncWrite for Stutter { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Self::stutter(self, cx, |s, cx| s.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Self::stutter(self, cx, AsyncWrite::poll_flush) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Self::stutter(self, cx, AsyncWrite::poll_close) } } From f1c5dae900282b2fd8fc178a12831b5f23d2a3fd Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Thu, 19 Dec 2024 12:18:52 +1100 Subject: [PATCH 18/20] Checkpoint --- ohttp/Cargo.toml | 2 +- ohttp/src/crypto.rs | 9 +++++++++ ohttp/src/lib.rs | 2 ++ ohttp/src/nss/hpke.rs | 19 +++++++++++++++---- ohttp/src/rh/hpke.rs | 9 +++++++-- ohttp/src/stream.rs | 4 ++-- 6 files changed, 36 insertions(+), 9 deletions(-) create mode 100644 ohttp/src/crypto.rs diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 17ab467..01a2027 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -10,7 +10,7 @@ description = "Oblivious HTTP" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["client", "server", "nss", "stream"] +default = ["client", "server", "rust-hpke", "stream"] app-svc = ["nss"] client = [] external-sqlite = [] diff --git a/ohttp/src/crypto.rs b/ohttp/src/crypto.rs new file mode 100644 index 0000000..5ab6aa9 --- /dev/null +++ b/ohttp/src/crypto.rs @@ -0,0 +1,9 @@ +use crate::err::Res; + +pub trait Decrypt { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res>; +} + +pub trait Encrypt { + fn 
seal(&mut self, aad: &[u8], ct: &[u8]) -> Res>; +} diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 6e9857d..b3f65f4 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -8,6 +8,7 @@ compile_error!("features \"nss\" and \"rust-hpke\" are mutually incompatible"); mod config; +mod crypto; mod err; pub mod hpke; #[cfg(feature = "nss")] @@ -27,6 +28,7 @@ use std::{ }; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use crypto::{Decrypt, Encrypt}; use log::trace; #[cfg(feature = "nss")] diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index 16d018d..c67f928 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -13,7 +13,10 @@ use super::{ err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, }; -use crate::err::Res; +use crate::{ + crypto::{Decrypt, Encrypt}, + err::Res, +}; /// Configuration for `Hpke`. #[derive(Clone, Copy)] @@ -135,8 +138,10 @@ impl HpkeS { let slc = unsafe { std::slice::from_raw_parts(r.data, len) }; Ok(Vec::from(slc)) } +} - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { +impl Encrypt for HpkeS { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { let mut out: *mut sys::SECItem = null_mut(); secstatus_to_res(unsafe { sys::PK11_HPKE_Seal(*self.context, &Item::wrap(aad), &Item::wrap(pt), &mut out) @@ -209,8 +214,10 @@ impl HpkeR { })?; PublicKey::from_ptr(ptr) } +} - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { +impl Decrypt for HpkeR { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { let mut out: *mut sys::SECItem = null_mut(); secstatus_to_res(unsafe { sys::PK11_HPKE_Open(*self.context, &Item::wrap(aad), &Item::wrap(ct), &mut out) @@ -294,7 +301,11 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { #[cfg(test)] mod test { use super::{generate_key_pair, Config, HpkeContext, HpkeR, HpkeS}; - use crate::{hpke::Aead, init}; + use crate::{ + crypto::{Decrypt, Encrypt}, + hpke::Aead, + init, + }; const INFO: &[u8] = b"info"; const AAD: &[u8] = b"aad"; diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index ce8ee78..46fd4d9 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -17,6 +17,7 @@ use rust_hpke::{ use super::SymKey; use crate::{ + crypto::{Decrypt, Encrypt}, hpke::{Aead, Kdf, Kem}, Error, Res, }; @@ -303,8 +304,10 @@ impl HpkeS { pub fn enc(&self) -> Res> { Ok(self.enc.clone()) } +} - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { +impl Encrypt for HpkeS { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { let mut buf = pt.to_owned(); let mut tag = self.context.seal(&mut buf, aad)?; buf.append(&mut tag); @@ -522,8 +525,10 @@ impl HpkeR { ), }) } +} - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { +impl Decrypt for HpkeR { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { let mut buf = ct.to_owned(); let pt_len = self.context.open(&mut buf, aad)?.len(); buf.truncate(pt_len); diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs index 1865305..6a94d3c 100644 --- a/ohttp/src/stream.rs +++ b/ohttp/src/stream.rs @@ -11,8 +11,8 @@ use std::{ use futures::{AsyncRead, AsyncWrite}; use crate::{ - build_info, entropy, err::Res, export_secret, make_aead, Aead, Error, HpkeConfig, HpkeR, HpkeS, - KeyConfig, KeyId, Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, + build_info, crypto::Encrypt, entropy, err::Res, export_secret, make_aead, Aead, Error, + HpkeConfig, HpkeR, HpkeS, KeyConfig, KeyId, Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, }; /// The info string for a chunked request. 
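The patch above stops at introducing the shared `Encrypt` and `Decrypt` traits; the next patch wires them into the chunked request and response paths. As a rough sketch (not part of any patch in this series; `seal_chunk` and `open_chunk` are hypothetical helpers for illustration, and the AAD values are the `CHUNK_AAD`/`FINAL_CHUNK_AAD` constants that ohttp/src/stream.rs already defines), the traits let the streaming code seal and open chunks without caring whether the cipher is an `HpkeS`, an `HpkeR`, or the `Aead` derived from the exported secret:

use crate::{
    crypto::{Decrypt, Encrypt},
    err::Res,
};

// Non-final chunks are bound to an empty AAD; the final chunk uses b"final".
fn seal_chunk<E: Encrypt>(cipher: &mut E, pt: &[u8], last: bool) -> Res<Vec<u8>> {
    let aad: &[u8] = if last { b"final" } else { b"" };
    cipher.seal(aad, pt)
}

fn open_chunk<D: Decrypt>(cipher: &mut D, ct: &[u8], last: bool) -> Res<Vec<u8>> {
    let aad: &[u8] = if last { b"final" } else { b"" };
    cipher.open(aad, ct)
}

Because both trait methods take `&mut self`, any per-chunk state such as the response AEAD's sequence number stays inside the cipher; the chunk writer and reader only have to supply the correct AAD.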
From 9cf1972cb450d67fa0a8e562f0af1dab360a88bb Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Thu, 19 Dec 2024 17:55:00 +1100 Subject: [PATCH 19/20] Working --- ohttp/src/config.rs | 16 + ohttp/src/crypto.rs | 5 +- ohttp/src/err.rs | 3 + ohttp/src/lib.rs | 47 +-- ohttp/src/nss/aead.rs | 102 ++--- ohttp/src/nss/hpke.rs | 8 + ohttp/src/rh/aead.rs | 47 ++- ohttp/src/rh/hpke.rs | 12 + ohttp/src/stream.rs | 851 ++++++++++++++++++++++++------------------ sync-async/src/lib.rs | 67 +++- 10 files changed, 710 insertions(+), 448 deletions(-) diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index 0f6fab3..f2316ce 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -257,6 +257,22 @@ impl KeyConfig { Err(Error::Unsupported) } } + + #[allow(clippy::similar_names)] // for kem_id and key_id + pub(crate) fn decode_hpke_config(&self, r: &mut Cursor<&[u8]>) -> Res { + let key_id = r.read_u8()?; + if key_id != self.key_id { + return Err(Error::KeyId); + } + let kem_id = Kem::try_from(r.read_u16::()?)?; + if kem_id != self.kem { + return Err(Error::InvalidKem); + } + let kdf_id = Kdf::try_from(r.read_u16::()?)?; + let aead_id = AeadId::try_from(r.read_u16::()?)?; + let hpke_config = HpkeConfig::new(self.kem, kdf_id, aead_id); + Ok(hpke_config) + } } impl AsRef for KeyConfig { diff --git a/ohttp/src/crypto.rs b/ohttp/src/crypto.rs index 5ab6aa9..b20a019 100644 --- a/ohttp/src/crypto.rs +++ b/ohttp/src/crypto.rs @@ -1,9 +1,12 @@ -use crate::err::Res; +use crate::{err::Res, AeadId}; pub trait Decrypt { fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res>; + fn alg(&self) -> AeadId; } pub trait Encrypt { + #[allow(dead_code)] // TODO + fn alg(&self) -> AeadId; fn seal(&mut self, aad: &[u8], ct: &[u8]) -> Res>; } diff --git a/ohttp/src/err.rs b/ohttp/src/err.rs index c52f56b..caaedc8 100644 --- a/ohttp/src/err.rs +++ b/ohttp/src/err.rs @@ -32,6 +32,9 @@ pub enum Error { Io(#[from] std::io::Error), #[error("the key ID was invalid")] KeyId, + #[cfg(feature = "stream")] + #[error("the object was not ready")] + NotReady, #[error("a field was truncated")] Truncated, #[error("the configuration was not supported")] diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index b3f65f4..133d9b8 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -27,7 +27,7 @@ use std::{ mem::size_of, }; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{NetworkEndian, WriteBytesExt}; use crypto::{Decrypt, Encrypt}; use log::trace; @@ -39,15 +39,12 @@ use crate::nss::{ random, PublicKey, SymKey, }; #[cfg(feature = "stream")] -use crate::stream::{ClientRequest as StreamClient, ServerRequest as StreamServer}; +use crate::stream::{ClientRequest as StreamClient, ServerRequest as ServerRequestStream}; pub use crate::{ config::{KeyConfig, SymmetricSuite}, err::Error, }; -use crate::{ - err::Res, - hpke::{Aead as AeadId, Kdf, Kem}, -}; +use crate::{err::Res, hpke::Aead as AeadId}; #[cfg(feature = "rust-hpke")] use crate::{ rand::random, @@ -185,24 +182,8 @@ impl Server { &self.config } - #[allow(clippy::similar_names)] // for kem_id and key_id - fn decode_hpke_config(&self, r: &mut Cursor<&[u8]>) -> Res { - let key_id = r.read_u8()?; - if key_id != self.config.key_id { - return Err(Error::KeyId); - } - let kem_id = Kem::try_from(r.read_u16::()?)?; - if kem_id != self.config.kem { - return Err(Error::InvalidKem); - } - let kdf_id = Kdf::try_from(r.read_u16::()?)?; - let aead_id = AeadId::try_from(r.read_u16::()?)?; - let hpke_config = HpkeConfig::new(self.config.kem, kdf_id, aead_id); - Ok(hpke_config) - } 
- fn decode_request_header(&self, r: &mut Cursor<&[u8]>, label: &[u8]) -> Res<(HpkeR, Vec)> { - let hpke_config = self.decode_hpke_config(r)?; + let hpke_config = self.config.decode_hpke_config(r)?; let sym = SymmetricSuite::new(hpke_config.kdf(), hpke_config.aead()); let config = self.config.select(sym)?; let info = build_info(label, self.config.key_id, hpke_config)?; @@ -233,13 +214,13 @@ impl Server { let (mut hpke, enc) = self.decode_request_header(&mut r, INFO_REQUEST)?; let request = hpke.open(&[], &enc_request[usize::try_from(r.position())?..])?; - Ok((request, ServerResponse::new(&hpke, enc)?)) + Ok((request, ServerResponse::new(&hpke, &enc)?)) } /// Remove encapsulation on a streamed request. #[cfg(feature = "stream")] - pub fn decapsulate_stream(self, src: S) -> StreamServer { - StreamServer::new(self.config, src) + pub fn decapsulate_stream(self, src: S) -> ServerRequestStream { + ServerRequestStream::new(self.config, src) } } @@ -251,14 +232,8 @@ fn export_secret(exp: &E, label: &[u8], cfg: HpkeConfig) -> Res, - nonce: &[u8], -) -> Res { - let mut salt = enc; +fn make_aead(mode: Mode, cfg: HpkeConfig, secret: &SymKey, enc: &[u8], nonce: &[u8]) -> Res { + let mut salt = enc.to_vec(); salt.extend_from_slice(nonce); let hkdf = Hkdf::new(cfg.kdf()); @@ -281,7 +256,7 @@ pub struct ServerResponse { #[cfg(feature = "server")] impl ServerResponse { - fn new(hpke: &HpkeR, enc: Vec) -> Res { + fn new(hpke: &HpkeR, enc: &[u8]) -> Res { let response_nonce = random(entropy(hpke.config())); let aead = make_aead( Mode::Encrypt, @@ -340,7 +315,7 @@ impl ClientResponse { Mode::Decrypt, self.hpke.config(), &export_secret(&self.hpke, LABEL_RESPONSE, self.hpke.config())?, - self.enc, + &self.enc, response_nonce, )?; aead.open(&[], ct) // 0 is the sequence number diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index d389b99..bbf1d00 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -18,6 +18,7 @@ use super::{ }, }; use crate::{ + crypto::{Decrypt, Encrypt}, err::{Error, Res}, hpke::Aead as AeadId, }; @@ -69,6 +70,7 @@ impl Mode { /// This is an AEAD instance that uses the pub struct Aead { mode: Mode, + algorithm: AeadId, ctx: Context, nonce_base: [u8; NONCE_LEN], } @@ -120,55 +122,27 @@ impl Aead { }; Ok(Self { mode, + algorithm, ctx: Context::from_ptr(ptr)?, nonce_base, }) } - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Encrypt); - // A copy for the nonce generator to write into. But we don't use the value. - let mut nonce = self.nonce_base; - // Ciphertext with enough space for the tag. - // Even though we give the operation a separate buffer for the tag, - // reserve the capacity on allocation. - let mut ct = vec![0; pt.len() + TAG_LEN]; - let mut ct_len: c_int = 0; - let mut tag = vec![0; TAG_LEN]; - secstatus_to_res(unsafe { - PK11_AEADOp( - *self.ctx, - CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), - c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. 
- nonce.as_mut_ptr(), - c_int_len(nonce.len()), - aad.as_ptr(), - c_int_len(aad.len()), - ct.as_mut_ptr(), - &mut ct_len, - c_int_len(ct.len()), // signed :( - tag.as_mut_ptr(), - c_int_len(tag.len()), - pt.as_ptr(), - c_int_len(pt.len()), - ) - })?; - ct.truncate(usize::try_from(ct_len).unwrap()); - debug_assert_eq!(ct.len(), pt.len()); - ct.append(&mut tag); - Ok(ct) - } - - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + #[allow(dead_code)] + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Decrypt); let mut nonce = self.nonce_base; + for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { + *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); + } + let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. let mut pt_len: c_int = 0; let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; secstatus_to_res(unsafe { PK11_AEADOp( *self.ctx, - CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), + CK_GENERATOR_FUNCTION::from(CKG_NO_GENERATE), c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. nonce.as_mut_ptr(), c_int_len(nonce.len()), @@ -188,22 +162,19 @@ impl Aead { pt.truncate(len); Ok(pt) } +} - #[allow(dead_code)] - pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { +impl Decrypt for Aead { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Decrypt); let mut nonce = self.nonce_base; - for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { - *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); - } - let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. let mut pt_len: c_int = 0; let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; secstatus_to_res(unsafe { PK11_AEADOp( *self.ctx, - CK_GENERATOR_FUNCTION::from(CKG_NO_GENERATE), + CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. nonce.as_mut_ptr(), c_int_len(nonce.len()), @@ -223,6 +194,50 @@ impl Aead { pt.truncate(len); Ok(pt) } + + fn alg(&self) -> AeadId { + self.algorithm + } +} + +impl Encrypt for Aead { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Encrypt); + // A copy for the nonce generator to write into. But we don't use the value. + let mut nonce = self.nonce_base; + // Ciphertext with enough space for the tag. + // Even though we give the operation a separate buffer for the tag, + // reserve the capacity on allocation. + let mut ct = vec![0; pt.len() + TAG_LEN]; + let mut ct_len: c_int = 0; + let mut tag = vec![0; TAG_LEN]; + secstatus_to_res(unsafe { + PK11_AEADOp( + *self.ctx, + CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), + c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. 
+ nonce.as_mut_ptr(), + c_int_len(nonce.len()), + aad.as_ptr(), + c_int_len(aad.len()), + ct.as_mut_ptr(), + &mut ct_len, + c_int_len(ct.len()), // signed :( + tag.as_mut_ptr(), + c_int_len(tag.len()), + pt.as_ptr(), + c_int_len(pt.len()), + ) + })?; + ct.truncate(usize::try_from(ct_len).unwrap()); + debug_assert_eq!(ct.len(), pt.len()); + ct.append(&mut tag); + Ok(ct) + } + + fn alg(&self) -> AeadId { + self.algorithm + } } #[cfg(test)] @@ -231,6 +246,7 @@ mod test { super::{super::hpke::Aead as AeadId, init}, Aead, Mode, SequenceNumber, NONCE_LEN, }; + use crate::crypto::{Decrypt, Encrypt}; /// Check that the first invocation of encryption matches expected values. /// Also check decryption of the same. diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index c67f928..53fa3a5 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -149,6 +149,10 @@ impl Encrypt for HpkeS { let v = Item::from_ptr(out)?; Ok(unsafe { v.into_vec() }) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeS { @@ -225,6 +229,10 @@ impl Decrypt for HpkeR { let v = Item::from_ptr(out)?; Ok(unsafe { v.into_vec() }) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeR { diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index 24499f9..9521b30 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -5,7 +5,11 @@ use aes_gcm::{Aes128Gcm, Aes256Gcm}; use chacha20poly1305::ChaCha20Poly1305; use super::SymKey; -use crate::{err::Res, hpke::Aead as AeadId}; +use crate::{ + crypto::{Decrypt, Encrypt}, + err::Res, + hpke::Aead as AeadId, +}; /// All the nonces are the same length. Exploit that. pub const NONCE_LEN: usize = 12; @@ -53,6 +57,7 @@ impl AeadEngine { /// A switch-hitting AEAD that uses a selected primitive. pub struct Aead { mode: Mode, + algorithm: AeadId, engine: AeadEngine, nonce_base: [u8; NONCE_LEN], seq: SequenceNumber, @@ -79,6 +84,7 @@ impl Aead { }; Ok(Self { mode, + algorithm, engine: aead, nonce_base, seq: 0, @@ -99,26 +105,40 @@ impl Aead { nonce } - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Decrypt); + let nonce = self.nonce(seq); + let pt = self.engine.decrypt(&nonce, Payload { msg: ct, aad })?; + Ok(pt) + } +} + +impl Decrypt for Aead { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + println!("aead open: {}", hex::encode(ct)); + let res = self.open_seq(aad, self.seq, ct); + self.seq += 1; + res + } + + fn alg(&self) -> AeadId { + self.algorithm + } +} + +impl Encrypt for Aead { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Encrypt); // A copy for the nonce generator to write into. But we don't use the value. 
let nonce = self.nonce(self.seq); self.seq += 1; let ct = self.engine.encrypt(&nonce, Payload { msg: pt, aad })?; + println!("aead seal: {}", hex::encode(&ct)); Ok(ct) } - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { - let res = self.open_seq(aad, self.seq, ct); - self.seq += 1; - res - } - - pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Decrypt); - let nonce = self.nonce(seq); - let pt = self.engine.decrypt(&nonce, Payload { msg: ct, aad })?; - Ok(pt) + fn alg(&self) -> AeadId { + self.algorithm } } @@ -128,6 +148,7 @@ mod test { super::super::{hpke::Aead as AeadId, init}, Aead, Mode, SequenceNumber, NONCE_LEN, }; + use crate::crypto::{Decrypt, Encrypt}; /// Check that the first invocation of encryption matches expected values. /// Also check decryption of the same. diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 46fd4d9..48c83aa 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -311,8 +311,13 @@ impl Encrypt for HpkeS { let mut buf = pt.to_owned(); let mut tag = self.context.seal(&mut buf, aad)?; buf.append(&mut tag); + println!("hpke seal: {}", hex::encode(&buf)); Ok(buf) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeS { @@ -529,11 +534,17 @@ impl HpkeR { impl Decrypt for HpkeR { fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + println!("hpke open: {}", hex::encode(&ct)); + let mut buf = ct.to_owned(); let pt_len = self.context.open(&mut buf, aad)?.len(); buf.truncate(pt_len); Ok(buf) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeR { @@ -599,6 +610,7 @@ pub fn derive_key_pair(kem: Kem, ikm: &[u8]) -> Res<(PrivateKey, PublicKey)> { mod test { use super::{generate_key_pair, Config, HpkeR, HpkeS}; use crate::{ + crypto::{Decrypt, Encrypt}, hpke::{Aead, Kem}, init, }; diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs index 6a94d3c..4b6751f 100644 --- a/ohttp/src/stream.rs +++ b/ohttp/src/stream.rs @@ -2,17 +2,22 @@ use std::{ cmp::min, - io::{Error as IoError, Result as IoResult}, + io::{Cursor, Error as IoError, Result as IoResult}, mem, pin::Pin, task::{Context, Poll}, }; use futures::{AsyncRead, AsyncWrite}; +use pin_project::pin_project; use crate::{ - build_info, crypto::Encrypt, entropy, err::Res, export_secret, make_aead, Aead, Error, - HpkeConfig, HpkeR, HpkeS, KeyConfig, KeyId, Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, + build_info, + crypto::{Decrypt, Encrypt}, + entropy, + err::Res, + export_secret, make_aead, random, Aead, Error, HpkeConfig, HpkeR, HpkeS, KeyConfig, KeyId, + Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, }; /// The info string for a chunked request. @@ -24,140 +29,37 @@ const MAX_CHUNK_PLAINTEXT: usize = 1 << 14; const CHUNK_AAD: &[u8] = b""; const FINAL_CHUNK_AAD: &[u8] = b"final"; -fn write_len(w: &mut [u8], len: usize) -> &[u8] { - let v: u64 = len.try_into().unwrap(); - let (v, len) = match () { - () if v < (1 << 6) => (v, 1), - () if v < (1 << 14) => (v | 1 << 14, 2), - () if v < (1 << 30) => (v | (2 << 30), 4), - () if v < (1 << 62) => (v | (3 << 62), 8), - () => panic!("varint value too large"), - }; - w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); - &w[..len] -} - -#[pin_project::pin_project(project = ClientProjection)] -pub struct ClientRequest { +#[pin_project(project = ChunkWriterProjection)] +struct ChunkWriter { #[pin] - dst: S, - hpke: HpkeS, + dst: D, + cipher: E, buf: Vec, } -impl ClientRequest { - /// Start the processing of a stream. 
- pub fn start(dst: S, config: HpkeConfig, key_id: KeyId, mut pk: PublicKey) -> Res { - let info = build_info(INFO_REQUEST, key_id, config)?; - let hpke = HpkeS::new(config, &mut pk, &info)?; - - let mut header = Vec::from(&info[INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); - - let mut e = hpke.enc()?; - header.append(&mut e); - - Ok(Self { - dst, - hpke, - buf: header, - }) - } - - /// Get an object that can be used to process the response. - /// - /// While this can be used while sending the request, - /// doing so creates a risk of revealing unwanted information to the gateway. - /// That includes the round trip time between client and gateway, - /// which might reveal information about the location of the client. - pub fn response(&self, src: R) -> Res> { - let enc = self.hpke.enc()?; - let secret = export_secret(&self.hpke, LABEL_RESPONSE, self.hpke.config())?; - Ok(ClientResponse { - src, - config: self.hpke.config(), - state: ClientResponseState::Header { - enc, - secret, - nonce: [0; 16], - read: 0, - }, - }) - } -} - -impl AsyncRead for ClientRequest { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - mut buf: &mut [u8], - ) -> Poll> { - let this = self.project(); - // We have buffered data, so dump it into the output directly. - let mut written = if this.buf.is_empty() { - 0 - } else { - let amnt = min(this.buf.len(), buf.len()); - buf[..amnt].copy_from_slice(&this.buf[..amnt]); - buf = &mut buf[amnt..]; - *this.buf = this.buf.split_off(amnt); - if buf.is_empty() { - return Poll::Ready(Ok(amnt)); - } - amnt - }; - - // Now read into the buffer. - // Because we are expanding the data, when the buffer we are provided is too small, - // we have to use a temporary buffer so that we can save some bytes. - let mut tmp = [0; 64]; - let read_buf = if buf.len() < tmp.len() { - // Use the provided buffer, but cap the amount we read to MAX_CHUNK_PLAINTEXT. - let read_len = min(buf.len(), MAX_CHUNK_PLAINTEXT); - &mut buf[8..read_len] - } else { - &mut tmp[..] +impl ChunkWriter { + fn write_len(w: &mut [u8], len: usize) -> &[u8] { + let v: u64 = len.try_into().unwrap(); + let (v, len) = match () { + () if v < (1 << 6) => (v, 1), + () if v < (1 << 14) => (v | 1 << 14, 2), + () if v < (1 << 30) => (v | (2 << 30), 4), + () if v < (1 << 62) => (v | (3 << 62), 8), + () => panic!("varint value too large"), }; - let (aad, len): (&[u8], _) = match this.dst.poll_read(cx, read_buf) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(0)) => (FINAL_CHUNK_AAD, 0), - Poll::Ready(Ok(len)) => (&[], len), - e @ Poll::Ready(Err(_)) => return e, - }; - - let ct = this - .hpke - .seal(aad, &read_buf[..len]) - .map_err(IoError::other)?; - - // Now we need to write the length of the chunk. - let len_len = write_len(&mut tmp, ct.len()).len(); - if len_len <= buf.len() { - // If the length fits in the buffer, that's easy. - buf[..len_len].copy_from_slice(&tmp[..len_len]); - written += len_len; - buf = &mut buf[len_len..]; - } else { - // Otherwise, we need to save any remainder in our own buffer. 
- buf.copy_from_slice(&tmp[..buf.len()]); - this.buf.extend_from_slice(&tmp[buf.len()..len_len]); - let amnt = buf.len(); - written += amnt; - buf = &mut buf[amnt..]; - } - - let amnt = min(ct.len(), buf.len()); - buf[..amnt].copy_from_slice(&ct[..amnt]); - this.buf.extend_from_slice(&ct[amnt..]); - Poll::Ready(Ok(amnt + written)) + w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); + &w[..len] } } -impl ClientRequest { +impl ChunkWriter { /// Flush our buffer. /// Returns `Some` if the flush blocks or is unsuccessful. /// If that contains `Ready`, it does so only when there is an error. - fn flush(this: &mut ClientProjection<'_, S>, cx: &mut Context<'_>) -> Option> { + fn flush( + this: &mut ChunkWriterProjection<'_, D, E>, + cx: &mut Context<'_>, + ) -> Option> { while !this.buf.is_empty() { match this.dst.as_mut().poll_write(cx, &this.buf[..]) { Poll::Pending => return Some(Poll::Pending), @@ -185,13 +87,13 @@ impl ClientRequest { } fn write_chunk( - this: &mut ClientProjection<'_, S>, + this: &mut ChunkWriterProjection<'_, D, E>, cx: &mut Context<'_>, input: &[u8], last: bool, ) -> Poll> { let aad = if last { FINAL_CHUNK_AAD } else { CHUNK_AAD }; - let mut ct = this.hpke.seal(aad, input).map_err(IoError::other)?; + let mut ct = this.cipher.seal(aad, input).map_err(IoError::other)?; let (len, written) = if last { (0, 0) } else { @@ -199,7 +101,8 @@ impl ClientRequest { }; let mut len_buf = [0; 8]; - let len = write_len(&mut len_buf[..], len); + let len = Self::write_len(&mut len_buf[..], len); + println!("chunk: {}", hex::encode(len)); let w = match this.dst.as_mut().poll_write(cx, len) { Poll::Pending => 0, Poll::Ready(Ok(w)) => w, @@ -224,7 +127,7 @@ impl ClientRequest { } } -impl AsyncWrite for ClientRequest { +impl AsyncWrite for ChunkWriter { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -256,7 +159,78 @@ impl AsyncWrite for ClientRequest { } } -enum ChunkState { +#[pin_project(project = ClientProjection)] +pub struct ClientRequest { + #[pin] + writer: ChunkWriter, +} + +impl ClientRequest { + /// Start the processing of a stream. + pub fn start(dst: D, config: HpkeConfig, key_id: KeyId, mut pk: PublicKey) -> Res { + let info = build_info(INFO_REQUEST, key_id, config)?; + let hpke = HpkeS::new(config, &mut pk, &info)?; + + let mut header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let mut e = hpke.enc()?; + header.append(&mut e); + + Ok(Self { + writer: ChunkWriter { + dst, + cipher: hpke, + buf: header, + }, + }) + } + + /// Get an object that can be used to process the response. + /// + /// While this can be used while sending the request, + /// doing so creates a risk of revealing unwanted information to the gateway. + /// That includes the round trip time between client and gateway, + /// which might reveal information about the location of the client. 
+ pub fn response(&self, src: R) -> Res> { + let enc = self.writer.cipher.enc()?; + let secret = export_secret( + &self.writer.cipher, + LABEL_RESPONSE, + self.writer.cipher.config(), + )?; + Ok(ClientResponse { + src, + config: self.writer.cipher.config(), + state: ClientResponseState::Header { + enc, + secret, + nonce: [0; 16], + read: 0, + }, + }) + } +} + +impl AsyncWrite for ClientRequest { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &[u8], + ) -> Poll> { + self.project().writer.as_mut().poll_write(cx, input) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_close(cx) + } +} + +enum ChunkReader { Length { len: [u8; 8], offset: usize, @@ -269,7 +243,7 @@ enum ChunkState { Done, } -impl ChunkState { +impl ChunkReader { fn length() -> Self { Self::Length { len: [0; 8], @@ -288,44 +262,386 @@ impl ChunkState { length, } } + + fn read_length( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + aead: &mut A, + ) -> Option>> { + // Read the first byte. + let Self::Length { len, offset } = self else { + return None; + }; + + if *offset == 0 { + match src.as_mut().poll_read(cx, &mut len[..1]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + } + Poll::Ready(Ok(1)) => {} + Poll::Ready(Ok(_)) => unreachable!(), + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + + let form = len[0] >> 6; + if form == 0 { + *self = Self::data(usize::from(len[0])); + return None; + } + let v = mem::replace(&mut len[0], 0) & 0x3f; + let i = match form { + 1 => 6, + 2 => 4, + 3 => 0, + _ => unreachable!(), + }; + len[i] = v; + *offset = i + 1; + + while *offset < len.len() { + // Read any remaining bytes of the length. + match src.as_mut().poll_read(cx, &mut len[*offset..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + } + Poll::Ready(Ok(r)) => { + *offset += r; + if *offset < 8 { + continue; + } + let remaining = match usize::try_from(u64::from_be_bytes(*len)) { + Ok(remaining) => remaining, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + if remaining > MAX_CHUNK_PLAINTEXT + aead.alg().n_t() { + return Some(Poll::Ready(Err(IoError::other(Error::ChunkTooLarge)))); + } + *self = Self::data(remaining); + return None; + } + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + None + } + + /// Optional optimization that reads a single chunk into the output buffer. + fn read_into_output( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + aead: &mut A, + output: &mut [u8], + ) -> Option>> { + let Self::Data { + buf, + offset, + length, + } = self + else { + return None; + }; + if *length == 0 || *offset > 0 || output.len() < *length { + // We need to pull in a complete chunk in one go for this to be worthwhile. 
+ return None; + } + + match src.as_mut().poll_read(cx, &mut output[..*length]) { + Poll::Pending => Some(Poll::Pending), + Poll::Ready(Ok(0)) => Some(Poll::Ready(Err(IoError::other(Error::Truncated)))), + Poll::Ready(Ok(r)) => { + if r == *length { + let pt = match aead.open(CHUNK_AAD, &output[..r]) { + Ok(pt) => pt, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + output[..pt.len()].copy_from_slice(&pt); + *self = Self::length(); + Some(Poll::Ready(Ok(pt.len()))) + } else { + buf.reserve_exact(*length); + buf.extend_from_slice(&output[..r]); + buf.resize(*length, 0); + *offset += r; + None + } + } + e @ Poll::Ready(Err(_)) => Some(e), + } + } + + fn read( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + cipher: &mut A, + output: &mut [u8], + ) -> Poll> { + while !matches!(self, Self::Done) { + if let Some(res) = self.read_length(src.as_mut(), cx, cipher) { + return res; + } + + // Read data. + if let Some(res) = self.read_into_output(src.as_mut(), cx, cipher, output) { + return res; + } + + let Self::Data { + buf, + offset, + length, + } = self + else { + unreachable!(); + }; + + // Allocate now as needed. + let last = *length == 0; + if buf.is_empty() { + let sz = if last { + MAX_CHUNK_PLAINTEXT + cipher.alg().n_t() + } else { + *length + }; + buf.resize(sz, 0); + } + + match src.as_mut().poll_read(cx, &mut buf[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => { + if last { + buf.truncate(*offset); + } else { + return Poll::Ready(Err(IoError::other(Error::Truncated))); + } + } + Poll::Ready(Ok(r)) => { + *offset += r; + if last || *offset < *length { + continue; // Keep reading + } + } + e @ Poll::Ready(Err(_)) => return e, + } + + let aad = if last { FINAL_CHUNK_AAD } else { CHUNK_AAD }; + let pt = cipher.open(aad, buf).map_err(IoError::other)?; + output[..pt.len()].copy_from_slice(&pt); + + if last { + *self = Self::Done; + } else { + *self = Self::length(); + if pt.is_empty() { + continue; // Read the next chunk + } + }; + + return Poll::Ready(Ok(pt.len())); + } + + Poll::Ready(Ok(0)) + } } -#[allow(dead_code)] // TODO enum ServerRequestState { HpkeConfig { - config: KeyConfig, buf: [u8; 7], read: usize, }, Enc { config: HpkeConfig, info: Vec, - buf: Vec, + read: usize, }, Body { hpke: HpkeR, - state: ChunkState, + state: ChunkReader, }, } -#[pin_project::pin_project(project = ServerRequestProjection)] +#[pin_project(project = ServerRequestProjection)] pub struct ServerRequest { #[pin] src: S, + key_config: KeyConfig, + enc: Vec, state: ServerRequestState, } impl ServerRequest { - pub fn new(config: KeyConfig, src: S) -> Self { + pub fn new(key_config: KeyConfig, src: S) -> Self { Self { src, + key_config, + enc: Vec::new(), state: ServerRequestState::HpkeConfig { - config, buf: [0; 7], read: 0, }, } } + + /// Get a response that wraps the given async write instance. + /// This fails with an error if the request header hasn't been processed. + /// This condition is not exposed through a future anywhere, + /// but you can wait for the first byte of data. 
+ pub fn response(&self, dst: D) -> Res> { + let ServerRequestState::Body { hpke, state: _ } = &self.state else { + return Err(Error::NotReady); + }; + + let response_nonce = random(entropy(hpke.config())); + let aead = make_aead( + Mode::Encrypt, + hpke.config(), + &export_secret(hpke, LABEL_RESPONSE, hpke.config())?, + &self.enc, + &response_nonce, + )?; + Ok(ServerResponse { + writer: ChunkWriter { + dst, + cipher: aead, + buf: response_nonce, + }, + }) + } +} + +impl ServerRequest { + fn read_config( + this: &mut ServerRequestProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + let ServerRequestState::HpkeConfig { buf, read } = this.state else { + return None; + }; + + while *read < buf.len() { + match this.src.as_mut().poll_read(cx, &mut buf[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) + } + Poll::Ready(Ok(len)) => { + *read += len; + } + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + + let config = match this + .key_config + .decode_hpke_config(&mut Cursor::new(&buf[..])) + { + Ok(cfg) => cfg, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + let info = match build_info(INFO_REQUEST, this.key_config.key_id, config) { + Ok(info) => info, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + this.enc.resize(config.kem().n_enc(), 0); + + *this.state = ServerRequestState::Enc { + config, + info, + read: 0, + }; + None + } + + fn read_enc( + this: &mut ServerRequestProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + let ServerRequestState::Enc { config, info, read } = this.state else { + return None; + }; + + while *read < this.enc.len() { + match this.src.as_mut().poll_read(cx, &mut this.enc[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) + } + Poll::Ready(Ok(len)) => { + *read += len; + } + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + + let hpke = match HpkeR::new( + *config, + &this.key_config.pk, + this.key_config.sk.as_ref().unwrap(), + this.enc, + info, + ) { + Ok(hpke) => hpke, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + + *this.state = ServerRequestState::Body { + hpke, + state: ChunkReader::length(), + }; + None + } +} + +impl AsyncRead for ServerRequest { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + output: &mut [u8], + ) -> Poll> { + let mut this = self.project(); + if let Some(res) = Self::read_config(&mut this, cx) { + return res; + } + + if let Some(res) = Self::read_enc(&mut this, cx) { + return res; + } + + if let ServerRequestState::Body { hpke, state } = this.state { + state.read(this.src, cx, hpke, output) + } else { + Poll::Ready(Ok(0)) + } + } +} + +#[pin_project(project = ServerResponseProjection)] +pub struct ServerResponse { + #[pin] + writer: ChunkWriter, +} + +impl AsyncWrite for ServerResponse { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &[u8], + ) -> Poll> { + self.project().writer.as_mut().poll_write(cx, input) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_close(cx) + } } enum ClientResponseState { @@ -337,23 +653,11 @@ enum ClientResponseState { }, Body { aead: Aead, - state: ChunkState, + state: ChunkReader, }, } -impl 
ClientResponseState { - fn done(&self) -> bool { - matches!( - self, - Self::Body { - state: ChunkState::Done, - .. - } - ) - } -} - -#[pin_project::pin_project(project = ClientResponseProjection)] +#[pin_project(project = ClientResponseProjection)] pub struct ClientResponse { #[pin] src: S, @@ -366,154 +670,46 @@ impl ClientResponse { this: &mut ClientResponseProjection<'_, S>, cx: &mut Context<'_>, ) -> Option>> { - if let ClientResponseState::Header { + let ClientResponseState::Header { enc, secret, nonce, read, - } = &mut this.state - { - let aead = match this.src.as_mut().poll_read(cx, &mut nonce[*read..]) { + } = this.state + else { + return None; + }; + loop { + match this.src.as_mut().poll_read(cx, &mut nonce[*read..]) { Poll::Pending => return Some(Poll::Pending), Poll::Ready(Ok(0)) => { return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) } Poll::Ready(Ok(len)) => { *read += len; - if *read < entropy(*this.config) { - return Some(Poll::Pending); - } - match make_aead( - Mode::Decrypt, - *this.config, - secret, - mem::take(enc), - &nonce[..entropy(*this.config)], - ) { - Ok(aead) => aead, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + if *read == entropy(*this.config) { + break; } } e @ Poll::Ready(Err(_)) => return Some(e), - }; - - *this.state = ClientResponseState::Body { - aead, - state: ChunkState::length(), - }; - }; - None - } - - fn read_length( - this: &mut ClientResponseProjection<'_, S>, - cx: &mut Context<'_>, - ) -> Option>> { - if let ClientResponseState::Body { aead: _, state } = this.state { - // Read the first byte. - if let ChunkState::Length { len, offset } = state { - if *offset == 0 { - match this.src.as_mut().poll_read(cx, &mut len[..1]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); - } - Poll::Ready(Ok(1)) => { - let form = len[0] >> 6; - if form == 0 { - *state = ChunkState::data(usize::from(len[0])); - } else { - let v = mem::replace(&mut len[0], 0) & 0x3f; - let i = match form { - 1 => 6, - 2 => 4, - 3 => 0, - _ => unreachable!(), - }; - len[i] = v; - *offset = i + 1; - } - } - Poll::Ready(Ok(_)) => unreachable!(), - e @ Poll::Ready(Err(_)) => return Some(e), - } - } - } - - // Read any remaining bytes of the length. - if let ChunkState::Length { len, offset } = state { - if *offset != 0 { - *state = match this.src.as_mut().poll_read(cx, &mut len[*offset..]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); - } - Poll::Ready(Ok(r)) => { - *offset += r; - if *offset < 8 { - return Some(Poll::Pending); - } - let remaining = match usize::try_from(u64::from_be_bytes(*len)) { - Ok(remaining) => remaining, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), - }; - if remaining > MAX_CHUNK_PLAINTEXT + this.config.aead().n_t() { - return Some(Poll::Ready(Err(IoError::other( - Error::ChunkTooLarge, - )))); - } - ChunkState::data(remaining) - } - e @ Poll::Ready(Err(_)) => return Some(e), - }; - } } } - None - } - - /// Optional optimization that reads a single chunk into the output buffer. 
- fn read_into_output( - this: &mut ClientResponseProjection<'_, S>, - cx: &mut Context<'_>, - output: &mut [u8], - ) -> Option>> { - if let ClientResponseState::Body { aead, state } = this.state { - if let ChunkState::Data { - buf, - offset, - length, - } = state - { - if *length > 0 && *offset == 0 && output.len() + this.config.aead().n_t() >= *length - { - match this.src.as_mut().poll_read(cx, output) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); - } - Poll::Ready(Ok(r)) => { - if r < *length { - buf.extend_from_slice(&output[..r]); - *offset += r; - return Some(Poll::Pending); - } - - let pt = match aead.open(CHUNK_AAD, &output[..r]) { - Ok(pt) => pt, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), - }; - output[..pt.len()].copy_from_slice(&pt); - *state = ChunkState::length(); - return Some(Poll::Ready(Ok(pt.len()))); - } - e @ Poll::Ready(Err(_)) => return Some(e), - } - } - } - } + let aead = match make_aead( + Mode::Decrypt, + *this.config, + secret, + enc, + &nonce[..entropy(*this.config)], + ) { + Ok(aead) => aead, + Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + }; + *this.state = ClientResponseState::Body { + aead, + state: ChunkReader::length(), + }; None } } @@ -529,77 +725,19 @@ impl AsyncRead for ClientResponse { return res; } - while !this.state.done() { - if let Some(res) = Self::read_length(&mut this, cx) { - return res; - } - - // Read data. - if let Some(res) = Self::read_into_output(&mut this, cx, output) { - return res; - } - - if let ClientResponseState::Body { aead, state } = this.state { - if let ChunkState::Data { - buf, - offset, - length, - } = state - { - // Allocate now as needed. - let last = *length == 0; - if buf.is_empty() { - let sz = if last { - MAX_CHUNK_PLAINTEXT + this.config.aead().n_t() - } else { - *length - }; - buf.resize(sz, 0); - } - - let aad = match this.src.as_mut().poll_read(cx, &mut buf[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(0)) => { - if !last { - return Poll::Ready(Err(IoError::other(Error::Truncated))); - } - - FINAL_CHUNK_AAD - } - Poll::Ready(Ok(r)) => { - if *offset + r < *length { - buf.extend_from_slice(&output[..r]); - *offset += r; - return Poll::Pending; - } - - CHUNK_AAD - } - e @ Poll::Ready(Err(_)) => return e, - }; - - let pt = aead.open(aad, buf).map_err(IoError::other)?; - output[..pt.len()].copy_from_slice(&pt); - *state = if last { - ChunkState::Done - } else { - ChunkState::length() - }; - if !pt.is_empty() { - return Poll::Ready(Ok(pt.len())); - } - } - } + if let ClientResponseState::Body { aead, state } = this.state { + state.read(this.src, cx, aead, output) + } else { + Poll::Ready(Ok(0)) } - Poll::Ready(Ok(0)) } } #[cfg(test)] mod test { - use futures::{io::Cursor, AsyncReadExt, AsyncWriteExt}; + use futures::AsyncWriteExt; use log::trace; - use sync_async::{SyncRead, SyncResolve}; + use sync_async::{Pipe, SyncRead, SyncResolve}; use crate::{ test::{init, make_config, REQUEST, RESPONSE}, @@ -616,19 +754,24 @@ mod test { trace!("Config: {}", hex::encode(&encoded_config)); let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); - let (mut request_read, request_write) = AsyncReadExt::split(Cursor::new(Vec::new())); + let (mut request_read, request_write) = Pipe::new(); let mut client_request = client.encapsulate_stream(request_write).unwrap(); client_request.write_all(REQUEST).sync_resolve().unwrap(); 
client_request.close().sync_resolve().unwrap(); trace!("Request: {}", hex::encode(REQUEST)); - let request_buf = request_read.sync_read_to_end(); - trace!("Encapsulated Request: {}", hex::encode(&request_buf)); + let enc_request = request_read.sync_read_to_end(); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + let mut server_request = server.decapsulate_stream(&enc_request[..]); + assert_eq!(server_request.sync_read_to_end(), REQUEST); - let (request, server_response) = server.decapsulate(&request_buf[..]).unwrap(); - assert_eq!(&request[..], REQUEST); + let (mut response_read, response_write) = Pipe::new(); + let mut server_response = server_request.response(response_write).unwrap(); + server_response.write_all(RESPONSE).sync_resolve().unwrap(); + server_response.close().sync_resolve().unwrap(); - let enc_response = server_response.encapsulate(RESPONSE).unwrap(); + let enc_response = response_read.sync_read_to_end(); trace!("Encapsulated Response: {}", hex::encode(&enc_response)); let mut client_response = client_request.response(&enc_response[..]).unwrap(); diff --git a/sync-async/src/lib.rs b/sync-async/src/lib.rs index f421021..07cb8ed 100644 --- a/sync-async/src/lib.rs +++ b/sync-async/src/lib.rs @@ -1,11 +1,15 @@ use std::{ + cmp::min, future::Future, io::Result as IoResult, pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::{AsyncRead, AsyncReadExt, AsyncWrite, TryStream, TryStreamExt}; +use futures::{ + io::{ReadHalf, WriteHalf}, + AsyncRead, AsyncReadExt, AsyncWrite, TryStream, TryStreamExt, +}; use pin_project::pin_project; fn noop_context() -> Context<'static> { @@ -203,3 +207,64 @@ impl AsyncWrite for Stutter { Self::stutter(self, cx, AsyncWrite::poll_close) } } + +/// A Cursor implementation that has separate read and write cursors. +/// +/// This allows tests to create paired read and write objects, +/// where writes to one can be read by the other. +/// +/// This relies on the implementation of `AyncReadExt::split` to provide +/// any locking and concurrency, rather than implementing it. 
+#[derive(Default)] +#[pin_project] +pub struct Pipe { + buf: Vec, + r: usize, + w: usize, +} + +impl Pipe { + #[must_use] + pub fn new() -> (ReadHalf, WriteHalf) { + AsyncReadExt::split(Self::default()) + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let amnt = min(buf.len(), self.buf.len() - self.r); + buf[..amnt].copy_from_slice(&self.buf[self.r..self.r + amnt]); + self.r += amnt; + Poll::Ready(Ok(amnt)) + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { + if self.w < self.buf.len() { + let overlap = min(buf.len() - self.w, self.buf.len()); + let range = self.w..(self.w + overlap); + self.buf[range].copy_from_slice(&buf[..overlap]); + buf = &buf[overlap..]; + } + self.buf.extend_from_slice(buf); + self.w += buf.len(); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} From 20d238f1092f4fbd44089932eaf547fef0075612 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Thu, 19 Dec 2024 18:47:58 +1100 Subject: [PATCH 20/20] Working better --- ohttp/src/crypto.rs | 5 +- ohttp/src/nss/mod.rs | 6 +- ohttp/src/rh/aead.rs | 2 - ohttp/src/rh/hpke.rs | 3 - ohttp/src/stream.rs | 209 +++++++++++++++++++++---------------------- pre-commit | 2 +- 6 files changed, 109 insertions(+), 118 deletions(-) diff --git a/ohttp/src/crypto.rs b/ohttp/src/crypto.rs index b20a019..a22b74e 100644 --- a/ohttp/src/crypto.rs +++ b/ohttp/src/crypto.rs @@ -2,11 +2,12 @@ use crate::{err::Res, AeadId}; pub trait Decrypt { fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res>; + #[allow(dead_code)] // Used by stream feature. fn alg(&self) -> AeadId; } pub trait Encrypt { - #[allow(dead_code)] // TODO - fn alg(&self) -> AeadId; fn seal(&mut self, aad: &[u8], ct: &[u8]) -> Res>; + #[allow(dead_code)] // Used by stream feature. + fn alg(&self) -> AeadId; } diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index 4f30b22..91dc44e 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -1,8 +1,4 @@ -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. 
+#![allow(clippy::incompatible_msrv)] // This feature needs 1.70 mod err; #[macro_use] diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index 9521b30..cc05e89 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -115,7 +115,6 @@ impl Aead { impl Decrypt for Aead { fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { - println!("aead open: {}", hex::encode(ct)); let res = self.open_seq(aad, self.seq, ct); self.seq += 1; res @@ -133,7 +132,6 @@ impl Encrypt for Aead { let nonce = self.nonce(self.seq); self.seq += 1; let ct = self.engine.encrypt(&nonce, Payload { msg: pt, aad })?; - println!("aead seal: {}", hex::encode(&ct)); Ok(ct) } diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 48c83aa..58e29c7 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -311,7 +311,6 @@ impl Encrypt for HpkeS { let mut buf = pt.to_owned(); let mut tag = self.context.seal(&mut buf, aad)?; buf.append(&mut tag); - println!("hpke seal: {}", hex::encode(&buf)); Ok(buf) } @@ -534,8 +533,6 @@ impl HpkeR { impl Decrypt for HpkeR { fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { - println!("hpke open: {}", hex::encode(&ct)); - let mut buf = ct.to_owned(); let pt_len = self.context.open(&mut buf, aad)?.len(); buf.truncate(pt_len); diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs index 4b6751f..f66f707 100644 --- a/ohttp/src/stream.rs +++ b/ohttp/src/stream.rs @@ -29,6 +29,14 @@ const MAX_CHUNK_PLAINTEXT: usize = 1 << 14; const CHUNK_AAD: &[u8] = b""; const FINAL_CHUNK_AAD: &[u8] = b"final"; +#[allow(clippy::unnecessary_wraps)] +fn some_error(e: E) -> Option>> +where + Error: From, +{ + Some(Poll::Ready(Err(IoError::other(Error::from(e))))) +} + #[pin_project(project = ChunkWriterProjection)] struct ChunkWriter { #[pin] @@ -52,12 +60,12 @@ impl ChunkWriter { } } -impl ChunkWriter { +impl ChunkWriter { /// Flush our buffer. /// Returns `Some` if the flush blocks or is unsuccessful. /// If that contains `Ready`, it does so only when there is an error. fn flush( - this: &mut ChunkWriterProjection<'_, D, E>, + this: &mut ChunkWriterProjection<'_, D, C>, cx: &mut Context<'_>, ) -> Option> { while !this.buf.is_empty() { @@ -87,7 +95,7 @@ impl ChunkWriter { } fn write_chunk( - this: &mut ChunkWriterProjection<'_, D, E>, + this: &mut ChunkWriterProjection<'_, D, C>, cx: &mut Context<'_>, input: &[u8], last: bool, @@ -102,7 +110,6 @@ impl ChunkWriter { let mut len_buf = [0; 8]; let len = Self::write_len(&mut len_buf[..], len); - println!("chunk: {}", hex::encode(len)); let w = match this.dst.as_mut().poll_write(cx, len) { Poll::Pending => 0, Poll::Ready(Ok(w)) => w, @@ -127,7 +134,7 @@ impl ChunkWriter { } } -impl AsyncWrite for ChunkWriter { +impl AsyncWrite for ChunkWriter { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -263,78 +270,98 @@ impl ChunkReader { } } - fn read_length( - &mut self, + fn read_fixed( mut src: Pin<&mut S>, cx: &mut Context<'_>, - aead: &mut A, + buf: &mut [u8], + offset: &mut usize, ) -> Option>> { - // Read the first byte. - let Self::Length { len, offset } = self else { - return None; - }; - - if *offset == 0 { - match src.as_mut().poll_read(cx, &mut len[..1]) { + while *offset < buf.len() { + // Read any remaining bytes of the length. 
+ match src.as_mut().poll_read(cx, &mut buf[*offset..]) { Poll::Pending => return Some(Poll::Pending), Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); + return some_error(Error::Truncated); + } + Poll::Ready(Ok(r)) => { + *offset += r; } - Poll::Ready(Ok(1)) => {} - Poll::Ready(Ok(_)) => unreachable!(), e @ Poll::Ready(Err(_)) => return Some(e), } } + None + } + + fn read_length0( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + ) -> Option>> { + let Self::Length { len, offset } = self else { + return None; + }; + + let res = Self::read_fixed(src.as_mut(), cx, &mut len[..1], offset); + if res.is_some() { + return res; + } let form = len[0] >> 6; if form == 0 { *self = Self::data(usize::from(len[0])); - return None; + } else { + let v = mem::replace(&mut len[0], 0) & 0x3f; + let i = match form { + 1 => 6, + 2 => 4, + 3 => 0, + _ => unreachable!(), + }; + len[i] = v; + *offset = i + 1; + } + None + } + + fn read_length( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + aead: &mut C, + ) -> Option>> { + // Read the first byte. + let res = self.read_length0(src.as_mut(), cx); + if res.is_some() { + return res; } - let v = mem::replace(&mut len[0], 0) & 0x3f; - let i = match form { - 1 => 6, - 2 => 4, - 3 => 0, - _ => unreachable!(), + + let Self::Length { len, offset } = self else { + return None; }; - len[i] = v; - *offset = i + 1; - while *offset < len.len() { - // Read any remaining bytes of the length. - match src.as_mut().poll_read(cx, &mut len[*offset..]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))); - } - Poll::Ready(Ok(r)) => { - *offset += r; - if *offset < 8 { - continue; - } - let remaining = match usize::try_from(u64::from_be_bytes(*len)) { - Ok(remaining) => remaining, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), - }; - if remaining > MAX_CHUNK_PLAINTEXT + aead.alg().n_t() { - return Some(Poll::Ready(Err(IoError::other(Error::ChunkTooLarge)))); - } - *self = Self::data(remaining); - return None; - } - e @ Poll::Ready(Err(_)) => return Some(e), - } + let res = Self::read_fixed(src.as_mut(), cx, &mut len[..], offset); + if res.is_some() { + return res; + } + + let remaining = match usize::try_from(u64::from_be_bytes(*len)) { + Ok(remaining) => remaining, + Err(e) => return some_error(e), + }; + if remaining > MAX_CHUNK_PLAINTEXT + aead.alg().n_t() { + return some_error(Error::ChunkTooLarge); } + + *self = Self::data(remaining); None } /// Optional optimization that reads a single chunk into the output buffer. 
- fn read_into_output( + fn read_into_output( &mut self, mut src: Pin<&mut S>, cx: &mut Context<'_>, - aead: &mut A, + aead: &mut C, output: &mut [u8], ) -> Option>> { let Self::Data { @@ -352,12 +379,12 @@ impl ChunkReader { match src.as_mut().poll_read(cx, &mut output[..*length]) { Poll::Pending => Some(Poll::Pending), - Poll::Ready(Ok(0)) => Some(Poll::Ready(Err(IoError::other(Error::Truncated)))), + Poll::Ready(Ok(0)) => some_error(Error::Truncated), Poll::Ready(Ok(r)) => { if r == *length { let pt = match aead.open(CHUNK_AAD, &output[..r]) { Ok(pt) => pt, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + Err(e) => return some_error(e), }; output[..pt.len()].copy_from_slice(&pt); *self = Self::length(); @@ -374,11 +401,11 @@ impl ChunkReader { } } - fn read( + fn read( &mut self, mut src: Pin<&mut S>, cx: &mut Context<'_>, - cipher: &mut A, + cipher: &mut C, output: &mut [u8], ) -> Poll> { while !matches!(self, Self::Done) { @@ -523,17 +550,9 @@ impl ServerRequest { return None; }; - while *read < buf.len() { - match this.src.as_mut().poll_read(cx, &mut buf[*read..]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) - } - Poll::Ready(Ok(len)) => { - *read += len; - } - e @ Poll::Ready(Err(_)) => return Some(e), - } + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, &mut buf[..], read); + if res.is_some() { + return res; } let config = match this @@ -541,11 +560,11 @@ impl ServerRequest { .decode_hpke_config(&mut Cursor::new(&buf[..])) { Ok(cfg) => cfg, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + Err(e) => return some_error(e), }; let info = match build_info(INFO_REQUEST, this.key_config.key_id, config) { Ok(info) => info, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + Err(e) => return some_error(e), }; this.enc.resize(config.kem().n_enc(), 0); @@ -565,17 +584,9 @@ impl ServerRequest { return None; }; - while *read < this.enc.len() { - match this.src.as_mut().poll_read(cx, &mut this.enc[*read..]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) - } - Poll::Ready(Ok(len)) => { - *read += len; - } - e @ Poll::Ready(Err(_)) => return Some(e), - } + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, &mut this.enc[..], read); + if res.is_some() { + return res; } let hpke = match HpkeR::new( @@ -586,7 +597,7 @@ impl ServerRequest { info, ) { Ok(hpke) => hpke, - Err(e) => return Some(Poll::Ready(Err(IoError::other(e)))), + Err(e) => return some_error(e), }; *this.state = ServerRequestState::Body { @@ -679,31 +690,16 @@ impl ClientResponse { else { return None; }; - loop { - match this.src.as_mut().poll_read(cx, &mut nonce[*read..]) { - Poll::Pending => return Some(Poll::Pending), - Poll::Ready(Ok(0)) => { - return Some(Poll::Ready(Err(IoError::other(Error::Truncated)))) - } - Poll::Ready(Ok(len)) => { - *read += len; - if *read == entropy(*this.config) { - break; - } - } - e @ Poll::Ready(Err(_)) => return Some(e), - } + + let nonce = &mut nonce[..entropy(*this.config)]; + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, nonce, read); + if res.is_some() { + return res; } - let aead = match make_aead( - Mode::Decrypt, - *this.config, - secret, - enc, - &nonce[..entropy(*this.config)], - ) { + let aead = match make_aead(Mode::Decrypt, *this.config, secret, enc, nonce) { Ok(aead) => aead, - Err(e) => return 
Some(Poll::Ready(Err(IoError::other(e)))), + Err(e) => return some_error(e), }; *this.state = ClientResponseState::Body { @@ -753,6 +749,7 @@ mod test { let encoded_config = server.config().encode().unwrap(); trace!("Config: {}", hex::encode(&encoded_config)); + // The client sends a request. let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); let (mut request_read, request_write) = Pipe::new(); let mut client_request = client.encapsulate_stream(request_write).unwrap(); @@ -763,9 +760,11 @@ mod test { let enc_request = request_read.sync_read_to_end(); trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + // The server receives a request. let mut server_request = server.decapsulate_stream(&enc_request[..]); assert_eq!(server_request.sync_read_to_end(), REQUEST); + // The server sends a response. let (mut response_read, response_write) = Pipe::new(); let mut server_response = server_request.response(response_write).unwrap(); server_response.write_all(RESPONSE).sync_resolve().unwrap(); @@ -774,8 +773,8 @@ mod test { let enc_response = response_read.sync_read_to_end(); trace!("Encapsulated Response: {}", hex::encode(&enc_response)); + // The client receives a response. let mut client_response = client_request.response(&enc_response[..]).unwrap(); - let response_buf = client_response.sync_read_to_end(); assert_eq!(response_buf, RESPONSE); trace!("Response: {}", hex::encode(response_buf)); diff --git a/pre-commit b/pre-commit index 18e917d..4693b48 100755 --- a/pre-commit +++ b/pre-commit @@ -63,7 +63,7 @@ fi check() { msg="$1" shift - if ! echo "$@"; then + if ! "$@"; then echo "${msg}: Failed command:" echo " ${@@Q}" exit 1
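
The `some_error` helper and the `read_fixed` refactor above share one calling convention: internal poll helpers return `Option<Poll<io::Result<_>>>`, where `Some` carries a pending or failed poll that the caller must propagate unchanged, and `None` means the state machine advanced and the outer loop keeps going. A toy, self-contained sketch of that convention (the names `bail`, `step`, and `poll_once` are illustrative, not part of this patch):

use std::{io, task::Poll};

// Hypothetical mirror of `some_error`: lift any error that converts into
// io::Error into the `Some(Poll::Ready(Err(..)))` shape that helpers return.
fn bail<T, E: Into<io::Error>>(e: E) -> Option<Poll<io::Result<T>>> {
    Some(Poll::Ready(Err(e.into())))
}

// One helper step: `None` means "progress was made, keep driving the state
// machine"; `Some(..)` is either `Poll::Pending` or an error to surface as-is.
fn step(ready: bool, ok: bool) -> Option<Poll<io::Result<()>>> {
    if !ready {
        return Some(Poll::Pending);
    }
    if !ok {
        return bail(io::Error::new(io::ErrorKind::UnexpectedEof, "truncated"));
    }
    None
}

// The caller pattern used repeatedly in the patch: propagate `Some`,
// fall through on `None`.
fn poll_once(ready: bool, ok: bool) -> Poll<io::Result<()>> {
    if let Some(res) = step(ready, ok) {
        return res;
    }
    Poll::Ready(Ok(()))
}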
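
For readers following the `ChunkReader` changes: the length prefix that `read_length0` and `read_length` assemble byte-by-byte appears to be a QUIC-style variable-length integer, where the top two bits of the first byte select a 1-, 2-, 4-, or 8-byte encoding and the remaining bits hold the high-order bits of a big-endian value. A minimal synchronous sketch of that decoding, assuming (unlike the streaming code) that the whole prefix is already buffered:

/// Decode a variable-length integer from `buf`, returning the value and the
/// number of bytes consumed, or `None` if the encoding is incomplete.
fn decode_len_prefix(buf: &[u8]) -> Option<(u64, usize)> {
    let first = *buf.first()?;
    let size = 1usize << (first >> 6); // forms 0..=3 map to 1, 2, 4, or 8 bytes
    if buf.len() < size {
        return None; // incomplete; the async reader keeps polling instead
    }
    // Right-align the encoded bytes in an 8-byte buffer, masking off the two
    // form bits, much as read_length0 seeds len[i] before read_fixed fills
    // the remainder.
    let mut len = [0u8; 8];
    len[8 - size] = first & 0x3f;
    len[8 - size + 1..8].copy_from_slice(&buf[1..size]);
    Some((u64::from_be_bytes(len), size))
}

On the wire, each chunk is such a length prefix followed by that many AEAD-sealed bytes (capped at MAX_CHUNK_PLAINTEXT plus the AEAD tag length); ChunkWriter::write_len presumably emits the matching encoding on the sending side.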