From 981479dec0225bbe82f4e7a1f05bc1dad68c4aa0 Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Fri, 27 Sep 2024 14:00:10 -0700 Subject: [PATCH] Add a Payload type to avoid performance hit --- protocol/src/lib.rs | 98 ++++++++++++++++++++--------------- protocol/tests/round_trips.rs | 12 ++--- proxy/src/lib.rs | 16 +++--- 3 files changed, 71 insertions(+), 55 deletions(-) diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 26598fd..5c9d325 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -220,10 +220,9 @@ pub enum PacketType { } impl PacketType { - /// Check if plaintext packet is a decoy. - pub fn from_bytes(plaintext: &[u8]) -> Self { - // Check if header byte has the decoy flag flipped. - if plaintext.first() == Some(&DECOY_BYTE) { + /// Check if header byte has the decoy flag flipped. + pub fn from_byte(header: &u8) -> Self { + if header == &DECOY_BYTE { PacketType::Decoy } else { PacketType::Genuine @@ -239,6 +238,30 @@ impl PacketType { } } +/// Plaintext payload of a packet, which includes the header and contents. +#[cfg(feature = "alloc")] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Payload { + bytes: Vec, +} + +impl Payload { + pub fn new(bytes: Vec) -> Self { + Self { bytes } + } + + /// Contents of the payload. + pub fn contents(&self) -> &[u8] { + // Exclude the header byte. + &self.bytes[1..] + } + + /// Packet type of the payload. + pub fn packet_type(&self) -> PacketType { + PacketType::from_byte(&self.bytes[0]) + } +} + /// Read packets over established encrypted channel from a peer. #[derive(Clone)] pub struct PacketReader { @@ -271,7 +294,7 @@ impl PacketReader { content_len as usize + NUM_HEADER_BYTES + NUM_TAG_BYTES } - /// Decrypt the packet contents. + /// Decrypt the packet header byte and contents. /// /// # Arguments /// @@ -292,12 +315,12 @@ impl PacketReader { /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet. /// * `BufferTooSmall ` - Contents buffer argument is not large enough for plaintext. /// * Decryption errors for any failures such as a tag mismatch. - pub fn decrypt_contents( + pub fn decrypt_payload( &mut self, ciphertext: &[u8], contents: &mut [u8], aad: Option<&[u8]>, - ) -> Result { + ) -> Result<(), Error> { let auth = aad.unwrap_or_default(); // Check minimum size of ciphertext. if ciphertext.len() < NUM_TAG_BYTES { @@ -317,10 +340,10 @@ impl PacketReader { tag.try_into().expect("16 byte tag"), )?; - Ok(PacketType::from_bytes(contents)) + Ok(()) } - /// Decrypt the packet contents. + /// Decrypt the packet header byte and contents. /// /// # Arguments /// @@ -331,27 +354,21 @@ impl PacketReader { /// # Returns /// /// A `Result` containing: - /// * `Ok(Some(Vec))`: The plaintext contents in a byte vector if it is not a decoy packet. + /// * `Ok(Payload)`: The plaintext header and contents. /// * `Err(Error)`: An error that occurred during decryption. /// /// # Errors /// /// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet. #[cfg(feature = "alloc")] - pub fn decrypt_contents_with_alloc( + pub fn decrypt_payload_with_alloc( &mut self, ciphertext: &[u8], aad: Option<&[u8]>, - ) -> Result>, Error> { - let mut contents = vec![0u8; ciphertext.len() - NUM_TAG_BYTES]; - match self.decrypt_contents(ciphertext, &mut contents, aad)? { - PacketType::Decoy => Ok(None), - PacketType::Genuine => { - // Drop the header byte. - contents.remove(0); - Ok(Some(contents)) - } - } + ) -> Result { + let mut payload = vec![0u8; ciphertext.len() - NUM_TAG_BYTES]; + self.decrypt_payload(ciphertext, &mut payload, aad)?; + Ok(Payload::new(payload)) } } @@ -881,7 +898,7 @@ impl<'a> Handshake<'a> { if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES + packet_length { return Err(Error::CiphertextTooSmall); } - packet_handler.packet_reader.decrypt_contents( + packet_handler.packet_reader.decrypt_payload( &ciphertext[self.current_buffer_index + NUM_LENGTH_BYTES ..self.current_buffer_index + NUM_LENGTH_BYTES + packet_length], packet_buffer, @@ -892,9 +909,8 @@ impl<'a> Handshake<'a> { self.current_buffer_index = self.current_buffer_index + NUM_LENGTH_BYTES + packet_length; self.current_packet_length_bytes = None; - // The version packet is currently just an empty packet. Ok(matches!( - PacketType::from_bytes(packet_buffer), + PacketType::from_byte(packet_buffer.first().expect("header byte")), PacketType::Genuine )) } @@ -1083,9 +1099,9 @@ mod tests { .unwrap(); let dec = bob_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(None, dec); + assert_eq!(PacketType::Decoy, dec.packet_type()); let message = b"Windows sox!".to_vec(); let enc_packet = bob_packet_handler .packet_writer @@ -1093,9 +1109,9 @@ mod tests { .unwrap(); let dec = alice_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(message, dec.unwrap()); + assert_eq!(message, dec.contents()); } #[test] @@ -1126,9 +1142,9 @@ mod tests { .unwrap(); let dec_packet = bob_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(message, dec_packet.unwrap()); + assert_eq!(message, dec_packet.contents()); let message = gen_garbage(420, &mut rng); let enc_packet = bob_packet_handler .packet_writer @@ -1136,9 +1152,9 @@ mod tests { .unwrap(); let dec_packet = alice_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(message, dec_packet.unwrap()); + assert_eq!(message, dec_packet.contents()); } } @@ -1168,7 +1184,7 @@ mod tests { .unwrap(); let _ = bob_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], Some(&auth_garbage)) + .decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], Some(&auth_garbage)) .unwrap(); } @@ -1281,9 +1297,9 @@ mod tests { .unwrap(); let dec = alice .packet_reader - .decrypt_contents_with_alloc(&encrypted_message_to_alice[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&encrypted_message_to_alice[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(message, dec.unwrap()); + assert_eq!(message, dec.contents()); let message = b"g!".to_vec(); let encrypted_message_to_bob = alice .packet_writer @@ -1291,9 +1307,9 @@ mod tests { .unwrap(); let dec = bob .packet_reader - .decrypt_contents_with_alloc(&encrypted_message_to_bob[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&encrypted_message_to_bob[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(message, dec.unwrap()); + assert_eq!(message, dec.contents()); } #[test] @@ -1347,9 +1363,9 @@ mod tests { .decypt_len(message_to_bob[..3].try_into().unwrap()); let contents = bob .packet_reader - .decrypt_contents_with_alloc(&message_to_bob[3..3 + alice_message_len], None) + .decrypt_payload_with_alloc(&message_to_bob[3..3 + alice_message_len], None) .unwrap(); - assert_eq!(contents.unwrap(), message); + assert_eq!(contents.contents(), message); } // The rest are sourced from [the BIP324 test vectors](https://github.com/bitcoin/bips/blob/master/bip-0324/packet_encoding_test_vectors.csv). @@ -1380,9 +1396,9 @@ mod tests { .unwrap(); let dec_packet = bob_packet_handler .packet_reader - .decrypt_contents_with_alloc(&enc[NUM_LENGTH_BYTES..], None) + .decrypt_payload_with_alloc(&enc[NUM_LENGTH_BYTES..], None) .unwrap(); - assert_eq!(first, dec_packet.unwrap()); + assert_eq!(first, dec_packet.contents()); let contents: Vec = vec![0x8e]; let enc = alice_packet_handler .packet_writer diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 7f95faf..3665d81 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -53,9 +53,9 @@ fn hello_world_happy_path() { .unwrap(); let messages = alice .packet_reader - .decrypt_contents_with_alloc(&encrypted_message_to_alice[3..], None) + .decrypt_payload_with_alloc(&encrypted_message_to_alice[3..], None) .unwrap(); - assert_eq!(message, messages.unwrap()); + assert_eq!(message, messages.contents()); let message = b"Goodbye!".to_vec(); let encrypted_message_to_bob = alice .packet_writer @@ -63,9 +63,9 @@ fn hello_world_happy_path() { .unwrap(); let messages = bob .packet_reader - .decrypt_contents_with_alloc(&encrypted_message_to_bob[3..], None) + .decrypt_payload_with_alloc(&encrypted_message_to_bob[3..], None) .unwrap(); - assert_eq!(message, messages.unwrap()); + assert_eq!(message, messages.contents()); } #[test] @@ -159,9 +159,9 @@ fn regtest_handshake() { let mut response_message = vec![0; message_len]; stream.read_exact(&mut response_message).unwrap(); let msg = decrypter - .decrypt_contents_with_alloc(&response_message, None) + .decrypt_payload_with_alloc(&response_message, None) .unwrap(); - let message = deserialize(&msg.unwrap()).unwrap(); + let message = deserialize(msg.contents()).unwrap(); dbg!("{}", message.cmd()); assert_eq!(message.cmd(), "version"); rpc.stop().unwrap(); diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 5a0c7d6..1f6bc03 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -132,21 +132,21 @@ pub async fn read_v2( input: &mut T, decrypter: &mut PacketReader, ) -> Result { - let mut plaintext: Option> = None; - // Ignore any decoy packets. - while plaintext.is_none() { + let payload = loop { let mut length_bytes = [0u8; 3]; input.read_exact(&mut length_bytes).await?; let packet_bytes_len = decrypter.decypt_len(length_bytes); let mut packet_bytes = vec![0u8; packet_bytes_len]; input.read_exact(&mut packet_bytes).await?; - plaintext = decrypter - .decrypt_contents_with_alloc(&packet_bytes, None) - .expect("decrypt"); - } + let payload = decrypter.decrypt_payload_with_alloc(&packet_bytes, None)?; + + if payload.packet_type() == PacketType::Genuine { + break payload; + } + }; - let message = deserialize(&plaintext.expect("not a decoy")).map_err(|_| Error::Serde)?; + let message = deserialize(payload.contents()).map_err(|_| Error::Serde)?; Ok(message) }