Skip to content

Commit

Permalink
Implement draft-ietf-tsvwg-sctp-zero-checksum-01
Browse files Browse the repository at this point in the history
  • Loading branch information
mengelbart authored and Sean-Der committed Feb 9, 2024
1 parent 2927025 commit 932b71a
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 71 deletions.
87 changes: 73 additions & 14 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ type Association struct {
cumulativeTSNAckPoint uint32
advancedPeerTSNAckPoint uint32
useForwardTSN bool
useZeroChecksum bool
requestZeroChecksum bool

// Congestion control parameters
maxReceiveBufferSize uint32
Expand Down Expand Up @@ -233,6 +235,7 @@ type Config struct {
NetConn net.Conn
MaxReceiveBufferSize uint32
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory
}

Expand Down Expand Up @@ -320,6 +323,7 @@ func createAssociation(config Config) *Association {
handshakeCompletedCh: make(chan error),
cumulativeTSNAckPoint: tsn - 1,
advancedPeerTSNAckPoint: tsn - 1,
requestZeroChecksum: config.EnableZeroChecksum,
silentError: ErrSilentlyDiscard,
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
Expand Down Expand Up @@ -362,6 +366,11 @@ func (a *Association) init(isClient bool) {
init.initiateTag = a.myVerificationTag
init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize
setSupportedExtensions(&init.chunkInitCommon)

if a.requestZeroChecksum {
init.params = append(init.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
}

a.storedInit = init

err := a.sendInit()
Expand Down Expand Up @@ -618,10 +627,45 @@ func (a *Association) unregisterStream(s *Stream, err error) {
s.readNotifier.Broadcast()
}

func chunkMandatoryChecksum(cc []chunk) bool {
for _, c := range cc {
switch c.(type) {
case *chunkInit, *chunkInitAck, *chunkCookieEcho:
return true
}
}
return false
}

func (a *Association) marshalPacket(p *packet) ([]byte, error) {
return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks))
}

func (a *Association) unmarshalPacket(raw []byte) (*packet, error) {
p := &packet{}
if !a.useZeroChecksum {
if err := p.unmarshal(true, raw); err != nil {
return nil, err
}
return p, nil
}

if err := p.unmarshal(false, raw); err != nil {
return nil, err
}
if chunkMandatoryChecksum(p.chunks) {
if err := p.unmarshal(true, raw); err != nil {
return nil, err
}
}

return p, nil
}

// handleInbound parses incoming raw packets
func (a *Association) handleInbound(raw []byte) error {
p := &packet{}
if err := p.unmarshal(raw); err != nil {
p, err := a.unmarshalPacket(raw)
if err != nil {
a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err)
return nil
}
Expand All @@ -647,7 +691,7 @@ func (a *Association) handleInbound(raw []byte) error {
// The caller should hold the lock
func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte {
for _, p := range a.getDataPacketsToRetransmit() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name)
continue
Expand All @@ -668,7 +712,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name)
a.t3RTX.start(a.rtoMgr.getRTO())
for _, p := range a.bundleDataChunksIntoPackets(chunks) {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet", a.name)
continue
Expand All @@ -683,7 +727,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs))
for _, c := range a.reconfigs {
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name)
} else {
Expand All @@ -706,7 +750,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v",
a.name, rsn, a.myNextTSN-1, sisToReset)
p := a.createPacket([]chunk{c})
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name)
} else {
Expand Down Expand Up @@ -769,7 +813,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
}

