Skip to content

Commit

Permalink
Add a Payload type to avoid performance hit
Browse files Browse the repository at this point in the history
  • Loading branch information
nyonson committed Sep 27, 2024
1 parent 1b40672 commit 981479d
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 55 deletions.
98 changes: 57 additions & 41 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,9 @@ pub enum PacketType {
}

impl PacketType {
/// Check if plaintext packet is a decoy.
pub fn from_bytes(plaintext: &[u8]) -> Self {
// Check if header byte has the decoy flag flipped.
if plaintext.first() == Some(&DECOY_BYTE) {
/// Check if header byte has the decoy flag flipped.
pub fn from_byte(header: &u8) -> Self {
if header == &DECOY_BYTE {
PacketType::Decoy
} else {
PacketType::Genuine
Expand All @@ -239,6 +238,30 @@ impl PacketType {
}
}

/// Plaintext payload of a packet, which includes the header and contents.
#[cfg(feature = "alloc")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Payload {
bytes: Vec<u8>,
}

impl Payload {
pub fn new(bytes: Vec<u8>) -> Self {
Self { bytes }
}

/// Contents of the payload.
pub fn contents(&self) -> &[u8] {
// Exclude the header byte.
&self.bytes[1..]
}

/// Packet type of the payload.
pub fn packet_type(&self) -> PacketType {
PacketType::from_byte(&self.bytes[0])
}
}

/// Read packets over established encrypted channel from a peer.
#[derive(Clone)]
pub struct PacketReader {
Expand Down Expand Up @@ -271,7 +294,7 @@ impl PacketReader {
content_len as usize + NUM_HEADER_BYTES + NUM_TAG_BYTES
}

/// Decrypt the packet contents.
/// Decrypt the packet header byte and contents.
///
/// # Arguments
///
Expand All @@ -292,12 +315,12 @@ impl PacketReader {
/// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet.
/// * `BufferTooSmall ` - Contents buffer argument is not large enough for plaintext.
/// * Decryption errors for any failures such as a tag mismatch.
pub fn decrypt_contents(
pub fn decrypt_payload(
&mut self,
ciphertext: &[u8],
contents: &mut [u8],
aad: Option<&[u8]>,
) -> Result<PacketType, Error> {
) -> Result<(), Error> {
let auth = aad.unwrap_or_default();
// Check minimum size of ciphertext.
if ciphertext.len() < NUM_TAG_BYTES {
Expand All @@ -317,10 +340,10 @@ impl PacketReader {
tag.try_into().expect("16 byte tag"),
)?;

Ok(PacketType::from_bytes(contents))
Ok(())
}

/// Decrypt the packet contents.
/// Decrypt the packet header byte and contents.
///
/// # Arguments
///
Expand All @@ -331,27 +354,21 @@ impl PacketReader {
/// # Returns
///
/// A `Result` containing:
/// * `Ok(Some(Vec<u8>))`: The plaintext contents in a byte vector if it is not a decoy packet.
/// * `Ok(Payload)`: The plaintext header and contents.
/// * `Err(Error)`: An error that occurred during decryption.
///
/// # Errors
///
/// * `CiphertextTooSmall` - Ciphertext argument does not contain a whole packet.
#[cfg(feature = "alloc")]
pub fn decrypt_contents_with_alloc(
pub fn decrypt_payload_with_alloc(
&mut self,
ciphertext: &[u8],
aad: Option<&[u8]>,
) -> Result<Option<Vec<u8>>, Error> {
let mut contents = vec![0u8; ciphertext.len() - NUM_TAG_BYTES];
match self.decrypt_contents(ciphertext, &mut contents, aad)? {
PacketType::Decoy => Ok(None),
PacketType::Genuine => {
// Drop the header byte.
contents.remove(0);
Ok(Some(contents))
}
}
) -> Result<Payload, Error> {
let mut payload = vec![0u8; ciphertext.len() - NUM_TAG_BYTES];
self.decrypt_payload(ciphertext, &mut payload, aad)?;
Ok(Payload::new(payload))
}
}

Expand Down Expand Up @@ -881,7 +898,7 @@ impl<'a> Handshake<'a> {
if ciphertext.len() < self.current_buffer_index + NUM_LENGTH_BYTES + packet_length {
return Err(Error::CiphertextTooSmall);
}
packet_handler.packet_reader.decrypt_contents(
packet_handler.packet_reader.decrypt_payload(
&ciphertext[self.current_buffer_index + NUM_LENGTH_BYTES
..self.current_buffer_index + NUM_LENGTH_BYTES + packet_length],
packet_buffer,
Expand All @@ -892,9 +909,8 @@ impl<'a> Handshake<'a> {
self.current_buffer_index = self.current_buffer_index + NUM_LENGTH_BYTES + packet_length;
self.current_packet_length_bytes = None;

// The version packet is currently just an empty packet.
Ok(matches!(
PacketType::from_bytes(packet_buffer),
PacketType::from_byte(packet_buffer.first().expect("header byte")),
PacketType::Genuine
))
}
Expand Down Expand Up @@ -1083,19 +1099,19 @@ mod tests {
.unwrap();
let dec = bob_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(None, dec);
assert_eq!(PacketType::Decoy, dec.packet_type());
let message = b"Windows sox!".to_vec();
let enc_packet = bob_packet_handler
.packet_writer
.encrypt_packet_with_alloc(&message, None, PacketType::Genuine)
.unwrap();
let dec = alice_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.unwrap());
assert_eq!(message, dec.contents());
}

#[test]
Expand Down Expand Up @@ -1126,19 +1142,19 @@ mod tests {
.unwrap();
let dec_packet = bob_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec_packet.unwrap());
assert_eq!(message, dec_packet.contents());
let message = gen_garbage(420, &mut rng);
let enc_packet = bob_packet_handler
.packet_writer
.encrypt_packet_with_alloc(&message, None, PacketType::Genuine)
.unwrap();
let dec_packet = alice_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec_packet.unwrap());
assert_eq!(message, dec_packet.contents());
}
}

Expand Down Expand Up @@ -1168,7 +1184,7 @@ mod tests {
.unwrap();
let _ = bob_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], Some(&auth_garbage))
.decrypt_payload_with_alloc(&enc_packet[NUM_LENGTH_BYTES..], Some(&auth_garbage))
.unwrap();
}

Expand Down Expand Up @@ -1281,19 +1297,19 @@ mod tests {
.unwrap();
let dec = alice
.packet_reader
.decrypt_contents_with_alloc(&encrypted_message_to_alice[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&encrypted_message_to_alice[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.unwrap());
assert_eq!(message, dec.contents());
let message = b"g!".to_vec();
let encrypted_message_to_bob = alice
.packet_writer
.encrypt_packet_with_alloc(&message, None, PacketType::Genuine)
.unwrap();
let dec = bob
.packet_reader
.decrypt_contents_with_alloc(&encrypted_message_to_bob[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&encrypted_message_to_bob[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(message, dec.unwrap());
assert_eq!(message, dec.contents());
}

#[test]
Expand Down Expand Up @@ -1347,9 +1363,9 @@ mod tests {
.decypt_len(message_to_bob[..3].try_into().unwrap());
let contents = bob
.packet_reader
.decrypt_contents_with_alloc(&message_to_bob[3..3 + alice_message_len], None)
.decrypt_payload_with_alloc(&message_to_bob[3..3 + alice_message_len], None)
.unwrap();
assert_eq!(contents.unwrap(), message);
assert_eq!(contents.contents(), message);
}

// The rest are sourced from [the BIP324 test vectors](https://github.com/bitcoin/bips/blob/master/bip-0324/packet_encoding_test_vectors.csv).
Expand Down Expand Up @@ -1380,9 +1396,9 @@ mod tests {
.unwrap();
let dec_packet = bob_packet_handler
.packet_reader
.decrypt_contents_with_alloc(&enc[NUM_LENGTH_BYTES..], None)
.decrypt_payload_with_alloc(&enc[NUM_LENGTH_BYTES..], None)
.unwrap();
assert_eq!(first, dec_packet.unwrap());
assert_eq!(first, dec_packet.contents());
let contents: Vec<u8> = vec![0x8e];
let enc = alice_packet_handler
.packet_writer
Expand Down
12 changes: 6 additions & 6 deletions protocol/tests/round_trips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,19 @@ fn hello_world_happy_path() {
.unwrap();
let messages = alice
.packet_reader
.decrypt_contents_with_alloc(&encrypted_message_to_alice[3..], None)
.decrypt_payload_with_alloc(&encrypted_message_to_alice[3..], None)
.unwrap();
assert_eq!(message, messages.unwrap());
assert_eq!(message, messages.contents());
let message = b"Goodbye!".to_vec();
let encrypted_message_to_bob = alice
.packet_writer
.encrypt_packet_with_alloc(&message, None, PacketType::Genuine)
.unwrap();
let messages = bob
.packet_reader
.decrypt_contents_with_alloc(&encrypted_message_to_bob[3..], None)
.decrypt_payload_with_alloc(&encrypted_message_to_bob[3..], None)
.unwrap();
assert_eq!(message, messages.unwrap());
assert_eq!(message, messages.contents());
}

#[test]
Expand Down Expand Up @@ -159,9 +159,9 @@ fn regtest_handshake() {
let mut response_message = vec![0; message_len];
stream.read_exact(&mut response_message).unwrap();
let msg = decrypter
.decrypt_contents_with_alloc(&response_message, None)
.decrypt_payload_with_alloc(&response_message, None)
.unwrap();
let message = deserialize(&msg.unwrap()).unwrap();
let message = deserialize(msg.contents()).unwrap();
dbg!("{}", message.cmd());
assert_eq!(message.cmd(), "version");
rpc.stop().unwrap();
Expand Down
16 changes: 8 additions & 8 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,21 @@ pub async fn read_v2<T: AsyncRead + Unpin>(
input: &mut T,
decrypter: &mut PacketReader,
) -> Result<NetworkMessage, Error> {
let mut plaintext: Option<Vec<u8>> = None;

// Ignore any decoy packets.
while plaintext.is_none() {
let payload = loop {
let mut length_bytes = [0u8; 3];
input.read_exact(&mut length_bytes).await?;
let packet_bytes_len = decrypter.decypt_len(length_bytes);
let mut packet_bytes = vec![0u8; packet_bytes_len];
input.read_exact(&mut packet_bytes).await?;
plaintext = decrypter
.decrypt_contents_with_alloc(&packet_bytes, None)
.expect("decrypt");
}
let payload = decrypter.decrypt_payload_with_alloc(&packet_bytes, None)?;

if payload.packet_type() == PacketType::Genuine {
break payload;
}
};

let message = deserialize(&plaintext.expect("not a decoy")).map_err(|_| Error::Serde)?;
let message = deserialize(payload.contents()).map_err(|_| Error::Serde)?;
Ok(message)
}

Expand Down

0 comments on commit 981479d

Please sign in to comment.