Skip to content

Commit

Permalink
Fix closing pending frames (#194)
Browse files Browse the repository at this point in the history
* send pending frames after collecting them

* test

* Apply suggestions from code review

* Update yamux/src/connection/closing.rs
  • Loading branch information
turuslan authored Oct 23, 2024
1 parent 8bd5d40 commit 7063268
Showing 1 changed file with 110 additions and 11 deletions.
121 changes: 110 additions & 11 deletions yamux/src/connection/closing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ where
socket: Fuse<frame::Io<T>>,
) -> Self {
Self {
state: State::FlushingPendingFrames,
state: State::ClosingStreamReceiver,
stream_receivers,
pending_frames,
socket,
Expand All @@ -49,14 +49,6 @@ where

loop {
match this.state {
State::FlushingPendingFrames => {
ready!(this.socket.poll_ready_unpin(cx))?;

match this.pending_frames.pop_front() {
Some(frame) => this.socket.start_send_unpin(frame)?,
None => this.state = State::ClosingStreamReceiver,
}
}
State::ClosingStreamReceiver => {
for stream in this.stream_receivers.iter_mut() {
stream.inner_mut().close();
Expand All @@ -77,11 +69,19 @@ where
Poll::Pending | Poll::Ready(None) => {
// No more frames from streams, append `Term` frame and flush them all.
this.pending_frames.push_back(Frame::term().into());
this.state = State::ClosingSocket;
this.state = State::FlushingPendingFrames;
continue;
}
}
}
State::FlushingPendingFrames => {
ready!(this.socket.poll_ready_unpin(cx))?;

match this.pending_frames.pop_front() {
Some(frame) => this.socket.start_send_unpin(frame)?,
None => this.state = State::ClosingSocket,
}
}
State::ClosingSocket => {
ready!(this.socket.poll_close_unpin(cx))?;

Expand All @@ -93,8 +93,107 @@ where
}

enum State {
FlushingPendingFrames,
ClosingStreamReceiver,
DrainingStreamReceiver,
FlushingPendingFrames,
ClosingSocket,
}

#[cfg(test)]
mod tests {
use super::*;
use futures::future::poll_fn;
use futures::FutureExt;

struct Socket {
written: Vec<u8>,
closed: bool,
}
impl AsyncRead for Socket {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut [u8],
) -> Poll<std::io::Result<usize>> {
unimplemented!()
}
}
impl AsyncWrite for Socket {
fn poll_write(
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
assert!(!self.closed);
self.written.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
unimplemented!()
}

fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
assert!(!self.closed);
self.closed = true;
Poll::Ready(Ok(()))
}
}

#[test]
fn pending_frames() {
let frame_pending = Frame::data(StreamId::new(1), vec![2]).unwrap().into();
let frame_data = Frame::data(StreamId::new(3), vec![4]).unwrap().into();
let frame_close = Frame::close_stream(StreamId::new(5), false).into();
let frame_close_ack = Frame::close_stream(StreamId::new(6), true).into();
let frame_term = Frame::term().into();
fn encode(buf: &mut Vec<u8>, frame: &Frame<()>) {
buf.extend_from_slice(&frame::header::encode(frame.header()));
if frame.header().tag() == frame::header::Tag::Data {
buf.extend_from_slice(frame.clone().into_data().body());
}
}
let mut expected_written = vec![];
encode(&mut expected_written, &frame_pending);
encode(&mut expected_written, &frame_data);
encode(&mut expected_written, &frame_close);
encode(&mut expected_written, &frame_close_ack);
encode(&mut expected_written, &frame_term);

let receiver = |frame: &Frame<_>, command: StreamCommand| {
TaggedStream::new(frame.header().stream_id(), {
let (mut tx, rx) = mpsc::channel(1);
tx.try_send(command).unwrap();
rx
})
};

let mut stream_receivers: SelectAll<_> = Default::default();
stream_receivers.push(receiver(
&frame_data,
StreamCommand::SendFrame(frame_data.clone().into_data().left()),
));
stream_receivers.push(receiver(
&frame_close,
StreamCommand::CloseStream { ack: false },
));
stream_receivers.push(receiver(
&frame_close_ack,
StreamCommand::CloseStream { ack: true },
));
let pending_frames = vec![frame_pending.into()];
let mut socket = Socket {
written: vec![],
closed: false,
};
let mut closing = Closing::new(
stream_receivers,
pending_frames.into(),
frame::Io::new(crate::connection::Id(0), &mut socket).fuse(),
);
futures::executor::block_on(async { poll_fn(|cx| closing.poll_unpin(cx)).await.unwrap() });
assert!(closing.pending_frames.is_empty());
assert!(socket.closed);
assert_eq!(socket.written, expected_written);
}
}

0 comments on commit 7063268

Please sign in to comment.