Skip to content

Commit

Permalink
Make async read cancellation safe
Browse files Browse the repository at this point in the history
  • Loading branch information
nyonson committed Oct 2, 2024
1 parent 9f43bce commit 1e42825
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ where
reader: AsyncProtocolReader {
buffer: reader,
packet_reader,
state: DecryptState::default(),
},
writer: AsyncProtocolWriter {
buffer: writer,
Expand All @@ -1099,6 +1100,17 @@ where
}
}

/// State machine of an asynchronous packet read.
#[cfg(feature = "async")]
#[derive(Default, Debug)]
enum DecryptState {
#[default]
ReadingLength,
ReadingPayload {
packet_bytes: Vec<u8>,
},
}

/// Manages an async buffer to automatically decrypt contents of received packets.
#[cfg(feature = "async")]
pub struct AsyncProtocolReader<R>
Expand All @@ -1107,6 +1119,7 @@ where
{
buffer: R,
packet_reader: PacketReader,
state: DecryptState,
}

#[cfg(feature = "async")]
Expand All @@ -1122,16 +1135,26 @@ where
/// * `Ok(Payload)`: A decrypted payload.
/// * `Err(ProtocolError)`: An error that occurred during the read or decryption.
pub async fn decrypt(&mut self) -> Result<Payload, ProtocolError> {
// TODO: make cancellation safe with state between read_exacts.
let mut length_bytes = [0u8; 3];
self.buffer.read_exact(&mut length_bytes).await?;
let packet_bytes_len = self.packet_reader.decypt_len(length_bytes);
let mut packet_bytes = vec![0u8; packet_bytes_len];
self.buffer.read_exact(&mut packet_bytes).await?;
let payload = self
.packet_reader
.decrypt_payload_with_alloc(&packet_bytes, None)?;
Ok(payload)
// Storing state between async read_exacts to make function more cancellation safe.
loop {
match &mut self.state {
DecryptState::ReadingLength => {
let mut length_bytes = [0u8; 3];
self.buffer.read_exact(&mut length_bytes).await?;
let packet_bytes_len = self.packet_reader.decypt_len(length_bytes);
let packet_bytes = vec![0u8; packet_bytes_len];
self.state = DecryptState::ReadingPayload { packet_bytes };
}
DecryptState::ReadingPayload { packet_bytes } => {
self.buffer.read_exact(packet_bytes).await?;
let payload = self
.packet_reader
.decrypt_payload_with_alloc(packet_bytes, None)?;
self.state = DecryptState::ReadingLength;
return Ok(payload);
}
}
}
}
}

Expand Down

0 comments on commit 1e42825

Please sign in to comment.