fix: resolve data race in Broadcasters system #153

Merged 6 commits on Jun 4, 2024
Changes from 3 commits
35 changes: 15 additions & 20 deletions internal/broadcasters.go
@@ -1,9 +1,8 @@
 package internal
 
 import (
+	"slices"
 	"sync"
-
-	"golang.org/x/exp/slices"
 )
 
 // This file defines the publish-subscribe model we use for various status/event types in the SDK.
@@ -19,11 +18,11 @@ const subscriberChannelBufferLength = 10
 // Broadcaster is our generalized implementation of broadcasters.
 type Broadcaster[V any] struct {
 	subscribers []channelPair[V]
-	lock        sync.Mutex
+	lock        sync.RWMutex
 }
 
 // We need to keep track of both the channel we use for sending (stored as a reflect.Value, because Value
-// has methods for sending and closing), and also the
+// has methods for sending and closing), and also the channel for receiving.
 type channelPair[V any] struct {
 	sendCh    chan<- V
 	receiveCh <-chan V
@@ -50,35 +49,31 @@ func (b *Broadcaster[V]) AddListener() <-chan V {
 func (b *Broadcaster[V]) RemoveListener(ch <-chan V) {
 	b.lock.Lock()
 	defer b.lock.Unlock()
-	ss := b.subscribers
-	for i, s := range ss {
+	b.subscribers = slices.DeleteFunc(b.subscribers, func(pair channelPair[V]) bool {
 		// The following equality test is the reason why we have to store both the sendCh (chan X) and
 		// the receiveCh (<-chan X) for each subscriber; "s.sendCh == ch" would not be true because
 		// they're of two different types.
-		if s.receiveCh == ch {
-			copy(ss[i:], ss[i+1:])
-			ss[len(ss)-1] = channelPair[V]{}
-			b.subscribers = ss[:len(ss)-1]
-			close(s.sendCh)
-			break
+		if pair.receiveCh == ch {
+			close(pair.sendCh)
+			return true
 		}
-	}
+		return false
+	})
 }
 
 // HasListeners returns true if there are any current subscribers.
 func (b *Broadcaster[V]) HasListeners() bool {
+	b.lock.RLock()
+	defer b.lock.RUnlock()
 	return len(b.subscribers) > 0
 }
 
 // Broadcast broadcasts a value to all current subscribers.
 func (b *Broadcaster[V]) Broadcast(value V) {
-	b.lock.Lock()
-	ss := slices.Clone(b.subscribers)
-	b.lock.Unlock()
-	if len(ss) > 0 {
-		for _, ch := range ss {
-			ch.sendCh <- value
-		}
+	b.lock.RLock()
Member commented:

So, there is a behavioral change here, which may or may not be relevant. Previously, you could remove a listener during a broadcast (the lock wasn't retained during the actual broadcast); now you wouldn't be able to get the exclusive lock during that time.
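
To make the trade-off concrete, here is a self-contained sketch of the pre-change pattern; the miniBroadcaster type and its names are invented for illustration and are not code from this PR. Because the lock is released before any send, a remove call can acquire the lock and return while the broadcast is still delivering to its private clone:

```go
package main

import (
	"fmt"
	"slices"
	"sync"
	"time"
)

// miniBroadcaster is a stripped-down stand-in that reproduces only the
// pre-change Broadcast pattern: clone the subscriber list under the lock,
// then send with the lock released.
type miniBroadcaster struct {
	mu   sync.Mutex
	subs []chan string
}

func (b *miniBroadcaster) broadcastCloned(v string) {
	b.mu.Lock()
	ss := slices.Clone(b.subs)
	b.mu.Unlock() // lock released before any send happens
	for _, ch := range ss {
		time.Sleep(50 * time.Millisecond) // simulate a slow delivery
		ch <- v                           // still sends to channels removed in the meantime
	}
}

func (b *miniBroadcaster) remove(ch chan string) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.subs = slices.DeleteFunc(b.subs, func(c chan string) bool { return c == ch })
	// The real RemoveListener also closes the channel here, which is what
	// leads to the panic discussed further down in this thread.
}

func main() {
	b := &miniBroadcaster{}
	ch := make(chan string, 1)
	b.subs = append(b.subs, ch)

	removed := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond) // let the broadcast clone the list first
		b.remove(ch)                      // acquires the lock mid-broadcast and returns
		close(removed)
	}()
	b.broadcastCloned("hello")
	<-removed
	fmt.Println("received after removal:", <-ch) // the removed channel still gets the value
}
```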

Contributor Author commented:

Hmm, we probably shouldn't hold the lock during the broadcast, as that could take an unbounded amount of time to complete. I'll revert that and add a test.
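
A minimal sketch of that concern, using a bare channel and a sync.RWMutex rather than the real Broadcaster (all names here are invented): if a subscriber stops draining its buffered channel, a send blocks while the read lock is held, and anything that needs the write lock, such as RemoveListener, is stuck behind it:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var mu sync.RWMutex
	sub := make(chan string, 1) // tiny buffer so it fills up quickly

	go func() { // plays the role of Broadcast: holds the read lock across sends
		mu.RLock()
		defer mu.RUnlock()
		sub <- "first"  // fits in the buffer
		sub <- "second" // blocks indefinitely: nobody is draining sub
	}()

	time.Sleep(100 * time.Millisecond) // let the broadcast get stuck

	gotLock := make(chan struct{})
	go func() { // plays the role of RemoveListener: needs the exclusive lock
		mu.Lock()
		defer mu.Unlock()
		close(gotLock)
	}()

	select {
	case <-gotLock:
		fmt.Println("acquired the write lock")
	case <-time.After(500 * time.Millisecond):
		fmt.Println("write lock is stuck behind the in-progress broadcast")
	}
}
```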

Contributor Author commented:

OK, so there's a problem here: if we allow removing a listener while Broadcast is running concurrently, then Broadcast could try to send on a channel that has just been closed (by removing the listener). That's a panic.

So that behavior was already bad.

I'm thinking it's better to say "if you aren't draining your receivers, you're not going to be able to add or remove receivers," which would be the case if we locked around the broadcast.
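
For reference, the failure mode described above: in Go, a send on a closed channel always panics, so a broadcast loop racing with a removal that has just closed sendCh would crash the process. A trivial standalone reproduction:

```go
package main

func main() {
	ch := make(chan string, 1)
	close(ch)
	ch <- "foo" // panic: send on closed channel
}
```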

Contributor Author commented:

I'm reverting to locking around the broadcast. I'll note in the release notes that this may affect the use case of adding/removing a listener while broadcasting.

+	defer b.lock.RUnlock()
+	for _, ch := range b.subscribers {
+		ch.sendCh <- value
 	}
 }

64 changes: 64 additions & 0 deletions internal/broadcasters_test.go
@@ -2,6 +2,8 @@ package internal
 
 import (
 	"fmt"
+	"math/rand"
+	"sync"
 	"testing"
 	"time"
 
@@ -81,3 +83,65 @@ func testBroadcasterGenerically[V any](t *testing.T, broadcasterFactory func() *
 		})
 	})
 }
+
+func TestBroadcasterDataRace(t *testing.T) {
+	t.Parallel()
+	b := NewBroadcaster[string]()
+	t.Cleanup(b.Close)
+
+	var waitGroup sync.WaitGroup
+	for _, fn := range []func(){
+		// Run every method that uses b.subscribers concurrently to detect data races.
+		func() { b.AddListener() },
+		func() { b.Broadcast("foo") },
+		func() { b.Close() },
+		func() { b.HasListeners() },
+		func() { b.RemoveListener(nil) },
+	} {
+		const concurrentRoutinesWithSelf = 2
+		// Run a method concurrently with itself to detect data races. These methods will also be
+		// run concurrently with the previous/next methods in the list.
+		for i := 0; i < concurrentRoutinesWithSelf; i++ {
+			waitGroup.Add(1)
+			fn := fn // make fn a loop-local variable
+			go func() {
+				defer waitGroup.Done()
+				fn()
+			}()
+		}
+	}
+	waitGroup.Wait()
+}
+
+func TestBroadcasterDataRaceRandomFunctionOrder(t *testing.T) {
+	t.Parallel()
+	b := NewBroadcaster[string]()
+	t.Cleanup(b.Close)
+
+	funcs := []func(){
+		func() {
+			for range b.AddListener() {
+			}
+		},
+		func() { b.Broadcast("foo") },
+		func() { b.Close() },
+		func() { b.HasListeners() },
+		func() { b.RemoveListener(nil) },
+	}
+	var waitGroup sync.WaitGroup
+
+	const N = 1000
+
+	// We're going to keep adding random functions to the set of currently executing functions
+	// for N iterations. This way, we can detect races that might be order-dependent.
+
+	for i := 0; i < N; i++ {
+		waitGroup.Add(1)
+		fn := funcs[rand.Intn(len(funcs))]
+		go func() {
+			defer waitGroup.Done()
+			fn()
+		}()
+	}
+	waitGroup.Wait()
+}