From a3f16bbe6b9476d72e86f484e56f12261c089acc Mon Sep 17 00:00:00 2001 From: Nick Johnson Date: Thu, 17 Oct 2024 12:36:56 -0700 Subject: [PATCH] Update proxy with more cancellatoin safety --- proxy/src/lib.rs | 56 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index 941594c..a8fffb2 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -90,16 +90,28 @@ pub async fn peek_addr(client: &TcpStream, network: Network) -> Result, + bytes_read: usize, }, } +impl Default for ReadState { + fn default() -> Self { + ReadState::ReadingLength { + header_bytes: [0u8; V1_HEADER_BYTES], + bytes_read: 0, + } + } +} + /// Read messages on the V1 protocol. pub struct V1ProtocolReader { input: T, @@ -119,9 +131,14 @@ impl V1ProtocolReader { pub async fn read(&mut self) -> Result { loop { match &mut self.state { - ReadState::ReadingLength => { - let mut header_bytes = [0u8; V1_HEADER_BYTES]; - self.input.read_exact(&mut header_bytes).await?; + ReadState::ReadingLength { + header_bytes, + bytes_read, + } => { + while *bytes_read < V1_HEADER_BYTES { + let n = self.input.read(&mut header_bytes[*bytes_read..]).await?; + *bytes_read += n; + } let payload_len = u32::from_le_bytes( header_bytes[16..20] @@ -130,21 +147,28 @@ impl V1ProtocolReader { ) as usize; let mut packet_bytes = vec![0u8; V1_HEADER_BYTES + payload_len]; - packet_bytes[..V1_HEADER_BYTES].copy_from_slice(&header_bytes); + packet_bytes[..V1_HEADER_BYTES].copy_from_slice(header_bytes); - self.state = ReadState::ReadingPayload { packet_bytes }; + self.state = ReadState::ReadingPayload { + packet_bytes, + bytes_read: V1_HEADER_BYTES, + }; } - ReadState::ReadingPayload { packet_bytes } => { - self.input - .read_exact(&mut packet_bytes[V1_HEADER_BYTES..]) - .await?; + ReadState::ReadingPayload { + packet_bytes, + bytes_read, + } => { + while *bytes_read < packet_bytes.len() { + let n = self.input.read(&mut packet_bytes[*bytes_read..]).await?; + *bytes_read += n; + } let message = RawNetworkMessage::consensus_decode(&mut &packet_bytes[..]) .expect("decode v1"); - // Reset state for next read. - self.state = ReadState::ReadingLength; - + self.state = ReadState::default(); + // The RawNetworkMessage type doesn't have a nice way to pull + // out the payload, so using a clone here. return Ok(message.payload().clone()); } }