Skip to content

Commit

Permalink
Add PeerConnection.GracefulClose
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Aug 5, 2024
1 parent dbe26d3 commit a77c5e7
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 18 deletions.
62 changes: 62 additions & 0 deletions datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type DataChannel struct {
readyState atomic.Value // DataChannelState
bufferedAmountLowThreshold uint64
detachCalled bool
readLoopActive chan struct{}
isGracefulClosed bool

// The binaryType represents attribute MUST, on getting, return the value to
// which it was last set. On setting, if the new value is either the string
Expand Down Expand Up @@ -225,6 +227,10 @@ func (d *DataChannel) OnOpen(f func()) {
func (d *DataChannel) onOpen() {
d.mu.RLock()
handler := d.onOpenHandler
if d.isGracefulClosed {
d.mu.RUnlock()
return

Check warning on line 232 in datachannel.go

View check run for this annotation

Codecov / codecov/patch

datachannel.go#L231-L232

Added lines #L231 - L232 were not covered by tests
}
d.mu.RUnlock()

if handler != nil {
Expand Down Expand Up @@ -252,6 +258,10 @@ func (d *DataChannel) OnDial(f func()) {
func (d *DataChannel) onDial() {
d.mu.RLock()
handler := d.onDialHandler
if d.isGracefulClosed {
d.mu.RUnlock()
return

Check warning on line 263 in datachannel.go

View check run for this annotation

Codecov / codecov/patch

datachannel.go#L262-L263

Added lines #L262 - L263 were not covered by tests
}
d.mu.RUnlock()

if handler != nil {
Expand All @@ -261,6 +271,10 @@ func (d *DataChannel) onDial() {

// OnClose sets an event handler which is invoked when
// the underlying data transport has been closed.
// Note: Due to backwards compatibility, there is a chance that
// OnClose can be called, even if the GracefulClose is used.
// If this is the case for you, you can deregister OnClose
// prior to GracefulClose.
func (d *DataChannel) OnClose(f func()) {
d.mu.Lock()
defer d.mu.Unlock()
Expand Down Expand Up @@ -292,6 +306,10 @@ func (d *DataChannel) OnMessage(f func(msg DataChannelMessage)) {
func (d *DataChannel) onMessage(msg DataChannelMessage) {
d.mu.RLock()
handler := d.onMessageHandler
if d.isGracefulClosed {
d.mu.RUnlock()
return
}
d.mu.RUnlock()

if handler == nil {
Expand All @@ -302,6 +320,10 @@ func (d *DataChannel) onMessage(msg DataChannelMessage) {

func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlreadyNegotiated bool) {
d.mu.Lock()
if d.isGracefulClosed {
d.mu.Unlock()
return

Check warning on line 325 in datachannel.go

View check run for this annotation

Codecov / codecov/patch

datachannel.go#L324-L325

Added lines #L324 - L325 were not covered by tests
}
d.dataChannel = dc
bufferedAmountLowThreshold := d.bufferedAmountLowThreshold
onBufferedAmountLow := d.onBufferedAmountLow
Expand All @@ -326,7 +348,12 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread
d.mu.Lock()
defer d.mu.Unlock()

if d.isGracefulClosed {
return

Check warning on line 352 in datachannel.go

View check run for this annotation

Codecov / codecov/patch

datachannel.go#L352

Added line #L352 was not covered by tests
}

if !d.api.settingEngine.detach.DataChannels {
d.readLoopActive = make(chan struct{})
go d.readLoop()
}
}
Expand All @@ -342,6 +369,10 @@ func (d *DataChannel) OnError(f func(err error)) {
func (d *DataChannel) onError(err error) {
d.mu.RLock()
handler := d.onErrorHandler
if d.isGracefulClosed {
d.mu.RUnlock()
return
}
d.mu.RUnlock()

if handler != nil {
Expand All @@ -350,6 +381,12 @@ func (d *DataChannel) onError(err error) {
}

func (d *DataChannel) readLoop() {
defer func() {
d.mu.Lock()
readLoopActive := d.readLoopActive
d.mu.Unlock()
defer close(readLoopActive)
}()
buffer := make([]byte, dataChannelBufferSize)
for {
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
Expand Down Expand Up @@ -449,7 +486,32 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
// Close Closes the DataChannel. It may be called regardless of whether
// the DataChannel object was created by this peer or the remote peer.
func (d *DataChannel) Close() error {
return d.close(false)
}

// GracefulClose Closes the DataChannel. It may be called regardless of whether
// the DataChannel object was created by this peer or the remote peer. It also waits
// for any goroutines it started to complete. This is only safe to call outside of
// DataChannel callbacks or if in a callback, in its own goroutine.
func (d *DataChannel) GracefulClose() error {
return d.close(true)
}

// Normally, close only stops writes from happening, so graceful=true
// will wait for reads to be finished based on underlying SCTP association
// closure or a SCTP reset stream from the other side. This is safe to call
// with graceful=true after tearing down a PeerConnection but not
// necessarily before. For example, if you used a vnet and dropped all packets
// right before closing the DataChannel, you'd need never see a reset stream.
func (d *DataChannel) close(shouldGracefullyClose bool) error {
d.mu.Lock()
d.isGracefulClosed = true
readLoopActive := d.readLoopActive
if shouldGracefullyClose && readLoopActive != nil {
defer func() {
<-readLoopActive
}()
}
haveSctpTransport := d.dataChannel != nil
d.mu.Unlock()

Expand Down
22 changes: 20 additions & 2 deletions icegatherer.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,31 @@ func (g *ICEGatherer) Gather() error {

// Close prunes all local candidates, and closes the ports.
func (g *ICEGatherer) Close() error {
return g.close(false /* shouldGracefullyClose */)
}

// GracefulClose prunes all local candidates, and closes the ports. It also waits
// for any goroutines it started to complete. This is only safe to call outside of
// ICEGatherer callbacks or if in a callback, in its own goroutine.
func (g *ICEGatherer) GracefulClose() error {
return g.close(true /* shouldGracefullyClose */)

Check warning on line 200 in icegatherer.go

View check run for this annotation

Codecov / codecov/patch

icegatherer.go#L199-L200

Added lines #L199 - L200 were not covered by tests
}

func (g *ICEGatherer) close(shouldGracefullyClose bool) error {
g.lock.Lock()
defer g.lock.Unlock()

if g.agent == nil {
return nil
} else if err := g.agent.Close(); err != nil {
return err
}
if shouldGracefullyClose {
if err := g.agent.GracefulClose(); err != nil {
return err

Check warning on line 212 in icegatherer.go

View check run for this annotation

Codecov / codecov/patch

icegatherer.go#L211-L212

Added lines #L211 - L212 were not covered by tests
}
} else {
if err := g.agent.Close(); err != nil {
return err

Check warning on line 216 in icegatherer.go

View check run for this annotation

Codecov / codecov/patch

icegatherer.go#L216

Added line #L216 was not covered by tests
}
}

g.agent = nil
Expand Down
14 changes: 14 additions & 0 deletions icetransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,17 @@ func (t *ICETransport) restart() error {

// Stop irreversibly stops the ICETransport.
func (t *ICETransport) Stop() error {
return t.stop(false /* shouldGracefullyClose */)
}

// GracefulStop irreversibly stops the ICETransport. It also waits
// for any goroutines it started to complete. This is only safe to call outside of
// ICETransport callbacks or if in a callback, in its own goroutine.
func (t *ICETransport) GracefulStop() error {
return t.stop(true /* shouldGracefullyClose */)
}

func (t *ICETransport) stop(shouldGracefullyClose bool) error {
t.lock.Lock()
defer t.lock.Unlock()

Expand All @@ -199,6 +210,9 @@ func (t *ICETransport) Stop() error {
if t.mux != nil {
return t.mux.Close()
} else if t.gatherer != nil {
if shouldGracefullyClose {
return t.gatherer.GracefulClose()

Check warning on line 214 in icetransport.go

View check run for this annotation

Codecov / codecov/patch

icetransport.go#L214

Added line #L214 was not covered by tests
}
return t.gatherer.Close()
}
return nil
Expand Down
71 changes: 58 additions & 13 deletions operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ type operation func()

// Operations is a task executor.
type operations struct {
mu sync.Mutex
busy bool
ops *list.List
mu sync.Mutex
busyCh chan struct{}
ops *list.List

updateNegotiationNeededFlagOnEmptyChain *atomicBool
onNegotiationNeeded func()
isClosed bool
}

func newOperations(
Expand All @@ -33,21 +34,34 @@ func newOperations(
}

// Enqueue adds a new action to be executed. If there are no actions scheduled,
// the execution will start immediately in a new goroutine.
// the execution will start immediately in a new goroutine. If the queue has been
// closed, the operation will be dropped. The queue is only deliberately closed
// by a user.
func (o *operations) Enqueue(op operation) {
o.mu.Lock()
defer o.mu.Unlock()
_ = o.tryEnqueue(op)
}

// tryEnqueue attempts to enqueue the given operation. It returns false
// if the op is invalid or the queue is closed. mu must be locked by
// tryEnqueue's caller.
func (o *operations) tryEnqueue(op operation) bool {
if op == nil {
return
return false

Check warning on line 51 in operations.go

View check run for this annotation

Codecov / codecov/patch

operations.go#L51

Added line #L51 was not covered by tests
}

o.mu.Lock()
running := o.busy
if o.isClosed {
return false
}
o.ops.PushBack(op)
o.busy = true
o.mu.Unlock()

if !running {
if o.busyCh == nil {
o.busyCh = make(chan struct{})
go o.start()
}

return true
}

// IsEmpty checks if there are tasks in the queue
Expand All @@ -62,12 +76,38 @@ func (o *operations) IsEmpty() bool {
func (o *operations) Done() {
var wg sync.WaitGroup
wg.Add(1)
o.Enqueue(func() {
o.mu.Lock()
enqueued := o.tryEnqueue(func() {
wg.Done()
})
o.mu.Unlock()
if !enqueued {
return
}
wg.Wait()
}

// GracefulClose waits for the operations queue to be cleared and forbids
// new operations from being enqueued.
func (o *operations) GracefulClose() {
o.mu.Lock()
if o.isClosed {
o.mu.Unlock()
return

Check warning on line 96 in operations.go

View check run for this annotation

Codecov / codecov/patch

operations.go#L95-L96

Added lines #L95 - L96 were not covered by tests
}
// do not enqueue anymore ops from here on
// o.isClosed=true will also not allow a new busyCh
// to be created.
o.isClosed = true

busyCh := o.busyCh
o.mu.Unlock()
if busyCh == nil {
return
}
<-busyCh

Check warning on line 108 in operations.go

View check run for this annotation

Codecov / codecov/patch

operations.go#L108

Added line #L108 was not covered by tests
}

func (o *operations) pop() func() {
o.mu.Lock()
defer o.mu.Unlock()
Expand All @@ -87,12 +127,17 @@ func (o *operations) start() {
defer func() {
o.mu.Lock()
defer o.mu.Unlock()
if o.ops.Len() == 0 {
o.busy = false
// this wil lbe the most recent busy chan
close(o.busyCh)

if o.ops.Len() == 0 || o.isClosed {
o.busyCh = nil
return
}

// either a new operation was enqueued while we
// were busy, or an operation panicked
o.busyCh = make(chan struct{})
go o.start()
}()

Expand Down
32 changes: 32 additions & 0 deletions operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ func TestOperations_Enqueue(t *testing.T) {
onNegotiationNeededCalledCount++
onNegotiationNeededCalledCountMu.Unlock()
})
defer ops.GracefulClose()

for resultSet := 0; resultSet < 100; resultSet++ {
results := make([]int, 16)
resultSetCopy := resultSet
Expand Down Expand Up @@ -46,5 +48,35 @@ func TestOperations_Enqueue(t *testing.T) {
func TestOperations_Done(*testing.T) {
ops := newOperations(&atomicBool{}, func() {
})
defer ops.GracefulClose()
ops.Done()
}

func TestOperations_GracefulClose(t *testing.T) {
ops := newOperations(&atomicBool{}, func() {
})

counter := 0
var counterMu sync.Mutex
incFunc := func() {
counterMu.Lock()
counter++
counterMu.Unlock()
}
const times = 25
for i := 0; i < times; i++ {
ops.Enqueue(incFunc)
}
ops.Done()
counterMu.Lock()
counterCur := counter
counterMu.Unlock()
assert.Equal(t, counterCur, times)

ops.GracefulClose()
for i := 0; i < times; i++ {
ops.Enqueue(incFunc)
}
ops.Done()
assert.Equal(t, counterCur, times)
}
Loading

0 comments on commit a77c5e7

Please sign in to comment.