diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 74d37611..c30f6c98 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -312,7 +312,7 @@ impl RequestContext { Err(e) => { log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); let (req, context_v1) = self.extract_v1()?; - Ok((req, ContextV2 { context_v1, e: None, ohttp_res: None })) + Ok((req, ContextV2 { context_v1, rs: None, e: None, ohttp_res: None })) } } } @@ -355,6 +355,7 @@ impl RequestContext { sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, + rs: Some(self.extract_rs_pubkey()?), e: Some(self.e.clone()), ohttp_res: Some(ohttp_res), }, @@ -402,6 +403,7 @@ pub struct ContextV1 { #[cfg(feature = "v2")] pub struct ContextV2 { context_v1: ContextV1, + rs: Option, e: Option, ohttp_res: Option, } @@ -437,8 +439,8 @@ impl ContextV2 { self, response: &mut impl std::io::Read, ) -> Result, ResponseError> { - match (self.ohttp_res, self.e) { - (Some(ohttp_res), Some(e)) => { + match (self.ohttp_res, self.rs, self.e) { + (Some(ohttp_res), Some(rs), Some(e)) => { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; let response = crate::v2::ohttp_decapsulate(ohttp_res, &res_buf) @@ -448,7 +450,7 @@ impl ContextV2 { http::StatusCode::ACCEPTED => return Ok(None), _ => return Err(InternalValidationError::UnexpectedStatusCode)?, }; - let psbt = crate::v2::decrypt_message_b_hpke(&body, e) + let psbt = crate::v2::decrypt_message_b_hpke(&body, rs, e) .map_err(InternalValidationError::Hpke)?; let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index c2a002f3..81a245d0 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -11,8 +11,9 @@ use hpke::rand_core::OsRng; use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE * 2; - +pub const PADDED_PLAINTEXT_A_LENGTH: usize = + PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE * 2; +pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; pub const INFO_A: &[u8] = b"Payjoin v2 Message A"; pub const INFO_B: &[u8] = b"Payjoin v2 Message B"; @@ -131,7 +132,7 @@ pub fn encrypt_message_a_hpke( &mut OsRng, )?; let aad = pk.to_bytes().to_vec(); - let plaintext = pad_plaintext(&mut plaintext)?; + let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; let ciphertext = encryption_context.seal(plaintext, &aad)?; let mut message_a = encapsulated_key.to_bytes().to_vec(); message_a.extend(&aad); @@ -172,36 +173,36 @@ pub fn encrypt_message_b_hpke( INFO_B, &mut OsRng, )?; - let aad = pk.to_bytes().to_vec(); - let plaintext = pad_plaintext(&mut plaintext)?; - let ciphertext = encryption_context.seal(plaintext, &aad)?; + let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_B_LENGTH)?; + let ciphertext = encryption_context.seal(plaintext, &[])?; let mut message_b = encapsulated_key.to_bytes().to_vec(); - message_b.extend(&aad); message_b.extend(&ciphertext); Ok(message_b.to_vec()) } #[cfg(feature = "send")] -pub fn decrypt_message_b_hpke(message_b: &[u8], s: HpkeSecretKey) -> Result, HpkeError> { +pub fn decrypt_message_b_hpke( + message_b: &[u8], + rs: HpkePublicKey, + s: HpkeSecretKey, +) -> Result, HpkeError> { let enc = message_b.get(..65).ok_or(HpkeError::PayloadTooShort)?; - let enc = EncappedKey::from_bytes(enc).unwrap(); - let aad = message_b.get(65..130).ok_or(HpkeError::PayloadTooShort)?; - let pk_s = PublicKey::from_bytes(aad)?; + let enc = EncappedKey::from_bytes(enc)?; let mut decryption_ctx = hpke::setup_receiver::< ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, - >(&OpModeR::Auth(pk_s), &s.0, &enc, INFO_B)?; + >(&OpModeR::Auth(rs.0), &s.0, &enc, INFO_B)?; let plaintext = - decryption_ctx.open(message_b.get(130..).ok_or(HpkeError::PayloadTooShort)?, aad)?; + decryption_ctx.open(message_b.get(65..).ok_or(HpkeError::PayloadTooShort)?, &[])?; Ok(plaintext) } -fn pad_plaintext(msg: &mut Vec) -> Result<&[u8], HpkeError> { - if msg.len() > PADDED_PLAINTEXT_LENGTH { - return Err(HpkeError::PayloadTooLarge); +fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeError> { + if msg.len() > padded_length { + return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length }); } - msg.resize(PADDED_PLAINTEXT_LENGTH, 0); + msg.resize(padded_length, 0); Ok(msg) } @@ -211,7 +212,7 @@ pub enum HpkeError { Secp256k1(bitcoin::secp256k1::Error), Hpke(hpke::HpkeError), InvalidKeyLength, - PayloadTooLarge, + PayloadTooLarge { actual: usize, max: usize }, PayloadTooShort, } @@ -230,8 +231,13 @@ impl fmt::Display for HpkeError { match &self { Hpke(e) => e.fmt(f), InvalidKeyLength => write!(f, "Invalid Length"), - PayloadTooLarge => - write!(f, "Plaintext too large, max size is {} bytes", PADDED_PLAINTEXT_LENGTH), + PayloadTooLarge { actual, max } => { + write!( + f, + "Plaintext too large, max size is {} bytes, actual size is {} bytes", + max, actual + ) + } PayloadTooShort => write!(f, "Payload too small"), Secp256k1(e) => e.fmt(f), } @@ -244,7 +250,8 @@ impl error::Error for HpkeError { match &self { Hpke(e) => Some(e), - InvalidKeyLength | PayloadTooLarge | PayloadTooShort => None, + PayloadTooLarge { .. } => None, + InvalidKeyLength | PayloadTooShort => None, Secp256k1(e) => Some(e), } }