if len(toFastRetrans) > 0 {
raw, err := a.createPacket(toFastRetrans).marshal()
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
Expand All @@ -787,7 +831,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
a.ackState = ackStateIdle
sack := a.createSelectiveAckChunk()
a.log.Debugf("[%s] sending SACK: %s", a.name, sack)
raw, err := a.createPacket([]chunk{sack}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{sack}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a SACK packet", a.name)
} else {
Expand All @@ -804,7 +848,7 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b
a.willSendForwardTSN = false
if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
fwdtsn := a.createForwardTSN()
raw, err := a.createPacket([]chunk{fwdtsn}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name)
} else {
Expand All @@ -827,7 +871,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
cumulativeTSNAck: a.cumulativeTSNAckPoint,
}

raw, err := a.createPacket([]chunk{shutdown}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name)
} else {
Expand All @@ -839,7 +883,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by

shutdownAck := &chunkShutdownAck{}

raw, err := a.createPacket([]chunk{shutdownAck}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name)
} else {
Expand All @@ -851,7 +895,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by

shutdownComplete := &chunkShutdownComplete{}

raw, err := a.createPacket([]chunk{shutdownComplete}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete}))
if err != nil {
a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name)
} else {
Expand All @@ -875,7 +919,7 @@ func (a *Association) gatherAbortPacket() ([]byte, error) {
abort.errorCauses = []errorCause{cause}
}

raw, err := a.createPacket([]chunk{abort}).marshal()
raw, err := a.marshalPacket(a.createPacket([]chunk{abort}))

return raw, err
}
Expand All @@ -900,7 +944,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) {

if a.controlQueue.size() > 0 {
for _, p := range a.controlQueue.popAll() {
raw, err := p.marshal()
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a control packet", a.name)
continue
Expand Down Expand Up @@ -1092,6 +1136,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
// subtracting one from it.
a.peerLastTSN = i.initialTSN - 1

peerHasZeroChecksum := false
for _, param := range i.params {
switch v := param.(type) { // nolint:gocritic
case *paramSupportedExtensions:
Expand All @@ -1101,8 +1146,11 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name)
}
Expand All @@ -1129,6 +1177,12 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {

initAck.params = []param{a.myCookie}

if peerHasZeroChecksum {
initAck.params = append(initAck.params, &paramZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod})
a.useZeroChecksum = true
}
a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum)

setSupportedExtensions(&initAck.chunkInitCommon)

outbound.chunks = []chunk{initAck}
Expand Down Expand Up @@ -1186,8 +1240,13 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
a.useForwardTSN = true
}
}
case *paramZeroChecksumAcceptable:
a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod
}
}

a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum)

if !a.useForwardTSN {
a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name)
}
Expand Down
90 changes: 87 additions & 3 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,7 @@ func TestAssocT1CookieTimer(t *testing.T) {
// Drop all COOKIE-ECHO
br.Filter(0, func(raw []byte) bool {
p := &packet{}
err := p.unmarshal(raw)
err := p.unmarshal(true, raw)
if !assert.Nil(t, err, "failed to parse packet") {
return false // drop
}
Expand Down Expand Up @@ -2285,7 +2285,7 @@ func TestAssocAbort(t *testing.T) {
errorCauseHeader: errorCauseHeader{code: protocolViolation},
}},
}
packet, err := a0.createPacket([]chunk{abort}).marshal()
packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort}))
assert.NoError(t, err)

_, _, err = establishSessionPair(br, a0, a1, si)
Expand Down Expand Up @@ -2964,7 +2964,7 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) {
}()
}

packet, err := testCase.inputPacket.marshal()
packet, err := a.marshalPacket(testCase.inputPacket)
assert.NoError(t, err)
_, err = charlieConn.Write(packet)
assert.NoError(t, err)
Expand Down Expand Up @@ -3072,3 +3072,87 @@ loop:
assert.Error(t, err1, "context canceled")
assert.Error(t, err2, "context canceled")
}

type customLogger struct {
expectZeroChecksum bool
t *testing.T
}

func (c customLogger) Trace(string) {}
func (c customLogger) Tracef(string, ...interface{}) {}
func (c customLogger) Debug(string) {}
func (c customLogger) Debugf(format string, args ...interface{}) {
if format == "[%s] useZeroChecksum=%t (on initAck)" {
assert.Equal(c.t, args[1], c.expectZeroChecksum)
}
}
func (c customLogger) Info(string) {}
func (c customLogger) Infof(string, ...interface{}) {}
func (c customLogger) Warn(string) {}
func (c customLogger) Warnf(string, ...interface{}) {}
func (c customLogger) Error(string) {}
func (c customLogger) Errorf(string, ...interface{}) {}

func (c customLogger) NewLogger(string) logging.LeveledLogger {
return c
}

func TestAssociation_ZeroChecksum(t *testing.T) {
checkGoroutineLeaks(t)

lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

for _, testCase := range []struct {
clientZeroChecksum, serverZeroChecksum, expectChecksumEnabled bool
}{
{true, true, true},
{false, false, false},
{true, false, true},
{false, true, false},
} {
a1chan, a2chan := make(chan *Association), make(chan *Association)

udp1, udp2 := createUDPConnPair()

go func() {
a1, err := Client(Config{
NetConn: udp1,
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
EnableZeroChecksum: testCase.clientZeroChecksum,
})
assert.NoError(t, err)
a1chan <- a1
}()

go func() {
a2, err := Server(Config{
NetConn: udp2,
LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t},
EnableZeroChecksum: testCase.serverZeroChecksum,
})
assert.NoError(t, err)
a2chan <- a2
}()

a1, a2 := <-a1chan, <-a2chan

writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

testData := []byte("test")
_, err = writeStream.Write(testData)
require.NoError(t, err)

buf := make([]byte, len(testData))
_, err = readStream.Read(buf)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

require.NoError(t, a1.Close())
require.NoError(t, a2.Close())
}
}
Loading

0 comments on commit 932b71a

Please sign in to comment.