diff --git a/association.go b/association.go index 466e620d..47c0d94e 100644 --- a/association.go +++ b/association.go @@ -645,22 +645,9 @@ func (a *Association) marshalPacket(p *packet) ([]byte, error) { func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { p := &packet{} - if !a.useZeroChecksum { - if err := p.unmarshal(true, raw); err != nil { - return nil, err - } - return p, nil - } - - if err := p.unmarshal(false, raw); err != nil { + if err := p.unmarshal(!a.useZeroChecksum, raw); err != nil { return nil, err } - if chunkMandatoryChecksum(p.chunks) { - if err := p.unmarshal(true, raw); err != nil { - return nil, err - } - } - return p, nil } diff --git a/packet.go b/packet.go index 3bff1227..6d3cb3ca 100644 --- a/packet.go +++ b/packet.go @@ -70,11 +70,30 @@ func (p *packet) unmarshal(doChecksum bool, raw []byte) error { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) } + offset := packetHeaderSize + + // Check if doing CRC32c is required. + // Without having SCTP AUTH implemented, this depends only on the type + // og the first chunk. + if offset+chunkHeaderSize <= len(raw) { + switch chunkType(raw[offset]) { + case ctInit, ctCookieEcho: + doChecksum = true + default: + } + } + theirChecksum := binary.LittleEndian.Uint32(raw[8:]) + if theirChecksum != 0 || doChecksum { + ourChecksum := generatePacketChecksum(raw) + if theirChecksum != ourChecksum { + return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) + } + } + p.sourcePort = binary.BigEndian.Uint16(raw[0:]) p.destinationPort = binary.BigEndian.Uint16(raw[2:]) p.verificationTag = binary.BigEndian.Uint32(raw[4:]) - offset := packetHeaderSize for { // Exact match, no more chunks if offset == len(raw) { @@ -126,14 +145,6 @@ func (p *packet) unmarshal(doChecksum bool, raw []byte) error { offset += chunkHeaderSize + c.valueLength() + chunkValuePadding } - theirChecksum := binary.LittleEndian.Uint32(raw[8:]) - if theirChecksum != 0 || doChecksum { - ourChecksum := generatePacketChecksum(raw) - if theirChecksum != ourChecksum { - return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) - } - } - return nil }