diff --git a/.gitignore b/.gitignore index 200abe7..39ac3d3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *~ *.swp /.vscode/ +/mutants.out*/ diff --git a/Cargo.toml b/Cargo.toml index 0621518..eb9b298 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,5 @@ members = [ "ohttp-client", "ohttp-client-cli", "ohttp-server", + "sync-async", ] diff --git a/bhttp-convert/Cargo.toml b/bhttp-convert/Cargo.toml index 8a82101..47e11f1 100644 --- a/bhttp-convert/Cargo.toml +++ b/bhttp-convert/Cargo.toml @@ -9,4 +9,4 @@ structopt = "0.3" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] diff --git a/bhttp-convert/src/main.rs b/bhttp-convert/src/main.rs index 763fadf..054c64f 100644 --- a/bhttp-convert/src/main.rs +++ b/bhttp-convert/src/main.rs @@ -1,11 +1,12 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{ fs::File, io::{self, Read}, path::PathBuf, }; + +use bhttp::{Message, Mode}; use structopt::StructOpt; #[derive(Debug, StructOpt)] diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 8c89536..fe3d325 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -9,17 +9,18 @@ description = "Binary HTTP messages (RFC 9292)" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["bhttp"] -bhttp = ["read-bhttp", "write-bhttp"] -http = ["read-http", "write-http"] -read-bhttp = [] -write-bhttp = [] -read-http = ["url"] -write-http = [] +default = [] +http = ["dep:url"] +stream = ["dep:futures", "dep:pin-project"] [dependencies] +futures = {version = "0.3", optional = true} +pin-project = {version = "1.1", optional = true} thiserror = "1" url = {version = "2", optional = true} [dev-dependencies] hex = "0.4" + +[dev-dependencies.sync-async] +path= "../sync-async" diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..45d6fab 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -1,14 +1,9 @@ -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, 
Debug)] pub enum Error { #[error("a request used the CONNECT method")] ConnectUnsupported, #[error("a field contained invalid Unicode: {0}")] CharacterEncoding(#[from] std::string::FromUtf8Error), - #[error("a chunk of data of {0} bytes is too large")] - #[cfg(feature = "stream")] - ChunkTooLarge(u64), #[error("read a response when expecting a request")] ExpectedRequest, #[error("read a request when expecting a response")] @@ -19,8 +14,14 @@ pub enum Error { InvalidMode, #[error("the status code of a response needs to be in 100..=599")] InvalidStatus, + #[cfg(feature = "stream")] + #[error("a method was called when the message was in the wrong state")] + InvalidState, #[error("IO error {0}")] Io(#[from] std::io::Error), + #[cfg(feature = "stream")] + #[error("the size of a vector exceeded the limit that was set")] + LimitExceeded, #[error("a field or line was missing a necessary character 0x{0:x}")] Missing(u8), #[error("a URL was missing a key component")] @@ -34,14 +35,8 @@ pub enum Error { #[error("a message included the Upgrade field")] UpgradeUnsupported, #[error("a URL could not be parsed into components: {0}")] - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] UrlParse(#[from] url::ParseError), } -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] pub type Res = Result; diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index 3c8fbde..2082b0f 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -1,51 +1,35 @@ #![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // Too lazy to document these. 
-#[cfg(feature = "read-bhttp")] -use std::convert::TryFrom; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] -use std::io; - -#[cfg(feature = "read-http")] +use std::{ + borrow::BorrowMut, + io, + ops::{Deref, DerefMut}, +}; + +#[cfg(feature = "http")] use url::Url; mod err; mod parse; -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] mod rw; - -#[cfg(any(feature = "read-http", feature = "read-bhttp",))] -use std::borrow::BorrowMut; +#[cfg(feature = "stream")] +pub mod stream; pub use err::Error; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] use err::Res; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use parse::{downcase, is_ows, read_line, split_at, COLON, SEMICOLON, SLASH, SP}; use parse::{index_of, trim_ows, COMMA}; -#[cfg(feature = "read-bhttp")] -use rw::{read_varint, read_vec}; -#[cfg(feature = "write-bhttp")] -use rw::{write_len, write_varint, write_vec}; +use rw::{read_varint, read_vec, write_len, write_varint, write_vec}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] const CONTENT_LENGTH: &[u8] = b"content-length"; -#[cfg(feature = "read-bhttp")] const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug)] pub struct StatusCode(u16); impl StatusCode { @@ -88,17 +72,47 @@ impl From for u16 { } } +#[cfg(test)] +impl PartialEq for StatusCode +where + Self: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + StatusCode::try_from(*other).map_or(false, |o| o.0 == self.0) + } +} + +#[cfg(not(test))] +impl PartialEq for StatusCode { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for StatusCode {} + pub trait ReadSeek: io::BufRead + io::Seek {} impl ReadSeek for io::Cursor where T: AsRef<[u8]> {} impl ReadSeek for 
io::BufReader where T: io::Read + io::Seek {} #[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] pub enum Mode { KnownLength, IndeterminateLength, } +impl TryFrom for Mode { + type Error = Error; + fn try_from(t: u64) -> Result { + match t { + 0 | 1 => Ok(Self::KnownLength), + 2 | 3 => Ok(Self::IndeterminateLength), + _ => Err(Error::InvalidMode), + } + } +} + pub struct Field { name: Vec, value: Vec, @@ -120,7 +134,7 @@ impl Field { &self.value } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { w.write_all(&self.name)?; w.write_all(b": ")?; @@ -129,20 +143,31 @@ impl Field { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { write_vec(&self.name, w)?; write_vec(&self.value, w)?; Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn obs_fold(&mut self, extra: &[u8]) { self.value.push(SP); self.value.extend(trim_ows(extra)); } } +#[cfg(test)] +impl std::fmt::Debug for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "{n}: {v}", + n = String::from_utf8_lossy(&self.name), + v = String::from_utf8_lossy(&self.value), + ) + } +} + #[derive(Default)] pub struct FieldSection(Vec); impl FieldSection { @@ -151,15 +176,26 @@ impl FieldSection { self.0.is_empty() } + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + /// Gets the value from the first instance of the field. #[must_use] pub fn get(&self, n: &[u8]) -> Option<&[u8]> { - for f in &self.0 { + self.get_all(n).next() + } + + /// Gets all of the values of the named field. + pub fn get_all<'a, 'b>(&'a self, n: &'b [u8]) -> impl Iterator + use<'a, 'b> { + self.0.iter().filter_map(move |f| { if &f.name[..] 
== n { - return Some(&f.value); + Some(&f.value[..]) + } else { + None } - } - None + }) } pub fn put(&mut self, name: impl Into>, value: impl Into>) { @@ -192,7 +228,7 @@ impl FieldSection { /// As required by the HTTP specification, remove the Connection header /// field, everything it refers to, and a few extra fields. - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn strip_connection_headers(&mut self) { const CONNECTION: &[u8] = b"connection"; const PROXY_CONNECTION: &[u8] = b"proxy-connection"; @@ -232,7 +268,7 @@ impl FieldSection { }); } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn parse_line(fields: &mut Vec, line: Vec) -> Res<()> { // obs-fold is helpful in specs, so support it here too let f = if is_ows(line[0]) { @@ -251,7 +287,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -267,7 +303,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] fn read_bhttp_fields(terminator: bool, r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -302,7 +337,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(mode: Mode, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -320,7 +354,6 @@ impl FieldSection { Ok(Self(fields)) } - #[cfg(feature = "write-bhttp")] fn write_bhttp_headers(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_bhttp(w)?; @@ -328,7 +361,6 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { if mode == Mode::KnownLength { let mut buf = Vec::new(); @@ -341,7 +373,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_http(w)?; @@ -351,6 +383,16 @@ impl FieldSection { } } +#[cfg(test)] +impl std::fmt::Debug for FieldSection { + fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + for fv in self.fields() { + fv.fmt(f)?; + } + Ok(()) + } +} + pub enum ControlData { Request { method: Vec, @@ -420,7 +462,7 @@ impl ControlData { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(line: Vec) -> Res { // request-line = method SP request-target SP HTTP-version // status-line = HTTP-version SP status-code SP [reason-phrase] @@ -467,7 +509,6 @@ impl ControlData { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(request: bool, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -493,7 +534,6 @@ impl ControlData { } /// If this is an informational response. - #[cfg(any(feature = "read-bhttp", feature = "read-http"))] #[must_use] fn informational(&self) -> Option { match self { @@ -502,7 +542,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] #[must_use] fn code(&self, mode: Mode) -> u64 { match (self, mode) { @@ -513,7 +552,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -532,7 +570,7 @@ impl ControlData { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -560,6 +598,68 @@ impl ControlData { } } +#[cfg(test)] +impl PartialEq<(M, S, A, P)> for ControlData +where + M: AsRef<[u8]>, + S: AsRef<[u8]>, + A: AsRef<[u8]>, + P: AsRef<[u8]>, +{ + fn eq(&self, other: &(M, S, A, P)) -> bool { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => { + method == other.0.as_ref() + && scheme == other.1.as_ref() + && authority == other.2.as_ref() + && path == other.3.as_ref() + } + Self::Response(_) => false, + } + } +} + +#[cfg(test)] +impl PartialEq for ControlData +where + StatusCode: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + match self { + Self::Request { .. 
} => false, + Self::Response(code) => code == other, + } + } +} + +#[cfg(test)] +impl std::fmt::Debug for ControlData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => write!( + f, + "{m} {s}://{a}{p}", + m = String::from_utf8_lossy(method), + s = String::from_utf8_lossy(scheme), + a = String::from_utf8_lossy(authority), + p = String::from_utf8_lossy(path), + ), + Self::Response(code) => write!(f, "{code:?}"), + } + } +} + pub struct InformationalResponse { status: StatusCode, fields: FieldSection, @@ -581,7 +681,6 @@ impl InformationalResponse { &self.fields } - #[cfg(feature = "write-bhttp")] fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { write_varint(self.status.code(), w)?; self.fields.write_bhttp(mode, w)?; @@ -589,10 +688,71 @@ impl InformationalResponse { } } +impl Deref for InformationalResponse { + type Target = FieldSection; + + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl DerefMut for InformationalResponse { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + +pub struct Header { + control: ControlData, + fields: FieldSection, +} + +impl Header { + #[must_use] + pub fn control(&self) -> &ControlData { + &self.control + } +} + +impl From for Header { + fn from(control: ControlData) -> Self { + Self { + control, + fields: FieldSection::default(), + } + } +} + +impl From<(ControlData, FieldSection)> for Header { + fn from((control, fields): (ControlData, FieldSection)) -> Self { + Self { control, fields } + } +} + +impl std::ops::Deref for Header { + type Target = FieldSection; + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl std::ops::DerefMut for Header { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + +#[cfg(test)] +impl std::fmt::Debug for Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> 
{ + self.control.fmt(f)?; + self.fields.fmt(f) + } +} + pub struct Message { informational: Vec, - control: ControlData, - header: FieldSection, + header: Header, content: Vec, trailer: FieldSection, } @@ -602,13 +762,12 @@ impl Message { pub fn request(method: Vec, scheme: Vec, authority: Vec, path: Vec) -> Self { Self { informational: Vec::new(), - control: ControlData::Request { + header: Header::from(ControlData::Request { method, scheme, authority, path, - }, - header: FieldSection::default(), + }), content: Vec::new(), trailer: FieldSection::default(), } @@ -618,8 +777,7 @@ impl Message { pub fn response(status: StatusCode) -> Self { Self { informational: Vec::new(), - control: ControlData::Response(status), - header: FieldSection::default(), + header: Header::from(ControlData::Response(status)), content: Vec::new(), trailer: FieldSection::default(), } @@ -644,11 +802,11 @@ impl Message { #[must_use] pub fn control(&self) -> &ControlData { - &self.control + self.header.control() } #[must_use] - pub fn header(&self) -> &FieldSection { + pub fn header(&self) -> &Header { &self.header } @@ -662,7 +820,7 @@ impl Message { &self.trailer } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn read_chunked(r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -686,7 +844,7 @@ impl Message { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] #[allow(clippy::read_zero_byte_vec)] // https://github.com/rust-lang/rust-clippy/issues/9274 pub fn read_http(r: &mut T) -> Res where @@ -703,20 +861,20 @@ impl Message { control = ControlData::read_http(line)?; } - let mut header = FieldSection::read_http(r)?; + let mut hfields = FieldSection::read_http(r)?; let (content, trailer) = if matches!(control.status().map(StatusCode::code), Some(204 | 304)) { // 204 and 304 have no body, no matter what Content-Length says. // Unfortunately, we can't do the same for responses to HEAD. 
(Vec::new(), FieldSection::default()) - } else if header.is_chunked() { + } else if hfields.is_chunked() { let content = Self::read_chunked(r)?; let trailer = FieldSection::read_http(r)?; (content, trailer) } else { let mut content = Vec::new(); - if let Some(cl) = header.get(CONTENT_LENGTH) { + if let Some(cl) = hfields.get(CONTENT_LENGTH) { let cl_str = String::from_utf8(Vec::from(cl))?; let cl_int = cl_str.parse::()?; if cl_int > 0 { @@ -731,23 +889,22 @@ impl Message { (content, FieldSection::default()) }; - header.strip_connection_headers(); + hfields.strip_connection_headers(); Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for info in &self.informational { ControlData::Response(info.status()).write_http(w)?; info.fields().write_http(w)?; } - self.control.write_http(w)?; + self.header.control.write_http(w)?; if !self.content.is_empty() { if self.trailer.is_empty() { write!(w, "Content-Length: {}\r\n", self.content.len())?; @@ -770,7 +927,6 @@ impl Message { } /// Read a BHTTP message. 
- #[cfg(feature = "read-bhttp")] pub fn read_bhttp(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -778,11 +934,7 @@ impl Message { { let t = read_varint(r)?.ok_or(Error::Truncated)?; let request = t == 0 || t == 2; - let mode = match t { - 0 | 1 => Mode::KnownLength, - 2 | 3 => Mode::IndeterminateLength, - _ => return Err(Error::InvalidMode), - }; + let mode = Mode::try_from(t)?; let mut control = ControlData::read_bhttp(request, r)?; let mut informational = Vec::new(); @@ -791,7 +943,7 @@ impl Message { informational.push(InformationalResponse::new(status, fields)); control = ControlData::read_bhttp(request, r)?; } - let header = FieldSection::read_bhttp(mode, r)?; + let hfields = FieldSection::read_bhttp(mode, r)?; let mut content = read_vec(r)?.unwrap_or_default(); if mode == Mode::IndeterminateLength && !content.is_empty() { @@ -808,20 +960,18 @@ impl Message { Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { - write_varint(self.control.code(mode), w)?; + write_varint(self.header.control.code(mode), w)?; for info in &self.informational { info.write_bhttp(mode, w)?; } - self.control.write_bhttp(w)?; + self.header.control.write_bhttp(w)?; self.header.write_bhttp(mode, w)?; write_vec(&self.content, w)?; @@ -833,7 +983,7 @@ impl Message { } } -#[cfg(feature = "write-http")] +#[cfg(feature = "http")] impl std::fmt::Debug for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { let mut buf = Vec::new(); diff --git a/bhttp/src/parse.rs b/bhttp/src/parse.rs index ee52493..06165fc 100644 --- a/bhttp/src/parse.rs +++ b/bhttp/src/parse.rs @@ -1,20 +1,21 @@ -#[cfg(feature = "read-http")] -use crate::{Error, ReadSeek, Res}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use std::borrow::BorrowMut; +#[cfg(feature = "http")] +use crate::{Error, ReadSeek, 
Res}; + pub const HTAB: u8 = 0x09; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const NL: u8 = 0x0a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const CR: u8 = 0x0d; pub const SP: u8 = 0x20; pub const COMMA: u8 = 0x2c; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SLASH: u8 = 0x2f; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const COLON: u8 = 0x3a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SEMICOLON: u8 = 0x3b; pub fn is_ows(x: u8) -> bool { @@ -34,7 +35,7 @@ pub fn trim_ows(v: &[u8]) -> &[u8] { &v[..0] } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn downcase(n: &mut [u8]) { for i in n { if *i >= 0x41 && *i <= 0x5a { @@ -52,7 +53,7 @@ pub fn index_of(v: u8, line: &[u8]) -> Option { None } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn split_at(v: u8, mut line: Vec) -> Option<(Vec, Vec)> { index_of(v, &line).map(|i| { let tail = line.split_off(i + 1); @@ -61,7 +62,7 @@ pub fn split_at(v: u8, mut line: Vec) -> Option<(Vec, Vec)> { }) } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn read_line(r: &mut T) -> Res> where T: BorrowMut + ?Sized, diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..4659dd4 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -1,79 +1,66 @@ -#[cfg(feature = "read-bhttp")] -use std::borrow::BorrowMut; -use std::{convert::TryFrom, io}; +use std::{borrow::BorrowMut, convert::TryFrom, io}; -use crate::err::Res; -#[cfg(feature = "read-bhttp")] -use crate::{err::Error, ReadSeek}; +use crate::{ + err::{Error, Res}, + ReadSeek, +}; -#[cfg(feature = "write-bhttp")] #[allow(clippy::cast_possible_truncation)] -fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { - let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); - for i in 0..n { - w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; - } +pub(crate) fn write_uint(v: impl Into, w: &mut impl 
io::Write) -> Res<()> { + let v = v.into().to_be_bytes(); + assert!((1..=std::mem::size_of::()).contains(&N)); + w.write_all(&v[std::mem::size_of::() - N..])?; Ok(()) } -#[cfg(feature = "write-bhttp")] pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); match () { - () if v < (1 << 6) => write_uint(1, v, w), - () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), - () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), - () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), - () => panic!("Varint value too large"), + () if v < (1 << 6) => write_uint::<1>(v, w), + () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), + () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), + () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), + () => panic!("varint value too large"), } } -#[cfg(feature = "write-bhttp")] pub fn write_len(len: usize, w: &mut impl io::Write) -> Res<()> { write_varint(u64::try_from(len).unwrap(), w) } -#[cfg(feature = "write-bhttp")] pub fn write_vec(v: &[u8], w: &mut impl io::Write) -> Res<()> { write_len(v.len(), w)?; w.write_all(v)?; Ok(()) } -#[cfg(feature = "read-bhttp")] -fn read_uint(n: usize, r: &mut T) -> Res> +fn read_uint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - let mut buf = [0; 7]; - let count = r.borrow_mut().read(&mut buf[..n])?; + let mut buf = [0; 8]; + let count = r.borrow_mut().read(&mut buf[(8 - N)..])?; if count == 0 { Ok(None) - } else if count < n { + } else if count < N { Err(Error::Truncated) } else { - let mut v = 0; - for i in &buf[..n] { - v = (v << 8) | u64::from(*i); - } - Ok(Some(v)) + Ok(Some(u64::from_be_bytes(buf))) } } -#[cfg(feature = "read-bhttp")] pub fn read_varint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - if let Some(b1) = read_uint(1, r)? { + if let Some(b1) = read_uint::<_, _, 1>(r)? 
{ Ok(Some(match b1 >> 6 { 0 => b1 & 0x3f, - 1 => ((b1 & 0x3f) << 8) | read_uint(1, r)?.ok_or(Error::Truncated)?, - 2 => ((b1 & 0x3f) << 24) | read_uint(3, r)?.ok_or(Error::Truncated)?, - 3 => ((b1 & 0x3f) << 56) | read_uint(7, r)?.ok_or(Error::Truncated)?, + 1 => ((b1 & 0x3f) << 8) | read_uint::<_, _, 1>(r)?.ok_or(Error::Truncated)?, + 2 => ((b1 & 0x3f) << 24) | read_uint::<_, _, 3>(r)?.ok_or(Error::Truncated)?, + 3 => ((b1 & 0x3f) << 56) | read_uint::<_, _, 7>(r)?.ok_or(Error::Truncated)?, _ => unreachable!(), })) } else { @@ -81,7 +68,6 @@ where } } -#[cfg(feature = "read-bhttp")] pub fn read_vec(r: &mut T) -> Res>> where T: BorrowMut + ?Sized, @@ -106,3 +92,67 @@ where Ok(None) } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::{read_varint, write_varint}; + use crate::{rw::read_vec, Error}; + + #[test] + fn basics() { + for i in [ + 0_u64, + 1, + 17, + 63, + 64, + 100, + 0x3fff, + 0x4000, + 0x1_0002, + 0x3fff_ffff, + 0x4000_0000, + 0x3456_dead_beef, + 0x3fff_ffff_ffff_ffff, + ] { + let mut buf = Vec::new(); + write_varint(i, &mut buf).unwrap(); + let sz_bytes = (64 - i.leading_zeros() + 2 + 7) / 8; // +2 size bits, +7 to round up + assert_eq!( + buf.len(), + usize::try_from(sz_bytes.next_power_of_two()).unwrap() + ); + + let o = read_varint(&mut Cursor::new(buf.clone())).unwrap(); + assert_eq!(Some(i), o); + + for cut in 1..buf.len() { + let e = read_varint(&mut Cursor::new(buf[..cut].to_vec())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } + } + } + + #[test] + fn read_nothing() { + let o = read_varint(&mut Cursor::new(Vec::new())).unwrap(); + assert!(o.is_none()); + } + + #[test] + #[should_panic(expected = "varint value too large")] + fn too_big() { + _ = write_varint(0x4000_0000_0000_0000_u64, &mut Vec::new()); + } + + #[test] + fn too_big_vec() { + let mut buf = Vec::new(); + write_varint(10_u64, &mut buf).unwrap(); + buf.resize(10, 0); // Not enough extra for the promised length. 
+ let e = read_vec(&mut Cursor::new(buf.clone())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } +} diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs new file mode 100644 index 0000000..f70231b --- /dev/null +++ b/bhttp/src/stream/int.rs @@ -0,0 +1,255 @@ +use std::{ + future::Future, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::io::AsyncRead; + +use crate::{Error, Res}; + +#[pin_project::pin_project] +pub struct ReadUint { + /// The source of data. + src: S, + /// A buffer that holds the bytes that have been read so far. + v: [u8; 8], + /// A counter of the number of bytes that are already in place. + /// This starts out at `8-N`. + read: usize, +} + +impl ReadUint { + pub fn stream(self) -> S { + self.src + } +} + +impl Future for ReadUint { + type Output = Res; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match pin!(this.src).poll_read(cx, &mut this.v[*this.read..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(count)) => { + if count == 0 { + return Poll::Ready(Err(Error::Truncated)); + } + *this.read += count; + if *this.read == 8 { + Poll::Ready(Ok(u64::from_be_bytes(*this.v))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + } + } +} + +#[cfg(test)] +fn read_uint(src: S) -> ReadUint { + ReadUint { + src, + v: [0; 8], + read: 8 - N, + } +} + +#[pin_project::pin_project(project = ReadVarintProj)] +pub enum ReadVarint { + // Invariant: this Option always contains Some. 
+ First(Option), + Extra1(#[pin] ReadUint), + Extra3(#[pin] ReadUint), + Extra7(#[pin] ReadUint), +} + +impl ReadVarint { + pub fn stream(self) -> S { + match self { + Self::Extra1(s) | Self::Extra3(s) | Self::Extra7(s) => s.stream(), + Self::First(mut s) => s.take().unwrap(), + } + } +} + +impl Future for ReadVarint { + type Output = Res>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::First(ref mut src) = this.get_mut() { + let mut buf = [0; 1]; + let src_ref = src.as_mut().unwrap(); + if let Poll::Ready(res) = pin!(src_ref).poll_read(cx, &mut buf[..]) { + match res { + Ok(0) => return Poll::Ready(Ok(None)), + Ok(_) => (), + Err(e) => return Poll::Ready(Err(Error::from(e))), + } + + let b1 = buf[0]; + let mut v = [0; 8]; + let next = match b1 >> 6 { + 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), + 1 => { + let src = src.take().unwrap(); + v[6] = b1 & 0x3f; + Self::Extra1(ReadUint { src, v, read: 7 }) + } + 2 => { + let src = src.take().unwrap(); + v[4] = b1 & 0x3f; + Self::Extra3(ReadUint { src, v, read: 5 }) + } + 3 => { + let src = src.take().unwrap(); + v[0] = b1 & 0x3f; + Self::Extra7(ReadUint { src, v, read: 1 }) + } + _ => unreachable!(), + }; + + self.set(next); + } + } + let extra = match self.project() { + ReadVarintProj::Extra1(s) | ReadVarintProj::Extra3(s) | ReadVarintProj::Extra7(s) => { + s.poll(cx) + } + ReadVarintProj::First(_) => return Poll::Pending, + }; + if let Poll::Ready(v) = extra { + Poll::Ready(v.map(Some)) + } else { + Poll::Pending + } + } +} + +pub fn read_varint(src: S) -> ReadVarint { + ReadVarint::First(Some(src)) +} + +#[cfg(test)] +mod test { + use sync_async::SyncResolve; + + use crate::{ + err::Error, + rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, + stream::int::{read_uint, read_varint}, + }; + + const VARINTS: &[u64] = &[ + 0, + 1, + 63, + 64, + (1 << 14) - 1, + 1 << 14, + (1 << 30) - 1, + 1 << 30, + (1 << 62) - 1, + ]; + + 
#[test] + fn read_uint_values() { + macro_rules! validate_uint_range { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_uint::<_, $n>(&mut buf_ref); + assert_eq!(v, fut.sync_resolve().unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_range!(@ $n); + )+ + } + } + validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_uint_truncated() { + macro_rules! validate_uint_truncated { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = read_uint::<_, $n>(&mut &buf[..i]).sync_resolve().unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + }; + ($($n:expr),+ $(,)?) 
=> { + $( + validate_uint_truncated!(@ $n); + )+ + } + } + validate_uint_truncated!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_varint_values() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + } + + #[test] + fn read_varint_none() { + assert!(read_varint(&mut &[][..]).sync_resolve().unwrap().is_none()); + } + + #[test] + fn read_varint_truncated() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = { + let mut buf: &[u8] = &buf[..i]; + read_varint(&mut buf).sync_resolve() + } + .unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + } + + #[test] + fn read_varint_extra() { + const EXTRA: &[u8] = &[161, 2, 49]; + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + buf.extend_from_slice(EXTRA); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); + let s = fut.stream(); + assert_eq!(&s[..], EXTRA); + } + } +} diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs new file mode 100644 index 0000000..009b1eb --- /dev/null +++ b/bhttp/src/stream/mod.rs @@ -0,0 +1,588 @@ +#![allow(dead_code)] +#![allow(clippy::incompatible_msrv)] // This module uses features from rust 1.82 + +use std::{ + cmp::min, + io::{Cursor, Error as IoError, Result as IoResult}, + mem, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; + +use crate::{ + err::Res, + stream::{int::read_varint, vec::read_vec}, + ControlData, Error, Field, FieldSection, Header, InformationalResponse, Message, Mode, COOKIE, +}; +mod int; +mod vec; + +trait AsyncReadControlData: Sized { + async fn async_read(request: 
bool, src: S) -> Res; +} + +impl AsyncReadControlData for ControlData { + async fn async_read(request: bool, mut src: S) -> Res { + let v = if request { + let method = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let scheme = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let authority = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let path = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + Self::Request { + method, + scheme, + authority, + path, + } + } else { + let code = read_varint(&mut src).await?.ok_or(Error::Truncated)?; + Self::Response(crate::StatusCode::try_from(code)?) + }; + Ok(v) + } +} + +trait AsyncReadFieldSection: Sized { + async fn async_read(mode: Mode, src: S) -> Res; +} + +impl AsyncReadFieldSection for FieldSection { + async fn async_read(mode: Mode, mut src: S) -> Res { + let fields = if mode == Mode::KnownLength { + // Known-length fields can just be read into a buffer. + if let Some(buf) = read_vec(&mut src).await? { + Self::read_bhttp_fields(false, &mut Cursor::new(&buf[..]))? + } else { + Vec::new() + } + } else { + // The async version needs to be implemented directly. + let mut fields: Vec = Vec::new(); + let mut cookie_index: Option = None; + loop { + if let Some(n) = read_vec(&mut src).await? { + if n.is_empty() { + break fields; + } + let mut v = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + if n == COOKIE { + if let Some(i) = &cookie_index { + fields[*i].value.extend_from_slice(b"; "); + fields[*i].value.append(&mut v); + continue; + } + cookie_index = Some(fields.len()); + } + fields.push(Field::new(n, v)); + } else if fields.is_empty() { + break fields; + } else { + return Err(Error::Truncated); + } + } + }; + Ok(Self(fields)) + } +} + +#[derive(Default)] +enum BodyState { + // The starting state. + #[default] + Init, + // When reading the length, use this. + ReadLength { + buf: [u8; 8], + read: usize, + }, + // When reading the data, track how much is left. 
+ ReadData { + remaining: usize, + }, +} + +impl BodyState { + fn read_len() -> Self { + Self::ReadLength { + buf: [0; 8], + read: 0, + } + } +} + +pub struct Body<'b, S> { + msg: &'b mut AsyncMessage, +} + +impl<'b, S> Body<'b, S> {} + +impl<'b, S: AsyncRead + Unpin> AsyncRead for Body<'b, S> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.msg.read_body(cx, buf).map_err(IoError::other) + } +} + +/// A helper function for the more complex body-reading code. +fn poll_error(e: Error) -> Poll> { + Poll::Ready(Err(IoError::other(e))) +} + +enum AsyncMessageState { + Init, + // Processing Informational responses (or before that). + Informational(bool), + // Having obtained the control data for the header, this is it. + Header(ControlData), + // Processing the Body. + Body(BodyState), + // Processing the trailer. + Trailer, + // All done. + Done, +} + +pub struct AsyncMessage { + // Whether this is a request and which mode. + mode: Option, + state: AsyncMessageState, + src: S, +} + +unsafe impl Send for AsyncMessage {} + +impl AsyncMessage { + async fn next_info(&mut self) -> Res> { + let request = if matches!(self.state, AsyncMessageState::Init) { + // Read control data ... + let t = read_varint(&mut self.src).await?.ok_or(Error::Truncated)?; + let request = t == 0 || t == 2; + self.mode = Some(Mode::try_from(t)?); + self.state = AsyncMessageState::Informational(request); + request + } else { + // ... or recover it. 
+ let AsyncMessageState::Informational(request) = self.state else { + return Err(Error::InvalidState); + }; + request + }; + + let control = ControlData::async_read(request, &mut self.src).await?; + if let Some(status) = control.informational() { + let mode = self.mode.unwrap(); + let fields = FieldSection::async_read(mode, &mut self.src).await?; + Ok(Some(InformationalResponse::new(status, fields))) + } else { + self.state = AsyncMessageState::Header(control); + Ok(None) + } + } + + /// Produces a stream of informational responses from a fresh message. + /// Returns an empty stream if passed a request (or if there are no informational responses). + /// Error values on the stream indicate failures. + /// + /// There is no need to call this method to read a request, though + /// doing so is harmless. + /// + /// You can discard the stream that this function returns + /// without affecting the message. You can then either call this + /// method again to get any additional informational responses or + /// call `header()` to get the message header. + pub fn informational(&mut self) -> impl Stream> + use<'_, S> { + unfold(self, |this| async move { + this.next_info().await.transpose().map(|info| (info, this)) + }) + } + + /// This reads the header. If you have not called `informational` + /// and drained the resulting stream, this will do that for you. + /// # Panics + /// Never. + pub async fn header(&mut self) -> Res
{ + if matches!( + self.state, + AsyncMessageState::Init | AsyncMessageState::Informational(_) + ) { + // Need to scrub for errors, + // so that this can abort properly if there is one. + // The `try_any` usage is there to ensure that the stream is fully drained. + _ = self.informational().try_any(|_| async { false }).await?; + } + + if matches!(self.state, AsyncMessageState::Header(_)) { + let mode = self.mode.unwrap(); + let hfields = FieldSection::async_read(mode, &mut self.src).await?; + + let AsyncMessageState::Header(control) = mem::replace( + &mut self.state, + AsyncMessageState::Body(BodyState::default()), + ) else { + unreachable!(); + }; + Ok(Header::from((control, hfields))) + } else { + Err(Error::InvalidState) + } + } + + fn body_state(&mut self, s: BodyState) { + self.state = AsyncMessageState::Body(s); + } + + fn body_done(&mut self) { + self.state = AsyncMessageState::Trailer; + } + + /// Read the length of a body chunk. + /// This updates the values of `read` and `buf` to track the portion of the length + /// that was successfully read. + /// Returns `Some` with the poll result (pending or error) that the caller + /// should return immediately; `None` when reading the length made progress. + fn read_body_len( + cx: &mut Context<'_>, + mut src: &mut S, + first: bool, + read: &mut usize, + buf: &mut [u8; 8], + ) -> Option>> { + let mut src = pin!(src); + if *read == 0 { + let mut b = [0; 1]; + match src.as_mut().poll_read(cx, &mut b[..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return if first { + // It's OK for the first length to be absent. + // Just skip to the end. + *read = 8; + None + } else { + // ...it's not OK to drop length when continuing. 
+ Some(poll_error(Error::Truncated)) + }; + } + Poll::Ready(Ok(1)) => match b[0] >> 6 { + 0 => { + buf[7] = b[0] & 0x3f; + *read = 8; + } + 1 => { + buf[6] = b[0] & 0x3f; + *read = 7; + } + 2 => { + buf[4] = b[0] & 0x3f; + *read = 5; + } + 3 => { + buf[0] = b[0] & 0x3f; + *read = 1; + } + _ => unreachable!(), + }, + Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + if *read < 8 { + match src.as_mut().poll_read(cx, &mut buf[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => return Some(poll_error(Error::Truncated)), + Poll::Ready(Ok(len)) => { + *read += len; + } + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + None + } + + fn read_body(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + // The length that precedes the first chunk can be absent. + // Only allow that for the first chunk (if indeterminate length). + let first = if let AsyncMessageState::Body(BodyState::Init) = &self.state { + self.body_state(BodyState::read_len()); + true + } else { + false + }; + + // Read the length. This uses `read_body_len` to track the state of this reading. + // This doesn't use `ReadVarint` or any convenience functions because we + // need to track the state and we don't want the borrow checker to flip out. 
+ if let AsyncMessageState::Body(BodyState::ReadLength { buf, read }) = &mut self.state { + if let Some(res) = Self::read_body_len(cx, &mut self.src, first, read, buf) { + return res; + } + if *read == 8 { + match usize::try_from(u64::from_be_bytes(*buf)) { + Ok(0) => { + self.body_done(); + return Poll::Ready(Ok(0)); + } + Ok(remaining) => { + self.body_state(BodyState::ReadData { remaining }); + } + Err(e) => return poll_error(Error::IntRange(e)), + } + } + } + + match &mut self.state { + AsyncMessageState::Body(BodyState::ReadData { remaining }) => { + let amount = min(*remaining, buf.len()); + let res = pin!(&mut self.src).poll_read(cx, &mut buf[..amount]); + match res { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) => poll_error(Error::Truncated), + Poll::Ready(Ok(len)) => { + *remaining -= len; + if *remaining == 0 { + let mode = self.mode.unwrap(); + if mode == Mode::IndeterminateLength { + self.body_state(BodyState::read_len()); + } else { + self.body_done(); + } + } + Poll::Ready(Ok(len)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + AsyncMessageState::Trailer => Poll::Ready(Ok(0)), + _ => Poll::Pending, + } + } + + /// Read the body. + /// This produces an implementation of `AsyncRead` that filters out + /// the framing from the message body. + /// # Errors + /// This errors when the header has not been read. + /// Any IO errors are generated by the returned `Body` instance. + pub fn body(&mut self) -> Res> { + match self.state { + AsyncMessageState::Body(_) => Ok(Body { msg: self }), + _ => Err(Error::InvalidState), + } + } + + /// Read any trailer. + /// This might be empty. + /// # Errors + /// This errors when the body has not been read. + /// # Panics + /// Never. 
+ pub async fn trailer(&mut self) -> Res { + if matches!(self.state, AsyncMessageState::Trailer) { + let trailer = FieldSection::async_read(self.mode.unwrap(), &mut self.src).await?; + self.state = AsyncMessageState::Done; + Ok(trailer) + } else { + Err(Error::InvalidState) + } + } +} + +pub trait AsyncReadMessage: Sized { + fn async_read(src: S) -> AsyncMessage; +} + +impl AsyncReadMessage for Message { + fn async_read(src: S) -> AsyncMessage { + AsyncMessage { + mode: None, + state: AsyncMessageState::Init, + src, + } + } +} + +#[cfg(test)] +mod test { + use std::pin::pin; + + use futures::TryStreamExt; + use sync_async::{Dribble, SyncRead, SyncResolve, SyncTryCollect}; + + use crate::{stream::AsyncReadMessage, Error, Message}; + + // Example from Section 5.1 of RFC 9292. + const REQUEST1: &[u8] = &[ + 0x00, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x40, 0x6c, 0x0a, 0x75, 0x73, 0x65, 0x72, + 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, + 0x2e, 0x37, 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, + 0x68, 0x6f, 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, + 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, + ]; + const REQUEST2: &[u8] = &[ + 0x02, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x2d, 0x61, + 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 
0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, 0x2e, 0x37, + 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, 0x68, 0x6f, + 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, + 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, + 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + #[test] + fn informational() { + const INFO: &[u8] = &[1, 64, 100, 0, 64, 200, 0]; + let mut buf_alias = INFO; + let mut msg = Message::async_read(&mut buf_alias); + let info = msg.informational().sync_collect().unwrap(); + assert_eq!(info.len(), 1); + let err = msg.informational().sync_collect(); + assert!(matches!(err, Err(Error::InvalidState))); + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control().status().unwrap().code(), 200); + assert!(hdr.is_empty()); + } + + #[test] + fn sample_requests() { + fn validate_sample_request(mut buf: &[u8]) { + let mut msg = Message::async_read(&mut buf); + let info = msg.informational().sync_collect().unwrap(); + assert!(info.is_empty()); + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &(b"GET", b"https", b"", b"/hello.txt")); + assert_eq!( + hdr.get(b"user-agent"), + Some(&b"curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"[..]), + ); + assert_eq!(hdr.get(b"host"), Some(&b"www.example.com"[..])); + assert_eq!(hdr.get(b"accept-language"), Some(&b"en, mi"[..])); + assert_eq!(hdr.len(), 3); + + let body = pin!(msg.body().unwrap()).sync_read_to_end(); + assert!(body.is_empty()); + + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } + + validate_sample_request(REQUEST1); + validate_sample_request(REQUEST2); + 
validate_sample_request(&REQUEST2[..REQUEST2.len() - 12]); + } + + #[test] + fn truncated_header() { + // The indefinite-length request example includes 10 bytes of padding. + // The three additional zero values at the end represent: + // 1. The terminating zero for the header field section. + // 2. The terminating zero for the (empty) body. + // 3. The terminating zero for the (absent) trailer field section. + // The latter two (body and trailer) can be cut and the message will still work. + // The first is not optional; dropping it means that the message is truncated. + let mut buf = &mut &REQUEST2[..REQUEST2.len() - 13]; + let mut msg = Message::async_read(&mut buf); + // Use this test to test skipping a few things. + let err = pin!(msg.header()).sync_resolve().unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + + /// This test is crazy. It reads a byte at a time and checks the state constantly. + #[test] + fn sample_response() { + const RESPONSE: &[u8] = &[ + 0x03, 0x40, 0x66, 0x07, 0x72, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x0a, 0x22, 0x73, + 0x6c, 0x65, 0x65, 0x70, 0x20, 0x31, 0x35, 0x22, 0x00, 0x40, 0x67, 0x04, 0x6c, 0x69, + 0x6e, 0x6b, 0x23, 0x3c, 0x2f, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x2e, 0x63, 0x73, 0x73, + 0x3e, 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, + 0x3b, 0x20, 0x61, 0x73, 0x3d, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x04, 0x6c, 0x69, 0x6e, + 0x6b, 0x24, 0x3c, 0x2f, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x2e, 0x6a, 0x73, 0x3e, + 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, 0x3b, + 0x20, 0x61, 0x73, 0x3d, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x00, 0x40, 0xc8, 0x04, + 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x37, 0x20, 0x4a, + 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x32, 0x3a, 0x32, 0x38, 0x3a, + 0x35, 0x33, 0x20, 0x47, 0x4d, 0x54, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x06, + 0x41, 0x70, 0x61, 0x63, 0x68, 0x65, 0x0d, 0x6c, 0x61, 0x73, 
0x74, 0x2d, 0x6d, 0x6f, + 0x64, 0x69, 0x66, 0x69, 0x65, 0x64, 0x1d, 0x57, 0x65, 0x64, 0x2c, 0x20, 0x32, 0x32, + 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x39, 0x3a, 0x31, + 0x35, 0x3a, 0x35, 0x36, 0x20, 0x47, 0x4d, 0x54, 0x04, 0x65, 0x74, 0x61, 0x67, 0x14, + 0x22, 0x33, 0x34, 0x61, 0x61, 0x33, 0x38, 0x37, 0x2d, 0x64, 0x2d, 0x31, 0x35, 0x36, + 0x38, 0x65, 0x62, 0x30, 0x30, 0x22, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, + 0x72, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x0e, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x02, + 0x35, 0x31, 0x04, 0x76, 0x61, 0x72, 0x79, 0x0f, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, + 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x0c, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x0a, 0x74, 0x65, 0x78, 0x74, 0x2f, + 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x00, 0x33, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, + 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x20, 0x4d, 0x79, 0x20, 0x63, 0x6f, 0x6e, 0x74, 0x65, + 0x6e, 0x74, 0x20, 0x69, 0x6e, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x73, 0x20, 0x61, 0x20, + 0x74, 0x72, 0x61, 0x69, 0x6c, 0x69, 0x6e, 0x67, 0x20, 0x43, 0x52, 0x4c, 0x46, 0x2e, + 0x0d, 0x0a, 0x00, 0x00, + ]; + + let mut buf = RESPONSE; + let mut msg = Message::async_read(Dribble::new(&mut buf)); + + { + // Need to scope access to `info` or it will hold the reference to `msg`. 
+ let mut info = pin!(msg.informational()); + + let info1 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info1.status(), 102_u16); + assert_eq!(info1.len(), 1); + assert_eq!(info1.get(b"running"), Some(&b"\"sleep 15\""[..])); + + let info2 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info2.status(), 103_u16); + assert_eq!(info2.len(), 2); + let links = info2.get_all(b"link").collect::>(); + assert_eq!( + &links, + &[ + &b"; rel=preload; as=style"[..], + &b"; rel=preload; as=script"[..], + ] + ); + + assert!(info.try_next().sync_resolve().unwrap().is_none()); + } + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &200_u16); + assert_eq!(hdr.len(), 8); + assert_eq!(hdr.get(b"vary"), Some(&b"Accept-Encoding"[..])); + assert_eq!(hdr.get(b"etag"), Some(&b"\"34aa387-d-1568eb00\""[..])); + + { + let mut body = pin!(msg.body().unwrap()); + assert_eq!(body.sync_read_exact(12), b"Hello World!"); + } + // Attempting to read the trailer before finishing the body should fail. + assert!(matches!( + pin!(msg.trailer()).sync_resolve(), + Err(Error::InvalidState) + )); + { + // Picking up the body again should work fine. + let mut body = pin!(msg.body().unwrap()); + assert_eq!( + body.sync_read_to_end(), + b" My content includes a trailing CRLF.\r\n" + ); + } + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } +} diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs new file mode 100644 index 0000000..a0989a9 --- /dev/null +++ b/bhttp/src/stream/vec.rs @@ -0,0 +1,222 @@ +use std::{ + future::Future, + mem, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{io::AsyncRead, FutureExt}; + +use super::int::{read_varint, ReadVarint}; +use crate::{Error, Res}; + +#[pin_project::pin_project(project = ReadVecProj)] +#[allow(clippy::module_name_repetitions)] +pub enum ReadVec { + // Invariant: This Option is always Some. 
+ ReadLen { + src: Option>, + cap: u64, + }, + ReadBody { + src: S, + buf: Vec, + remaining: usize, + }, +} + +impl ReadVec { + #![allow(dead_code)] // TODO these really need to be used. + + /// # Panics + /// If `limit` is more than `usize::MAX` or + /// if this is called after the length is read. + pub fn limit(&mut self, limit: u64) { + usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); + if let Self::ReadLen { ref mut cap, .. } = self { + *cap = limit; + } else { + panic!("cannot set a limit once the size has been read"); + } + } + + pub fn stream(self) -> S { + match self { + Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), + Self::ReadBody { src, .. } => src, + } + } +} + +impl Future for ReadVec { + type Output = Res>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::ReadLen { src, cap } = this.get_mut() { + match src.as_mut().unwrap().poll_unpin(cx) { + Poll::Ready(Ok(None)) => return Poll::Ready(Ok(None)), + Poll::Ready(Ok(Some(0))) => return Poll::Ready(Ok(Some(Vec::new()))), + Poll::Ready(Ok(Some(sz))) => { + if sz > *cap { + return Poll::Ready(Err(Error::LimitExceeded)); + } + // `cap` cannot exceed min(usize::MAX, u64::MAX). 
+ let sz = usize::try_from(sz).unwrap(); + let body = Self::ReadBody { + src: src.take().unwrap().stream(), + buf: vec![0; sz], + remaining: sz, + }; + self.set(body); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + let ReadVecProj::ReadBody { + src, + buf, + remaining, + } = self.project() + else { + return Poll::Pending; + }; + + let offset = buf.len() - *remaining; + match pin!(src).poll_read(cx, &mut buf[offset..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), + Poll::Ready(Ok(c)) => { + *remaining -= c; + if *remaining > 0 { + Poll::Pending + } else { + Poll::Ready(Ok(Some(mem::take(buf)))) + } + } + } + } +} + +#[allow(clippy::module_name_repetitions)] +pub fn read_vec(src: S) -> ReadVec { + ReadVec::ReadLen { + src: Some(read_varint(src)), + cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), + } +} + +#[cfg(test)] +mod test { + + use std::{ + cmp, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, + }; + + use futures::AsyncRead; + use sync_async::SyncResolve; + + use crate::{rw::write_varint as sync_write_varint, stream::vec::read_vec, Error}; + + const FILL_VALUE: u8 = 90; + + fn fill(len: T) -> Vec + where + u64: TryFrom, + >::Error: Debug, + usize: TryFrom, + >::Error: Debug, + T: Debug + Copy, + { + let mut buf = Vec::new(); + sync_write_varint(u64::try_from(len).unwrap(), &mut buf).unwrap(); + buf.resize(buf.len() + usize::try_from(len).unwrap(), FILL_VALUE); + buf + } + + #[test] + fn read_vecs() { + for len in [0, 1, 2, 3, 64] { + let buf = fill(len); + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + if let Ok(Some(out)) = fut.sync_resolve() { + assert_eq!(len, out.len()); + assert!(out.iter().all(|&v| v == FILL_VALUE)); + + assert!(fut.stream().is_empty()); + } + } + } + + #[test] + fn exceed_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); 
+ let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(LEN - 1); + assert!(matches!(fut.sync_resolve(), Err(Error::LimitExceeded))); + } + + /// This type implements `AsyncRead`, but + /// blocks (returns `Poll::Pending`) once its fixed data has been consumed. + #[derive(Default)] + struct IncompleteRead<'a> { + data: &'a [u8], + consumed: usize, + } + + impl<'a> IncompleteRead<'a> { + fn new(data: &'a [u8]) -> Self { + Self { data, consumed: 0 } + } + } + + impl AsyncRead for IncompleteRead<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let remaining = &self.data[self.consumed..]; + if remaining.is_empty() { + Poll::Pending + } else { + let copied = cmp::min(buf.len(), remaining.len()); + buf[..copied].copy_from_slice(&remaining[..copied]); + self.as_mut().consumed += copied; + Poll::Ready(std::io::Result::Ok(copied)) + } + } + } + + #[test] + #[should_panic(expected = "cannot set a limit once the size has been read")] + fn late_cap() { + let mut buf = IncompleteRead::new(&[2, 1]); + _ = read_vec(&mut buf).sync_resolve_with(|mut f| { + println!("pending"); + f.as_mut().limit(100); + }); + } + + #[test] + #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] + #[should_panic(expected = "cannot set a limit larger than usize::MAX")] + fn too_large_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); + + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(u64::try_from(usize::MAX).unwrap() + 1); + } +} diff --git a/bhttp/tests/test.rs b/bhttp/tests/test.rs index c6729c6..9ed9731 100644 --- a/bhttp/tests/test.rs +++ b/bhttp/tests/test.rs @@ -1,5 +1,5 @@ // Rather than grapple with #[cfg(...)] for every variable and import. 
-#![cfg(all(feature = "http", feature = "bhttp"))] +#![cfg(feature = "http")] use std::{io::Cursor, mem::drop}; diff --git a/ohttp-client-cli/Cargo.toml b/ohttp-client-cli/Cargo.toml index 4f530f6..f40b198 100644 --- a/ohttp-client-cli/Cargo.toml +++ b/ohttp-client-cli/Cargo.toml @@ -15,7 +15,7 @@ hex = "0.4" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-client-cli/src/main.rs b/ohttp-client-cli/src/main.rs index 7acc0f1..37a948d 100644 --- a/ohttp-client-cli/src/main.rs +++ b/ohttp-client-cli/src/main.rs @@ -1,8 +1,9 @@ #![deny(warnings, clippy::pedantic)] +use std::io::{self, BufRead, Write}; + use bhttp::{Message, Mode}; use ohttp::{init, ClientRequest}; -use std::io::{self, BufRead, Write}; fn main() { init(); diff --git a/ohttp-client/Cargo.toml b/ohttp-client/Cargo.toml index 7677608..1fa27fc 100644 --- a/ohttp-client/Cargo.toml +++ b/ohttp-client/Cargo.toml @@ -19,7 +19,7 @@ tokio = { version = "1", features = ["full"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-client/src/main.rs b/ohttp-client/src/main.rs index 1d71dc2..4e8a431 100644 --- a/ohttp-client/src/main.rs +++ b/ohttp-client/src/main.rs @@ -1,7 +1,8 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{fs::File, io, io::Read, ops::Deref, path::PathBuf, str::FromStr}; + +use bhttp::{Message, Mode}; use structopt::StructOpt; type Res = Result>; diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index 026df48..87b595e 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -18,7 +18,7 @@ warp = { version = "0.3", features = ["tls"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "write-http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index ffaabd5..01a2027 100644 --- 
a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -10,27 +10,29 @@ description = "Oblivious HTTP" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["client", "server", "rust-hpke"] +default = ["client", "server", "rust-hpke", "stream"] app-svc = ["nss"] client = [] external-sqlite = [] -gecko = ["nss", "mozbuild"] -nss = ["bindgen", "regex-mess"] -pq = ["hpke-pq"] -regex-mess = ["regex", "regex-automata", "regex-syntax"] -rust-hpke = ["rand", "aead", "aes-gcm", "chacha20poly1305", "hkdf", "sha2", "hpke"] +gecko = ["nss", "dep:mozbuild"] +nss = ["dep:bindgen", "regex-mess"] +pq = ["dep:hpke-pq"] +regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] +rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] +stream = ["dep:futures", "dep:pin-project"] [dependencies] aead = {version = "0.4", optional = true, features = ["std"]} aes-gcm = {version = "0.9", optional = true} byteorder = "1.4" chacha20poly1305 = {version = "0.8", optional = true} +futures = {version = "0.3", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} -lazy_static = "1.4" log = {version = "0.4", default-features = false} +pin-project = {version = "1.1", optional = true} rand = {version = "0.8", optional = true} # bindgen uses regex and friends, which have been updated past our MSRV # however, the cargo resolver happily resolves versions that it can't compile @@ -64,3 +66,4 @@ features = ["runtime"] [dev-dependencies] env_logger = {version = "0.10", default-features = false} +sync-async = {path = "../sync-async"} diff --git a/ohttp/build.rs b/ohttp/build.rs index 1c01e3f..312cce2 100644 --- a/ohttp/build.rs +++ b/ohttp/build.rs @@ -8,8 +8,6 @@ #[cfg(feature = "nss")] mod nss { - use bindgen::Builder; - use serde_derive::Deserialize; use std::{ collections::HashMap, env, fs, @@ 
-17,6 +15,9 @@ mod nss { process::Command, }; + use bindgen::Builder; + use serde_derive::Deserialize; + const BINDINGS_DIR: &str = "bindings"; const BINDINGS_CONFIG: &str = "bindings.toml"; @@ -114,7 +115,6 @@ mod nss { let mut build_nss = vec![ String::from("./build.sh"), String::from("-Ddisable_tests=1"), - String::from("-Denable_draft_hpke=1"), ]; if is_debug() { build_nss.push(String::from("--static")); @@ -191,16 +191,8 @@ mod nss { } fn static_link(nsslibdir: &Path, use_static_softoken: bool, use_static_nspr: bool) { - let mut static_libs = vec![ - "certdb", - "certhi", - "cryptohi", - "nss_static", - "nssb", - "nssdev", - "nsspki", - "nssutil", - ]; + // The ordering of these libraries is critical for the linker. + let mut static_libs = vec!["cryptohi", "nss_static"]; let mut dynamic_libs = vec![]; if use_static_softoken { @@ -211,6 +203,8 @@ mod nss { static_libs.push("pk11wrap"); } + static_libs.extend_from_slice(&["nsspki", "nssdev", "nssb", "certhi", "certdb", "nssutil"]); + if use_static_nspr { static_libs.append(&mut nspr_libs()); } else { diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index 9b55755..f2316ce 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -1,24 +1,24 @@ -use crate::{ - err::{Error, Res}, - hpke::{Aead as AeadId, Kdf, Kem}, - KeyId, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::{ convert::TryFrom, io::{BufRead, BufReader, Cursor, Read}, }; +use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; + #[cfg(feature = "nss")] use crate::nss::{ hpke::{generate_key_pair, Config as HpkeConfig, HpkeR}, PrivateKey, PublicKey, }; - #[cfg(feature = "rust-hpke")] use crate::rh::hpke::{ derive_key_pair, generate_key_pair, Config as HpkeConfig, HpkeR, PrivateKey, PublicKey, }; +use crate::{ + err::{Error, Res}, + hpke::{Aead as AeadId, Kdf, Kem}, + KeyId, +}; /// A tuple of KDF and AEAD identifiers. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -95,18 +95,15 @@ impl KeyConfig { Self::strip_unsupported(&mut symmetric, kem); assert!(!symmetric.is_empty()); let (sk, pk) = derive_key_pair(kem, ikm)?; - Ok(Self { + return Ok(Self { key_id, kem, symmetric, sk: Some(sk), pk, - }) - } - #[cfg(not(feature = "rust-hpke"))] - { - Err(Error::Unsupported) + }); } + Err(Error::Unsupported) } /// Encode a list of key configurations. @@ -260,6 +257,22 @@ impl KeyConfig { Err(Error::Unsupported) } } + + #[allow(clippy::similar_names)] // for kem_id and key_id + pub(crate) fn decode_hpke_config(&self, r: &mut Cursor<&[u8]>) -> Res { + let key_id = r.read_u8()?; + if key_id != self.key_id { + return Err(Error::KeyId); + } + let kem_id = Kem::try_from(r.read_u16::()?)?; + if kem_id != self.kem { + return Err(Error::InvalidKem); + } + let kdf_id = Kdf::try_from(r.read_u16::()?)?; + let aead_id = AeadId::try_from(r.read_u16::()?)?; + let hpke_config = HpkeConfig::new(self.kem, kdf_id, aead_id); + Ok(hpke_config) + } } impl AsRef for KeyConfig { @@ -270,11 +283,12 @@ impl AsRef for KeyConfig { #[cfg(test)] mod test { + use std::iter::zip; + use crate::{ hpke::{Aead, Kdf, Kem}, init, Error, KeyConfig, KeyId, SymmetricSuite, }; - use std::iter::zip; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; diff --git a/ohttp/src/crypto.rs b/ohttp/src/crypto.rs new file mode 100644 index 0000000..a22b74e --- /dev/null +++ b/ohttp/src/crypto.rs @@ -0,0 +1,13 @@ +use crate::{err::Res, AeadId}; + +pub trait Decrypt { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res>; + #[allow(dead_code)] // Used by stream feature. + fn alg(&self) -> AeadId; +} + +pub trait Encrypt { + fn seal(&mut self, aad: &[u8], ct: &[u8]) -> Res>; + #[allow(dead_code)] // Used by stream feature. 
+ fn alg(&self) -> AeadId; +} diff --git a/ohttp/src/err.rs b/ohttp/src/err.rs index 3c6ebd2..caaedc8 100644 --- a/ohttp/src/err.rs +++ b/ohttp/src/err.rs @@ -5,9 +5,15 @@ pub enum Error { #[cfg(feature = "rust-hpke")] #[error("a problem occurred with the AEAD")] Aead(#[from] aead::Error), + #[cfg(feature = "stream")] + #[error("a stream chunk was larger than the maximum allowed size")] + ChunkTooLarge, #[cfg(feature = "nss")] #[error("a problem occurred during cryptographic processing: {0}")] Crypto(#[from] crate::nss::Error), + #[cfg(feature = "stream")] + #[error("a stream contained data after the last chunk")] + ExtraData, #[error("an error was found in the format")] Format, #[cfg(all(feature = "rust-hpke", not(feature = "pq")))] @@ -26,6 +32,9 @@ pub enum Error { Io(#[from] std::io::Error), #[error("the key ID was invalid")] KeyId, + #[cfg(feature = "stream")] + #[error("the object was not ready")] + NotReady, #[error("a field was truncated")] Truncated, #[error("the configuration was not supported")] diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 38e3666..133d9b8 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -4,8 +4,11 @@ not(all(feature = "client", feature = "server")), allow(dead_code, unused_imports) )] +#[cfg(all(feature = "nss", feature = "rust-hpke"))] +compile_error!("features \"nss\" and \"rust-hpke\" are mutually incompatible"); mod config; +mod crypto; mod err; pub mod hpke; #[cfg(feature = "nss")] @@ -14,48 +17,48 @@ mod nss; mod rand; #[cfg(feature = "rust-hpke")] mod rh; +#[cfg(feature = "stream")] +mod stream; -pub use crate::{ - config::{KeyConfig, SymmetricSuite}, - err::Error, -}; - -use crate::{ - err::Res, - hpke::{Aead as AeadId, Kdf, Kem}, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -use log::trace; use std::{ cmp::max, convert::TryFrom, - io::{BufReader, Read}, + io::{Cursor, Read}, mem::size_of, }; -#[cfg(feature = "nss")] -use crate::nss::random; +use byteorder::{NetworkEndian, WriteBytesExt}; 
+use crypto::{Decrypt, Encrypt}; +use log::trace; + #[cfg(feature = "nss")] use crate::nss::{ aead::{Aead, Mode, NONCE_LEN}, hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, + random, PublicKey, SymKey, }; - -#[cfg(feature = "rust-hpke")] -use crate::rand::random; +#[cfg(feature = "stream")] +use crate::stream::{ClientRequest as StreamClient, ServerRequest as ServerRequestStream}; +pub use crate::{ + config::{KeyConfig, SymmetricSuite}, + err::Error, +}; +use crate::{err::Res, hpke::Aead as AeadId}; #[cfg(feature = "rust-hpke")] -use crate::rh::{ - aead::{Aead, Mode, NONCE_LEN}, - hkdf::{Hkdf, KeyMechanism}, - hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, +use crate::{ + rand::random, + rh::{ + aead::{Aead, Mode, NONCE_LEN}, + hkdf::{Hkdf, KeyMechanism}, + hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS, PublicKey}, + SymKey, + }, }; /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; const INFO_REQUEST: &[u8] = b"message/bhttp request"; -/// The info used for HPKE export is `INFO_REQUEST`, a zero byte, and the header. -const INFO_LEN: usize = INFO_REQUEST.len() + 1 + REQUEST_HEADER_LEN; const LABEL_RESPONSE: &[u8] = b"message/bhttp response"; const INFO_KEY: &[u8] = b"key"; const INFO_NONCE: &[u8] = b"nonce"; @@ -69,9 +72,9 @@ pub fn init() { } /// Construct the info parameter we use to initialize an `HpkeS` instance. 
-fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { - let mut info = Vec::with_capacity(INFO_LEN); - info.extend_from_slice(INFO_REQUEST); +fn build_info(label: &[u8], key_id: KeyId, config: HpkeConfig) -> Res> { + let mut info = Vec::with_capacity(label.len() + 1 + REQUEST_HEADER_LEN); + info.extend_from_slice(label); info.push(0); info.write_u8(key_id)?; info.write_u16::(u16::from(config.kem()))?; @@ -85,8 +88,9 @@ fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { /// This might not be necessary if we agree on a format. #[cfg(feature = "client")] pub struct ClientRequest { - hpke: HpkeS, - header: Vec, + key_id: KeyId, + config: HpkeConfig, + pk: PublicKey, } #[cfg(feature = "client")] @@ -95,14 +99,11 @@ impl ClientRequest { pub fn from_config(config: &mut KeyConfig) -> Res { // TODO(mt) choose the best config, not just the first. let selected = config.select(config.symmetric[0])?; - - // Build the info, which contains the message header. - let info = build_info(config.key_id, selected)?; - let hpke = HpkeS::new(selected, &mut config.pk, &info)?; - - let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); - Ok(Self { hpke, header }) + Ok(Self { + key_id: config.key_id, + config: selected, + pk: config.pk.clone(), + }) } /// Reads an encoded configuration and constructs a single use client sender. @@ -127,21 +128,32 @@ impl ClientRequest { /// Encapsulate a request. This consumes this object. /// This produces a response handler and the bytes of an encapsulated request. pub fn encapsulate(mut self, request: &[u8]) -> Res<(Vec, ClientResponse)> { - let extra = - self.hpke.config().kem().n_enc() + self.hpke.config().aead().n_t() + request.len(); - let expected_len = self.header.len() + extra; + // Build the info, which contains the message header. 
+ let info = build_info(INFO_REQUEST, self.key_id, self.config)?; + let mut hpke = HpkeS::new(self.config, &mut self.pk, &info)?; - let mut enc_request = self.header; + let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let extra = hpke.config().kem().n_enc() + hpke.config().aead().n_t() + request.len(); + let expected_len = header.len() + extra; + + let mut enc_request = header; enc_request.reserve_exact(extra); - let enc = self.hpke.enc()?; + let enc = hpke.enc()?; enc_request.extend_from_slice(&enc); - let mut ct = self.hpke.seal(&[], request)?; + let mut ct = hpke.seal(&[], request)?; enc_request.append(&mut ct); debug_assert_eq!(expected_len, enc_request.len()); - Ok((enc_request, ClientResponse::new(self.hpke, enc))) + Ok((enc_request, ClientResponse::new(hpke, enc))) + } + + #[cfg(feature = "stream")] + pub fn encapsulate_stream(self, dst: S) -> Res> { + StreamClient::start(dst, self.config, self.key_id, self.pk) } } @@ -170,48 +182,45 @@ impl Server { &self.config } - /// Remove encapsulation on a message. + fn decode_request_header(&self, r: &mut Cursor<&[u8]>, label: &[u8]) -> Res<(HpkeR, Vec)> { + let hpke_config = self.config.decode_hpke_config(r)?; + let sym = SymmetricSuite::new(hpke_config.kdf(), hpke_config.aead()); + let config = self.config.select(sym)?; + let info = build_info(label, self.config.key_id, hpke_config)?; + + let mut enc = vec![0; config.kem().n_enc()]; + r.read_exact(&mut enc)?; + + Ok(( + HpkeR::new( + config, + &self.config.pk, + self.config.sk.as_ref().unwrap(), + &enc, + &info, + )?, + enc, + )) + } + + /// Remove encapsulation on a request. /// # Panics /// Not as a consequence of this code, but Rust won't know that for sure. 
- #[allow(clippy::similar_names)] // for kem_id and key_id pub fn decapsulate(&self, enc_request: &[u8]) -> Res<(Vec, ServerResponse)> { - if enc_request.len() < REQUEST_HEADER_LEN { + if enc_request.len() <= REQUEST_HEADER_LEN { return Err(Error::Truncated); } - let mut r = BufReader::new(enc_request); - let key_id = r.read_u8()?; - if key_id != self.config.key_id { - return Err(Error::KeyId); - } - let kem_id = Kem::try_from(r.read_u16::()?)?; - if kem_id != self.config.kem { - return Err(Error::InvalidKem); - } - let kdf_id = Kdf::try_from(r.read_u16::()?)?; - let aead_id = AeadId::try_from(r.read_u16::()?)?; - let sym = SymmetricSuite::new(kdf_id, aead_id); - - let info = build_info( - key_id, - HpkeConfig::new(self.config.kem, sym.kdf(), sym.aead()), - )?; - - let cfg = self.config.select(sym)?; - let mut enc = vec![0; cfg.kem().n_enc()]; - r.read_exact(&mut enc)?; - let mut hpke = HpkeR::new( - cfg, - &self.config.pk, - self.config.sk.as_ref().unwrap(), - &enc, - &info, - )?; + let mut r = Cursor::new(enc_request); + let (mut hpke, enc) = self.decode_request_header(&mut r, INFO_REQUEST)?; - let mut ct = Vec::new(); - r.read_to_end(&mut ct)?; + let request = hpke.open(&[], &enc_request[usize::try_from(r.position())?..])?; + Ok((request, ServerResponse::new(&hpke, &enc)?)) + } - let request = hpke.open(&[], &ct)?; - Ok((request, ServerResponse::new(&hpke, enc)?)) + /// Remove encapsulation on a streamed request. 
+ #[cfg(feature = "stream")] + pub fn decapsulate_stream(self, src: S) -> ServerRequestStream { + ServerRequestStream::new(self.config, src) } } @@ -219,19 +228,16 @@ fn entropy(config: HpkeConfig) -> usize { max(config.aead().n_n(), config.aead().n_k()) } -fn make_aead( - mode: Mode, - cfg: HpkeConfig, - exp: &impl Exporter, - enc: Vec, - response_nonce: &[u8], -) -> Res { - let secret = exp.export(LABEL_RESPONSE, entropy(cfg))?; - let mut salt = enc; - salt.extend_from_slice(response_nonce); +fn export_secret(exp: &E, label: &[u8], cfg: HpkeConfig) -> Res { + exp.export(label, entropy(cfg)) +} + +fn make_aead(mode: Mode, cfg: HpkeConfig, secret: &SymKey, enc: &[u8], nonce: &[u8]) -> Res { + let mut salt = enc.to_vec(); + salt.extend_from_slice(nonce); let hkdf = Hkdf::new(cfg.kdf()); - let prk = hkdf.extract(&salt, &secret)?; + let prk = hkdf.extract(&salt, secret)?; let key = hkdf.expand_key(&prk, INFO_KEY, KeyMechanism::Aead(cfg.aead()))?; let iv = hkdf.expand_data(&prk, INFO_NONCE, cfg.aead().n_n())?; @@ -250,9 +256,15 @@ pub struct ServerResponse { #[cfg(feature = "server")] impl ServerResponse { - fn new(hpke: &HpkeR, enc: Vec) -> Res { + fn new(hpke: &HpkeR, enc: &[u8]) -> Res { let response_nonce = random(entropy(hpke.config())); - let aead = make_aead(Mode::Encrypt, hpke.config(), hpke, enc, &response_nonce)?; + let aead = make_aead( + Mode::Encrypt, + hpke.config(), + &export_secret(hpke, LABEL_RESPONSE, hpke.config())?, + enc, + &response_nonce, + )?; Ok(Self { response_nonce, aead, @@ -302,48 +314,54 @@ impl ClientResponse { let mut aead = make_aead( Mode::Decrypt, self.hpke.config(), - &self.hpke, - self.enc, + &export_secret(&self.hpke, LABEL_RESPONSE, self.hpke.config())?, + &self.enc, response_nonce, )?; - aead.open(&[], 0, ct) // 0 is the sequence number + aead.open(&[], ct) // 0 is the sequence number } } #[cfg(all(test, feature = "client", feature = "server"))] mod test { + use std::{fmt::Debug, io::ErrorKind}; + + use log::trace; + use crate::{ 
config::SymmetricSuite, err::Res, hpke::{Aead, Kdf, Kem}, ClientRequest, Error, KeyConfig, KeyId, Server, }; - use log::trace; - use std::{fmt::Debug, io::ErrorKind}; - const KEY_ID: KeyId = 1; - const KEM: Kem = Kem::X25519Sha256; - const SYMMETRIC: &[SymmetricSuite] = &[ + pub const KEY_ID: KeyId = 1; + pub const KEM: Kem = Kem::X25519Sha256; + pub const SYMMETRIC: &[SymmetricSuite] = &[ SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm), SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305), ]; - const REQUEST: &[u8] = &[ + pub const REQUEST: &[u8] = &[ 0x00, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x0b, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x01, 0x2f, ]; - const RESPONSE: &[u8] = &[0x01, 0x40, 0xc8]; + pub const RESPONSE: &[u8] = &[0x01, 0x40, 0xc8]; - fn init() { + pub fn init() { crate::init(); _ = env_logger::try_init(); // ignore errors here } + pub fn make_config() -> KeyConfig { + KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap() + } + #[test] fn request_response() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); trace!("Config: {}", hex::encode(&encoded_config)); @@ -368,7 +386,7 @@ mod test { fn two_requests() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -408,7 +426,7 @@ mod test { fn request_truncated(cut: usize) { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -439,7 +457,7 @@ mod test { fn response_truncated(cut: usize) { 
init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); @@ -498,7 +516,7 @@ mod test { fn request_from_config_list() { init(); - let server_config = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap(); + let server_config = make_config(); let server = Server::new(server_config).unwrap(); let encoded_config = server.config().encode().unwrap(); diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index 18f0b66..bbf1d00 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -1,3 +1,11 @@ +use std::{ + convert::{TryFrom, TryInto}, + mem, + os::raw::c_int, +}; + +use log::trace; + use super::{ err::secstatus_to_res, p11::{ @@ -10,15 +18,10 @@ use super::{ }, }; use crate::{ + crypto::{Decrypt, Encrypt}, err::{Error, Res}, hpke::Aead as AeadId, }; -use log::trace; -use std::{ - convert::{TryFrom, TryInto}, - mem, - os::raw::c_int, -}; /// All the nonces are the same length. Exploit that. pub const NONCE_LEN: usize = 12; @@ -67,6 +70,7 @@ impl Mode { /// This is an AEAD instance that uses the pub struct Aead { mode: Mode, + algorithm: AeadId, ctx: Context, nonce_base: [u8; NONCE_LEN], } @@ -118,58 +122,59 @@ impl Aead { }; Ok(Self { mode, + algorithm, ctx: Context::from_ptr(ptr)?, nonce_base, }) } - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Encrypt); - // A copy for the nonce generator to write into. But we don't use the value. + #[allow(dead_code)] + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Decrypt); let mut nonce = self.nonce_base; - // Ciphertext with enough space for the tag. - // Even though we give the operation a separate buffer for the tag, - // reserve the capacity on allocation. 
- let mut ct = vec![0; pt.len() + TAG_LEN]; - let mut ct_len: c_int = 0; - let mut tag = vec![0; TAG_LEN]; + for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { + *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); + } + + let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. + let mut pt_len: c_int = 0; + let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; secstatus_to_res(unsafe { PK11_AEADOp( *self.ctx, - CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), + CK_GENERATOR_FUNCTION::from(CKG_NO_GENERATE), c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. nonce.as_mut_ptr(), c_int_len(nonce.len()), aad.as_ptr(), c_int_len(aad.len()), - ct.as_mut_ptr(), - &mut ct_len, - c_int_len(ct.len()), // signed :( - tag.as_mut_ptr(), - c_int_len(tag.len()), - pt.as_ptr(), - c_int_len(pt.len()), + pt.as_mut_ptr(), + &mut pt_len, + c_int_len(pt.len()), // signed :( + ct.as_ptr().add(pt_expected).cast_mut(), // const cast :( + c_int_len(TAG_LEN), + ct.as_ptr(), + c_int_len(pt_expected), ) })?; - ct.truncate(usize::try_from(ct_len).unwrap()); - debug_assert_eq!(ct.len(), pt.len()); - ct.append(&mut tag); - Ok(ct) + let len = usize::try_from(pt_len).unwrap(); + debug_assert_eq!(len, pt_expected); + pt.truncate(len); + Ok(pt) } +} - pub fn open(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { +impl Decrypt for Aead { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Decrypt); let mut nonce = self.nonce_base; - for (i, n) in nonce.iter_mut().rev().take(COUNTER_LEN).enumerate() { - *n ^= u8::try_from((seq >> (8 * i)) & 0xff).unwrap(); - } let mut pt = vec![0; ct.len()]; // NSS needs more space than it uses for plaintext. 
let mut pt_len: c_int = 0; let pt_expected = ct.len().checked_sub(TAG_LEN).ok_or(Error::Truncated)?; secstatus_to_res(unsafe { PK11_AEADOp( *self.ctx, - CK_GENERATOR_FUNCTION::from(CKG_NO_GENERATE), + CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. nonce.as_mut_ptr(), c_int_len(nonce.len()), @@ -177,8 +182,8 @@ impl Aead { c_int_len(aad.len()), pt.as_mut_ptr(), &mut pt_len, - c_int_len(pt.len()), // signed :( - ct.as_ptr().add(pt_expected) as *mut _, // const cast :( + c_int_len(pt.len()), // signed :( + ct.as_ptr().add(pt_expected).cast_mut(), // const cast :( c_int_len(TAG_LEN), ct.as_ptr(), c_int_len(pt_expected), @@ -189,6 +194,50 @@ impl Aead { pt.truncate(len); Ok(pt) } + + fn alg(&self) -> AeadId { + self.algorithm + } +} + +impl Encrypt for Aead { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Encrypt); + // A copy for the nonce generator to write into. But we don't use the value. + let mut nonce = self.nonce_base; + // Ciphertext with enough space for the tag. + // Even though we give the operation a separate buffer for the tag, + // reserve the capacity on allocation. + let mut ct = vec![0; pt.len() + TAG_LEN]; + let mut ct_len: c_int = 0; + let mut tag = vec![0; TAG_LEN]; + secstatus_to_res(unsafe { + PK11_AEADOp( + *self.ctx, + CK_GENERATOR_FUNCTION::from(CKG_GENERATE_COUNTER_XOR), + c_int_len(NONCE_LEN - COUNTER_LEN), // Fixed portion of the nonce. 
+ nonce.as_mut_ptr(), + c_int_len(nonce.len()), + aad.as_ptr(), + c_int_len(aad.len()), + ct.as_mut_ptr(), + &mut ct_len, + c_int_len(ct.len()), // signed :( + tag.as_mut_ptr(), + c_int_len(tag.len()), + pt.as_ptr(), + c_int_len(pt.len()), + ) + })?; + ct.truncate(usize::try_from(ct_len).unwrap()); + debug_assert_eq!(ct.len(), pt.len()); + ct.append(&mut tag); + Ok(ct) + } + + fn alg(&self) -> AeadId { + self.algorithm + } } #[cfg(test)] @@ -197,6 +246,7 @@ mod test { super::{super::hpke::Aead as AeadId, init}, Aead, Mode, SequenceNumber, NONCE_LEN, }; + use crate::crypto::{Decrypt, Encrypt}; /// Check that the first invocation of encryption matches expected values. /// Also check decryption of the same. @@ -216,7 +266,7 @@ mod test { assert_eq!(&ciphertext[..], ct); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, 0, ct).unwrap(); + let plaintext = dec.open(aad, ct).unwrap(); assert_eq!(&plaintext[..], pt); } @@ -231,7 +281,7 @@ mod test { ) { let k = Aead::import_key(algorithm, key).unwrap(); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, seq, ct).unwrap(); + let plaintext = dec.open_seq(aad, seq, ct).unwrap(); assert_eq!(&plaintext[..], pt); } diff --git a/ohttp/src/nss/err.rs b/ohttp/src/nss/err.rs index af85066..bc7c86d 100644 --- a/ohttp/src/nss/err.rs +++ b/ohttp/src/nss/err.rs @@ -10,9 +10,10 @@ clippy::module_name_repetitions )] +use std::os::raw::c_char; + use super::{SECStatus, SECSuccess}; use crate::err::Res; -use std::os::raw::c_char; include!(concat!(env!("OUT_DIR"), "/nspr_error.rs")); mod codes { diff --git a/ohttp/src/nss/hkdf.rs b/ohttp/src/nss/hkdf.rs index 470b1dd..af38ba9 100644 --- a/ohttp/src/nss/hkdf.rs +++ b/ohttp/src/nss/hkdf.rs @@ -1,3 +1,7 @@ +use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; + +use log::trace; + use super::{ super::hpke::{Aead, Kdf}, p11::{ @@ -10,8 +14,6 @@ use super::{ }, }; use crate::err::Res; 
-use log::trace; -use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index b7ef845..53fa3a5 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -1,10 +1,3 @@ -use super::{ - super::hpke::{Aead, Kdf, Kem}, - err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, - p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, -}; -use crate::err::Res; -use log::{log_enabled, trace}; use std::{ convert::TryFrom, ops::Deref, @@ -12,8 +5,19 @@ use std::{ ptr::{addr_of_mut, null, null_mut}, }; +use log::{log_enabled, trace}; pub use sys::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId}; +use super::{ + super::hpke::{Aead, Kdf, Kem}, + err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, + p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, +}; +use crate::{ + crypto::{Decrypt, Encrypt}, + err::Res, +}; + /// Configuration for `Hpke`. #[derive(Clone, Copy)] pub struct Config { @@ -134,8 +138,10 @@ impl HpkeS { let slc = unsafe { std::slice::from_raw_parts(r.data, len) }; Ok(Vec::from(slc)) } +} - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { +impl Encrypt for HpkeS { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { let mut out: *mut sys::SECItem = null_mut(); secstatus_to_res(unsafe { sys::PK11_HPKE_Seal(*self.context, &Item::wrap(aad), &Item::wrap(pt), &mut out) @@ -143,6 +149,10 @@ impl HpkeS { let v = Item::from_ptr(out)?; Ok(unsafe { v.into_vec() }) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeS { @@ -208,8 +218,10 @@ impl HpkeR { })?; PublicKey::from_ptr(ptr) } +} - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { +impl Decrypt for HpkeR { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { let mut out: *mut sys::SECItem = null_mut(); secstatus_to_res(unsafe { sys::PK11_HPKE_Open(*self.context, &Item::wrap(aad), &Item::wrap(ct), &mut out) @@ -217,6 +229,10 @@ 
impl HpkeR { let v = Item::from_ptr(out)?; Ok(unsafe { v.into_vec() }) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeR { @@ -293,7 +309,11 @@ pub fn generate_key_pair(kem: Kem) -> Res<(PrivateKey, PublicKey)> { #[cfg(test)] mod test { use super::{generate_key_pair, Config, HpkeContext, HpkeR, HpkeS}; - use crate::{hpke::Aead, init}; + use crate::{ + crypto::{Decrypt, Encrypt}, + hpke::Aead, + init, + }; const INFO: &[u8] = b"info"; const AAD: &[u8] = b"aad"; diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index 1b60c9e..91dc44e 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -1,8 +1,4 @@ -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. +#![allow(clippy::incompatible_msrv)] // This feature needs 1.70 mod err; #[macro_use] @@ -11,11 +7,12 @@ pub mod aead; pub mod hkdf; pub mod hpke; -pub use self::p11::{random, PrivateKey, PublicKey}; +use std::{ptr::null, sync::OnceLock}; + use err::secstatus_to_res; pub use err::Error; -use lazy_static::lazy_static; -use std::ptr::null; + +pub use self::p11::{random, PrivateKey, PublicKey, SymKey}; #[allow(clippy::pedantic, non_upper_case_globals, clippy::upper_case_acronyms)] mod nss_init { @@ -45,17 +42,7 @@ impl Drop for NssLoaded { } } -lazy_static! { - static ref INITIALIZED: NssLoaded = { - if already_initialized() { - return NssLoaded::External; - } - - secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }).expect("NSS_NoDB_Init failed"); - - NssLoaded::NoDb - }; -} +static INITIALIZED: OnceLock = OnceLock::new(); fn already_initialized() -> bool { unsafe { nss_init::NSS_IsInitialized() != 0 } @@ -63,5 +50,13 @@ fn already_initialized() -> bool { /// Initialize NSS. This only executes the initialization routines once. 
pub fn init() { - lazy_static::initialize(&INITIALIZED); + INITIALIZED.get_or_init(|| { + if already_initialized() { + NssLoaded::External + } else { + secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }) + .expect("NSS_NoDB_Init failed"); + NssLoaded::NoDb + } + }); } diff --git a/ohttp/src/nss/p11.rs b/ohttp/src/nss/p11.rs index 9bddef6..19d9420 100644 --- a/ohttp/src/nss/p11.rs +++ b/ohttp/src/nss/p11.rs @@ -4,8 +4,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::err::{secstatus_to_res, Error}; -use crate::err::Res; use std::{ convert::TryFrom, marker::PhantomData, @@ -14,6 +12,9 @@ use std::{ ptr::null_mut, }; +use super::err::{secstatus_to_res, Error}; +use crate::err::Res; + #[allow( clippy::pedantic, clippy::upper_case_acronyms, diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index f209086..cc05e89 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -1,16 +1,19 @@ -#![allow(dead_code)] // TODO: remove +use std::convert::TryFrom; -use super::SymKey; -use crate::{err::Res, hpke::Aead as AeadId}; use aead::{AeadMut, Key, NewAead, Nonce, Payload}; use aes_gcm::{Aes128Gcm, Aes256Gcm}; use chacha20poly1305::ChaCha20Poly1305; -use std::convert::TryFrom; + +use super::SymKey; +use crate::{ + crypto::{Decrypt, Encrypt}, + err::Res, + hpke::Aead as AeadId, +}; /// All the nonces are the same length. Exploit that. pub const NONCE_LEN: usize = 12; const COUNTER_LEN: usize = 8; -const TAG_LEN: usize = 16; type SequenceNumber = u64; @@ -54,6 +57,7 @@ impl AeadEngine { /// A switch-hitting AEAD that uses a selected primitive. 
pub struct Aead { mode: Mode, + algorithm: AeadId, engine: AeadEngine, nonce_base: [u8; NONCE_LEN], seq: SequenceNumber, @@ -80,6 +84,7 @@ impl Aead { }; Ok(Self { mode, + algorithm, engine: aead, nonce_base, seq: 0, @@ -100,7 +105,28 @@ impl Aead { nonce } - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { + pub fn open_seq(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { + assert_eq!(self.mode, Mode::Decrypt); + let nonce = self.nonce(seq); + let pt = self.engine.decrypt(&nonce, Payload { msg: ct, aad })?; + Ok(pt) + } +} + +impl Decrypt for Aead { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { + let res = self.open_seq(aad, self.seq, ct); + self.seq += 1; + res + } + + fn alg(&self) -> AeadId { + self.algorithm + } +} + +impl Encrypt for Aead { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { assert_eq!(self.mode, Mode::Encrypt); // A copy for the nonce generator to write into. But we don't use the value. let nonce = self.nonce(self.seq); @@ -109,11 +135,8 @@ impl Aead { Ok(ct) } - pub fn open(&mut self, aad: &[u8], seq: SequenceNumber, ct: &[u8]) -> Res> { - assert_eq!(self.mode, Mode::Decrypt); - let nonce = self.nonce(seq); - let pt = self.engine.decrypt(&nonce, Payload { msg: ct, aad })?; - Ok(pt) + fn alg(&self) -> AeadId { + self.algorithm } } @@ -123,6 +146,7 @@ mod test { super::super::{hpke::Aead as AeadId, init}, Aead, Mode, SequenceNumber, NONCE_LEN, }; + use crate::crypto::{Decrypt, Encrypt}; /// Check that the first invocation of encryption matches expected values. /// Also check decryption of the same. 
@@ -142,7 +166,7 @@ mod test { assert_eq!(&ciphertext[..], ct); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, 0, ct).unwrap(); + let plaintext = dec.open(aad, ct).unwrap(); assert_eq!(&plaintext[..], pt); } @@ -157,7 +181,7 @@ mod test { ) { let k = Aead::import_key(algorithm, key).unwrap(); let mut dec = Aead::new(Mode::Decrypt, algorithm, &k, *nonce).unwrap(); - let plaintext = dec.open(aad, seq, ct).unwrap(); + let plaintext = dec.open_seq(aad, seq, ct).unwrap(); assert_eq!(&plaintext[..], pt); } diff --git a/ohttp/src/rh/hkdf.rs b/ohttp/src/rh/hkdf.rs index aeb3a8d..a88f60e 100644 --- a/ohttp/src/rh/hkdf.rs +++ b/ohttp/src/rh/hkdf.rs @@ -1,13 +1,14 @@ #![allow(dead_code)] // TODO: remove +use hkdf::Hkdf as HkdfImpl; +use log::trace; +use sha2::{Sha256, Sha384, Sha512}; + use super::SymKey; use crate::{ err::{Error, Res}, hpke::{Aead, Kdf}, }; -use hkdf::Hkdf as HkdfImpl; -use log::trace; -use sha2::{Sha256, Sha384, Sha512}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 4b81152..58e29c7 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -1,15 +1,13 @@ -use super::SymKey; -use crate::{ - hpke::{Aead, Kdf, Kem}, - Error, Res, -}; +use std::ops::Deref; #[cfg(not(feature = "pq"))] use ::hpke as rust_hpke; - #[cfg(feature = "pq")] use ::hpke_pq as rust_hpke; - +use ::rand::thread_rng; +use log::trace; +#[cfg(feature = "pq")] +use rust_hpke::kem::X25519Kyber768Draft00; use rust_hpke::{ aead::{AeadCtxR, AeadCtxS, AeadTag, AesGcm128, ChaCha20Poly1305}, kdf::HkdfSha256, @@ -17,12 +15,12 @@ use rust_hpke::{ setup_receiver, setup_sender, Deserializable, OpModeR, OpModeS, Serializable, }; -#[cfg(feature = "pq")] -use rust_hpke::kem::X25519Kyber768Draft00; - -use ::rand::thread_rng; -use log::trace; -use std::ops::Deref; +use super::SymKey; +use crate::{ + crypto::{Decrypt, Encrypt}, + hpke::{Aead, Kdf, Kem}, + Error, Res, +}; /// 
Configuration for `Hpke`. #[derive(Clone, Copy)] @@ -306,13 +304,19 @@ impl HpkeS { pub fn enc(&self) -> Res> { Ok(self.enc.clone()) } +} - pub fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { +impl Encrypt for HpkeS { + fn seal(&mut self, aad: &[u8], pt: &[u8]) -> Res> { let mut buf = pt.to_owned(); let mut tag = self.context.seal(&mut buf, aad)?; buf.append(&mut tag); Ok(buf) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeS { @@ -525,13 +529,19 @@ impl HpkeR { ), }) } +} - pub fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { +impl Decrypt for HpkeR { + fn open(&mut self, aad: &[u8], ct: &[u8]) -> Res> { let mut buf = ct.to_owned(); let pt_len = self.context.open(&mut buf, aad)?.len(); buf.truncate(pt_len); Ok(buf) } + + fn alg(&self) -> Aead { + self.config.aead() + } } impl Exporter for HpkeR { @@ -597,6 +607,7 @@ pub fn derive_key_pair(kem: Kem, ikm: &[u8]) -> Res<(PrivateKey, PublicKey)> { mod test { use super::{generate_key_pair, Config, HpkeR, HpkeS}; use crate::{ + crypto::{Decrypt, Encrypt}, hpke::{Aead, Kem}, init, }; diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs new file mode 100644 index 0000000..f66f707 --- /dev/null +++ b/ohttp/src/stream.rs @@ -0,0 +1,782 @@ +#![allow(clippy::incompatible_msrv)] // Until I can make MSRV conditional on feature choice. + +use std::{ + cmp::min, + io::{Cursor, Error as IoError, Result as IoResult}, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{AsyncRead, AsyncWrite}; +use pin_project::pin_project; + +use crate::{ + build_info, + crypto::{Decrypt, Encrypt}, + entropy, + err::Res, + export_secret, make_aead, random, Aead, Error, HpkeConfig, HpkeR, HpkeS, KeyConfig, KeyId, + Mode, PublicKey, SymKey, REQUEST_HEADER_LEN, +}; + +/// The info string for a chunked request. +pub(crate) const INFO_REQUEST: &[u8] = b"message/bhttp chunked request"; +/// The exporter label for a chunked response. 
+pub(crate) const LABEL_RESPONSE: &[u8] = b"message/bhttp chunked response"; +/// The length of the plaintext of the largest chunk that is permitted. +const MAX_CHUNK_PLAINTEXT: usize = 1 << 14; +const CHUNK_AAD: &[u8] = b""; +const FINAL_CHUNK_AAD: &[u8] = b"final"; + +#[allow(clippy::unnecessary_wraps)] +fn some_error(e: E) -> Option>> +where + Error: From, +{ + Some(Poll::Ready(Err(IoError::other(Error::from(e))))) +} + +#[pin_project(project = ChunkWriterProjection)] +struct ChunkWriter { + #[pin] + dst: D, + cipher: E, + buf: Vec, +} + +impl ChunkWriter { + fn write_len(w: &mut [u8], len: usize) -> &[u8] { + let v: u64 = len.try_into().unwrap(); + let (v, len) = match () { + () if v < (1 << 6) => (v, 1), + () if v < (1 << 14) => (v | 1 << 14, 2), + () if v < (1 << 30) => (v | (2 << 30), 4), + () if v < (1 << 62) => (v | (3 << 62), 8), + () => panic!("varint value too large"), + }; + w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); + &w[..len] + } +} + +impl ChunkWriter { + /// Flush our buffer. + /// Returns `Some` if the flush blocks or is unsuccessful. + /// If that contains `Ready`, it does so only when there is an error. + fn flush( + this: &mut ChunkWriterProjection<'_, D, C>, + cx: &mut Context<'_>, + ) -> Option> { + while !this.buf.is_empty() { + match this.dst.as_mut().poll_write(cx, &this.buf[..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(len)) => { + if len < this.buf.len() { + // We've written something to the underlying writer, + // which is probably blocked. + // We could return `Poll::Pending`, + // but that would mean taking responsibility + // for calling `cx.waker().wake()` + // when more space comes available. + // + // So, rather than do that, loop. + // If the underlying writer is truly blocked, + // it assumes responsibility for waking the task. 
+ *this.buf = this.buf.split_off(len); + } else { + this.buf.clear(); + } + } + Poll::Ready(Err(e)) => return Some(Poll::Ready(e)), + } + } + None + } + + fn write_chunk( + this: &mut ChunkWriterProjection<'_, D, C>, + cx: &mut Context<'_>, + input: &[u8], + last: bool, + ) -> Poll> { + let aad = if last { FINAL_CHUNK_AAD } else { CHUNK_AAD }; + let mut ct = this.cipher.seal(aad, input).map_err(IoError::other)?; + let (len, written) = if last { + (0, 0) + } else { + (ct.len(), input.len()) + }; + + let mut len_buf = [0; 8]; + let len = Self::write_len(&mut len_buf[..], len); + let w = match this.dst.as_mut().poll_write(cx, len) { + Poll::Pending => 0, + Poll::Ready(Ok(w)) => w, + e @ Poll::Ready(Err(_)) => return e, + }; + + if w < len.len() { + this.buf.extend_from_slice(&len[w..]); + this.buf.append(&mut ct); + } else { + match this.dst.as_mut().poll_write(cx, &ct[..]) { + Poll::Pending => { + *this.buf = ct; + } + Poll::Ready(Ok(w)) => { + *this.buf = ct.split_off(w); + } + e @ Poll::Ready(Err(_)) => return e, + } + } + Poll::Ready(Ok(written)) + } +} + +impl AsyncWrite for ChunkWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &[u8], + ) -> Poll> { + let mut this = self.project(); + // We have buffered data, so dump it into the output directly. + if let Some(value) = Self::flush(&mut this, cx) { + return value.map(Err); + } + + // Now encipher a chunk. + let len = min(input.len(), MAX_CHUNK_PLAINTEXT); + Self::write_chunk(&mut this, cx, &input[..len], false) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if let Some(p) = Self::flush(&mut this, cx) { + // Flushing our buffers either blocked or failed. 
+ p.map(Err) + } else { + this.dst.as_mut().poll_flush(cx) + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Self::write_chunk(&mut self.project(), cx, &[], true).map(|p| p.map(|_| ())) + } +} + +#[pin_project(project = ClientProjection)] +pub struct ClientRequest { + #[pin] + writer: ChunkWriter, +} + +impl ClientRequest { + /// Start the processing of a stream. + pub fn start(dst: D, config: HpkeConfig, key_id: KeyId, mut pk: PublicKey) -> Res { + let info = build_info(INFO_REQUEST, key_id, config)?; + let hpke = HpkeS::new(config, &mut pk, &info)?; + + let mut header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let mut e = hpke.enc()?; + header.append(&mut e); + + Ok(Self { + writer: ChunkWriter { + dst, + cipher: hpke, + buf: header, + }, + }) + } + + /// Get an object that can be used to process the response. + /// + /// While this can be used while sending the request, + /// doing so creates a risk of revealing unwanted information to the gateway. + /// That includes the round trip time between client and gateway, + /// which might reveal information about the location of the client. 
+    pub fn response(&self, src: R) -> Res> {
+        let enc = self.writer.cipher.enc()?;
+        let secret = export_secret(
+            &self.writer.cipher,
+            LABEL_RESPONSE,
+            self.writer.cipher.config(),
+        )?;
+        Ok(ClientResponse {
+            src,
+            config: self.writer.cipher.config(),
+            state: ClientResponseState::Header {
+                enc,
+                secret,
+                nonce: [0; 16],
+                read: 0,
+            },
+        })
+    }
+}
+
+impl AsyncWrite for ClientRequest {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        input: &[u8],
+    ) -> Poll> {
+        self.project().writer.as_mut().poll_write(cx, input)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+        self.project().writer.as_mut().poll_flush(cx)
+    }
+
+    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+        self.project().writer.as_mut().poll_close(cx)
+    }
+}
+
+enum ChunkReader {
+    Length {
+        len: [u8; 8],
+        offset: usize,
+    },
+    Data {
+        buf: Vec,
+        offset: usize,
+        length: usize,
+    },
+    Done,
+}
+
+impl ChunkReader {
+    fn length() -> Self {
+        Self::Length {
+            len: [0; 8],
+            offset: 0,
+        }
+    }
+
+    fn data(length: usize) -> Self {
+        // Avoid using `with_capacity` here. Only allocate when necessary.
+        // We might be able to read into the buffer we're given instead, to save an allocation.
+        Self::Data {
+            buf: Vec::new(),
+            // Note that because we're allocating the full chunk,
+            // we need to track what has been used.
+            offset: 0,
+            length,
+        }
+    }
+
+    fn read_fixed(
+        mut src: Pin<&mut S>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+        offset: &mut usize,
+    ) -> Option>> {
+        while *offset < buf.len() {
+            // Read any remaining bytes into the buffer.
+ match src.as_mut().poll_read(cx, &mut buf[*offset..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return some_error(Error::Truncated); + } + Poll::Ready(Ok(r)) => { + *offset += r; + } + e @ Poll::Ready(Err(_)) => return Some(e), + } + } + None + } + + fn read_length0( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + ) -> Option>> { + let Self::Length { len, offset } = self else { + return None; + }; + + let res = Self::read_fixed(src.as_mut(), cx, &mut len[..1], offset); + if res.is_some() { + return res; + } + + let form = len[0] >> 6; + if form == 0 { + *self = Self::data(usize::from(len[0])); + } else { + let v = mem::replace(&mut len[0], 0) & 0x3f; + let i = match form { + 1 => 6, + 2 => 4, + 3 => 0, + _ => unreachable!(), + }; + len[i] = v; + *offset = i + 1; + } + None + } + + fn read_length( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + aead: &mut C, + ) -> Option>> { + // Read the first byte. + let res = self.read_length0(src.as_mut(), cx); + if res.is_some() { + return res; + } + + let Self::Length { len, offset } = self else { + return None; + }; + + let res = Self::read_fixed(src.as_mut(), cx, &mut len[..], offset); + if res.is_some() { + return res; + } + + let remaining = match usize::try_from(u64::from_be_bytes(*len)) { + Ok(remaining) => remaining, + Err(e) => return some_error(e), + }; + if remaining > MAX_CHUNK_PLAINTEXT + aead.alg().n_t() { + return some_error(Error::ChunkTooLarge); + } + + *self = Self::data(remaining); + None + } + + /// Optional optimization that reads a single chunk into the output buffer. + fn read_into_output( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + aead: &mut C, + output: &mut [u8], + ) -> Option>> { + let Self::Data { + buf, + offset, + length, + } = self + else { + return None; + }; + if *length == 0 || *offset > 0 || output.len() < *length { + // We need to pull in a complete chunk in one go for this to be worthwhile. 
+ return None; + } + + match src.as_mut().poll_read(cx, &mut output[..*length]) { + Poll::Pending => Some(Poll::Pending), + Poll::Ready(Ok(0)) => some_error(Error::Truncated), + Poll::Ready(Ok(r)) => { + if r == *length { + let pt = match aead.open(CHUNK_AAD, &output[..r]) { + Ok(pt) => pt, + Err(e) => return some_error(e), + }; + output[..pt.len()].copy_from_slice(&pt); + *self = Self::length(); + Some(Poll::Ready(Ok(pt.len()))) + } else { + buf.reserve_exact(*length); + buf.extend_from_slice(&output[..r]); + buf.resize(*length, 0); + *offset += r; + None + } + } + e @ Poll::Ready(Err(_)) => Some(e), + } + } + + fn read( + &mut self, + mut src: Pin<&mut S>, + cx: &mut Context<'_>, + cipher: &mut C, + output: &mut [u8], + ) -> Poll> { + while !matches!(self, Self::Done) { + if let Some(res) = self.read_length(src.as_mut(), cx, cipher) { + return res; + } + + // Read data. + if let Some(res) = self.read_into_output(src.as_mut(), cx, cipher, output) { + return res; + } + + let Self::Data { + buf, + offset, + length, + } = self + else { + unreachable!(); + }; + + // Allocate now as needed. 
+ let last = *length == 0; + if buf.is_empty() { + let sz = if last { + MAX_CHUNK_PLAINTEXT + cipher.alg().n_t() + } else { + *length + }; + buf.resize(sz, 0); + } + + match src.as_mut().poll_read(cx, &mut buf[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => { + if last { + buf.truncate(*offset); + } else { + return Poll::Ready(Err(IoError::other(Error::Truncated))); + } + } + Poll::Ready(Ok(r)) => { + *offset += r; + if last || *offset < *length { + continue; // Keep reading + } + } + e @ Poll::Ready(Err(_)) => return e, + } + + let aad = if last { FINAL_CHUNK_AAD } else { CHUNK_AAD }; + let pt = cipher.open(aad, buf).map_err(IoError::other)?; + output[..pt.len()].copy_from_slice(&pt); + + if last { + *self = Self::Done; + } else { + *self = Self::length(); + if pt.is_empty() { + continue; // Read the next chunk + } + }; + + return Poll::Ready(Ok(pt.len())); + } + + Poll::Ready(Ok(0)) + } +} + +enum ServerRequestState { + HpkeConfig { + buf: [u8; 7], + read: usize, + }, + Enc { + config: HpkeConfig, + info: Vec, + read: usize, + }, + Body { + hpke: HpkeR, + state: ChunkReader, + }, +} + +#[pin_project(project = ServerRequestProjection)] +pub struct ServerRequest { + #[pin] + src: S, + key_config: KeyConfig, + enc: Vec, + state: ServerRequestState, +} + +impl ServerRequest { + pub fn new(key_config: KeyConfig, src: S) -> Self { + Self { + src, + key_config, + enc: Vec::new(), + state: ServerRequestState::HpkeConfig { + buf: [0; 7], + read: 0, + }, + } + } + + /// Get a response that wraps the given async write instance. + /// This fails with an error if the request header hasn't been processed. + /// This condition is not exposed through a future anywhere, + /// but you can wait for the first byte of data. 
+ pub fn response(&self, dst: D) -> Res> { + let ServerRequestState::Body { hpke, state: _ } = &self.state else { + return Err(Error::NotReady); + }; + + let response_nonce = random(entropy(hpke.config())); + let aead = make_aead( + Mode::Encrypt, + hpke.config(), + &export_secret(hpke, LABEL_RESPONSE, hpke.config())?, + &self.enc, + &response_nonce, + )?; + Ok(ServerResponse { + writer: ChunkWriter { + dst, + cipher: aead, + buf: response_nonce, + }, + }) + } +} + +impl ServerRequest { + fn read_config( + this: &mut ServerRequestProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + let ServerRequestState::HpkeConfig { buf, read } = this.state else { + return None; + }; + + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, &mut buf[..], read); + if res.is_some() { + return res; + } + + let config = match this + .key_config + .decode_hpke_config(&mut Cursor::new(&buf[..])) + { + Ok(cfg) => cfg, + Err(e) => return some_error(e), + }; + let info = match build_info(INFO_REQUEST, this.key_config.key_id, config) { + Ok(info) => info, + Err(e) => return some_error(e), + }; + this.enc.resize(config.kem().n_enc(), 0); + + *this.state = ServerRequestState::Enc { + config, + info, + read: 0, + }; + None + } + + fn read_enc( + this: &mut ServerRequestProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + let ServerRequestState::Enc { config, info, read } = this.state else { + return None; + }; + + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, &mut this.enc[..], read); + if res.is_some() { + return res; + } + + let hpke = match HpkeR::new( + *config, + &this.key_config.pk, + this.key_config.sk.as_ref().unwrap(), + this.enc, + info, + ) { + Ok(hpke) => hpke, + Err(e) => return some_error(e), + }; + + *this.state = ServerRequestState::Body { + hpke, + state: ChunkReader::length(), + }; + None + } +} + +impl AsyncRead for ServerRequest { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + output: &mut [u8], + ) -> Poll> { + let mut 
this = self.project(); + if let Some(res) = Self::read_config(&mut this, cx) { + return res; + } + + if let Some(res) = Self::read_enc(&mut this, cx) { + return res; + } + + if let ServerRequestState::Body { hpke, state } = this.state { + state.read(this.src, cx, hpke, output) + } else { + Poll::Ready(Ok(0)) + } + } +} + +#[pin_project(project = ServerResponseProjection)] +pub struct ServerResponse { + #[pin] + writer: ChunkWriter, +} + +impl AsyncWrite for ServerResponse { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + input: &[u8], + ) -> Poll> { + self.project().writer.as_mut().poll_write(cx, input) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().writer.as_mut().poll_close(cx) + } +} + +enum ClientResponseState { + Header { + enc: Vec, + secret: SymKey, + nonce: [u8; 16], + read: usize, + }, + Body { + aead: Aead, + state: ChunkReader, + }, +} + +#[pin_project(project = ClientResponseProjection)] +pub struct ClientResponse { + #[pin] + src: S, + config: HpkeConfig, + state: ClientResponseState, +} + +impl ClientResponse { + fn read_nonce( + this: &mut ClientResponseProjection<'_, S>, + cx: &mut Context<'_>, + ) -> Option>> { + let ClientResponseState::Header { + enc, + secret, + nonce, + read, + } = this.state + else { + return None; + }; + + let nonce = &mut nonce[..entropy(*this.config)]; + let res = ChunkReader::read_fixed(this.src.as_mut(), cx, nonce, read); + if res.is_some() { + return res; + } + + let aead = match make_aead(Mode::Decrypt, *this.config, secret, enc, nonce) { + Ok(aead) => aead, + Err(e) => return some_error(e), + }; + + *this.state = ClientResponseState::Body { + aead, + state: ChunkReader::length(), + }; + None + } +} + +impl AsyncRead for ClientResponse { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + output: &mut [u8], + ) -> 
Poll> { + let mut this = self.project(); + if let Some(res) = Self::read_nonce(&mut this, cx) { + return res; + } + + if let ClientResponseState::Body { aead, state } = this.state { + state.read(this.src, cx, aead, output) + } else { + Poll::Ready(Ok(0)) + } + } +} + +#[cfg(test)] +mod test { + use futures::AsyncWriteExt; + use log::trace; + use sync_async::{Pipe, SyncRead, SyncResolve}; + + use crate::{ + test::{init, make_config, REQUEST, RESPONSE}, + ClientRequest, Server, + }; + + #[test] + fn request_response() { + init(); + + let server_config = make_config(); + let server = Server::new(server_config).unwrap(); + let encoded_config = server.config().encode().unwrap(); + trace!("Config: {}", hex::encode(&encoded_config)); + + // The client sends a request. + let client = ClientRequest::from_encoded_config(&encoded_config).unwrap(); + let (mut request_read, request_write) = Pipe::new(); + let mut client_request = client.encapsulate_stream(request_write).unwrap(); + client_request.write_all(REQUEST).sync_resolve().unwrap(); + client_request.close().sync_resolve().unwrap(); + + trace!("Request: {}", hex::encode(REQUEST)); + let enc_request = request_read.sync_read_to_end(); + trace!("Encapsulated Request: {}", hex::encode(&enc_request)); + + // The server receives a request. + let mut server_request = server.decapsulate_stream(&enc_request[..]); + assert_eq!(server_request.sync_read_to_end(), REQUEST); + + // The server sends a response. + let (mut response_read, response_write) = Pipe::new(); + let mut server_response = server_request.response(response_write).unwrap(); + server_response.write_all(RESPONSE).sync_resolve().unwrap(); + server_response.close().sync_resolve().unwrap(); + + let enc_response = response_read.sync_read_to_end(); + trace!("Encapsulated Response: {}", hex::encode(&enc_response)); + + // The client receives a response. 
+ let mut client_response = client_request.response(&enc_response[..]).unwrap(); + let response_buf = client_response.sync_read_to_end(); + assert_eq!(response_buf, RESPONSE); + trace!("Response: {}", hex::encode(response_buf)); + } +} diff --git a/pre-commit b/pre-commit index 758b923..4693b48 100755 --- a/pre-commit +++ b/pre-commit @@ -6,10 +6,12 @@ # $ ln -s ../../hooks/pre-commit .git/hooks/pre-commit root="$(git rev-parse --show-toplevel 2>/dev/null)" +RUST_FMT_CFG="imports_granularity=Crate,group_imports=StdExternalCrate" # Some sanity checking. -hash cargo || exit 1 -[[ -n "$root" ]] || exit 1 +set -e +hash cargo +[[ -n "$root" ]] # Installation. if [[ "$1" == "install" ]]; then @@ -23,31 +25,54 @@ if [[ "$1" == "install" ]]; then exit fi -# Check formatting. +# Stash unstaged changes if [[ "$1" != "all" ]]; then - msg="pre-commit stash @$(git rev-parse --short @) $RANDOM" - trap 'git stash list -1 --format="format:%s" | grep -q "'"$msg"'" && git stash pop -q' EXIT - git stash push -k -u -q -m "$msg" + stashdir="$(mktemp -d "$root"/.pre-commit.stashXXXXXX)" + msg="pre-commit stash @$(git rev-parse --short @) ${stashdir##*.stash}" + gitdir="$(git rev-parse --git-dir 2>/dev/null)" + + stash() { + # Move MERGE_[HEAD|MODE|MSG] files to the root directory, and let `git stash push` save them. + find "$gitdir" -maxdepth 1 -name 'MERGE_*' -exec mv \{\} "$stashdir" \; + git stash push -k -u -q -m "$msg" + } + + unstash() { + git stash list -1 --format="format:%s" | grep -q "$msg" && git stash pop -q + # Moves MERGE files restored by `git stash pop` back into .git/ directory. + if [[ -d "$stashdir" ]]; then + find "$stashdir" -exec mv -n \{\} "$gitdir" \; + rmdir "$stashdir" + fi + } + + trap unstash EXIT + stash fi -if ! errors=($(cargo fmt -- --check --config imports_granularity=crate -l)); then - echo "Formatting errors found." - echo "Run \`cargo fmt\` to fix the following files:" + +# Check formatting +if ! 
errors=($(cargo fmt -- --check --config "$RUST_FMT_CFG" -l)); then + echo "Formatting errors found in:" for err in "${errors[@]}"; do echo " $err" done + echo "To fix, run \`cargo fmt -- --config $RUST_FMT_CFG\`" exit 1 fi -if ! cargo clippy --tests; then - exit 1 -fi -if ! cargo test; then - exit 1 -fi -if [[ -n "$NSS_DIR" ]]; then - if ! cargo clippy --tests --no-default-features --features nss; then - exit 1 - fi - if ! cargo test --no-default-features --features nss; then + +check() { + msg="$1" + shift + if ! "$@"; then + echo "${msg}: Failed command:" + echo " ${@@Q}" exit 1 fi +} + +check "clippy" cargo clippy --tests +check "test" cargo test +if [[ -n "$NSS_DIR" ]]; then + check "clippy(NSS)" cargo clippy --tests --no-default-features --features nss + check "test(NSS)" cargo test --no-default-features --features nss fi diff --git a/sync-async/Cargo.toml b/sync-async/Cargo.toml new file mode 100644 index 0000000..c419df1 --- /dev/null +++ b/sync-async/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sync-async" +version = "0.5.3" +authors = ["Martin Thomson "] +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Synchronous Helpers for Async Code" +repository = "https://github.com/martinthomson/ohttp" + +[dependencies] +futures = "0.3" +pin-project = "1.1" \ No newline at end of file diff --git a/sync-async/src/lib.rs b/sync-async/src/lib.rs new file mode 100644 index 0000000..07cb8ed --- /dev/null +++ b/sync-async/src/lib.rs @@ -0,0 +1,270 @@ +use std::{ + cmp::min, + future::Future, + io::Result as IoResult, + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use futures::{ + io::{ReadHalf, WriteHalf}, + AsyncRead, AsyncReadExt, AsyncWrite, TryStream, TryStreamExt, +}; +use pin_project::pin_project; + +fn noop_context() -> Context<'static> { + use std::{ + ptr::null, + task::{RawWaker, RawWakerVTable, Waker}, + }; + + const fn noop_raw_waker() -> RawWaker { + unsafe fn noop_clone(_data: *const ()) -> RawWaker { + noop_raw_waker() + } + + unsafe 
fn noop(_data: *const ()) {} + + const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + RawWaker::new(null(), &NOOP_WAKER_VTABLE) + } + + pub fn noop_waker_ref() -> &'static Waker { + #[repr(transparent)] + struct SyncRawWaker(RawWaker); + unsafe impl Sync for SyncRawWaker {} + + static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); + + // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. + unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } + } + + Context::from_waker(noop_waker_ref()) +} + +/// Drives the given future (`f`) until it resolves. +/// Executes the indicated function (`p`) each time the +/// poll returned `Poll::Pending`. +pub trait SyncResolve { + type Output; + + fn sync_resolve(&mut self) -> Self::Output { + self.sync_resolve_with(|_| {}) + } + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output; +} + +impl SyncResolve for F { + type Output = F::Output; + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output { + let mut cx = noop_context(); + let mut fut = Pin::new(self); + let mut v = fut.as_mut().poll(&mut cx); + while v.is_pending() { + p(fut.as_mut()); + v = fut.as_mut().poll(&mut cx); + } + if let Poll::Ready(v) = v { + v + } else { + unreachable!(); + } + } +} + +pub trait SyncTryCollect { + type Item; + type Error; + + /// Synchronously gather all items from a stream. + /// # Errors + /// When the underlying source produces an error. 
+ fn sync_collect(self) -> Result, Self::Error>; +} + +impl SyncTryCollect for S { + type Item = S::Ok; + type Error = S::Error; + + fn sync_collect(self) -> Result, Self::Error> { + pin!(self.try_collect::>()).sync_resolve() + } +} + +pub trait SyncRead { + fn sync_read_exact(&mut self, amount: usize) -> Vec; + fn sync_read_to_end(&mut self) -> Vec; +} + +impl SyncRead for S { + fn sync_read_exact(&mut self, amount: usize) -> Vec { + let mut buf = vec![0; amount]; + let res = self.read_exact(&mut buf[..]); + pin!(res).sync_resolve().unwrap(); + buf + } + + fn sync_read_to_end(&mut self) -> Vec { + let mut buf = Vec::new(); + let res = self.read_to_end(&mut buf); + pin!(res).sync_resolve().unwrap(); + buf + } +} + +#[pin_project(project = DribbleProjection)] +pub struct Dribble { + #[pin] + s: S, +} + +impl Dribble { + pub fn new(s: S) -> Self { + Self { s } + } +} + +impl AsyncRead for Dribble { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(&mut self.s).poll_read(cx, &mut buf[..1]) + } +} + +impl AsyncWrite for Dribble { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_write(cx, &buf[..1]) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + this.s.as_mut().poll_close(cx) + } +} + +#[pin_project(project = StutterProjection)] +pub struct Stutter { + stall: bool, + #[pin] + s: S, +} + +impl Stutter { + pub fn new(s: S) -> Self { + Self { stall: false, s } + } + + fn stutter(self: Pin<&mut Self>, cx: &mut Context<'_>, f: F) -> Poll + where + F: FnOnce(Pin<&mut S>, &mut Context<'_>) -> Poll, + { + let mut this = self.project(); + *this.stall = !*this.stall; + if *this.stall { + // When returning `Poll::Pending`, you have 
to wake the task.
+            // We aren't running code anywhere except here,
+            // so call it here and ensure that the task is picked up.
+            cx.waker().wake_by_ref();
+            Poll::Pending
+        } else {
+            f(this.s.as_mut(), cx)
+        }
+    }
+}
+
+impl AsyncRead for Stutter {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll> {
+        Self::stutter(self, cx, |s, cx| s.poll_read(cx, buf))
+    }
+}
+
+impl AsyncWrite for Stutter {
+    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> {
+        Self::stutter(self, cx, |s, cx| s.poll_write(cx, buf))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+        Self::stutter(self, cx, AsyncWrite::poll_flush)
+    }
+
+    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
+        Self::stutter(self, cx, AsyncWrite::poll_close)
+    }
+}
+
+/// A Cursor implementation that has separate read and write cursors.
+///
+/// This allows tests to create paired read and write objects,
+/// where writes to one can be read by the other.
+///
+/// This relies on the implementation of `AsyncReadExt::split` to provide
+/// any locking and concurrency, rather than implementing it.
+#[derive(Default)] +#[pin_project] +pub struct Pipe { + buf: Vec, + r: usize, + w: usize, +} + +impl Pipe { + #[must_use] + pub fn new() -> (ReadHalf, WriteHalf) { + AsyncReadExt::split(Self::default()) + } +} + +impl AsyncRead for Pipe { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let amnt = min(buf.len(), self.buf.len() - self.r); + buf[..amnt].copy_from_slice(&self.buf[self.r..self.r + amnt]); + self.r += amnt; + Poll::Ready(Ok(amnt)) + } +} + +impl AsyncWrite for Pipe { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { + if self.w < self.buf.len() { + let overlap = min(buf.len() - self.w, self.buf.len()); + let range = self.w..(self.w + overlap); + self.buf[range].copy_from_slice(&buf[..overlap]); + buf = &buf[overlap..]; + } + self.buf.extend_from_slice(buf); + self.w += buf.len(); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +}