diff --git a/packet.go b/packet.go index 7aebb14..af88af3 100644 --- a/packet.go +++ b/packet.go @@ -149,64 +149,55 @@ func (h *Header) Unmarshal(buf []byte) (n int, err error) { //nolint:gocognit n += 2 extensionLength := int(binary.BigEndian.Uint16(buf[n:])) * 4 n += 2 + extensionEnd := n + extensionLength - if expected := n + extensionLength; len(buf) < expected { - return n, fmt.Errorf("size %d < %d: %w", - len(buf), expected, - errHeaderSizeInsufficientForExtension, - ) + if len(buf) < extensionEnd { + return n, fmt.Errorf("size %d < %d: %w", len(buf), extensionEnd, errHeaderSizeInsufficientForExtension) } - switch h.ExtensionProfile { - // RFC 8285 RTP One Byte Header Extension - case extensionProfileOneByte: - end := n + extensionLength - for n < end { + if h.ExtensionProfile == extensionProfileOneByte || h.ExtensionProfile == extensionProfileTwoByte { + var ( + extid uint8 + payloadLen int + ) + + for n < extensionEnd { if buf[n] == 0x00 { // padding n++ continue } - extid := buf[n] >> 4 - payloadLen := int(buf[n]&^0xF0 + 1) - n++ + if h.ExtensionProfile == extensionProfileOneByte { + extid = buf[n] >> 4 + payloadLen = int(buf[n]&^0xF0 + 1) + n++ - if extid == extensionIDReserved { - break - } + if extid == extensionIDReserved { + break + } + } else { + extid = buf[n] + n++ - extension := Extension{id: extid, payload: buf[n : n+payloadLen]} - h.Extensions = append(h.Extensions, extension) - n += payloadLen - } + if len(buf) <= n { + return n, fmt.Errorf("size %d < %d: %w", len(buf), n, errHeaderSizeInsufficientForExtension) + } - // RFC 8285 RTP Two Byte Header Extension - case extensionProfileTwoByte: - end := n + extensionLength - for n < end { - if buf[n] == 0x00 { // padding + payloadLen = int(buf[n]) n++ - continue } - extid := buf[n] - n++ - - payloadLen := int(buf[n]) - n++ + if extensionPayloadEnd := n + payloadLen; len(buf) <= extensionPayloadEnd { + return n, fmt.Errorf("size %d < %d: %w", len(buf), extensionPayloadEnd, errHeaderSizeInsufficientForExtension) + } extension := Extension{id: extid, payload: buf[n : n+payloadLen]} h.Extensions = append(h.Extensions, extension) n += payloadLen } - - default: // RFC3550 Extension - if len(buf) < n+extensionLength { - return n, fmt.Errorf("%w: %d < %d", - errHeaderSizeInsufficientForExtension, len(buf), n+extensionLength) - } - - extension := Extension{id: 0, payload: buf[n : n+extensionLength]} + } else { + // RFC3550 Extension + extension := Extension{id: 0, payload: buf[n:extensionEnd]} h.Extensions = append(h.Extensions, extension) n += len(h.Extensions[0].payload) } diff --git a/packet_test.go b/packet_test.go index 3044c0d..98dfb40 100644 --- a/packet_test.go +++ b/packet_test.go @@ -1192,6 +1192,34 @@ func TestRFC8285TwoByteSetExtensionShouldErrorWhenPayloadTooLarge(t *testing.T) } } +func TestRFC8285Padding(t *testing.T) { + header := &Header{} + + for _, payload := range [][]byte{ + { + 0b00010000, // header.Extension = true + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // SequenceNumber, Timestamp, SSRC + 0xBE, 0xDE, // header.ExtensionProfile = extensionProfileOneByte + 0, 1, // extensionLength + 0, 0, 0, // padding + 1, // extid + }, + { + 0b00010000, // header.Extension = true + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // SequenceNumber, Timestamp, SSRC + 0x10, 0x00, // header.ExtensionProfile = extensionProfileOneByte + 0, 1, // extensionLength + 0, 0, 0, // padding + 1, // extid + }, + } { + _, err := header.Unmarshal(payload) + if !errors.Is(err, errHeaderSizeInsufficientForExtension) { + t.Fatal("Expected errHeaderSizeInsufficientForExtension") + } + } +} + func TestRFC3550SetExtensionShouldErrorWhenNonZero(t *testing.T) { payload := []byte{ // Payload