From a3a3a63c3c70f2a25e5d0447748155c18deb75eb Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Tue, 4 Jul 2023 16:27:10 +0200 Subject: [PATCH] Implement draft-ietf-tsvwg-sctp-zero-checksum-01 --- association.go | 87 +++++++++++++++++++++++++++++------ association_test.go | 90 +++++++++++++++++++++++++++++++++++-- chunk_test.go | 24 +++++----- packet.go | 25 +++++++---- packet_test.go | 10 ++--- param.go | 2 + param_zero_checksum.go | 53 ++++++++++++++++++++++ param_zero_checksum_test.go | 40 +++++++++++++++++ paramtype.go | 57 ++++++++++++----------- vnet_test.go | 2 +- 10 files changed, 319 insertions(+), 71 deletions(-) create mode 100644 param_zero_checksum.go create mode 100644 param_zero_checksum_test.go diff --git a/association.go b/association.go index 507e3aad..cf493ef7 100644 --- a/association.go +++ b/association.go @@ -177,6 +177,8 @@ type Association struct { cumulativeTSNAckPoint uint32 advancedPeerTSNAckPoint uint32 useForwardTSN bool + useZeroChecksum bool + requestZeroChecksum bool // Congestion control parameters maxReceiveBufferSize uint32 @@ -233,6 +235,7 @@ type Config struct { NetConn net.Conn MaxReceiveBufferSize uint32 MaxMessageSize uint32 + EnableZeroChecksum bool LoggerFactory logging.LoggerFactory } @@ -320,6 +323,7 @@ func createAssociation(config Config) *Association { handshakeCompletedCh: make(chan error), cumulativeTSNAckPoint: tsn - 1, advancedPeerTSNAckPoint: tsn - 1, + requestZeroChecksum: config.EnableZeroChecksum, silentError: ErrSilentlyDiscard, stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), @@ -362,6 +366,11 @@ func (a *Association) init(isClient bool) { init.initiateTag = a.myVerificationTag init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize setSupportedExtensions(&init.chunkInitCommon) + + if a.requestZeroChecksum { + init.params = append(init.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) + } + a.storedInit = init err := a.sendInit() @@ -618,10 +627,45 @@ func (a *Association) unregisterStream(s *Stream, err error) { s.readNotifier.Broadcast() } +func chunkMandatoryChecksum(cc []chunk) bool { + for _, c := range cc { + switch c.(type) { + case *chunkInit, *chunkInitAck, *chunkCookieEcho: + return true + } + } + return false +} + +func (a *Association) marshalPacket(p *packet) ([]byte, error) { + return p.marshal(!a.useZeroChecksum || chunkMandatoryChecksum(p.chunks)) +} + +func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { + p := &packet{} + if !a.useZeroChecksum { + if err := p.unmarshal(true, raw); err != nil { + return nil, err + } + return p, nil + } + + if err := p.unmarshal(false, raw); err != nil { + return nil, err + } + if chunkMandatoryChecksum(p.chunks) { + if err := p.unmarshal(true, raw); err != nil { + return nil, err + } + } + + return p, nil +} + // handleInbound parses incoming raw packets func (a *Association) handleInbound(raw []byte) error { - p := &packet{} - if err := p.unmarshal(raw); err != nil { + p, err := a.unmarshalPacket(raw) + if err != nil { a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) return nil } @@ -647,7 +691,7 @@ func (a *Association) handleInbound(raw []byte) error { // The caller should hold the lock func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte { for _, p := range a.getDataPacketsToRetransmit() { - raw, err := p.marshal() + raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) continue @@ -668,7 +712,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name) a.t3RTX.start(a.rtoMgr.getRTO()) for _, p := range a.bundleDataChunksIntoPackets(chunks) { - raw, err := p.marshal() + raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) continue @@ -683,7 +727,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) for _, c := range a.reconfigs { p := a.createPacket([]chunk{c}) - raw, err := p.marshal() + raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name) } else { @@ -706,7 +750,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v", a.name, rsn, a.myNextTSN-1, sisToReset) p := a.createPacket([]chunk{c}) - raw, err := p.marshal() + raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name) } else { @@ -769,7 +813,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt } if len(toFastRetrans) > 0 { - raw, err := a.createPacket(toFastRetrans).marshal() + raw, err := a.marshalPacket(a.createPacket(toFastRetrans)) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) } else { @@ -787,7 +831,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { a.ackState = ackStateIdle sack := a.createSelectiveAckChunk() a.log.Debugf("[%s] sending SACK: %s", a.name, sack) - raw, err := a.createPacket([]chunk{sack}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{sack})) if err != nil { a.log.Warnf("[%s] failed to serialize a SACK packet", a.name) } else { @@ -804,7 +848,7 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b a.willSendForwardTSN = false if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { fwdtsn := a.createForwardTSN() - raw, err := a.createPacket([]chunk{fwdtsn}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn})) if err != nil { a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name) } else { @@ -827,7 +871,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by cumulativeTSNAck: a.cumulativeTSNAckPoint, } - raw, err := a.createPacket([]chunk{shutdown}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown})) if err != nil { a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name) } else { @@ -839,7 +883,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by shutdownAck := &chunkShutdownAck{} - raw, err := a.createPacket([]chunk{shutdownAck}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck})) if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name) } else { @@ -851,7 +895,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by shutdownComplete := &chunkShutdownComplete{} - raw, err := a.createPacket([]chunk{shutdownComplete}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete})) if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name) } else { @@ -875,7 +919,7 @@ func (a *Association) gatherAbortPacket() ([]byte, error) { abort.errorCauses = []errorCause{cause} } - raw, err := a.createPacket([]chunk{abort}).marshal() + raw, err := a.marshalPacket(a.createPacket([]chunk{abort})) return raw, err } @@ -900,7 +944,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) { if a.controlQueue.size() > 0 { for _, p := range a.controlQueue.popAll() { - raw, err := p.marshal() + raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a control packet", a.name) continue @@ -1092,6 +1136,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { // subtracting one from it. a.peerLastTSN = i.initialTSN - 1 + peerHasZeroChecksum := false for _, param := range i.params { switch v := param.(type) { // nolint:gocritic case *paramSupportedExtensions: @@ -1101,8 +1146,11 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { a.useForwardTSN = true } } + case *paramZeroChecksumAcceptable: + peerHasZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } + if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) } @@ -1129,6 +1177,12 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { initAck.params = []param{a.myCookie} + if peerHasZeroChecksum { + initAck.params = append(initAck.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) + a.useZeroChecksum = true + } + a.log.Debugf("[%s] useZeroChecksum=%t (on init)", a.name, a.useZeroChecksum) + setSupportedExtensions(&initAck.chunkInitCommon) outbound.chunks = []chunk{initAck} @@ -1186,8 +1240,13 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { a.useForwardTSN = true } } + case *paramZeroChecksumAcceptable: + a.useZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } + + a.log.Debugf("[%s] useZeroChecksum=%t (on initAck)", a.name, a.useZeroChecksum) + if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) } diff --git a/association_test.go b/association_test.go index 7e884772..2b6fb733 100644 --- a/association_test.go +++ b/association_test.go @@ -1622,7 +1622,7 @@ func TestAssocT1CookieTimer(t *testing.T) { // Drop all COOKIE-ECHO br.Filter(0, func(raw []byte) bool { p := &packet{} - err := p.unmarshal(raw) + err := p.unmarshal(true, raw) if !assert.Nil(t, err, "failed to parse packet") { return false // drop } @@ -2285,7 +2285,7 @@ func TestAssocAbort(t *testing.T) { errorCauseHeader: errorCauseHeader{code: protocolViolation}, }}, } - packet, err := a0.createPacket([]chunk{abort}).marshal() + packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort})) assert.NoError(t, err) _, _, err = establishSessionPair(br, a0, a1, si) @@ -2964,7 +2964,7 @@ func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) { }() } - packet, err := testCase.inputPacket.marshal() + packet, err := a.marshalPacket(testCase.inputPacket) assert.NoError(t, err) _, err = charlieConn.Write(packet) assert.NoError(t, err) @@ -3072,3 +3072,87 @@ loop: assert.Error(t, err1, "context canceled") assert.Error(t, err2, "context canceled") } + +type customLogger struct { + expectZeroChecksum bool + t *testing.T +} + +func (c customLogger) Trace(string) {} +func (c customLogger) Tracef(string, ...interface{}) {} +func (c customLogger) Debug(string) {} +func (c customLogger) Debugf(format string, args ...interface{}) { + if format == "[%s] useZeroChecksum=%t (on initAck)" { + assert.Equal(c.t, args[1], c.expectZeroChecksum) + } +} +func (c customLogger) Info(string) {} +func (c customLogger) Infof(string, ...interface{}) {} +func (c customLogger) Warn(string) {} +func (c customLogger) Warnf(string, ...interface{}) {} +func (c customLogger) Error(string) {} +func (c customLogger) Errorf(string, ...interface{}) {} + +func (c customLogger) NewLogger(string) logging.LeveledLogger { + return c +} + +func TestAssociation_ZeroChecksum(t *testing.T) { + checkGoroutineLeaks(t) + + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + for _, testCase := range []struct { + clientZeroChecksum, serverZeroChecksum, expectChecksumEnabled bool + }{ + {true, true, true}, + {false, false, false}, + {true, false, true}, + {false, true, false}, + } { + a1chan, a2chan := make(chan *Association), make(chan *Association) + + udp1, udp2 := createUDPConnPair() + + go func() { + a1, err := Client(Config{ + NetConn: udp1, + LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t}, + EnableZeroChecksum: testCase.clientZeroChecksum, + }) + assert.NoError(t, err) + a1chan <- a1 + }() + + go func() { + a2, err := Server(Config{ + NetConn: udp2, + LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t}, + EnableZeroChecksum: testCase.serverZeroChecksum, + }) + assert.NoError(t, err) + a2chan <- a2 + }() + + a1, a2 := <-a1chan, <-a2chan + + writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString) + require.NoError(t, err) + + readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString) + require.NoError(t, err) + + testData := []byte("test") + _, err = writeStream.Write(testData) + require.NoError(t, err) + + buf := make([]byte, len(testData)) + _, err = readStream.Read(buf) + assert.NoError(t, err) + assert.Equal(t, testData, buf) + + require.NoError(t, a1.Close()) + require.NoError(t, a2.Close()) + } +} diff --git a/chunk_test.go b/chunk_test.go index a060eb5e..c831e447 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -18,7 +18,7 @@ func TestInitChunk(t *testing.T) { 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk") } @@ -47,7 +47,7 @@ func TestInitChunk(t *testing.T) { func TestInitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, 0x1c, 0xeb, 0x81, 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, 0x94, 0x06, 0x2f, 0x93} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } @@ -63,12 +63,12 @@ func TestInitAck(t *testing.T) { func TestChromeChunk1Init(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } - rawPkt2, err := pkt.marshal() + rawPkt2, err := pkt.marshal(true) if err != nil { t.Errorf("Remarshal failed: %v", err) } @@ -79,12 +79,12 @@ func TestChromeChunk1Init(t *testing.T) { func TestChromeChunk2InitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } - rawPkt2, err := pkt.marshal() + rawPkt2, err := pkt.marshal(true) if err != nil { t.Errorf("Remarshal failed: %v", err) } @@ -112,13 +112,13 @@ func TestInitMarshalUnmarshal(t *testing.T) { initAck.params = []param{cookie} p.chunks = []chunk{initAck} - rawPkt, err := p.marshal() + rawPkt, err := p.marshal(true) if err != nil { t.Errorf("Failed to marshal packet: %v", err) } pkt := &packet{} - err = pkt.unmarshal(rawPkt) + err = pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } @@ -147,7 +147,7 @@ func TestInitMarshalUnmarshal(t *testing.T) { func TestPayloadDataMarshalUnmarshal(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, 0x24, 0x9b, 0x28, 0x7e, 0x48, 0xa3, 0x7b, 0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x00} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } @@ -161,7 +161,7 @@ func TestPayloadDataMarshalUnmarshal(t *testing.T) { func TestSelectAckChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, 0x78, 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } @@ -175,7 +175,7 @@ func TestSelectAckChunk(t *testing.T) { func TestReconfigChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, 0x0, 0x0, 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 0x8d, 0xff, 0x4e, 0x1c, 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0} - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } @@ -193,7 +193,7 @@ func TestReconfigChunk(t *testing.T) { func TestForwardTSNChunk(t *testing.T) { pkt := &packet{} rawPkt := append([]byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb}, testChunkForwardTSN()...) - err := pkt.unmarshal(rawPkt) + err := pkt.unmarshal(true, rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } diff --git a/packet.go b/packet.go index 2d40d295..316ca7d0 100644 --- a/packet.go +++ b/packet.go @@ -65,7 +65,7 @@ var ( ErrChecksumMismatch = errors.New("checksum mismatch theirs") ) -func (p *packet) unmarshal(raw []byte) error { +func (p *packet) unmarshal(doChecksum bool, raw []byte) error { if len(raw) < packetHeaderSize { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) } @@ -125,15 +125,19 @@ func (p *packet) unmarshal(raw []byte) error { chunkValuePadding := getPadding(c.valueLength()) offset += chunkHeaderSize + c.valueLength() + chunkValuePadding } - theirChecksum := binary.LittleEndian.Uint32(raw[8:]) - ourChecksum := generatePacketChecksum(raw) - if theirChecksum != ourChecksum { - return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) + + if doChecksum { + theirChecksum := binary.LittleEndian.Uint32(raw[8:]) + ourChecksum := generatePacketChecksum(raw) + if theirChecksum != ourChecksum { + return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) + } } + return nil } -func (p *packet) marshal() ([]byte, error) { +func (p *packet) marshal(doChecksum bool) ([]byte, error) { raw := make([]byte, packetHeaderSize) // Populate static headers @@ -156,9 +160,12 @@ func (p *packet) marshal() ([]byte, error) { } } - // Checksum is already in BigEndian - // Using LittleEndian.PutUint32 stops it from being flipped - binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw)) + if doChecksum { + // Checksum is already in BigEndian + // Using LittleEndian.PutUint32 stops it from being flipped + binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw)) + } + return raw, nil } diff --git a/packet_test.go b/packet_test.go index e36d0647..1a270e53 100644 --- a/packet_test.go +++ b/packet_test.go @@ -11,12 +11,12 @@ import ( func TestPacketUnmarshal(t *testing.T) { pkt := &packet{} - if err := pkt.unmarshal([]byte{}); err == nil { + if err := pkt.unmarshal(true, []byte{}); err == nil { t.Errorf("Unmarshal should fail when a packet is too small to be SCTP") } headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} - err := pkt.unmarshal(headerOnly) + err := pkt.unmarshal(true, headerOnly) switch { case err != nil: t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) @@ -36,7 +36,7 @@ func TestPacketUnmarshal(t *testing.T) { 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } - if err := pkt.unmarshal(rawChunk); err != nil { + if err := pkt.unmarshal(true, rawChunk); err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } } @@ -45,11 +45,11 @@ func TestPacketMarshal(t *testing.T) { pkt := &packet{} headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} - if err := pkt.unmarshal(headerOnly); err != nil { + if err := pkt.unmarshal(true, headerOnly); err != nil { t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) } - headerOnlyMarshaled, err := pkt.marshal() + headerOnlyMarshaled, err := pkt.marshal(true) if err != nil { t.Errorf("Marshal failed for SCTP packet with no chunks: %v", err) } else if !bytes.Equal(headerOnly, headerOnlyMarshaled) { diff --git a/param.go b/param.go index 8035add3..c28d7a5b 100644 --- a/param.go +++ b/param.go @@ -38,6 +38,8 @@ func buildParam(t paramType, rawParam []byte) (param, error) { return (¶mOutgoingResetRequest{}).unmarshal(rawParam) case reconfigResp: return (¶mReconfigResponse{}).unmarshal(rawParam) + case zeroChecksumAcceptable: + return (¶mZeroChecksumAcceptable{}).unmarshal(rawParam) default: return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, t) } diff --git a/param_zero_checksum.go b/param_zero_checksum.go new file mode 100644 index 00000000..258565c1 --- /dev/null +++ b/param_zero_checksum.go @@ -0,0 +1,53 @@ +package sctp + +import ( + "encoding/binary" + "errors" +) + +// This parameter is used to inform the receiver that a sender is willing to +// accept zero as checksum if some other error detection method is used +// instead. +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 0x8001 (suggested) | Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Error Detection Method Identifier (EDMID) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +type paramZeroChecksumAcceptable struct { + paramHeader + // The Error Detection Method Identifier (EDMID) specifies an alternate + // error detection method the sender of this parameter is willing to use for + // received packets. + edmid uint32 +} + +// Zero Checksum parameter error +var ( + ErrZeroChecksumParamTooShort = errors.New("zero checksum parameter too short") +) + +const ( + dtlsErrorDetectionMethod uint32 = 1 +) + +func (r *paramZeroChecksumAcceptable) marshal() ([]byte, error) { + r.typ = zeroChecksumAcceptable + r.raw = make([]byte, 4) + binary.BigEndian.PutUint32(r.raw, r.edmid) + return r.paramHeader.marshal() +} + +func (r *paramZeroChecksumAcceptable) unmarshal(raw []byte) (param, error) { + err := r.paramHeader.unmarshal(raw) + if err != nil { + return nil, err + } + if len(r.raw) < 4 { + return nil, ErrZeroChecksumParamTooShort + } + r.edmid = binary.BigEndian.Uint32(r.raw) + return r, nil +} diff --git a/param_zero_checksum_test.go b/param_zero_checksum_test.go new file mode 100644 index 00000000..7fc9396d --- /dev/null +++ b/param_zero_checksum_test.go @@ -0,0 +1,40 @@ +package sctp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParamZeroChecksum(t *testing.T) { + tt := []struct { + binary []byte + parsed *paramZeroChecksumAcceptable + }{ + { + binary: []byte{0x80, 0x01, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01}, + parsed: ¶mZeroChecksumAcceptable{ + paramHeader: paramHeader{ + typ: zeroChecksumAcceptable, + unrecognizedAction: paramHeaderUnrecognizedActionSkip, + len: 8, + raw: []byte{0x00, 0x00, 0x00, 0x01}, + }, + edmid: 1, + }, + }, + } + + for i, tc := range tt { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + actual := ¶mZeroChecksumAcceptable{} + _, err := actual.unmarshal(tc.binary) + assert.NoError(t, err) + assert.Equal(t, tc.parsed, actual) + b, err := actual.marshal() + assert.NoError(t, err) + assert.Equal(t, tc.binary, b) + }) + } +} diff --git a/paramtype.go b/paramtype.go index f0a3da38..de1b3dd9 100644 --- a/paramtype.go +++ b/paramtype.go @@ -13,33 +13,34 @@ import ( type paramType uint16 const ( - heartbeatInfo paramType = 1 // Heartbeat Info [RFC4960] - ipV4Addr paramType = 5 // IPv4 IP [RFC4960] - ipV6Addr paramType = 6 // IPv6 IP [RFC4960] - stateCookie paramType = 7 // State Cookie [RFC4960] - unrecognizedParam paramType = 8 // Unrecognized Parameters [RFC4960] - cookiePreservative paramType = 9 // Cookie Preservative [RFC4960] - hostNameAddr paramType = 11 // Host Name IP [RFC4960] - supportedAddrTypes paramType = 12 // Supported IP Types [RFC4960] - outSSNResetReq paramType = 13 // Outgoing SSN Reset Request Parameter [RFC6525] - incSSNResetReq paramType = 14 // Incoming SSN Reset Request Parameter [RFC6525] - ssnTSNResetReq paramType = 15 // SSN/TSN Reset Request Parameter [RFC6525] - reconfigResp paramType = 16 // Re-configuration Response Parameter [RFC6525] - addOutStreamsReq paramType = 17 // Add Outgoing Streams Request Parameter [RFC6525] - addIncStreamsReq paramType = 18 // Add Incoming Streams Request Parameter [RFC6525] - ecnCapable paramType = 32768 // ECN Capable (0x8000) [RFC2960] - random paramType = 32770 // Random (0x8002) [RFC4805] - chunkList paramType = 32771 // Chunk List (0x8003) [RFC4895] - reqHMACAlgo paramType = 32772 // Requested HMAC Algorithm Parameter (0x8004) [RFC4895] - padding paramType = 32773 // Padding (0x8005) - supportedExt paramType = 32776 // Supported Extensions (0x8008) [RFC5061] - forwardTSNSupp paramType = 49152 // Forward TSN supported (0xC000) [RFC3758] - addIPAddr paramType = 49153 // Add IP IP (0xC001) [RFC5061] - delIPAddr paramType = 49154 // Delete IP IP (0xC002) [RFC5061] - errClauseInd paramType = 49155 // Error Cause Indication (0xC003) [RFC5061] - setPriAddr paramType = 49156 // Set Primary IP (0xC004) [RFC5061] - successInd paramType = 49157 // Success Indication (0xC005) [RFC5061] - adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] + heartbeatInfo paramType = 1 // Heartbeat Info [RFC4960] + ipV4Addr paramType = 5 // IPv4 IP [RFC4960] + ipV6Addr paramType = 6 // IPv6 IP [RFC4960] + stateCookie paramType = 7 // State Cookie [RFC4960] + unrecognizedParam paramType = 8 // Unrecognized Parameters [RFC4960] + cookiePreservative paramType = 9 // Cookie Preservative [RFC4960] + hostNameAddr paramType = 11 // Host Name IP [RFC4960] + supportedAddrTypes paramType = 12 // Supported IP Types [RFC4960] + outSSNResetReq paramType = 13 // Outgoing SSN Reset Request Parameter [RFC6525] + incSSNResetReq paramType = 14 // Incoming SSN Reset Request Parameter [RFC6525] + ssnTSNResetReq paramType = 15 // SSN/TSN Reset Request Parameter [RFC6525] + reconfigResp paramType = 16 // Re-configuration Response Parameter [RFC6525] + addOutStreamsReq paramType = 17 // Add Outgoing Streams Request Parameter [RFC6525] + addIncStreamsReq paramType = 18 // Add Incoming Streams Request Parameter [RFC6525] + ecnCapable paramType = 32768 // ECN Capable (0x8000) [RFC2960] + zeroChecksumAcceptable paramType = 32769 // Zero Checksum Acceptable [draft-ietf-tsvwg-sctp-zero-checksum-00] + random paramType = 32770 // Random (0x8002) [RFC4805] + chunkList paramType = 32771 // Chunk List (0x8003) [RFC4895] + reqHMACAlgo paramType = 32772 // Requested HMAC Algorithm Parameter (0x8004) [RFC4895] + padding paramType = 32773 // Padding (0x8005) + supportedExt paramType = 32776 // Supported Extensions (0x8008) [RFC5061] + forwardTSNSupp paramType = 49152 // Forward TSN supported (0xC000) [RFC3758] + addIPAddr paramType = 49153 // Add IP IP (0xC001) [RFC5061] + delIPAddr paramType = 49154 // Delete IP IP (0xC002) [RFC5061] + errClauseInd paramType = 49155 // Error Cause Indication (0xC003) [RFC5061] + setPriAddr paramType = 49156 // Set Primary IP (0xC004) [RFC5061] + successInd paramType = 49157 // Success Indication (0xC005) [RFC5061] + adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] ) // Parameter packet errors @@ -86,6 +87,8 @@ func (p paramType) String() string { return "Add Incoming Streams Request Parameter" case ecnCapable: return "ECN Capable" + case zeroChecksumAcceptable: + return "Zero Checksum Acceptable" case random: return "Random" case chunkList: diff --git a/vnet_test.go b/vnet_test.go index edbf6549..6f225bbd 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -76,7 +76,7 @@ func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { return func(c vnet.Chunk) bool { var toDrop bool p := &packet{} - if err2 := p.unmarshal(c.UserData()); err2 != nil { + if err2 := p.unmarshal(true, c.UserData()); err2 != nil { panic(errSCTPPacketParse) }