diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index 4d581cd..41fc815 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -98,3 +98,102 @@ enum State { FlushingPendingFrames, ClosingSocket, } + +#[cfg(test)] +mod tests { + use super::*; + use futures::future::poll_fn; + use futures::FutureExt; + + struct Socket { + written: Vec, + closed: bool, + } + impl AsyncRead for Socket { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut [u8], + ) -> Poll> { + todo!() + } + } + impl AsyncWrite for Socket { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + assert!(!self.closed); + self.written.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + todo!() + } + + fn poll_close(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + assert!(!self.closed); + self.closed = true; + Poll::Ready(Ok(())) + } + } + + #[test] + fn pending_frames() { + let frame_pending = Frame::data(StreamId::new(1), vec![2]).unwrap().into(); + let frame_data = Frame::data(StreamId::new(3), vec![4]).unwrap().into(); + let frame_close = Frame::close_stream(StreamId::new(5), false).into(); + let frame_close_ack = Frame::close_stream(StreamId::new(6), true).into(); + let frame_term = Frame::term().into(); + fn encode(buf: &mut Vec, frame: &Frame<()>) { + buf.extend_from_slice(&frame::header::encode(frame.header())); + if frame.header().tag() == frame::header::Tag::Data { + buf.extend_from_slice(frame.clone().into_data().body()); + } + } + let mut expected_written = vec![]; + encode(&mut expected_written, &frame_pending); + encode(&mut expected_written, &frame_data); + encode(&mut expected_written, &frame_close); + encode(&mut expected_written, &frame_close_ack); + encode(&mut expected_written, &frame_term); + + let receiver = |frame: &Frame<_>, command: StreamCommand| { + TaggedStream::new(frame.header().stream_id(), { + let (mut tx, rx) = mpsc::channel(1); + tx.try_send(command).unwrap(); + rx + }) + }; + + let mut stream_receivers: SelectAll<_> = Default::default(); + stream_receivers.push(receiver( + &frame_data, + StreamCommand::SendFrame(frame_data.clone().into_data().left()), + )); + stream_receivers.push(receiver( + &frame_close, + StreamCommand::CloseStream { ack: false }, + )); + stream_receivers.push(receiver( + &frame_close_ack, + StreamCommand::CloseStream { ack: true }, + )); + let pending_frames = vec![frame_pending.clone().into()]; + let mut socket = Socket { + written: vec![], + closed: false, + }; + let mut closing = Closing::new( + stream_receivers, + pending_frames.into(), + frame::Io::new(crate::connection::Id(0), &mut socket).fuse(), + ); + futures::executor::block_on(async { poll_fn(|cx| closing.poll_unpin(cx)).await.unwrap() }); + assert!(closing.pending_frames.is_empty()); + assert!(socket.closed); + assert_eq!(socket.written, expected_written); + } +}