Skip to content
This repository has been archived by the owner on Mar 16, 2022. It is now read-only.

Commit

Permalink
Implement SRTCP Encryption
Browse files Browse the repository at this point in the history
This finishes adding full SRTCP support to the SRTP package

Resolves pion#117
  • Loading branch information
Sean-Der committed Sep 17, 2018
1 parent ec98022 commit 44ce8d4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 34 deletions.
23 changes: 7 additions & 16 deletions internal/srtp/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"crypto/sha1" // #nosec
"encoding/binary"

"github.com/pions/webrtc/pkg/rtp"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -54,6 +53,7 @@ type Context struct {
srtcpSessionKey []byte
srtcpSessionSalt []byte
srtcpSessionAuthTag []byte
srtcpIndex uint32
srtcpBlock cipher.Block
}

Expand Down Expand Up @@ -190,7 +190,7 @@ func (c *Context) generateCounter(sequenceNumber uint16, rolloverCounter uint32,
return counter
}

func (c *Context) addAuthTag(packet *rtp.Packet, s *ssrcState) error {
func (c *Context) generateAuthTag(buf []byte, authTag []byte) ([]byte, error) {
// https://tools.ietf.org/html/rfc3711#section-4.2
// In the case of SRTP, M SHALL consist of the Authenticated
// Portion of the packet (as specified in Figure 1) concatenated with
Expand All @@ -205,20 +205,11 @@ func (c *Context) addAuthTag(packet *rtp.Packet, s *ssrcState) error {
// - Authenticated portion of the packet is everything BEFORE MKI
// - k_a is the session message authentication key
// - n_tag is the bit-length of the output authentication tag

mac := hmac.New(sha1.New, c.srtpSessionAuthTag) // TODO
fullPkt, err := packet.Marshal()
if err != nil {
return err
}

fullPkt = append(fullPkt, make([]byte, 4)...)
binary.BigEndian.PutUint32(fullPkt[len(fullPkt)-4:], s.rolloverCounter)

if _, err := mac.Write(fullPkt); err != nil {
return err
// - ROC is already added by caller (to allow RTP + RTCP support)
mac := hmac.New(sha1.New, authTag)
if _, err := mac.Write(buf); err != nil {
return nil, err
}

packet.Payload = append(packet.Payload, mac.Sum(nil)[0:10]...)
return nil
return mac.Sum(nil)[0:10], nil
}
42 changes: 32 additions & 10 deletions internal/srtp/srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,21 @@ package srtp
import (
"crypto/cipher"
"encoding/binary"

"github.com/pkg/errors"
)

// DecryptRTCP decrypts a buffer that contains a RTCP packet
// We can't pass *rtcp.Packet as the encrypt will obscure significant fields
func (c *Context) DecryptRTCP(encrypted []byte) ([]byte, error) {
rtcpLen := int((binary.BigEndian.Uint16(encrypted[2:]) + 1) * 8)
if rtcpLen+srtcpIndexSize+authTagSize > len(encrypted) {
return nil, errors.Errorf("SRCTP packet invalid size: header_len %d, buffer_size %d. ", rtcpLen, len(encrypted))
}
tailOffset := len(encrypted) - (authTagSize + srtcpIndexSize)
out := append([]byte{}, encrypted[0:tailOffset]...)

out := append([]byte{}, encrypted[0:rtcpLen]...)
isEncrypted := encrypted[rtcpLen] >> 7
isEncrypted := encrypted[tailOffset] >> 7
if isEncrypted == 0 {
return out, nil
}

srtcpIndexBuffer := append([]byte{}, encrypted[rtcpLen:rtcpLen+srtcpIndexSize]...)
srtcpIndexBuffer[0] &= 0x7f //unset Encryption bit
srtcpIndexBuffer := append([]byte{}, encrypted[tailOffset:tailOffset+srtcpIndexSize]...)
srtcpIndexBuffer[0] &= 0x7f // unset Encryption bit

index := binary.BigEndian.Uint32(srtcpIndexBuffer)
ssrc := binary.BigEndian.Uint32(encrypted[4:])
Expand All @@ -32,3 +27,30 @@ func (c *Context) DecryptRTCP(encrypted []byte) ([]byte, error) {

return out, nil
}

// EncryptRTCP encrypts a buffer that contains a RTCP packet
func (c *Context) EncryptRTCP(decrypted []byte) ([]byte, error) {
out := append([]byte{}, decrypted[:]...)
ssrc := binary.BigEndian.Uint32(decrypted[4:])

// We roll over early because MSB is used for marking as encrypted
c.srtcpIndex++
if c.srtcpIndex >= 2147483647 {
c.srtcpIndex = 0
}

// Encrypt everything after header
stream := cipher.NewCTR(c.srtcpBlock, c.generateCounter(uint16(c.srtcpIndex&0xffff), c.srtcpIndex>>16, ssrc, c.srtcpSessionSalt))
stream.XORKeyStream(out[8:], out[8:])

// Add SRTCP Index and set Encryption bit
out = append(out, make([]byte, 4)...)
binary.BigEndian.PutUint32(out[len(out)-4:], c.srtcpIndex)
out[len(out)-4] |= 0x80

authTag, err := c.generateAuthTag(out, c.srtcpSessionAuthTag)
if err != nil {
return nil, err
}
return append(out, authTag...), nil
}
13 changes: 12 additions & 1 deletion internal/srtp/srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package srtp

import (
"crypto/cipher"
"encoding/binary"

"github.com/pions/webrtc/pkg/rtp"
)
Expand Down Expand Up @@ -34,10 +35,20 @@ func (c *Context) EncryptRTP(packet *rtp.Packet) bool {
stream := cipher.NewCTR(c.srtpBlock, c.generateCounter(packet.SequenceNumber, s.rolloverCounter, s.ssrc, c.srtpSessionSalt))
stream.XORKeyStream(packet.Payload, packet.Payload)

if err := c.addAuthTag(packet, s); err != nil {
fullPkt, err := packet.Marshal()
if err != nil {
return false
}

fullPkt = append(fullPkt, make([]byte, 4)...)
binary.BigEndian.PutUint32(fullPkt[len(fullPkt)-4:], s.rolloverCounter)

authTag, err := c.generateAuthTag(fullPkt, c.srtpSessionAuthTag)
if err != nil {
return false
}

packet.Payload = append(packet.Payload, authTag...)
return true
}

Expand Down
18 changes: 11 additions & 7 deletions internal/srtp/srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package srtp

import (
"bytes"
"fmt"
"testing"

"github.com/pions/webrtc/pkg/rtp"
Expand Down Expand Up @@ -190,16 +189,21 @@ func TestRTCPLifecycle(t *testing.T) {
t.Error(errors.Wrap(err, "CreateContext failed"))
}

// decryptContext, err := CreateContext(masterKey, masterSalt, cipherContextAlgo)
// if err != nil {
// t.Error(errors.Wrap(err, "CreateContext failed"))
// }
decryptContext, err := CreateContext(masterKey, masterSalt, cipherContextAlgo)
if err != nil {
t.Error(errors.Wrap(err, "CreateContext failed"))
}

fmt.Println(len(encrypted))
decryptResult, err := encryptContext.DecryptRTCP(append([]byte{}, encrypted...))
decryptResult, err := decryptContext.DecryptRTCP(append([]byte{}, encrypted...))
if err != nil {
t.Error(err)
}
assert.Equal(decryptResult, decrypted, "RTCP failed to decrypt")

encryptResult, err := encryptContext.EncryptRTCP(append([]byte{}, decrypted...))
if err != nil {
t.Error(err)
}
assert.Equal(encryptResult, encrypted, "RTCP failed to encrypt")

}

0 comments on commit 44ce8d4

Please sign in to comment.