Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: don't split header and body across TCP packets #168

Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1c06435
Compilation + test using zerocopy ok
stormshield-pj50 Jul 19, 2023
39fa1e5
Fix after comments
stormshield-pj50 Sep 7, 2023
5d4e2d2
cargo fmt
stormshield-pj50 Sep 8, 2023
62002ae
WIP Minimal diff
thomaseizinger Sep 17, 2023
ee00e1d
Some basic fixes
thomaseizinger Sep 17, 2023
801bf4f
Don't use `ready!` macro with `mem::replace`
thomaseizinger Oct 5, 2023
8e11260
Inline variable
thomaseizinger Oct 5, 2023
aeb4cd2
Remove `Init` state
thomaseizinger Oct 5, 2023
f90c990
Use `?` for decoding header
thomaseizinger Oct 5, 2023
93c3834
Use ctor
thomaseizinger Oct 5, 2023
872815e
Use type-system to only allocate for data frames
thomaseizinger Oct 5, 2023
722b7c8
Don't use `cast` outside of `header` module
thomaseizinger Oct 5, 2023
91e812a
Add TODO
thomaseizinger Oct 5, 2023
c095aac
Replace header::decode() with Frame<T>::try_from_header_buffer()
pjalaber Oct 4, 2023
2cab6b4
Cargo fmt
pjalaber Oct 5, 2023
166f8ff
Reduce diff
thomaseizinger Oct 6, 2023
5c6b172
Bring back header::decode
thomaseizinger Oct 6, 2023
524994f
Reduce diff
thomaseizinger Oct 6, 2023
0108a3d
Reduce diff
thomaseizinger Oct 6, 2023
8d32d16
Resolve todo
thomaseizinger Oct 6, 2023
851341e
Reduce diff
thomaseizinger Oct 6, 2023
9b26409
Reduce diff
thomaseizinger Oct 6, 2023
3df462d
Simplify things a bit further
thomaseizinger Oct 6, 2023
09cda48
Use body_len in debug impl
thomaseizinger Oct 6, 2023
5e3e65b
Remove generic length accessor
thomaseizinger Oct 6, 2023
ab29664
Reduce diff
thomaseizinger Oct 6, 2023
b65f2bd
Ensure we check max body len before allocating
thomaseizinger Oct 9, 2023
ee9c920
Don't allocate unless necessary
thomaseizinger Oct 9, 2023
ae3bc3d
WIP: Use `AsyncWrite::poll_write_vectored`
thomaseizinger Oct 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions yamux/src/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,12 @@ impl Chunks {
}

