diff --git a/binary_port/src/binary_message.rs b/binary_port/src/binary_message.rs index df76a2e3a9..8ca9cd1199 100644 --- a/binary_port/src/binary_message.rs +++ b/binary_port/src/binary_message.rs @@ -76,23 +76,25 @@ impl codec::Decoder for BinaryMessageCodec { fn decode(&mut self, src: &mut bytes::BytesMut) -> Result, Self::Error> { let (length, have_full_frame) = if let [b1, b2, b3, b4, remainder @ ..] = &src[..] { let length = LengthEncoding::from_le_bytes([*b1, *b2, *b3, *b4]) as usize; - (length, remainder.len() >= length) + let remainder_length = remainder.len(); + (length, remainder_length >= length) } else { // Not enough bytes to read the length. return Ok(None); }; - if !have_full_frame { - // Not enough bytes to read the whole message. - return Ok(None); - }; - if length > self.max_message_size_bytes as usize { return Err(Error::RequestTooLarge { allowed: self.max_message_size_bytes, got: length as u32, }); } + + if !have_full_frame { + // Not enough bytes to read the whole message. + return Ok(None); + }; + if length == 0 { return Err(Error::EmptyRequest); } @@ -177,7 +179,6 @@ mod tests { bytes.extend(&suffix); let _ = codec.decode(&mut bytes); - // Ensure that the bytes are not consumed. assert_eq!(bytes, suffix); } @@ -266,7 +267,28 @@ mod tests { None => break, } } - assert_eq!(messages, decoded_messages); } + + #[test] + fn should_not_decode_when_read_bytes_extend_max() { + const MAX_MESSAGE_BYTES: usize = 1000; + let rng = &mut TestRng::new(); + let mut codec = BinaryMessageCodec::new(MAX_MESSAGE_BYTES as u32); + let mut bytes = bytes::BytesMut::new(); + let some_length = (MAX_MESSAGE_BYTES * 2_usize) as LengthEncoding; //This value doesn't match the + // length of mock_bytes intentionally so we can be sure at what point did the encoder bail - + // we want to ensure that the encoder doesn't read the whole message before it bails + bytes.extend(&some_length.to_le_bytes()); + bytes.extend(std::iter::repeat_with(|| rng.gen::()).take(MAX_MESSAGE_BYTES * 3)); + + let message_res = codec.decode(&mut bytes); + assert!(message_res.is_err()); + let err = message_res.err().unwrap(); + assert!(matches!( + err, + Error::RequestTooLarge { allowed, got} + if allowed == MAX_MESSAGE_BYTES as u32 && got == MAX_MESSAGE_BYTES as u32 * 2, + )) + } }