diff --git a/sctp/src/association/association_test.rs b/sctp/src/association/association_test.rs
index 79f40b75a..c78d38fd0 100644
--- a/sctp/src/association/association_test.rs
+++ b/sctp/src/association/association_test.rs
@@ -735,16 +735,12 @@ async fn test_assoc_reliable_short_buffer() -> Result<()> {
 
     flush_buffers(&br, &a0, &a1).await;
 
+    // Verify partial reads are permitted.
     let mut buf = vec![0u8; 3];
-    let result = s1.read_sctp(&mut buf).await;
-    assert!(result.is_err(), "expected error to be io.ErrShortBuffer");
-    if let Err(err) = result {
-        assert_eq!(
-            Error::ErrShortBuffer,
-            err,
-            "expected error to be io.ErrShortBuffer"
-        );
-    }
+    let (n, ppi) = s1.read_sctp(&mut buf).await?;
+    assert_eq!(n, 3, "unexpected length of received data");
+    assert_eq!(&buf[..n], &MSG[..3], "unexpected received data");
+    assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");
 
     {
         let q = s0.reassembly_queue.lock().await;
diff --git a/sctp/src/queue/queue_test.rs b/sctp/src/queue/queue_test.rs
index 26a680114..997a811a3 100644
--- a/sctp/src/queue/queue_test.rs
+++ b/sctp/src/queue/queue_test.rs
@@ -1,4 +1,4 @@
-use crate::error::{Error, Result};
+use crate::error::Result;
 
 use bytes::{Bytes, BytesMut};
 
@@ -757,7 +757,7 @@ fn test_reassembly_queue_should_fail_to_read_if_the_nex_ssn_is_not_ready() -> Re
 }
 
 #[test]
-fn test_reassembly_queue_detect_buffer_too_short() -> Result<()> {
+fn test_reassembly_queue_permits_partial_reads() -> Result<()> {
     let mut rq = ReassemblyQueue::new(0);
 
     let org_ppi = PayloadProtocolIdentifier::Binary;
@@ -777,12 +777,11 @@ fn test_reassembly_queue_detect_buffer_too_short() -> Result<()> {
     assert_eq!(10, rq.get_num_bytes(), "num bytes mismatch");
 
     let mut buf = vec![0u8; 8]; // <- passing buffer too short
-    let result = rq.read(&mut buf);
-    assert!(result.is_err(), "read() should not succeed");
-    if let Err(err) = result {
-        assert_eq!(Error::ErrShortBuffer, err, "read() should not succeed");
-    }
-    assert_eq!(0, rq.get_num_bytes(), "num bytes mismatch");
+    let (n, ppi) = rq.read(&mut buf)?;
+    assert_eq!(8, n, "should have received 8 bytes");
+    assert_eq!(2, rq.get_num_bytes(), "num bytes mismatch");
+    assert_eq!(ppi, org_ppi, "should have valid ppi");
+    assert_eq!(&buf[..n], b"01234567", "data should match");
 
     Ok(())
 }
diff --git a/sctp/src/queue/reassembly_queue.rs b/sctp/src/queue/reassembly_queue.rs
index 339908eaf..3a6e35926 100644
--- a/sctp/src/queue/reassembly_queue.rs
+++ b/sctp/src/queue/reassembly_queue.rs
@@ -114,6 +114,9 @@ pub(crate) struct ReassemblyQueue {
     pub(crate) unordered: Vec<ChunkSet>,
     pub(crate) unordered_chunks: Vec<ChunkPayloadData>,
     pub(crate) n_bytes: usize,
+
+    pub(crate) unread_cset: Option<ChunkSet>,
+    pub(crate) unread_chunk_num_bytes_read: usize,
 }
 
 impl ReassemblyQueue {
@@ -126,10 +129,7 @@
         ReassemblyQueue {
             si,
             next_ssn: 0, // From RFC 4960 Sec 6.5:
-            ordered: vec![],
-            unordered: vec![],
-            unordered_chunks: vec![],
-            n_bytes: 0,
+            ..ReassemblyQueue::default()
         }
     }
 
@@ -254,8 +254,11 @@
     }
 
     pub(crate) fn read(&mut self, buf: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
-        // Check unordered first
-        let cset = if !self.unordered.is_empty() {
+        let cset = if self.unread_cset.is_some() {
+            // Read unread chunks left over from the previous read, if any.
+            self.unread_cset.take().unwrap()
+        } else if !self.unordered.is_empty() {
+            // Then check unordered
             self.unordered.remove(0)
         } else if !self.ordered.is_empty() {
             // Now, check ordered
@@ -274,27 +277,41 @@
             return Err(Error::ErrTryAgain);
         };
 
-        // Concat all fragments into the buffer
+        let ppi = cset.ppi;
+
+        // Concat fragments into the buffer.
         let mut n_written = 0;
-        let mut err = None;
-        for c in &cset.chunks {
-            let to_copy = c.user_data.len();
-            self.subtract_num_bytes(to_copy);
-            if err.is_none() {
-                let n = std::cmp::min(to_copy, buf.len() - n_written);
-                buf[n_written..n_written + n].copy_from_slice(&c.user_data[..n]);
-                n_written += n;
-                if n < to_copy {
-                    err = Some(Error::ErrShortBuffer);
+        for (i, c) in cset.chunks.iter().enumerate() {
+            // Skip the bytes of the first chunk that the previous read already consumed.
+            let user_data = if i == 0 && self.unread_chunk_num_bytes_read > 0 {
+                &c.user_data[self.unread_chunk_num_bytes_read..]
+            } else {
+                c.user_data.as_ref()
+            };
+            let n = std::cmp::min(user_data.len(), buf.len() - n_written);
+            buf[n_written..n_written + n].copy_from_slice(&user_data[..n]);
+            self.subtract_num_bytes(n);
+            n_written += n;
+
+            if n_written == buf.len() {
+                if n < user_data.len() {
+                    // If this chunk was read only partially
+                    self.unread_chunk_num_bytes_read = n;
+                    let mut s = ChunkSet::new(cset.ssn, cset.ppi);
+                    s.chunks = cset.chunks[i..].to_vec();
+                    self.unread_cset = Some(s);
+                } else if i < cset.chunks.len() - 1 {
+                    // If there are whole chunks still unread
+                    self.unread_chunk_num_bytes_read = 0;
+                    let mut s = ChunkSet::new(cset.ssn, cset.ppi);
+                    s.chunks = cset.chunks[i + 1..].to_vec();
+                    self.unread_cset = Some(s);
                 }
+                break;
             }
         }
 
-        if let Some(err) = err {
-            Err(err)
-        } else {
-            Ok((n_written, cset.ppi))
-        }
+        Ok((n_written, ppi))
     }
 
     /// Use last_ssn to locate a chunkSet then remove it if the set has
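Usage note (not part of the patch): after this change `ReassemblyQueue::read`, and `Stream::read_sctp` on top of it, no longer fails with `ErrShortBuffer`; a short buffer simply yields a partial read, and the remainder of the chunk set is returned by subsequent calls. The sketch below is illustrative only: it assumes it lives inside the crate next to the existing tests (the queue type is `pub(crate)`), the module path in the `use` line is assumed, and `drain_queue` is a made-up helper name.

```rust
use crate::queue::reassembly_queue::ReassemblyQueue; // module path assumed

/// Illustrative helper: drain whatever is currently reassembled in `rq`
/// through a deliberately small scratch buffer. Each `read` copies at most
/// `scratch.len()` bytes; anything left over is kept in `unread_cset` and
/// handed back by the next call, and `ErrTryAgain` signals an empty queue.
fn drain_queue(rq: &mut ReassemblyQueue, scratch: &mut [u8]) -> Vec<u8> {
    let mut out = Vec::new();
    // Message boundaries and the PPI are ignored here for brevity.
    while let Ok((n, _ppi)) = rq.read(scratch) {
        out.extend_from_slice(&scratch[..n]);
    }
    out
}
```

This mirrors the updated `queue_test` case: 10 queued bytes read through an 8-byte buffer leave 2 bytes still accounted for by `get_num_bytes()` and available to the next `read`.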