diff --git a/src/dtls.rs b/src/dtls.rs index fdca601..6c1b1c1 100644 --- a/src/dtls.rs +++ b/src/dtls.rs @@ -113,6 +113,7 @@ pub enum DTLSMessageHandshakeBody<'a> { Finished(&'a [u8]), CertificateStatus(TlsCertificateStatusContents<'a>), NextProtocol(TlsNextProtocolContent<'a>), + Fragment(&'a [u8]), } /// DTLS plaintext message @@ -127,6 +128,17 @@ pub enum DTLSMessage<'a> { Heartbeat(TlsMessageHeartbeat<'a>), } +impl<'a> DTLSMessage<'a> { + /// Tell if this DTLSMessage is a (handshake) fragment that needs combining with other + /// fragments to be a complete message. + pub fn is_fragment(&self) -> bool { + match self { + DTLSMessage::Handshake(h) => matches!(h.body, DTLSMessageHandshakeBody::Fragment(_)), + _ => false, + } + } +} + // --------------------------- PARSERS --------------------------- /// DTLS record header @@ -148,6 +160,11 @@ pub fn parse_dtls_record_header(i: &[u8]) -> IResult<&[u8], DTLSRecordHeader> { Ok((i, record)) } +/// Treat the entire input as an opaque fragment. +fn parse_dtls_fragment(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> { + Ok((&[], DTLSMessageHandshakeBody::Fragment(i))) +} + /// DTLS Client Hello // Section 4.2 of RFC6347 fn parse_dtls_client_hello(i: &[u8]) -> IResult<&[u8], DTLSMessageHandshakeBody> { @@ -222,8 +239,16 @@ pub fn parse_dtls_message_handshake(i: &[u8]) -> IResult<&[u8], DTLSMessage> { let (i, message_seq) = be_u16(i)?; let (i, fragment_offset) = be_u24(i)?; let (i, fragment_length) = be_u24(i)?; - let (i, raw_msg) = take(length)(i)?; + // This packet contains fragment_length (which is less than length for fragmentation) + let (i, raw_msg) = take(fragment_length)(i)?; + + // Handshake messages can be fragmented over multiple packets. When fragmented, the user + // needs the fragment_offset, fragment_length and length to determine whether they received + // all the fragments. The DTLS spec allows for overlapping and duplicated fragments. + let is_fragment = fragment_offset > 0 || fragment_length < length; + let (_, body) = match msg_type { + _ if is_fragment => parse_dtls_fragment(raw_msg), TlsHandshakeType::ClientHello => parse_dtls_client_hello(raw_msg), TlsHandshakeType::HelloVerifyRequest => parse_dtls_hello_verify_request(raw_msg), TlsHandshakeType::ServerHello => parse_dtls_handshake_msg_server_hello_tlsv12(raw_msg), @@ -281,18 +306,6 @@ pub fn parse_dtls_record_with_header<'i>( } } -/// Parse DTLS record, leaving `fragment` unparsed -// Section 4.1 of RFC6347 -pub fn parse_dtls_raw_record(i: &[u8]) -> IResult<&[u8], DTLSRawRecord> { - let (i, header) = parse_dtls_record_header(i)?; - // As in TLS 1.2, the length should not exceed 2^14. - if header.length > MAX_RECORD_LEN { - return Err(Err::Error(make_error(i, ErrorKind::TooLarge))); - } - let (i, fragment) = take(header.length as usize)(i)?; - Ok((i, DTLSRawRecord { header, fragment })) -} - /// Parse one DTLS plaintext record // Section 4.1 of RFC6347 pub fn parse_dtls_plaintext_record(i: &[u8]) -> IResult<&[u8], DTLSPlaintext> {