/// Add another chunk of bytes to the end.
pub(crate) fn push(&mut self, x: Vec<u8>, offset: usize) {
let x_len = x.len();
let cursor = io::Cursor::new(x);
let mut chunk = Chunk { cursor };
chunk.advance(offset);
if !chunk.is_empty() {
assert_eq!(chunk.len(), x_len - offset);
self.len += chunk.len() + offset;
self.seq.push_back(chunk);
pub(crate) fn push(&mut self, x: Vec<u8>) {
self.len += x.len();
if !x.is_empty() {
self.seq.push_back(Chunk {
cursor: io::Cursor::new(x),
})
}
}

Expand Down
130 changes: 54 additions & 76 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ mod cleanup;
mod closing;
mod stream;

use crate::frame::header::HEADER_SIZE;
use crate::tagged_stream::TaggedStream;
use crate::{
error::ConnectionError,
Expand Down Expand Up @@ -356,7 +355,7 @@ struct Active<T> {
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<u32, Arc<Mutex<stream::Shared>>>,
streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

Expand Down Expand Up @@ -520,14 +519,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {

if extra_credit > 0 {
let mut frame = Frame::window_update(id, extra_credit);
let mut parsed_frame = frame.parse_mut().expect("valid frame");
parsed_frame.header_mut().syn();
log::trace!(
"{}/{}: sending initial {}",
self.id,
id,
parsed_frame.header()
);
frame.header_mut().syn();
log::trace!("{}/{}: sending initial {}", self.id, id, frame.header());
self.pending_frames.push_back(frame.into());
}

Expand All @@ -538,18 +531,17 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id.val(), stream.clone_shared());
self.streams.insert(id, stream.clone_shared());

Poll::Ready(Ok(stream))
}

fn on_send_frame(&mut self, frame: Frame<Either<Data, WindowUpdate>>) {
let parsed_frame = frame.parse().expect("valid frame");
log::trace!(
"{}/{}: sending: {}",
self.id,
parsed_frame.header().stream_id(),
parsed_frame.header()
frame.header().stream_id(),
frame.header()
);
self.pending_frames.push_back(frame.into());
}
Expand All @@ -561,10 +553,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_drop_stream(&mut self, stream_id: StreamId) {
let s = self
.streams
.remove(&stream_id.val())
.expect("stream not found");
let s = self.streams.remove(&stream_id).expect("stream not found");

log::trace!("{}: removing dropped stream {}", self.id, stream_id);
let frame = {
Expand All @@ -575,15 +564,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
State::Open { .. } => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::from_header(header))
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::from_header(header))
Some(Frame::new(header))
}
// The stream was properly closed. We already sent our FIN frame.
// The remote may be out of credit though and blocked on
Expand All @@ -596,7 +585,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::from_header(header))
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
Expand All @@ -620,8 +609,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
frame
};
if let Some(f) = frame {
let pf = f.parse().expect("valid frame");
log::trace!("{}/{}: sending: {}", self.id, stream_id, pf.header());
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
}
Expand All @@ -633,12 +621,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
/// Otherwise we process the frame and potentially return a new `Stream`
/// if one was opened by the remote.
fn on_frame(&mut self, frame: Frame<()>) -> Result<Option<Stream>> {
let parsed_frame = frame.parse().expect("valid frame");
log::trace!("{}: received: {}", self.id, parsed_frame.header());
log::trace!("{}: received: {}", self.id, frame.header());

if parsed_frame.header().flags().contains(header::ACK) {
let id = parsed_frame.header().stream_id();
if let Some(stream) = self.streams.get(&id.val()) {
if frame.header().flags().contains(header::ACK) {
let id = frame.header().stream_id();
if let Some(stream) = self.streams.get(&id) {
stream
.lock()
.update_state(self.id, id, State::Open { acknowledged: true });
Expand All @@ -648,7 +635,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}
}

let action = match parsed_frame.header().tag().expect("valid header's tag") {
let action = match frame.header().tag() {
Tag::Data => self.on_data(frame.into_data()),
Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()),
Tag::Ping => self.on_ping(&frame.into_ping()),
Expand All @@ -659,25 +646,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::New(stream, update) => {
log::trace!("{}: new inbound {} of {}", self.id, stream, self);
if let Some(f) = update {
let pf = f.parse().expect("valid frame");
log::trace!("{}/{}: sending update", self.id, pf.header().stream_id());
log::trace!("{}/{}: sending update", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
return Ok(Some(stream));
}
Action::Update(f) => {
let pf = f.parse().expect("valid frame");
log::trace!("{}: sending update: {:?}", self.id, pf.header());
log::trace!("{}: sending update: {:?}", self.id, f.header());
self.pending_frames.push_back(f.into());
}
Action::Ping(f) => {
let pf = f.parse().expect("valid frame");
log::trace!("{}/{}: pong", self.id, pf.header().stream_id());
log::trace!("{}/{}: pong", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
Action::Reset(f) => {
let pf = f.parse().expect("valid frame");
log::trace!("{}/{}: sending reset", self.id, pf.header().stream_id());
log::trace!("{}/{}: sending reset", self.id, f.header().stream_id());
self.pending_frames.push_back(f.into());
}
Action::Terminate(f) => {
Expand All @@ -690,12 +673,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_data(&mut self, frame: Frame<Data>) -> Action {
let parsed_frame = frame.parse().expect("valid frame");
let stream_id = parsed_frame.header().stream_id();
let stream_id = frame.header().stream_id();

if parsed_frame.header().flags().contains(header::RST) {
if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id.val()) {
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
Expand All @@ -708,23 +690,23 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
return Action::None;
}

let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close
let is_finish = frame.header().flags().contains(header::FIN); // half-close

if parsed_frame.header().flags().contains(header::SYN) {
if frame.header().flags().contains(header::SYN) {
// new stream
if !self.is_valid_remote_id(stream_id, Tag::Data) {
log::error!("{}: invalid stream id {}", self.id, stream_id);
return Action::Terminate(Frame::protocol_error());
}
if parsed_frame.body().len() > DEFAULT_CREDIT as usize {
if frame.body().len() > DEFAULT_CREDIT as usize {
log::error!(
"{}/{}: 1st body of stream exceeds default credit",
self.id,
stream_id
);
return Action::Terminate(Frame::protocol_error());
}
if self.streams.contains_key(&stream_id.val()) {
if self.streams.contains_key(&stream_id) {
log::error!("{}/{}: stream already exists", self.id, stream_id);
return Action::Terminate(Frame::protocol_error());
}
Expand All @@ -739,30 +721,28 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
shared.window = shared.window.saturating_sub(parsed_frame.body_len());
shared.buffer.push(frame.into_buffer(), HEADER_SIZE);
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
shared.window = shared.window.saturating_sub(frame.body_len());
shared.buffer.push(frame.into_body());

if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) {
if let Some(credit) = shared.next_window_update() {
shared.window += credit;

let mut frame = Frame::window_update(stream_id, credit);
let mut parsed_frame = frame.parse_mut().expect("valid frame");
parsed_frame.header_mut().ack();
frame.header_mut().ack();
window_update = Some(frame)
}
}
}
if window_update.is_none() {
stream.set_flag(stream::Flag::Ack)
}
self.streams.insert(stream_id.val(), stream.clone_shared());
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, window_update);
}

if let Some(s) = self.streams.get_mut(&stream_id.val()) {
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
if parsed_frame.body().len() > shared.window as usize {
if frame.body().len() > shared.window as usize {
log::error!(
"{}/{}: frame body larger than window of stream",
self.id,
Expand All @@ -782,10 +762,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
);
let mut header = Header::data(stream_id, 0);
header.rst();
return Action::Reset(Frame::from_header(header));
return Action::Reset(Frame::new(header));
}
shared.window = shared.window.saturating_sub(parsed_frame.body_len());
shared.buffer.push(frame.into_buffer(), HEADER_SIZE);
shared.window = shared.window.saturating_sub(frame.body_len());
shared.buffer.push(frame.into_body());
if let Some(w) = shared.reader.take() {
w.wake()
}
Expand Down Expand Up @@ -816,12 +796,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
let parsed_frame = frame.parse().expect("valid frame");
let stream_id = parsed_frame.header().stream_id();
let stream_id = frame.header().stream_id();

if parsed_frame.header().flags().contains(header::RST) {
if frame.header().flags().contains(header::RST) {
// stream reset
if let Some(s) = self.streams.get_mut(&stream_id.val()) {
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.update_state(self.id, stream_id, State::Closed);
if let Some(w) = shared.reader.take() {
Expand All @@ -834,15 +813,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
return Action::None;
}

let is_finish = parsed_frame.header().flags().contains(header::FIN); // half-close
let is_finish = frame.header().flags().contains(header::FIN); // half-close

if parsed_frame.header().flags().contains(header::SYN) {
if frame.header().flags().contains(header::SYN) {
// new stream
if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) {
log::error!("{}: invalid stream id {}", self.id, stream_id);
return Action::Terminate(Frame::protocol_error());
}
if self.streams.contains_key(&stream_id.val()) {
if self.streams.contains_key(&stream_id) {
log::error!("{}/{}: stream already exists", self.id, stream_id);
return Action::Terminate(Frame::protocol_error());
}
Expand All @@ -851,7 +830,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
return Action::Terminate(Frame::protocol_error());
}

let credit = parsed_frame.header().credit() + DEFAULT_CREDIT;
let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_inbound_stream(stream_id, credit);
stream.set_flag(stream::Flag::Ack);

Expand All @@ -860,13 +839,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.shared()
.update_state(self.id, stream_id, State::RecvClosed);
}
self.streams.insert(stream_id.val(), stream.clone_shared());
self.streams.insert(stream_id, stream.clone_shared());
return Action::New(stream, None);
}

if let Some(s) = self.streams.get_mut(&stream_id.val()) {
if let Some(s) = self.streams.get_mut(&stream_id) {
let mut shared = s.lock();
shared.credit += parsed_frame.header().credit();
shared.credit += frame.header().credit();
if is_finish {
shared.update_state(self.id, stream_id, State::RecvClosed);
}
Expand All @@ -893,16 +872,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
}

fn on_ping(&mut self, frame: &Frame<Ping>) -> Action {
let parsed_frame = frame.parse().expect("valid frame");
let stream_id = parsed_frame.header().stream_id();
if parsed_frame.header().flags().contains(header::ACK) {
let stream_id = frame.header().stream_id();
if frame.header().flags().contains(header::ACK) {
// pong
return Action::None;
}
if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id.val()) {
let mut hdr = Header::ping(parsed_frame.header().nonce());
if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) {
let mut hdr = Header::ping(frame.header().nonce());
hdr.ack();
return Action::Ping(Frame::from_header(hdr));
return Action::Ping(Frame::new(hdr));
}
log::trace!(
"{}/{}: ping for unknown stream, possibly dropped earlier: {:?}",
Expand Down Expand Up @@ -969,8 +947,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
// - Its ID is odd and we are the client.
// - Its ID is even and we are the server.
.filter(|(id, _)| match self.mode {
Mode::Client => StreamId::new(**id).is_client(),
Mode::Server => StreamId::new(**id).is_server(),
Mode::Client => id.is_client(),
Mode::Server => id.is_server(),
})
.filter(|(_, s)| s.lock().is_pending_ack())
.count()
Expand All @@ -993,7 +971,7 @@ impl<T> Active<T> {
fn drop_all_streams(&mut self) {
for (id, s) in self.streams.drain() {
let mut shared = s.lock();
shared.update_state(self.id, StreamId::new(id), State::Closed);
shared.update_state(self.id, id, State::Closed);
if let Some(w) = shared.reader.take() {
w.wake()
}
Expand Down
Loading
Loading