From fd5edcb829a9c440edffcfc7961e4668c4ab9c5c Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Sat, 23 Dec 2023 13:21:31 -0700 Subject: [PATCH] Fix race conditions in backlog.reset() --- broadcaster/backlog/backlog.go | 26 ++++++++++++++------------ broadcaster/backlog/backlog_test.go | 8 ++++---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/broadcaster/backlog/backlog.go b/broadcaster/backlog/backlog.go index 461ec3a6f3..549a38ff86 100644 --- a/broadcaster/backlog/backlog.go +++ b/broadcaster/backlog/backlog.go @@ -35,18 +35,18 @@ type Backlog interface { type backlog struct { head atomic.Pointer[backlogSegment] tail atomic.Pointer[backlogSegment] - lookupByIndex *containers.SyncMap[uint64, *backlogSegment] + lookupByIndex atomic.Pointer[containers.SyncMap[uint64, *backlogSegment]] config ConfigFetcher messageCount atomic.Uint64 } // NewBacklog creates a backlog. func NewBacklog(c ConfigFetcher) Backlog { - lookup := &containers.SyncMap[uint64, *backlogSegment]{} - return &backlog{ - lookupByIndex: lookup, - config: c, + b := &backlog{ + config: c, } + b.lookupByIndex.Store(&containers.SyncMap[uint64, *backlogSegment]{}) + return b } // Head return the head backlogSegment within the backlog. @@ -63,6 +63,7 @@ func (b *backlog) Append(bm *m.BroadcastMessage) error { b.delete(uint64(bm.ConfirmedSequenceNumberMessage.SequenceNumber)) } + lookupByIndex := b.lookupByIndex.Load() for _, msg := range bm.Messages { segment := b.tail.Load() if segment == nil { @@ -95,7 +96,7 @@ func (b *backlog) Append(bm *m.BroadcastMessage) error { } else if err != nil { return err } - b.lookupByIndex.Store(uint64(msg.SequenceNumber), segment) + lookupByIndex.Store(uint64(msg.SequenceNumber), segment) b.messageCount.Add(1) } @@ -160,7 +161,7 @@ func (b *backlog) delete(confirmed uint64) { } if confirmed > tail.End() { - log.Error("confirmed sequence number is past the end of stored messages", "confirmed sequence number", confirmed, "last stored sequence number", tail.End()) + log.Warn("confirmed sequence number is past the end of stored messages", "confirmed sequence number", confirmed, "last stored sequence number", tail.End()) b.reset() return } @@ -211,14 +212,15 @@ func (b *backlog) delete(confirmed uint64) { // removeFromLookup removes all entries from the head segment's start index to // the given confirmed index. func (b *backlog) removeFromLookup(start, end uint64) { + lookupByIndex := b.lookupByIndex.Load() for i := start; i <= end; i++ { - b.lookupByIndex.Delete(i) + lookupByIndex.Delete(i) } } // Lookup attempts to find the backlogSegment storing the given message index. func (b *backlog) Lookup(i uint64) (BacklogSegment, error) { - segment, ok := b.lookupByIndex.Load(i) + segment, ok := b.lookupByIndex.Load().Load(i) if !ok { return nil, fmt.Errorf("error finding backlog segment containing message with SequenceNumber %d", i) } @@ -233,9 +235,9 @@ func (s *backlog) Count() uint64 { // reset removes all segments from the backlog. func (b *backlog) reset() { - b.head = atomic.Pointer[backlogSegment]{} - b.tail = atomic.Pointer[backlogSegment]{} - b.lookupByIndex = &containers.SyncMap[uint64, *backlogSegment]{} + b.head.Store(nil) + b.tail.Store(nil) + b.lookupByIndex.Store(&containers.SyncMap[uint64, *backlogSegment]{}) b.messageCount.Store(0) } diff --git a/broadcaster/backlog/backlog_test.go b/broadcaster/backlog/backlog_test.go index ab25a523f7..ee712de9ed 100644 --- a/broadcaster/backlog/backlog_test.go +++ b/broadcaster/backlog/backlog_test.go @@ -57,9 +57,9 @@ func validateBroadcastMessage(t *testing.T, bm *m.BroadcastMessage, expectedCoun func createDummyBacklog(indexes []arbutil.MessageIndex) (*backlog, error) { b := &backlog{ - lookupByIndex: &containers.SyncMap[uint64, *backlogSegment]{}, - config: func() *Config { return &DefaultTestConfig }, + config: func() *Config { return &DefaultTestConfig }, } + b.lookupByIndex.Store(&containers.SyncMap[uint64, *backlogSegment]{}) bm := &m.BroadcastMessage{Messages: m.CreateDummyBroadcastMessages(indexes)} err := b.Append(bm) return b, err @@ -161,9 +161,9 @@ func TestDeleteInvalidBacklog(t *testing.T) { lookup := &containers.SyncMap[uint64, *backlogSegment]{} lookup.Store(40, s) b := &backlog{ - lookupByIndex: lookup, - config: func() *Config { return &DefaultTestConfig }, + config: func() *Config { return &DefaultTestConfig }, } + b.lookupByIndex.Store(lookup) b.messageCount.Store(2) b.head.Store(s) b.tail.Store(s)