Skip to content

Commit

Permalink
Pass static size ohttp en/decapsulate arguments
Browse files Browse the repository at this point in the history
Take advantage of the edit to use `&[u8]` function signatures where
applicable to reduce tech debt.
  • Loading branch information
DanGould committed Nov 27, 2024
1 parent b404a91 commit d6f927c
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 57 deletions.
11 changes: 3 additions & 8 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl App {
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let res = post_request(req).await?;
payjoin_proposal
.process_res(res.bytes().await?.to_vec(), ohttp_ctx)
.process_res(&res.bytes().await?, ohttp_ctx)
.map_err(|e| anyhow!("Failed to deserialize response {}", e))?;
let payjoin_psbt = payjoin_proposal.psbt().clone();
println!(
Expand Down Expand Up @@ -198,16 +198,11 @@ impl App {
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
let v2_ctx = Arc::new(ctx.process_response(&response.bytes().await?)?);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = post_request(req).await?;
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
) {
match v2_ctx.process_response(&response.bytes().await?, ohttp_ctx) {
Ok(Some(psbt)) => return Ok(psbt),
Ok(None) => {
println!("No response yet.");
Expand Down
9 changes: 7 additions & 2 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ pub const DEFAULT_DIR_PORT: u16 = 8080;
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

const PADDED_BHTTP_BYTES: usize = 8192;
const ENCAPSULATED_MESSAGE_BYTES: usize = 8192;
const CHACHA20_POLY1305_NONCE_LEN: usize = 32; // chacha20poly1305 n_k
const POLY1305_TAG_SIZE: usize = 16;
pub const BHTTP_REQ_BYTES: usize =
ENCAPSULATED_MESSAGE_BYTES - (CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE);
const V1_MAX_BUFFER_SIZE: usize = 65536;

const V1_REJECT_RES_JSON: &str =
Expand Down Expand Up @@ -209,10 +213,11 @@ async fn handle_ohttp_gateway(
bhttp_res
.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
bhttp_bytes.resize(PADDED_BHTTP_BYTES, 0);
bhttp_bytes.resize(BHTTP_REQ_BYTES, 0);
let ohttp_res = res_ctx
.encapsulate(&bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
assert!(ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES, "Unexpected OHTTP response size");
Ok(Response::new(full(ohttp_res)))
}

Expand Down
27 changes: 19 additions & 8 deletions payjoin/src/ohttp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@ use std::{error, fmt};

use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD;
use bitcoin::base64::Engine;
use bitcoin::key::constants::UNCOMPRESSED_PUBLIC_KEY_SIZE;

pub const PADDED_MESSAGE_BYTES: usize = 8192;
pub const ENCAPSULATED_MESSAGE_BYTES: usize = 8192;
const N_ENC: usize = UNCOMPRESSED_PUBLIC_KEY_SIZE;
const N_T: usize = crate::hpke::POLY1305_TAG_SIZE;
const OHTTP_REQ_HEADER_BYTES: usize = 7;
pub const PADDED_BHTTP_REQ_BYTES: usize =
ENCAPSULATED_MESSAGE_BYTES - (N_ENC + N_T + OHTTP_REQ_HEADER_BYTES);

pub fn ohttp_encapsulate(
ohttp_keys: &mut ohttp::KeyConfig,
method: &str,
target_resource: &str,
body: Option<&[u8]>,
) -> Result<(Vec<u8>, ohttp::ClientResponse), OhttpEncapsulationError> {
) -> Result<([u8; ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse), OhttpEncapsulationError> {
use std::fmt::Write;

let ctx = ohttp::ClientRequest::from_config(ohttp_keys)?;
Expand All @@ -33,17 +39,22 @@ pub fn ohttp_encapsulate(
if let Some(body) = body {
bhttp_message.write_content(body);
}
let mut bhttp_req = Vec::new();
let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req);
bhttp_req.resize(PADDED_MESSAGE_BYTES, 0);
let encapsulated = ctx.encapsulate(&bhttp_req)?;
Ok(encapsulated)

let mut bhttp_req = [0u8; PADDED_BHTTP_REQ_BYTES];
let mut cursor = std::io::Cursor::new(&mut bhttp_req[..]);
let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut cursor);
let (encapsulated, ohttp_ctx) = ctx.encapsulate(&bhttp_req)?;

let mut buffer = [0u8; ENCAPSULATED_MESSAGE_BYTES];
let len = encapsulated.len().min(ENCAPSULATED_MESSAGE_BYTES);
buffer[..len].copy_from_slice(&encapsulated[..len]);
Ok((buffer, ohttp_ctx))
}

/// decapsulate ohttp, bhttp response and return http response body and status code
pub fn ohttp_decapsulate(
res_ctx: ohttp::ClientResponse,
ohttp_body: &[u8],
ohttp_body: &[u8; ENCAPSULATED_MESSAGE_BYTES],
) -> Result<http::Response<Vec<u8>>, OhttpEncapsulationError> {
let bhttp_body = res_ctx.decapsulate(ohttp_body)?;
let mut r = std::io::Cursor::new(bhttp_body);
Expand Down
9 changes: 9 additions & 0 deletions payjoin/src/receive/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ pub(crate) enum InternalSessionError {
Expired(std::time::SystemTime),
/// OHTTP Encapsulation failed
OhttpEncapsulation(OhttpEncapsulationError),
/// Unexpected response size
UnexpectedResponseSize(usize),
}

impl fmt::Display for SessionError {
Expand All @@ -20,6 +22,12 @@ impl fmt::Display for SessionError {
InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry),
InternalSessionError::OhttpEncapsulation(e) =>
write!(f, "OHTTP Encapsulation Error: {}", e),
InternalSessionError::UnexpectedResponseSize(size) => write!(
f,
"Unexpected response size {}, expected {} bytes",
size,
crate::ohttp::ENCAPSULATED_MESSAGE_BYTES
),
}
}
}
Expand All @@ -29,6 +37,7 @@ impl error::Error for SessionError {
match &self.0 {
InternalSessionError::Expired(_) => None,
InternalSessionError::OhttpEncapsulation(e) => Some(e),
InternalSessionError::UnexpectedResponseSize(_) => None,
}
}
}
Expand Down
27 changes: 20 additions & 7 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ impl Receiver {
/// indicating no UncheckedProposal is available yet.
pub fn process_res(
&mut self,
mut body: impl std::io::Read,
body: &[u8],
context: ohttp::ClientResponse,
) -> Result<Option<UncheckedProposal>, Error> {
let mut buf = Vec::new();
let _ = body.read_to_end(&mut buf);
let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] =
body.try_into().map_err(|_| {
Error::Server(Box::new(SessionError::from(
InternalSessionError::UnexpectedResponseSize(body.len()),
)))
})?;
log::trace!("decapsulating directory response");
let response = ohttp_decapsulate(context, &buf)?;
let response = ohttp_decapsulate(context, response_array)?;
if response.body().is_empty() {
log::debug!("response is empty");
return Ok(None);
Expand All @@ -134,7 +138,10 @@ impl Receiver {

fn fallback_req_body(
&mut self,
) -> Result<(Vec<u8>, ohttp::ClientResponse), OhttpEncapsulationError> {
) -> Result<
([u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES], ohttp::ClientResponse),
OhttpEncapsulationError,
> {
let fallback_target = self.pj_url();
ohttp_encapsulate(&mut self.context.ohttp_keys, "GET", fallback_target.as_str(), None)
}
Expand Down Expand Up @@ -509,10 +516,16 @@ impl PayjoinProposal {
/// choose to broadcast the original PSBT.
pub fn process_res(
&self,
res: Vec<u8>,
res: &[u8],
ohttp_context: ohttp::ClientResponse,
) -> Result<(), Error> {
let res = ohttp_decapsulate(ohttp_context, &res)?;
let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] =
res.try_into().map_err(|_| {
Error::Server(Box::new(SessionError::from(
InternalSessionError::UnexpectedResponseSize(res.len()),
)))
})?;
let res = ohttp_decapsulate(ohttp_context, response_array)?;
if res.status().is_success() {
Ok(())
} else {
Expand Down
4 changes: 2 additions & 2 deletions payjoin/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Request {
}

#[cfg(feature = "v2")]
pub fn new_v2(url: Url, body: Vec<u8>) -> Self {
Self { url, content_type: V2_REQ_CONTENT_TYPE, body }
pub fn new_v2(url: Url, body: [u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES]) -> Self {
Self { url, content_type: V2_REQ_CONTENT_TYPE, body: body.to_vec() }
}
}
6 changes: 6 additions & 0 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub(crate) enum InternalValidationError {
OhttpEncapsulation(crate::ohttp::OhttpEncapsulationError),
#[cfg(feature = "v2")]
UnexpectedStatusCode,
#[cfg(feature = "v2")]
UnexpectedResponseSize(usize),
}

impl From<InternalValidationError> for ValidationError {
Expand Down Expand Up @@ -119,6 +121,8 @@ impl fmt::Display for ValidationError {
OhttpEncapsulation(e) => write!(f, "Ohttp encapsulation error: {}", e),
#[cfg(feature = "v2")]
UnexpectedStatusCode => write!(f, "unexpected status code"),
#[cfg(feature = "v2")]
UnexpectedResponseSize(size) => write!(f, "unexpected response size {}, expected {} bytes", size, crate::ohttp::ENCAPSULATED_MESSAGE_BYTES),
}
}
}
Expand Down Expand Up @@ -164,6 +168,8 @@ impl std::error::Error for ValidationError {
OhttpEncapsulation(error) => Some(error),
#[cfg(feature = "v2")]
UnexpectedStatusCode => None,
#[cfg(feature = "v2")]
UnexpectedResponseSize(_) => None,
}
}
}
Expand Down
20 changes: 10 additions & 10 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,9 @@ pub struct V2PostContext {

#[cfg(feature = "v2")]
impl V2PostContext {
pub fn process_response(
self,
response: &mut impl std::io::Read,
) -> Result<V2GetContext, ResponseError> {
let mut res_buf = Vec::new();
response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?;
pub fn process_response(self, response: &[u8]) -> Result<V2GetContext, ResponseError> {
let mut res_buf = [0u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES];
res_buf[..response.len()].copy_from_slice(response);
let response = ohttp_decapsulate(self.ohttp_ctx, &res_buf)
.map_err(InternalValidationError::OhttpEncapsulation)?;
match response.status() {
Expand Down Expand Up @@ -417,12 +414,15 @@ impl V2GetContext {

pub fn process_response(
&self,
response: &mut impl std::io::Read,
response: &[u8],
ohttp_ctx: ohttp::ClientResponse,
) -> Result<Option<Psbt>, ResponseError> {
let mut res_buf = Vec::new();
response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?;
let response = ohttp_decapsulate(ohttp_ctx, &res_buf)
let response_array: &[u8; crate::ohttp::ENCAPSULATED_MESSAGE_BYTES] =
response
.try_into()
.map_err(|_| InternalValidationError::UnexpectedResponseSize(response.len()))?;

let response = ohttp_decapsulate(ohttp_ctx, response_array)
.map_err(InternalValidationError::OhttpEncapsulation)?;
let body = match response.status() {
http::StatusCode::OK => response.body().to_vec(),
Expand Down
31 changes: 11 additions & 20 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,7 @@ mod integration {
.unwrap();
log::info!("Response: {:#?}", &response);
assert!(response.status().is_success());
let send_ctx =
send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?;
let send_ctx = send_ctx.process_response(&response.bytes().await?)?;
// POST Original PSBT

// **********************
Expand All @@ -390,8 +389,7 @@ mod integration {
.body(req.body)
.send()
.await?;
let res = response.bytes().await?.to_vec();
payjoin_proposal.process_res(res, ctx)?;
payjoin_proposal.process_res(&response.bytes().await?, ctx)?;

// **********************
// Inside the Sender:
Expand All @@ -407,9 +405,8 @@ mod integration {
.await
.unwrap();
log::info!("Response: {:#?}", &response);
let checked_payjoin_proposal_psbt = send_ctx
.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)?
.unwrap();
let checked_payjoin_proposal_psbt =
send_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.unwrap();
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
log::info!("sent");
Expand Down Expand Up @@ -503,8 +500,7 @@ mod integration {
let (req, ctx) = session.extract_req()?;
let response = agent.post(req.url).body(req.body).send().await?;
assert!(response.status().is_success());
let response_body =
session.process_res(response.bytes().await?.to_vec().as_slice(), ctx).unwrap();
let response_body = session.process_res(&response.bytes().await?, ctx).unwrap();
// No proposal yet since sender has not responded
assert!(response_body.is_none());

Expand All @@ -530,8 +526,7 @@ mod integration {
.unwrap();
log::info!("Response: {:#?}", &response);
assert!(response.status().is_success());
let get_ctx =
post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?;
let get_ctx = post_ctx.process_response(&response.bytes().await?)?;
let (Request { url, body, content_type, .. }, ohttp_ctx) =
get_ctx.extract_req(directory.to_owned())?;
let response = agent
Expand All @@ -541,9 +536,7 @@ mod integration {
.send()
.await?;
// No response body yet since we are async and pushed fallback_psbt to the buffer
assert!(get_ctx
.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)?
.is_none());
assert!(get_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.is_none());

// **********************
// Inside the Receiver:
Expand All @@ -560,8 +553,7 @@ mod integration {
assert!(!payjoin_proposal.is_output_substitution_disabled());
let (req, ctx) = payjoin_proposal.extract_v2_req()?;
let response = agent.post(req.url).body(req.body).send().await?;
let res = response.bytes().await?.to_vec();
payjoin_proposal.process_res(res, ctx)?;
payjoin_proposal.process_res(&response.bytes().await?, ctx)?;

// **********************
// Inside the Sender:
Expand All @@ -575,9 +567,8 @@ mod integration {
.body(body.clone())
.send()
.await?;
let checked_payjoin_proposal_psbt = get_ctx
.process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)?
.unwrap();
let checked_payjoin_proposal_psbt =
get_ctx.process_response(&response.bytes().await?, ohttp_ctx)?.unwrap();
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
log::info!("sent");
Expand Down Expand Up @@ -739,7 +730,7 @@ mod integration {
let (req, ctx) = payjoin_proposal.extract_v2_req().unwrap();
let response = agent_clone.post(req.url).body(req.body).send().await?;
payjoin_proposal
.process_res(response.bytes().await?.to_vec(), ctx)
.process_res(&response.bytes().await?, ctx)
.map_err(|e| e.to_string())?;
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
});
Expand Down

0 comments on commit d6f927c

Please sign in to comment.