diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index c5cfb429..f443a661 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -466,4 +466,81 @@ mod test { }) ); } + + /// Test that the encrypted payloads are uniform. + /// + /// This randomized test will generate a false negative with negligible probability + /// if all encrypted messages share an identical byte at a given position by chance. + /// It should fail deterministically before the ellswift changes due to the first + /// byte being 0x01 in all messages in all runs (a only, b only, a+b). + /// It should fail deterministically for a only messages because of the second + /// 0x01 byte at position 65 (the previously aad'd reply key). + #[test] + fn test_encrypted_payload_bit_uniformity() { + fn generate_messages(count: usize) -> (Vec>, Vec>) { + let mut messages_a = Vec::with_capacity(count); + let mut messages_b = Vec::with_capacity(count); + + for _ in 0..count { + let sender_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + let reply_keypair = HpkeKeyPair::gen_keypair(); + + let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; + let message_a = encrypt_message_a( + plaintext_a, + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH]; + let message_b = + encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key()) + .expect("encryption should work"); + + messages_a.push(message_a); + messages_b.push(message_b); + } + + (messages_a, messages_b) + } + + /// For each of the 256 pairwise combinations, ensure their lengths are equal, + /// XOR the two messages together, and OR this into an accumulator that starts + /// as all 0x00s. + fn check_uniformity(messages: Vec>) { + let mut accumulator = vec![0u8; PADDED_MESSAGE_BYTES]; + + for (i, msg1) in messages.iter().enumerate() { + for msg2 in messages.iter().skip(i + 1) { + assert_eq!(msg1.len(), msg2.len(), "Message lengths should be equal"); + for (acc, (&b1, &b2)) in + accumulator.iter_mut().zip(msg1.iter().zip(msg2.iter())) + { + *acc |= b1 ^ b2; + } + } + } + + assert!( + accumulator.iter().all(|&b| b != 0), + "All bytes in the accumulator should be non-zero" + ); + } + + // Generate 8 each of messages a and b + let (messages_a, messages_b) = generate_messages(8); + let mut combined_messages = messages_a; + combined_messages.extend(messages_b); + check_uniformity(combined_messages); + + // Generate 16 messages of only a + let (messages_a, _) = generate_messages(16); + check_uniformity(messages_a); + + // Generate 16 messages of only b + let (_, messages_b) = generate_messages(16); + check_uniformity(messages_b); + } }