From 44ce8d41ae5952d9cf40aaed32f0bf9c1d9fec69 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Mon, 17 Sep 2018 00:26:06 -0700 Subject: [PATCH] Implement SRTCP Encryption This finishes adding full SRTCP support to the SRTP package Resolves #117 --- internal/srtp/context.go | 23 +++++++-------------- internal/srtp/srtcp.go | 42 +++++++++++++++++++++++++++++--------- internal/srtp/srtp.go | 13 +++++++++++- internal/srtp/srtp_test.go | 18 +++++++++------- 4 files changed, 62 insertions(+), 34 deletions(-) diff --git a/internal/srtp/context.go b/internal/srtp/context.go index 99f13900288..0992e6e11ee 100644 --- a/internal/srtp/context.go +++ b/internal/srtp/context.go @@ -7,7 +7,6 @@ import ( "crypto/sha1" // #nosec "encoding/binary" - "github.com/pions/webrtc/pkg/rtp" "github.com/pkg/errors" ) @@ -54,6 +53,7 @@ type Context struct { srtcpSessionKey []byte srtcpSessionSalt []byte srtcpSessionAuthTag []byte + srtcpIndex uint32 srtcpBlock cipher.Block } @@ -190,7 +190,7 @@ func (c *Context) generateCounter(sequenceNumber uint16, rolloverCounter uint32, return counter } -func (c *Context) addAuthTag(packet *rtp.Packet, s *ssrcState) error { +func (c *Context) generateAuthTag(buf []byte, authTag []byte) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#section-4.2 // In the case of SRTP, M SHALL consist of the Authenticated // Portion of the packet (as specified in Figure 1) concatenated with @@ -205,20 +205,11 @@ func (c *Context) addAuthTag(packet *rtp.Packet, s *ssrcState) error { // - Authenticated portion of the packet is everything BEFORE MKI // - k_a is the session message authentication key // - n_tag is the bit-length of the output authentication tag - - mac := hmac.New(sha1.New, c.srtpSessionAuthTag) // TODO - fullPkt, err := packet.Marshal() - if err != nil { - return err - } - - fullPkt = append(fullPkt, make([]byte, 4)...) - binary.BigEndian.PutUint32(fullPkt[len(fullPkt)-4:], s.rolloverCounter) - - if _, err := mac.Write(fullPkt); err != nil { - return err + // - ROC is already added by caller (to allow RTP + RTCP support) + mac := hmac.New(sha1.New, authTag) + if _, err := mac.Write(buf); err != nil { + return nil, err } - packet.Payload = append(packet.Payload, mac.Sum(nil)[0:10]...) - return nil + return mac.Sum(nil)[0:10], nil } diff --git a/internal/srtp/srtcp.go b/internal/srtp/srtcp.go index a9f3e991e20..197caa4cc61 100644 --- a/internal/srtp/srtcp.go +++ b/internal/srtp/srtcp.go @@ -3,26 +3,21 @@ package srtp import ( "crypto/cipher" "encoding/binary" - - "github.com/pkg/errors" ) // DecryptRTCP decrypts a buffer that contains a RTCP packet // We can't pass *rtcp.Packet as the encrypt will obscure significant fields func (c *Context) DecryptRTCP(encrypted []byte) ([]byte, error) { - rtcpLen := int((binary.BigEndian.Uint16(encrypted[2:]) + 1) * 8) - if rtcpLen+srtcpIndexSize+authTagSize > len(encrypted) { - return nil, errors.Errorf("SRCTP packet invalid size: header_len %d, buffer_size %d. ", rtcpLen, len(encrypted)) - } + tailOffset := len(encrypted) - (authTagSize + srtcpIndexSize) + out := append([]byte{}, encrypted[0:tailOffset]...) - out := append([]byte{}, encrypted[0:rtcpLen]...) - isEncrypted := encrypted[rtcpLen] >> 7 + isEncrypted := encrypted[tailOffset] >> 7 if isEncrypted == 0 { return out, nil } - srtcpIndexBuffer := append([]byte{}, encrypted[rtcpLen:rtcpLen+srtcpIndexSize]...) - srtcpIndexBuffer[0] &= 0x7f //unset Encryption bit + srtcpIndexBuffer := append([]byte{}, encrypted[tailOffset:tailOffset+srtcpIndexSize]...) + srtcpIndexBuffer[0] &= 0x7f // unset Encryption bit index := binary.BigEndian.Uint32(srtcpIndexBuffer) ssrc := binary.BigEndian.Uint32(encrypted[4:]) @@ -32,3 +27,30 @@ func (c *Context) DecryptRTCP(encrypted []byte) ([]byte, error) { return out, nil } + +// EncryptRTCP encrypts a buffer that contains a RTCP packet +func (c *Context) EncryptRTCP(decrypted []byte) ([]byte, error) { + out := append([]byte{}, decrypted[:]...) + ssrc := binary.BigEndian.Uint32(decrypted[4:]) + + // We roll over early because MSB is used for marking as encrypted + c.srtcpIndex++ + if c.srtcpIndex >= 2147483647 { + c.srtcpIndex = 0 + } + + // Encrypt everything after header + stream := cipher.NewCTR(c.srtcpBlock, c.generateCounter(uint16(c.srtcpIndex&0xffff), c.srtcpIndex>>16, ssrc, c.srtcpSessionSalt)) + stream.XORKeyStream(out[8:], out[8:]) + + // Add SRTCP Index and set Encryption bit + out = append(out, make([]byte, 4)...) + binary.BigEndian.PutUint32(out[len(out)-4:], c.srtcpIndex) + out[len(out)-4] |= 0x80 + + authTag, err := c.generateAuthTag(out, c.srtcpSessionAuthTag) + if err != nil { + return nil, err + } + return append(out, authTag...), nil +} diff --git a/internal/srtp/srtp.go b/internal/srtp/srtp.go index 88dff7d5f32..9dd3feccee2 100644 --- a/internal/srtp/srtp.go +++ b/internal/srtp/srtp.go @@ -2,6 +2,7 @@ package srtp import ( "crypto/cipher" + "encoding/binary" "github.com/pions/webrtc/pkg/rtp" ) @@ -34,10 +35,20 @@ func (c *Context) EncryptRTP(packet *rtp.Packet) bool { stream := cipher.NewCTR(c.srtpBlock, c.generateCounter(packet.SequenceNumber, s.rolloverCounter, s.ssrc, c.srtpSessionSalt)) stream.XORKeyStream(packet.Payload, packet.Payload) - if err := c.addAuthTag(packet, s); err != nil { + fullPkt, err := packet.Marshal() + if err != nil { return false } + fullPkt = append(fullPkt, make([]byte, 4)...) + binary.BigEndian.PutUint32(fullPkt[len(fullPkt)-4:], s.rolloverCounter) + + authTag, err := c.generateAuthTag(fullPkt, c.srtpSessionAuthTag) + if err != nil { + return false + } + + packet.Payload = append(packet.Payload, authTag...) return true } diff --git a/internal/srtp/srtp_test.go b/internal/srtp/srtp_test.go index 76061622cec..fdcbce1fe87 100644 --- a/internal/srtp/srtp_test.go +++ b/internal/srtp/srtp_test.go @@ -2,7 +2,6 @@ package srtp import ( "bytes" - "fmt" "testing" "github.com/pions/webrtc/pkg/rtp" @@ -190,16 +189,21 @@ func TestRTCPLifecycle(t *testing.T) { t.Error(errors.Wrap(err, "CreateContext failed")) } - // decryptContext, err := CreateContext(masterKey, masterSalt, cipherContextAlgo) - // if err != nil { - // t.Error(errors.Wrap(err, "CreateContext failed")) - // } + decryptContext, err := CreateContext(masterKey, masterSalt, cipherContextAlgo) + if err != nil { + t.Error(errors.Wrap(err, "CreateContext failed")) + } - fmt.Println(len(encrypted)) - decryptResult, err := encryptContext.DecryptRTCP(append([]byte{}, encrypted...)) + decryptResult, err := decryptContext.DecryptRTCP(append([]byte{}, encrypted...)) if err != nil { t.Error(err) } assert.Equal(decryptResult, decrypted, "RTCP failed to decrypt") + encryptResult, err := encryptContext.EncryptRTCP(append([]byte{}, decrypted...)) + if err != nil { + t.Error(err) + } + assert.Equal(encryptResult, encrypted, "RTCP failed to encrypt") + }