Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address multi-frame reception issue #6

Merged
merged 18 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

* A bug that caused connections seeing a high incidence of multi-frame sends to collapse due to a protocol violation by the sender has been fixed.

## [0.2.1] - 2023-01-23

### Changed

* There is now a timeout for how long a peer can take to accept an error message.
Expand Down
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "juliet"
version = "0.2.0"
version = "0.2.1"
edition = "2021"
authors = [ "Marc Brinkmann <[email protected]>" ]
exclude = [ "proptest-regressions" ]
Expand Down Expand Up @@ -46,3 +46,6 @@ static_assertions = "1.1.0"
[[example]]
name = "fizzbuzz"
required-features = [ "tracing" ]

[profile.test]
opt-level = 1
57 changes: 36 additions & 21 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,9 @@ pub struct IoCore<const N: usize, R, W> {
/// The maximum time allowed for a peer to receive an error.
error_timeout: Duration,

/// The frame in the process of being sent, which may be partially transferred already.
current_frame: Option<OutgoingFrame>,
/// The frame in the process of being sent, which may be partially transferred already. Also
/// indicates if the current frame is the final frame of a message.
current_frame: Option<(OutgoingFrame, bool)>,
/// The headers of active current multi-frame transfers.
active_multi_frame: [Option<Header>; N],
/// Frames waiting to be sent.
Expand Down Expand Up @@ -506,6 +507,8 @@ where
}
Outcome::Fatal(err_msg) => {
// The remote messed up, begin shutting down due to an error.
#[cfg(feature = "tracing")]
tracing::warn!(err_msg_header=%err_msg.header(), "injecting error due to fatal outcome");
self.inject_error(err_msg);
}
Outcome::Success(successful_read) => {
Expand Down Expand Up @@ -541,27 +544,42 @@ where
tokio::select! {
biased; // We actually like the bias, avoid the randomness overhead.

write_result = write_all_buf_if_some(&mut self.writer, self.current_frame.as_mut())
write_result = write_all_buf_if_some(&mut self.writer,
self.current_frame.as_mut()
.map(|(ref mut frame, _)| frame))
, if self.current_frame.is_some() => {

write_result.map_err(CoreError::WriteFailed)?;

// Clear `current_frame` via `Option::take` and examine what was sent.
if let Some(frame_sent) = self.current_frame.take() {
if let Some((frame_sent, was_final)) = self.current_frame.take() {
#[cfg(feature = "tracing")]
tracing::trace!(frame=%frame_sent, "sent");

if frame_sent.header().is_error() {
let header_sent = frame_sent.header();

// If we finished the active multi frame send, clear it.
if was_final {
let channel_idx = header_sent.channel().get() as usize;
if let Some(ref active_multi_frame) =
self.active_multi_frame[channel_idx] {
if header_sent == *active_multi_frame {
self.active_multi_frame[channel_idx] = None;
}
}
}

if header_sent.is_error() {
// We finished sending an error frame, time to exit.
return Err(CoreError::RemoteProtocolViolation(frame_sent.header()));
return Err(CoreError::RemoteProtocolViolation(header_sent));
}

// TODO: We should restrict the dirty-queue processing here a little bit
// (only check when completing a multi-frame message).
// A message has completed sending, process the wait queue in case we have
// to start sending a multi-frame message like a response that was delayed
// only because of the one-multi-frame-per-channel restriction.
self.process_wait_queue(frame_sent.header().channel())?;
self.process_wait_queue(header_sent.channel())?;
} else {
#[cfg(feature = "tracing")]
tracing::error!("current frame should not disappear");
Expand Down Expand Up @@ -726,6 +744,9 @@ where
let msg = self.juliet.create_request(channel, payload)?;
let id = msg.header().id();
self.request_map.insert(io_id, (channel, id));
if msg.is_multi_frame(self.juliet.max_frame_size()) {
self.active_multi_frame[channel.get() as usize] = Some(msg.header());
}
self.ready_queue.push_back(msg.frames());

drop(permit);
Expand All @@ -749,6 +770,9 @@ where
payload,
} => {
if let Some(msg) = self.juliet.create_response(channel, id, payload)? {
if msg.is_multi_frame(self.juliet.max_frame_size()) {
self.active_multi_frame[channel.get() as usize] = Some(msg.header());
}
self.ready_queue.push_back(msg.frames())
}
}
Expand Down Expand Up @@ -790,23 +814,14 @@ where
.next_owned(self.juliet.max_frame_size());

// If there are more frames after this one, schedule the remainder.
if let Some(next_frame_iter) = additional_frames {
let is_final = if let Some(next_frame_iter) = additional_frames {
self.ready_queue.push_back(next_frame_iter);
false
} else {
// No additional frames. Check if sending the next frame finishes a multi-frame message.
let about_to_finish = frame.header();
if let Some(ref active_multi) =
self.active_multi_frame[about_to_finish.channel().get() as usize]
{
if about_to_finish == *active_multi {
// Once the scheduled frame is processed, we will finished the multi-frame
// transfer, so we can allow for the next multi-frame transfer to be scheduled.
self.active_multi_frame[about_to_finish.channel().get() as usize] = None;
}
}
}
true
};

self.current_frame = Some(frame);
self.current_frame = Some((frame, is_final));
Ok(())
}

Expand Down
13 changes: 13 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ impl<T, E> Outcome<T, E> {
}
}

/// Maps the value of an [`Outcome`].
#[inline]
pub fn map<T2, F>(self, f: F) -> Outcome<T2, E>
where
F: FnOnce(T) -> T2,
{
match self {
Outcome::Incomplete(n) => Outcome::Incomplete(n),
Outcome::Fatal(err) => Outcome::Fatal(err),
Outcome::Success(value) => Outcome::Success(f(value)),
}
}

/// Maps the error of an [`Outcome`].
#[inline]
pub fn map_err<E2, F>(self, f: F) -> Outcome<T, E2>
Expand Down
33 changes: 11 additions & 22 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use std::{collections::HashSet, fmt::Display};
use bytes::{Buf, Bytes, BytesMut};
use thiserror::Error;

use self::multiframe::MultiframeReceiver;
use self::multiframe::{CompletedFrame, MultiframeReceiver};
pub use self::outgoing_message::{FrameIter, OutgoingFrame, OutgoingMessage};
use crate::{
header::{self, ErrorKind, Header, Kind},
Expand Down Expand Up @@ -751,8 +751,8 @@ impl<const N: usize> JulietProtocol<N> {
) -> Outcome<CompletedRead, OutgoingMessage> {
// First, attempt to complete a frame.
loop {
// We do not have enough data to extract a header, indicate and return.
if buffer.len() < Header::SIZE {
// We do not have enough data to extract a header, indicate and return.
return Outcome::incomplete(Header::SIZE - buffer.len());
}

Expand Down Expand Up @@ -859,11 +859,7 @@ impl<const N: usize> JulietProtocol<N> {
}
}
Kind::RequestPl => {
// Make a note whether or not we are continuing an existing request.
let is_new_request =
channel.current_multiframe_receiver.is_new_transfer(header);

let multiframe_outcome: Option<BytesMut> =
let completed_frame: CompletedFrame =
try_outcome!(channel.current_multiframe_receiver.accept(
header,
buffer,
Expand All @@ -873,8 +869,7 @@ impl<const N: usize> JulietProtocol<N> {
));

// If we made it to this point, we have consumed the frame. Record it.

if is_new_request {
if completed_frame.was_new() {
// Requests must be eagerly (first frame) rejected if exceeding the limit.
if channel.is_at_max_incoming_requests() {
return err_msg(header, ErrorKind::RequestLimitExceeded);
Expand All @@ -887,18 +882,15 @@ impl<const N: usize> JulietProtocol<N> {
channel.increment_cancellation_allowance();
}

if let Some(payload) = multiframe_outcome {
// Message is complete.
// If we completed the message, return it.
if let Some(payload) = completed_frame.into_completed_payload() {
let payload = payload.freeze();

return Success(CompletedRead::NewRequest {
channel: header.channel(),
id: header.id(),
payload: Some(payload),
});
} else {
// We need more frames to complete the payload. Do nothing and attempt
// to read the next frame.
}
}
Kind::ResponsePl => {
Expand All @@ -907,7 +899,7 @@ impl<const N: usize> JulietProtocol<N> {
return err_msg(header, ErrorKind::FictitiousRequest);
}

let multiframe_outcome: Option<BytesMut> =
let multiframe_outcome =
try_outcome!(channel.current_multiframe_receiver.accept(
header,
buffer,
Expand All @@ -916,8 +908,8 @@ impl<const N: usize> JulietProtocol<N> {
ErrorKind::ResponseTooLarge
));

if let Some(payload) = multiframe_outcome {
// Message is complete. Remove it from the outgoing requests.
// If the response is complete, process it.
if let Some(payload) = multiframe_outcome.into_completed_payload() {
channel.outgoing_requests.remove(&header.id());

let payload = payload.freeze();
Expand All @@ -927,9 +919,6 @@ impl<const N: usize> JulietProtocol<N> {
id: header.id(),
payload: Some(payload),
});
} else {
// We need more frames to complete the payload. Do nothing and attempt
// to read the next frame.
}
}
Kind::CancelReq => {
Expand Down Expand Up @@ -2395,9 +2384,9 @@ mod tests {
outcome,
CompletedRead::ReceivedResponse {
channel,
/// The ID of the request received.
// The ID of the request received.
id,
/// The response payload.
// The response payload.
payload: None,
}
);
Expand Down
Loading
Loading