From 68cb1a48724a9b739eca801e5ff9b64a8c6f6b14 Mon Sep 17 00:00:00 2001
From: Casey Waldren
Date: Tue, 4 Jun 2024 11:04:34 -0700
Subject: [PATCH] fix: resolve data race in Broadcasters system (#153)

This resolves a data race in the Broadcaster implementation: specifically,
`HasListeners()` was not protected by a lock. Additionally, it swaps the
plain mutex for a `sync.RWMutex` so that concurrent readers don't block
each other.

A behavioral change is that `Broadcast` now holds the lock while it sends to
subscribers, so `AddListener` and `RemoveListener` cannot run concurrently
with a broadcast.
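For illustration, a minimal sketch of a test that exposes the race in the
pre-fix code when run under `go test -race` (hypothetical; this test is not
part of the patch, though `NewBroadcaster`, `AddListener`, `HasListeners`,
and `Close` are the package's real methods):

```go
package internal

import "testing"

// Hypothetical reproduction: before this fix, the race detector reports the
// unlocked read in HasListeners racing with the locked write in AddListener.
func TestHasListenersRace(t *testing.T) {
	b := NewBroadcaster[string]()
	t.Cleanup(b.Close)

	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < 1000; i++ {
			b.AddListener() // writes b.subscribers while holding the lock
		}
	}()
	for i := 0; i < 1000; i++ {
		_ = b.HasListeners() // pre-fix: reads b.subscribers with no lock
	}
	<-done
}
```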
---------

Co-authored-by: John Starich
---
 internal/broadcasters.go      | 35 +++++++++++++++--------------------
 internal/broadcasters_test.go | 30 ++++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/internal/broadcasters.go b/internal/broadcasters.go
index 0632ec63..5b902d12 100644
--- a/internal/broadcasters.go
+++ b/internal/broadcasters.go
@@ -1,9 +1,8 @@
 package internal
 
 import (
+	"slices"
 	"sync"
-
-	"golang.org/x/exp/slices"
 )
 
 // This file defines the publish-subscribe model we use for various status/event types in the SDK.
@@ -19,11 +18,11 @@ const subscriberChannelBufferLength = 10
 // Broadcaster is our generalized implementation of broadcasters.
 type Broadcaster[V any] struct {
 	subscribers []channelPair[V]
-	lock        sync.Mutex
+	lock        sync.RWMutex
 }
 
 // We need to keep track of both the channel we use for sending (stored as a reflect.Value, because Value
-// has methods for sending and closing), and also the
+// has methods for sending and closing), and also the channel for receiving.
 type channelPair[V any] struct {
 	sendCh    chan<- V
 	receiveCh <-chan V
@@ -50,35 +49,31 @@ func (b *Broadcaster[V]) AddListener() <-chan V {
 func (b *Broadcaster[V]) RemoveListener(ch <-chan V) {
 	b.lock.Lock()
 	defer b.lock.Unlock()
-	ss := b.subscribers
-	for i, s := range ss {
+	b.subscribers = slices.DeleteFunc(b.subscribers, func(pair channelPair[V]) bool {
 		// The following equality test is the reason why we have to store both the sendCh (chan X) and
 		// the receiveCh (<-chan X) for each subscriber; "s.sendCh == ch" would not be true because
 		// they're of two different types.
-		if s.receiveCh == ch {
-			copy(ss[i:], ss[i+1:])
-			ss[len(ss)-1] = channelPair[V]{}
-			b.subscribers = ss[:len(ss)-1]
-			close(s.sendCh)
-			break
+		if pair.receiveCh == ch {
+			close(pair.sendCh)
+			return true
 		}
-	}
+		return false
+	})
 }
 
 // HasListeners returns true if there are any current subscribers.
 func (b *Broadcaster[V]) HasListeners() bool {
+	b.lock.RLock()
+	defer b.lock.RUnlock()
 	return len(b.subscribers) > 0
 }
 
 // Broadcast broadcasts a value to all current subscribers.
 func (b *Broadcaster[V]) Broadcast(value V) {
-	b.lock.Lock()
-	ss := slices.Clone(b.subscribers)
-	b.lock.Unlock()
-	if len(ss) > 0 {
-		for _, ch := range ss {
-			ch.sendCh <- value
-		}
+	b.lock.RLock()
+	defer b.lock.RUnlock()
+	for _, ch := range b.subscribers {
+		ch.sendCh <- value
 	}
 }
 
diff --git a/internal/broadcasters_test.go b/internal/broadcasters_test.go
index 9ba13e4d..9e7c7962 100644
--- a/internal/broadcasters_test.go
+++ b/internal/broadcasters_test.go
@@ -2,6 +2,7 @@ package internal
 
 import (
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -81,3 +82,32 @@ func testBroadcasterGenerically[V any](t *testing.T, broadcasterFactory func() *
 		})
 	})
 }
+
+func TestBroadcasterDataRace(t *testing.T) {
+	t.Parallel()
+	b := NewBroadcaster[string]()
+	t.Cleanup(b.Close)
+
+	var waitGroup sync.WaitGroup
+	for _, fn := range []func(){
+		// Run every method that uses b.subscribers concurrently to detect data races
+		func() { b.AddListener() },
+		func() { b.Broadcast("foo") },
+		func() { b.Close() },
+		func() { b.HasListeners() },
+		func() { b.RemoveListener(nil) },
+	} {
+		const concurrentRoutinesWithSelf = 2
+		// Run a method concurrently with itself to detect data races. These methods will also be
+		// run concurrently with the previous/next methods in the list.
+		for i := 0; i < concurrentRoutinesWithSelf; i++ {
+			waitGroup.Add(1)
+			fn := fn // make fn a loop-local variable
+			go func() {
+				defer waitGroup.Done()
+				fn()
+			}()
+		}
+	}
+	waitGroup.Wait()
+}
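A note on the behavioral change (illustrative, not part of the patch):
because `Broadcast` now holds the read lock for the duration of its channel
sends, a subscriber that stops draining its channel can fill the buffer
(length 10, per `subscriberChannelBufferLength`), block the send, and leave
`AddListener` or `RemoveListener` waiting on the write lock. The sketch
below demonstrates the standard `sync.RWMutex` semantics this relies on:
readers may overlap one another, while a writer waits for every reader to
exit.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// Standalone demonstration of sync.RWMutex semantics (standard library
// behavior, not code from this patch): two readers hold the read lock
// concurrently, and the writer's Lock call blocks until both release it.
func main() {
	var mu sync.RWMutex
	var readers sync.WaitGroup

	for i := 0; i < 2; i++ {
		readers.Add(1)
		go func(id int) {
			defer readers.Done()
			mu.RLock() // both readers can hold the read lock at the same time
			fmt.Println("reader", id, "entered")
			time.Sleep(50 * time.Millisecond)
			mu.RUnlock()
		}(i)
	}

	time.Sleep(10 * time.Millisecond) // let both readers acquire the read lock first
	mu.Lock()                         // blocks until every reader has called RUnlock
	fmt.Println("writer entered after all readers left")
	mu.Unlock()

	readers.Wait()
}
```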