sctp: allow partial reads
removes `ErrShortBuffer`

Refs #273
melekes committed Sep 30, 2022
1 parent c9409ba commit f414248
Showing 3 changed files with 51 additions and 39 deletions.
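The gist of the change: reading into a buffer smaller than the queued message no longer fails with `ErrShortBuffer`; the buffer is filled and the unread remainder is kept for the next read. Below is a minimal, self-contained sketch of that carry-over idea; it does not use the webrtc-sctp crate, and `PartialReadQueue` and its methods are purely illustrative rather than part of the crate's API.

```rust
use std::collections::VecDeque;

/// Illustrative stand-in for the reassembly queue's new bookkeeping:
/// reassembled messages wait in `chunks`, and `front_offset` remembers
/// how much of the front message earlier reads already consumed.
struct PartialReadQueue {
    chunks: VecDeque<Vec<u8>>,
    front_offset: usize,
}

impl PartialReadQueue {
    fn new() -> Self {
        Self { chunks: VecDeque::new(), front_offset: 0 }
    }

    fn push(&mut self, msg: Vec<u8>) {
        self.chunks.push_back(msg);
    }

    /// Copy as much queued data as fits into `buf`; whatever does not fit
    /// stays queued and is returned by the next call instead of an error.
    fn read(&mut self, buf: &mut [u8]) -> usize {
        let mut n_written = 0;
        while n_written < buf.len() {
            let remaining = match self.chunks.front() {
                Some(front) => &front[self.front_offset..],
                None => break,
            };
            let n = remaining.len().min(buf.len() - n_written);
            buf[n_written..n_written + n].copy_from_slice(&remaining[..n]);
            n_written += n;
            if n == remaining.len() {
                // Front message fully consumed: drop it and reset the offset.
                self.chunks.pop_front();
                self.front_offset = 0;
            } else {
                // Buffer is full: remember where to resume on the next read.
                self.front_offset += n;
            }
        }
        n_written
    }
}

fn main() {
    let mut q = PartialReadQueue::new();
    q.push(b"0123456789".to_vec()); // one 10-byte message, as in the new queue test

    let mut buf = [0u8; 8];
    let n = q.read(&mut buf);
    assert_eq!((n, &buf[..n]), (8, &b"01234567"[..])); // partial read succeeds

    let n = q.read(&mut buf);
    assert_eq!((n, &buf[..n]), (2, &b"89"[..])); // remainder arrives on the next read
    println!("partial reads ok");
}
```

The updated `test_reassembly_queue_permits_partial_reads` below exercises the same behaviour through `ReassemblyQueue::read`.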
14 changes: 5 additions & 9 deletions sctp/src/association/association_test.rs
@@ -735,16 +735,12 @@ async fn test_assoc_reliable_short_buffer() -> Result<()> {

flush_buffers(&br, &a0, &a1).await;

// Verify partial reads are permitted.
let mut buf = vec![0u8; 3];
let result = s1.read_sctp(&mut buf).await;
assert!(result.is_err(), "expected error to be io.ErrShortBuffer");
if let Err(err) = result {
assert_eq!(
Error::ErrShortBuffer,
err,
"expected error to be io.ErrShortBuffer"
);
}
let (n, ppi) = s1.read_sctp(&mut buf).await?;
assert_eq!(n, 3, "unexpected length of received data");
assert_eq!(&buf[..n], &MSG[..3], "unexpected received data");
assert_eq!(ppi, PayloadProtocolIdentifier::Binary, "unexpected ppi");

{
let q = s0.reassembly_queue.lock().await;
15 changes: 7 additions & 8 deletions sctp/src/queue/queue_test.rs
@@ -1,4 +1,4 @@
use crate::error::{Error, Result};
use crate::error::Result;

use bytes::{Bytes, BytesMut};

@@ -757,7 +757,7 @@ fn test_reassembly_queue_should_fail_to_read_if_the_nex_ssn_is_not_ready() -> Re
}

#[test]
fn test_reassembly_queue_detect_buffer_too_short() -> Result<()> {
fn test_reassembly_queue_permits_partial_reads() -> Result<()> {
let mut rq = ReassemblyQueue::new(0);

let org_ppi = PayloadProtocolIdentifier::Binary;
@@ -777,12 +777,11 @@ fn test_reassembly_queue_detect_buffer_too_short() -> Result<()> {
assert_eq!(10, rq.get_num_bytes(), "num bytes mismatch");

let mut buf = vec![0u8; 8]; // <- passing buffer too short
let result = rq.read(&mut buf);
assert!(result.is_err(), "read() should not succeed");
if let Err(err) = result {
assert_eq!(Error::ErrShortBuffer, err, "read() should not succeed");
}
assert_eq!(0, rq.get_num_bytes(), "num bytes mismatch");
let (n, ppi) = rq.read(&mut buf)?;
assert_eq!(8, n, "should have received 8 bytes");
assert_eq!(2, rq.get_num_bytes(), "num bytes mismatch");
assert_eq!(ppi, org_ppi, "should have valid ppi");
assert_eq!(&buf[..n], b"01234567", "data should match");

Ok(())
}
61 changes: 39 additions & 22 deletions sctp/src/queue/reassembly_queue.rs
@@ -114,6 +114,9 @@ pub(crate) struct ReassemblyQueue {
pub(crate) unordered: Vec<ChunkSet>,
pub(crate) unordered_chunks: Vec<ChunkPayloadData>,
pub(crate) n_bytes: usize,

pub(crate) unread_cset: Option<ChunkSet>,
pub(crate) unread_chunk_num_bytes_read: usize,
}

impl ReassemblyQueue {
@@ -126,10 +129,7 @@ impl ReassemblyQueue {
ReassemblyQueue {
si,
next_ssn: 0, // From RFC 4960 Sec 6.5:
ordered: vec![],
unordered: vec![],
unordered_chunks: vec![],
n_bytes: 0,
..ReassemblyQueue::default()
}
}

@@ -254,8 +254,11 @@ impl ReassemblyQueue {
}

pub(crate) fn read(&mut self, buf: &mut [u8]) -> Result<(usize, PayloadProtocolIdentifier)> {
// Check unordered first
let cset = if !self.unordered.is_empty() {
let cset = if self.unread_cset.is_some() {
// Read unread chunks from previous iteration, if any.
self.unread_cset.take().unwrap()
} else if !self.unordered.is_empty() {
// Then check unordered
self.unordered.remove(0)
} else if !self.ordered.is_empty() {
// Now, check ordered
@@ -274,27 +277,41 @@ impl ReassemblyQueue {
return Err(Error::ErrTryAgain);
};

// Concat all fragments into the buffer
let ppi = cset.ppi;

// Concat fragments into the buffer.
let mut n_written = 0;
let mut err = None;
for c in &cset.chunks {
let to_copy = c.user_data.len();
self.subtract_num_bytes(to_copy);
if err.is_none() {
let n = std::cmp::min(to_copy, buf.len() - n_written);
buf[n_written..n_written + n].copy_from_slice(&c.user_data[..n]);
n_written += n;
if n < to_copy {
err = Some(Error::ErrShortBuffer);
for (i, c) in cset.chunks.iter().enumerate() {
// If the last chunk was only partially read during the previous read.
let user_data = if i == 0 && self.unread_chunk_num_bytes_read > 0 {
&c.user_data[self.unread_chunk_num_bytes_read..]
} else {
c.user_data.as_ref()
};
let n = std::cmp::min(user_data.len(), buf.len() - n_written);
buf[n_written..n_written + n].copy_from_slice(&user_data[..n]);
self.subtract_num_bytes(n);
n_written += n;

if n_written == buf.len() {
if n < c.user_data.len() {
// If this chunk was read only partially
self.unread_chunk_num_bytes_read = n;
let mut s = ChunkSet::new(cset.ssn, cset.ppi);
s.chunks = cset.chunks[i..].to_vec();
self.unread_cset = Some(s);
} else if i < cset.chunks.len() - 1 {
// If there are unread chunks
self.unread_chunk_num_bytes_read = 0;
let mut s = ChunkSet::new(cset.ssn, cset.ppi);
s.chunks = cset.chunks[i + 1..].to_vec();
self.unread_cset = Some(s);
}
break;
}
}

if let Some(err) = err {
Err(err)
} else {
Ok((n_written, cset.ppi))
}
Ok((n_written, ppi))
}

/// Use last_ssn to locate a chunkSet then remove it if the set has
