diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index aa29cbb..6743872 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -1098,15 +1098,28 @@ where /// State machine of an asynchronous packet read. #[cfg(feature = "async")] -#[derive(Default, Debug)] +#[derive(Debug)] enum DecryptState { - #[default] - ReadingLength, + ReadingLength { + length_bytes: [u8; 3], + bytes_read: usize, + }, ReadingPayload { packet_bytes: Vec, + bytes_read: usize, }, } +#[cfg(feature = "async")] +impl Default for DecryptState { + fn default() -> Self { + DecryptState::ReadingLength { + length_bytes: [0u8; 3], + bytes_read: 0, + } + } +} + /// Manages an async buffer to automatically decrypt contents of received packets. #[cfg(feature = "async")] pub struct AsyncProtocolReader @@ -1125,26 +1138,42 @@ where { /// Decrypt contents of received packet from buffer. /// + /// This function is cancellation safe. + /// /// # Returns /// /// A `Result` containing: /// * `Ok(Payload)`: A decrypted payload. /// * `Err(ProtocolError)`: An error that occurred during the read or decryption. pub async fn decrypt(&mut self) -> Result { - // Storing state between async read_exacts to make function more cancellation safe. + // Storing state between async reads to make function cancellation safe. loop { match &mut self.state { - DecryptState::ReadingLength => { - let mut length_bytes = [0u8; 3]; - self.buffer.read_exact(&mut length_bytes).await?; - let packet_bytes_len = self.packet_reader.decypt_len(length_bytes); + DecryptState::ReadingLength { + length_bytes, + bytes_read, + } => { + while *bytes_read < 3 { + *bytes_read += self.buffer.read(&mut length_bytes[*bytes_read..]).await?; + } + + let packet_bytes_len = self.packet_reader.decypt_len(*length_bytes); let packet_bytes = vec![0u8; packet_bytes_len]; - self.state = DecryptState::ReadingPayload { packet_bytes }; + self.state = DecryptState::ReadingPayload { + packet_bytes, + bytes_read: 0, + }; } - DecryptState::ReadingPayload { packet_bytes } => { - self.buffer.read_exact(packet_bytes).await?; + DecryptState::ReadingPayload { + packet_bytes, + bytes_read, + } => { + while *bytes_read < packet_bytes.len() { + *bytes_read += self.buffer.read(&mut packet_bytes[*bytes_read..]).await?; + } + let payload = self.packet_reader.decrypt_payload(packet_bytes, None)?; - self.state = DecryptState::ReadingLength; + self.state = DecryptState::default(); return Ok(payload); } }