Skip to content

Commit

Permalink
Refactor plaintext padding for clarity and safety
Browse files Browse the repository at this point in the history
This change enforces that messages have a uniform length at the type
level.

If bitcoin-hpke was modified to retain the underlying in-place interface
then this code could be further simplified so that there is only one
PADDED_MESSAGE_BYTES length buffer shared by all steps, which would also
save a copy step.
  • Loading branch information
nothingmuch committed Oct 24, 2024
1 parent ed0555c commit 9d432cb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 49 deletions.
97 changes: 49 additions & 48 deletions payjoin/src/hpke.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::{Cursor, Read, Write};
use std::ops::Deref;
use std::{error, fmt};

Expand All @@ -12,11 +13,13 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable};
use serde::{Deserialize, Serialize};

pub const PADDED_MESSAGE_BYTES: usize = 7168;
pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES
- (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE + 4);
pub const PADDED_PLAINTEXT_B_LENGTH: usize =
PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE + 4);
pub const HPKE_OVERHEAD_BYTES: usize = ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE;
pub const MAX_PLAINTEXT_LENGTH: usize =
PADDED_MESSAGE_BYTES - (HPKE_OVERHEAD_BYTES + MAX_TLV_OVERHEAD);
pub const PADDED_PLAINTEXT_A_LENGTH: usize = MAX_PLAINTEXT_LENGTH - UNCOMPRESSED_PUBLIC_KEY_SIZE;
pub const PADDED_PLAINTEXT_B_LENGTH: usize = MAX_PLAINTEXT_LENGTH;
pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it?
pub const MAX_TLV_OVERHEAD: usize = 4;
pub const INFO_A: &[u8; 8] = b"PjV2MsgA";
pub const INFO_B: &[u8; 8] = b"PjV2MsgB";

Expand Down Expand Up @@ -152,7 +155,7 @@ pub fn encrypt_message_a(
body: Vec<u8>,
reply_pk: &HpkePublicKey,
receiver_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> {
let (encapsulated_key, mut encryption_context) =
hpke::setup_sender::<ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, _>(
&OpModeS::Base,
Expand All @@ -161,29 +164,25 @@ pub fn encrypt_message_a(
&mut OsRng,
)?;

let length = UNCOMPRESSED_PUBLIC_KEY_SIZE + body.len();

let mut body = body;
let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3
pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH + extra_pad)?;
let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES];
let mut c = prepare_tlv(&mut plaintext, body.len(), UNCOMPRESSED_PUBLIC_KEY_SIZE)?;
c.write(&reply_pk.to_bytes()).expect("length checked by prepare_tlv");
c.write(&body).expect("length checked by prepare_tlv");

let mut plaintext = encode_tlv(length.try_into().expect("checked by pad_plaintext"));
plaintext.extend(reply_pk.to_bytes());
plaintext.extend(body);
let mut message_a = [0u8; PADDED_MESSAGE_BYTES];
let mut c = Cursor::new(&mut message_a[..]);
c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?)
.expect("length checked by prepare_tlv");
c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv");

let ciphertext = encryption_context.seal(&plaintext, &[])?;
let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec();
message_a.extend(&ciphertext);
Ok(message_a.to_vec())
Ok(message_a)
}

#[cfg(feature = "receive")]
pub fn decrypt_message_a(
message_a: &[u8],
receiver_sk: HpkeSecretKey,
) -> Result<(Vec<u8>, HpkePublicKey), HpkeError> {
use std::io::{Cursor, Read};

let mut cursor = Cursor::new(message_a);

let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE];
Expand Down Expand Up @@ -213,10 +212,10 @@ pub fn decrypt_message_a(
/// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error
#[cfg(feature = "receive")]
pub fn encrypt_message_b(
mut body: Vec<u8>,
body: &[u8],
receiver_keypair: &HpkeKeyPair,
sender_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> {
let (encapsulated_key, mut encryption_context) =
hpke::setup_sender::<ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, _>(
&OpModeS::Auth((
Expand All @@ -228,18 +227,17 @@ pub fn encrypt_message_b(
&mut OsRng,
)?;

let length = body.len();
let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3
pad_plaintext(&mut body, PADDED_PLAINTEXT_B_LENGTH + extra_pad)?;
let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES];
let mut c = prepare_tlv(&mut plaintext, body.len(), 0)?;
c.write(body).expect("length checked by prepare_tlv");

let mut plaintext =
encode_tlv(length.try_into().expect("length already checked in pad_plaintext"));
plaintext.extend(body);
let mut message_b = [0u8; PADDED_MESSAGE_BYTES];
c = Cursor::new(&mut message_b);
c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?)
.expect("length checked by prepare_tlv");
c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv");

let ciphertext = encryption_context.seal(&plaintext, &[])?;
let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec();
message_b.extend(&ciphertext);
Ok(message_b.to_vec())
Ok(message_b)
}

#[cfg(feature = "send")]
Expand All @@ -262,24 +260,27 @@ pub fn decrypt_message_b(
Ok(extract_tlv_value(&plaintext)?.to_vec())
}

fn pad_plaintext(msg: &mut Vec<u8>, padded_length: usize) -> Result<&[u8], HpkeError> {
if msg.len() > padded_length {
return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length });
}
msg.resize(padded_length, 0);
Ok(msg)
}

fn encode_tlv(length: u16) -> Vec<u8> {
fn prepare_tlv<'a>(
buf: &'a mut [u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES],
body_length: usize,
overhead: usize,
) -> Result<Cursor<&'a mut [u8]>, HpkeError> {
let length = body_length + overhead;
if length < 0xfd {
vec![0x00, length.try_into().expect("length checked in conditional")]
} else {
let mut buf = vec![0x00, 0xfd, 0x00, 0x00];
buf[1] = length.try_into().expect("length checked in conditional");
Ok(Cursor::new(&mut buf[2..MAX_PLAINTEXT_LENGTH - 2]))
} else if length <= MAX_PLAINTEXT_LENGTH {
buf[1] = 0xfd;
NetworkEndian::write_u16(
&mut buf[2..4],
length.try_into().expect("length already checked in pad_plaintext"),
length.try_into().expect("length checked in conditional"),
);
buf
Ok(Cursor::new(&mut buf[4..]))
} else {
Err(HpkeError::PayloadTooLarge {
actual: body_length,
max: MAX_PLAINTEXT_LENGTH - overhead,
})
}
}

Expand Down Expand Up @@ -434,7 +435,7 @@ mod test {
let receiver_keypair = HpkeKeyPair::gen_keypair();

let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand All @@ -451,7 +452,7 @@ mod test {
plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0);
plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42;
let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand Down Expand Up @@ -506,7 +507,7 @@ mod test {

plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0);
assert_eq!(
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()),
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()),
Err(HpkeError::PayloadTooLarge {
actual: PADDED_PLAINTEXT_B_LENGTH + 1,
max: PADDED_PLAINTEXT_B_LENGTH
Expand Down
2 changes: 1 addition & 1 deletion payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ impl PayjoinProposal {
let sender_subdir = subdir_path_from_pubkey(e);
target_resource =
self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?;
body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?;
body = encrypt_message_b(&payjoin_bytes, &self.context.s, e)?.to_vec();
method = "POST";
} else {
// Prepare v2 wrapped and backwards-compatible v1 payload
Expand Down

0 comments on commit 9d432cb

Please sign in to comment.