diff --git a/pkg/flexfec/encoder_interceptor.go b/pkg/flexfec/encoder_interceptor.go new file mode 100644 index 00000000..22c6850b --- /dev/null +++ b/pkg/flexfec/encoder_interceptor.go @@ -0,0 +1,79 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package flexfec + +import ( + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +// FecInterceptor implements FlexFec. +type FecInterceptor struct { + interceptor.NoOp + flexFecEncoder FlexEncoder + packetBuffer []rtp.Packet + minNumMediaPackets uint32 +} + +// FecOption can be used to set initial options on Fec encoder interceptors. +type FecOption func(d *FecInterceptor) error + +// FecInterceptorFactory creates new FecInterceptors. +type FecInterceptorFactory struct { + opts []FecOption +} + +// NewFecInterceptor returns a new Fec interceptor factory. +func NewFecInterceptor(opts ...FecOption) (*FecInterceptorFactory, error) { + return &FecInterceptorFactory{opts: opts}, nil +} + +// NewInterceptor constructs a new FecInterceptor. +func (r *FecInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + // Hardcoded for now: + // Min num media packets to encode FEC -> 5 + // Min num fec packets -> 1 + + interceptor := &FecInterceptor{ + packetBuffer: make([]rtp.Packet, 0), + minNumMediaPackets: 5, + } + return interceptor, nil +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (r *FecInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + // Chromium supports version flexfec-03 of existing draft, this is the one we will configure by default + // although we should support configuring the latest (flexfec-20) as well. + r.flexFecEncoder = NewFlexEncoder03(info.PayloadType, info.SSRC) + + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + r.packetBuffer = append(r.packetBuffer, rtp.Packet{ + Header: *header, + Payload: payload, + }) + + // Send the media RTP packet + result, err := writer.Write(header, payload, attributes) + + // Send the FEC packets + var fecPackets []rtp.Packet + if len(r.packetBuffer) == int(r.minNumMediaPackets) { + fecPackets = r.flexFecEncoder.EncodeFec(r.packetBuffer, 2) + + for _, fecPacket := range fecPackets { + fecResult, fecErr := writer.Write(&fecPacket.Header, fecPacket.Payload, attributes) + + if fecErr != nil && fecResult == 0 { + break + } + } + // Reset the packet buffer now that we've sent the corresponding FEC packets. + r.packetBuffer = nil + } + + return result, err + }) +} diff --git a/pkg/flexfec/flexfec_coverage.go b/pkg/flexfec/flexfec_coverage.go index 580bdf7e..c1e463dc 100644 --- a/pkg/flexfec/flexfec_coverage.go +++ b/pkg/flexfec/flexfec_coverage.go @@ -42,9 +42,45 @@ func NewCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) *ProtectionCov // We allocate the biggest array of bitmasks that respects the max constraints. var packetMasks [MaxFecPackets]util.BitArray for i := 0; i < int(MaxFecPackets); i++ { - packetMasks[i] = util.NewBitArray(MaxMediaPackets) + packetMasks[i] = util.BitArray{} } + coverage := &ProtectionCoverage{ + packetMasks: packetMasks, + numFecPackets: 0, + numMediaPackets: 0, + mediaPackets: nil, + } + + coverage.UpdateCoverage(mediaPackets, numFecPackets) + return coverage +} + +// UpdateCoverage updates the ProtectionCoverage object with new bitmasks accounting for the numFecPackets +// we want to use to protect the batch media packets. +func (p *ProtectionCoverage) UpdateCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) { + numMediaPackets := uint32(len(mediaPackets)) + + // Basic sanity checks + if numMediaPackets <= 0 || numMediaPackets > MaxMediaPackets { + return + } + + p.mediaPackets = mediaPackets + + if numFecPackets == p.numFecPackets && numMediaPackets == p.numMediaPackets { + // We have the same number of FEC packets covering the same number of media packets, we can simply + // reuse the previous coverage map with the updated media packets. + return + } + + p.numFecPackets = numFecPackets + p.numMediaPackets = numMediaPackets + + // The number of FEC packets and/or the number of packets has changed, we need to update the coverage map + // to reflect these new values. + p.resetCoverage() + // Generate FEC bit mask where numFecPackets FEC packets are covering numMediaPackets Media packets. // In the packetMasks array, each FEC packet is represented by a single BitArray, each bit in a given BitArray // corresponds to a specific Media packet. @@ -52,26 +88,18 @@ func NewCoverage(mediaPackets []rtp.Packet, numFecPackets uint32) *ProtectionCov for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { // We use an interleaved method to determine coverage. Given N FEC packets, Media packet X will be // covered by FEC packet X % N. - for mediaPacketIndex := uint32(0); mediaPacketIndex < numMediaPackets; mediaPacketIndex++ { - coveringFecPktIndex := mediaPacketIndex % numFecPackets - packetMasks[coveringFecPktIndex].SetBit(mediaPacketIndex, 1) + coveredMediaPacketIndex := fecPacketIndex + for coveredMediaPacketIndex < numMediaPackets { + p.packetMasks[fecPacketIndex].SetBit(coveredMediaPacketIndex) + coveredMediaPacketIndex += numFecPackets } } - - return &ProtectionCoverage{ - packetMasks: packetMasks, - numFecPackets: numFecPackets, - numMediaPackets: numMediaPackets, - mediaPackets: mediaPackets, - } } // ResetCoverage clears the underlying map so that we can reuse it for new batches of RTP packets. -func (p *ProtectionCoverage) ResetCoverage() { +func (p *ProtectionCoverage) resetCoverage() { for i := uint32(0); i < MaxFecPackets; i++ { - for j := uint32(0); j < MaxMediaPackets; j++ { - p.packetMasks[i].SetBit(j, 0) - } + p.packetMasks[i].Reset() } } @@ -86,26 +114,46 @@ func (p *ProtectionCoverage) GetCoveredBy(fecPacketIndex uint32) *util.MediaPack return util.NewMediaPacketIterator(p.mediaPackets, coverage) } -// MarshalBitmasks returns the underlying bitmask that defines which media packets are protected by the -// specified fecPacketIndex. -func (p *ProtectionCoverage) MarshalBitmasks(fecPacketIndex uint32) []byte { - return p.packetMasks[fecPacketIndex].Marshal() -} - // ExtractMask1 returns the first section of the bitmask as defined by the FEC header. // https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 func (p *ProtectionCoverage) ExtractMask1(fecPacketIndex uint32) uint16 { - return uint16(p.packetMasks[fecPacketIndex].GetBitValue(0, 14)) + mask := p.packetMasks[fecPacketIndex] + // We get the first 16 bits (64 - 16 -> shift by 48) and we shift once more for K field + mask1 := mask.Lo >> 49 + return uint16(mask1) } // ExtractMask2 returns the second section of the bitmask as defined by the FEC header. // https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 func (p *ProtectionCoverage) ExtractMask2(fecPacketIndex uint32) uint32 { - return uint32(p.packetMasks[fecPacketIndex].GetBitValue(15, 45)) + mask := p.packetMasks[fecPacketIndex] + // We remove the first 15 bits + mask2 := mask.Lo << 15 + // We get the first 31 bits (64 - 31 -> shift by 33) and we shift once more for K field + mask2 >>= 34 + return uint32(mask2) } // ExtractMask3 returns the third section of the bitmask as defined by the FEC header. // https://datatracker.ietf.org/doc/html/rfc8627#section-4.2.2.1 func (p *ProtectionCoverage) ExtractMask3(fecPacketIndex uint32) uint64 { - return p.packetMasks[fecPacketIndex].GetBitValue(46, 109) + mask := p.packetMasks[fecPacketIndex] + // We remove the first 46 bits + maskLo := mask.Lo << 46 + maskHi := mask.Hi >> 18 + mask3 := maskLo | maskHi + return mask3 +} + +// ExtractMask3_03 returns the third section of the bitmask as defined by the FEC header. +// https://datatracker.ietf.org/doc/html/draft-ietf-payload-flexible-fec-scheme-03#section-4.2 +func (p *ProtectionCoverage) ExtractMask3_03(fecPacketIndex uint32) uint64 { + mask := p.packetMasks[fecPacketIndex] + // We remove the first 46 bits + maskLo := mask.Lo << 46 + maskHi := mask.Hi >> 18 + mask3 := maskLo | maskHi + // We shift once for the K bit. + mask3 >>= 1 + return mask3 } diff --git a/pkg/flexfec/flexfec_encoder.go b/pkg/flexfec/flexfec_encoder.go index e1531a47..479832e5 100644 --- a/pkg/flexfec/flexfec_encoder.go +++ b/pkg/flexfec/flexfec_encoder.go @@ -20,33 +20,37 @@ const ( BaseFecHeaderSize = 12 ) -// FlexEncoder implements the Fec encoding mechanism for the "Flex" variant of FlexFec. -type FlexEncoder struct { - baseSN uint16 +// FlexEncoder is the interface that FecInterceptor uses to encode Fec packets. +type FlexEncoder interface { + EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet +} + +// FlexEncoder20 implements the Fec encoding mechanism for the "Flex" variant of FlexFec. +type FlexEncoder20 struct { + fecBaseSn uint16 payloadType uint8 ssrc uint32 coverage *ProtectionCoverage } // NewFlexEncoder returns a new FlexFecEncer. -func NewFlexEncoder(baseSN uint16, payloadType uint8, ssrc uint32) *FlexEncoder { - return &FlexEncoder{ - baseSN: baseSN, +func NewFlexEncoder(payloadType uint8, ssrc uint32) *FlexEncoder20 { + return &FlexEncoder20{ payloadType: payloadType, ssrc: ssrc, - coverage: nil, + fecBaseSn: uint16(1000), } } // EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets. // This method does not account for missing RTP packets in the mediaPackets array nor does it account for // them being passed out of order. -func (flex *FlexEncoder) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet { +func (flex *FlexEncoder20) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet { // Start by defining which FEC packets cover which media packets if flex.coverage == nil { flex.coverage = NewCoverage(mediaPackets, numFecPackets) } else { - flex.coverage.ResetCoverage() + flex.coverage.UpdateCoverage(mediaPackets, numFecPackets) } if flex.coverage == nil { @@ -56,39 +60,42 @@ func (flex *FlexEncoder) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint // Generate FEC payloads fecPackets := make([]rtp.Packet, numFecPackets) for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { - fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex) + fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber) } return fecPackets } -func (flex *FlexEncoder) encodeFlexFecPacket(fecPacketIndex uint32) rtp.Packet { +func (flex *FlexEncoder20) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) rtp.Packet { mediaPacketsIt := flex.coverage.GetCoveredBy(fecPacketIndex) flexFecHeader := flex.encodeFlexFecHeader( mediaPacketsIt, flex.coverage.ExtractMask1(fecPacketIndex), flex.coverage.ExtractMask2(fecPacketIndex), flex.coverage.ExtractMask3(fecPacketIndex), + mediaBaseSn, ) flexFecRepairPayload := flex.encodeFlexFecRepairPayload(mediaPacketsIt.Reset()) - return rtp.Packet{ + packet := rtp.Packet{ Header: rtp.Header{ Version: 2, Padding: false, Extension: false, Marker: false, PayloadType: flex.payloadType, - SequenceNumber: flex.baseSN, + SequenceNumber: flex.fecBaseSn, Timestamp: 54243243, SSRC: flex.ssrc, CSRC: []uint32{}, }, Payload: append(flexFecHeader, flexFecRepairPayload...), } + flex.fecBaseSn++ + return packet } -func (flex *FlexEncoder) encodeFlexFecHeader(mediaPackets *util.MediaPacketIterator, mask1 uint16, optionalMask2 uint32, optionalMask3 uint64) []byte { +func (flex *FlexEncoder20) encodeFlexFecHeader(mediaPackets *util.MediaPacketIterator, mask1 uint16, optionalMask2 uint32, optionalMask3 uint64, mediaBaseSn uint16) []byte { /* 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -119,7 +126,7 @@ func (flex *FlexEncoder) encodeFlexFecHeader(mediaPackets *util.MediaPacketItera headerSize += 8 } - // Allocate a the FlexFec header + // Allocate the FlexFec header flexFecHeader := make([]byte, headerSize) // XOR the relevant fields for the header @@ -149,6 +156,9 @@ func (flex *FlexEncoder) encodeFlexFecHeader(mediaPackets *util.MediaPacketItera flexFecHeader[7] ^= flexFecHeader[7] } + // Write the base SN for the batch of media packets + binary.BigEndian.PutUint16(flexFecHeader[8:10], mediaBaseSn) + // Write the bitmasks to the header binary.BigEndian.PutUint16(flexFecHeader[10:12], mask1) @@ -163,7 +173,7 @@ func (flex *FlexEncoder) encodeFlexFecHeader(mediaPackets *util.MediaPacketItera return flexFecHeader } -func (flex *FlexEncoder) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte { +func (flex *FlexEncoder20) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte { flexFecPayload := make([]byte, len(mediaPackets.First().Payload)) for mediaPackets.HasNext() { diff --git a/pkg/flexfec/flexfec_encoder_03.go b/pkg/flexfec/flexfec_encoder_03.go new file mode 100644 index 00000000..15958a4b --- /dev/null +++ b/pkg/flexfec/flexfec_encoder_03.go @@ -0,0 +1,213 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package flexfec implements FlexFEC to recover missing RTP packets due to packet loss. +// https://datatracker.ietf.org/doc/html/rfc8627 +package flexfec + +import ( + "encoding/binary" + + "github.com/pion/interceptor/pkg/flexfec/util" + "github.com/pion/rtp" +) + +const ( + // BaseFec03HeaderSize represents the minium FEC payload's header size including the + // required first mask. + BaseFec03HeaderSize = 20 +) + +// FlexEncoder03 implements the Fec encoding mechanism for the "Flex" variant of FlexFec. +type FlexEncoder03 struct { + fecBaseSn uint16 + payloadType uint8 + ssrc uint32 + coverage *ProtectionCoverage +} + +// NewFlexEncoder03 returns a new FlexFecEncoder. +func NewFlexEncoder03(payloadType uint8, ssrc uint32) *FlexEncoder03 { + return &FlexEncoder03{ + payloadType: payloadType, + ssrc: ssrc, + fecBaseSn: uint16(1000), + } +} + +// EncodeFec returns a list of generated RTP packets with FEC payloads that protect the specified mediaPackets. +// This method does not account for missing RTP packets in the mediaPackets array nor does it account for +// them being passed out of order. +func (flex *FlexEncoder03) EncodeFec(mediaPackets []rtp.Packet, numFecPackets uint32) []rtp.Packet { + // Start by defining which FEC packets cover which media packets + if flex.coverage == nil { + flex.coverage = NewCoverage(mediaPackets, numFecPackets) + } else { + flex.coverage.UpdateCoverage(mediaPackets, numFecPackets) + } + + if flex.coverage == nil { + return nil + } + + // Generate FEC payloads + fecPackets := make([]rtp.Packet, numFecPackets) + for fecPacketIndex := uint32(0); fecPacketIndex < numFecPackets; fecPacketIndex++ { + fecPackets[fecPacketIndex] = flex.encodeFlexFecPacket(fecPacketIndex, mediaPackets[0].SequenceNumber) + } + + return fecPackets +} + +func (flex *FlexEncoder03) encodeFlexFecPacket(fecPacketIndex uint32, mediaBaseSn uint16) rtp.Packet { + mediaPacketsIt := flex.coverage.GetCoveredBy(fecPacketIndex) + flexFecHeader := flex.encodeFlexFecHeader( + mediaPacketsIt, + flex.coverage.ExtractMask1(fecPacketIndex), + flex.coverage.ExtractMask2(fecPacketIndex), + flex.coverage.ExtractMask3_03(fecPacketIndex), + mediaBaseSn, + ) + flexFecRepairPayload := flex.encodeFlexFecRepairPayload(mediaPacketsIt.Reset()) + + packet := rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Padding: false, + Extension: false, + Marker: false, + PayloadType: flex.payloadType, + SequenceNumber: flex.fecBaseSn, + Timestamp: 54243243, + SSRC: flex.ssrc, + CSRC: []uint32{}, + }, + Payload: append(flexFecHeader, flexFecRepairPayload...), + } + flex.fecBaseSn++ + return packet +} + +func (flex *FlexEncoder03) encodeFlexFecHeader(mediaPackets *util.MediaPacketIterator, mask1 uint16, optionalMask2 uint32, optionalMask3 uint64, mediaBaseSn uint16) []byte { + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |0|0| P|X| CC |M| PT recovery | length recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | TS recovery | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRCCount | reserved | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SSRC_i | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | SN base_i |k| Mask [0-14] | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |k| Mask [15-45] (optional) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |k| | + +-+ Mask [46-108] (optional) | + | | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... next in SSRC_i ... | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + // Get header size - This depends on the size of the bitmask. + headerSize := BaseFec03HeaderSize + if optionalMask2 > 0 { + headerSize += 4 + } + if optionalMask3 > 0 { + headerSize += 8 + } + + // Allocate the FlexFec header + flexFecHeader := make([]byte, headerSize) + + // We allocate a single temporary buffer to store the mediaPacket bytes. This reduces + // overall allocations. + tmpMediaPacketBuf := make([]byte, 0) + for mediaPackets.HasNext() { + mediaPacket := mediaPackets.Next() + + if mediaPacket.MarshalSize() > len(tmpMediaPacketBuf) { + // The temporary buffer is too small, we need to resize. + tmpMediaPacketBuf = make([]byte, mediaPacket.MarshalSize()) + } + n, err := mediaPacket.MarshalTo(tmpMediaPacketBuf) + + if n == 0 || err != nil { + return nil + } + + // XOR the first 2 bytes of the header: V, P, X, CC, M, PT fields + flexFecHeader[0] ^= tmpMediaPacketBuf[0] + flexFecHeader[1] ^= tmpMediaPacketBuf[1] + + // Clear the first 2 bits + flexFecHeader[0] &= 0b00111111 + + // XOR the length recovery field + lengthRecoveryVal := uint16(mediaPacket.MarshalSize() - BaseRTPHeaderSize) + flexFecHeader[2] ^= uint8(lengthRecoveryVal >> 8) + flexFecHeader[3] ^= uint8(lengthRecoveryVal) + + // XOR the 5th to 8th bytes of the header: the timestamp field + flexFecHeader[4] ^= tmpMediaPacketBuf[4] + flexFecHeader[5] ^= tmpMediaPacketBuf[5] + flexFecHeader[6] ^= tmpMediaPacketBuf[6] + flexFecHeader[7] ^= tmpMediaPacketBuf[7] + } + + // Write the SSRC count + flexFecHeader[8] = 1 + + // Write 0s in reserved + flexFecHeader[9] = 0 + flexFecHeader[10] = 0 + flexFecHeader[11] = 0 + + // Write the SSRC of media packets protected by this FEC packet + binary.BigEndian.PutUint32(flexFecHeader[12:16], mediaPackets.First().SSRC) + + // Write the base SN for the batch of media packets + binary.BigEndian.PutUint16(flexFecHeader[16:18], mediaBaseSn) + + // Write the bitmasks to the header + binary.BigEndian.PutUint16(flexFecHeader[18:20], mask1) + + if optionalMask2 == 0 { + flexFecHeader[18] |= 0b10000000 + return flexFecHeader + } + binary.BigEndian.PutUint32(flexFecHeader[20:24], optionalMask2) + + if optionalMask3 == 0 { + flexFecHeader[20] |= 0b10000000 + } else { + binary.BigEndian.PutUint64(flexFecHeader[24:32], optionalMask3) + } + + return flexFecHeader +} + +func (flex *FlexEncoder03) encodeFlexFecRepairPayload(mediaPackets *util.MediaPacketIterator) []byte { + flexFecPayload := make([]byte, len(mediaPackets.First().Payload)) + + for mediaPackets.HasNext() { + mediaPacketPayload := mediaPackets.Next().Payload + + if len(flexFecPayload) < len(mediaPacketPayload) { + // Expected FEC packet payload is bigger that what we can currently store, + // we need to resize. + flexFecPayloadTmp := make([]byte, len(mediaPacketPayload)) + copy(flexFecPayloadTmp, flexFecPayload) + flexFecPayload = flexFecPayloadTmp + } + for byteIndex := 0; byteIndex < len(mediaPacketPayload); byteIndex++ { + flexFecPayload[byteIndex] ^= mediaPacketPayload[byteIndex] + } + } + return flexFecPayload +} diff --git a/pkg/flexfec/util/bitarray.go b/pkg/flexfec/util/bitarray.go index d24efc20..c081f0a4 100644 --- a/pkg/flexfec/util/bitarray.go +++ b/pkg/flexfec/util/bitarray.go @@ -6,73 +6,40 @@ package util // BitArray provides support for bitmask manipulations. type BitArray struct { - bytes []byte + Lo uint64 // leftmost 64 bits + Hi uint64 // rightmost 64 bits } -// NewBitArray returns a new BitArray. It takes sizeBits as parameter which represents -// the size of the underlying bitmask. -func NewBitArray(sizeBits uint32) BitArray { - var sizeBytes uint32 - if sizeBits%8 == 0 { - sizeBytes = sizeBits / 8 +// SetBit sets a bit to the specified bit value on the bitmask. +func (b *BitArray) SetBit(bitIndex uint32) { + if bitIndex < 64 { + b.Lo |= uint64(0b1) << (63 - bitIndex) } else { - sizeBytes = sizeBits/8 + 1 - } - - return BitArray{ - bytes: make([]byte, sizeBytes), + hiBitIndex := bitIndex - 64 + b.Hi |= uint64(0b1) << (63 - hiBitIndex) } } -// SetBit sets a bit to the specified bit value on the bitmask. -func (b *BitArray) SetBit(bitIndex uint32, bitValue uint32) { - byteIndex := bitIndex / 8 - bitOffset := uint(bitIndex % 8) - - // Set the specific bit to 1 using bitwise OR - if bitValue == 1 { - b.bytes[byteIndex] |= 1 << bitOffset - } else { - b.bytes[byteIndex] |= 0 << bitOffset - } +// Reset clears the bitmask. +func (b *BitArray) Reset() { + b.Lo = 0 + b.Hi = 0 } // GetBit returns the bit value at a specified index of the bitmask. func (b *BitArray) GetBit(bitIndex uint32) uint8 { - return b.bytes[bitIndex/8] -} - -// Marshal returns the underlying bitmask. -func (b *BitArray) Marshal() []byte { - return b.bytes -} - -// GetBitValue returns a subsection of the bitmask. -func (b *BitArray) GetBitValue(i int, j int) uint64 { - if i < 0 || i >= len(b.bytes)*8 || j < 0 || j >= len(b.bytes)*8 || i > j { + if bitIndex < 64 { + result := (b.Lo & (uint64(0b1) << (63 - bitIndex))) + if result > 0 { + return 1 + } return 0 } - startByte := i / 8 - startBit := i % 8 - endByte := j / 8 - - // Create a slice containing the bytes to extract - subArray := b.bytes[startByte : endByte+1] - - // Initialize the result value - var result uint64 - - // Loop through the bytes and concatenate the bits - for idx, b := range subArray { - if idx == 0 { - b <<= uint(startBit) - } - result |= uint64(b) + hiBitIndex := bitIndex - 64 + result := (b.Hi & (uint64(0b1) << (63 - hiBitIndex))) + if result > 0 { + return 1 } - - // Mask the bits that are not part of the desired range - result &= (1<