Skip to content

Commit

Permalink
Fix Association and Stream closure
Browse files Browse the repository at this point in the history
* Always close `Association` on `writeLoop` exit

The connection is now always closed when `writeLoop` exits, because closing
it ensures that `readLoop` also exits, which is needed to propagate the
closing of `Stream`s.

* Guard against creating `Stream`s after `Association` close

It was possible for new `Stream`s to be created after `readLoop` has
exited and called `unregisterStream` on the existing ones. The new
`Stream`s would never close.

This also guards against a potential panic due to send on nil channel
(`acceptCh`).

This may fix pion/webrtc#2098.
  • Loading branch information
mafredri committed Aug 2, 2022
1 parent d0b7cf3 commit e4d7a2f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
30 changes: 19 additions & 11 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ type Association struct {
storedCookieEcho *chunkCookieEcho

streams map[uint16]*Stream
streamsClosedErr error
acceptCh chan *Stream
readLoopCloseCh chan struct{}
awakeWriteLoopCh chan struct{}
Expand Down Expand Up @@ -513,6 +514,7 @@ func (a *Association) readLoop() {
for _, s := range a.streams {
a.unregisterStream(s, closeErr)
}
a.streamsClosedErr = closeErr
a.lock.Unlock()
close(a.acceptCh)
close(a.readLoopCloseCh)
Expand Down Expand Up @@ -552,9 +554,13 @@ func (a *Association) readLoop() {

func (a *Association) writeLoop() {
a.log.Debugf("[%s] writeLoop entered", a.name)
defer a.log.Debugf("[%s] writeLoop exited", a.name)
defer func() {
if err := a.close(); err != nil {
a.log.Warnf("[%s] failed to close association: %v", a.name, err)
}
a.log.Debugf("[%s] writeLoop exited", a.name)
}()

loop:
for {
rawPackets, ok := a.gatherOutbound()

Expand All @@ -565,28 +571,21 @@ loop:
a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err)
}
a.log.Debugf("[%s] writeLoop ended", a.name)
break loop
return
}
atomic.AddUint64(&a.bytesSent, uint64(len(raw)))
}

if !ok {
if err := a.close(); err != nil {
a.log.Warnf("[%s] failed to close association: %v", a.name, err)
}

return
}

select {
case <-a.awakeWriteLoopCh:
case <-a.closeWriteLoopCh:
break loop
return
}
}

a.setState(closed)
a.closeAllTimers()
}

func (a *Association) awakeWriteLoop() {
Expand Down Expand Up @@ -1349,6 +1348,10 @@ func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType Pay
a.lock.Lock()
defer a.lock.Unlock()

if a.streamsClosedErr != nil {
return nil, a.streamsClosedErr
}

return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil
}

Expand All @@ -1363,6 +1366,11 @@ func (a *Association) AcceptStream() (*Stream, error) {

// createStream creates a stream. The caller should hold the lock and check no stream exists for this id.
func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream {
if a.streamsClosedErr != nil {
a.log.Debugf("[%s] dropped a new stream (streamsClosedErr: %s)", a.name, a.streamsClosedErr)
return nil
}

s := &Stream{
association: a,
streamIdentifier: streamIdentifier,
Expand Down
71 changes: 71 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,7 @@ func TestAssocReset(t *testing.T) {
_, _, err = s0.ReadSCTP(buf)
assert.Equal(t, io.EOF, err, "should be EOF")
doneCh <- err
return
}
}()

Expand Down Expand Up @@ -2278,6 +2279,11 @@ func (c *fakeEchoConn) Write(b []byte) (int, error) {
func (c *fakeEchoConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
select {
case <-c.closed:
return c.errClose
default:
}
close(c.echo)
close(c.closed)
return c.errClose
Expand Down Expand Up @@ -2836,4 +2842,69 @@ func TestAssociation_Abort(t *testing.T) {
i, err = s21.Read(buf)
assert.Equal(t, i, 0, "expected no data read")
assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason")

// Ensure a1 has closed down as well (avoid goroutine leak).
select {
case <-a1.readLoopCloseCh:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a1 read loop to close")
}

time.Sleep(time.Millisecond) // give readLoop a ms to completely exit.
}

// TestAssociation_OpenStreamAfterCloseMustNotHang opens a stream on a2 only
// after a2.close() has returned. OpenStream must then either fail, or return
// a stream whose ReadSCTP returns an error instead of blocking. The test also
// checks that both read loops shut down and that the goroutine count returns
// to its starting value.
func TestAssociation_OpenStreamAfterCloseMustNotHang(t *testing.T) {
	runtime.GC()
	n0 := runtime.NumGoroutine()

	defer func() {
		runtime.GC()
		assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
	}()

	a1, a2 := createAssocs(t)

	s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
	require.NoError(t, err)

	startOpenStream := make(chan struct{})
	go func() {
		_ = a2.close() // trigger close of read loop.
		close(startOpenStream)
	}()

	done := make(chan struct{})
	go func() {
		defer close(done)

		<-startOpenStream
		s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
		if err == nil {
			// If stream opened, ensure ReadSCTP doesn't hang.
			_, _, err = s21.ReadSCTP(make([]byte, 1))
			assert.Error(t, err, "read did not exit with error")
		}
	}()

	// One deadline shared by every wait below. A time.After channel delivers
	// a single value, so after one select consumed it the remaining selects
	// could block forever. A channel that is closed at the deadline stays
	// readable for all of them.
	timeout := make(chan struct{})
	timer := time.AfterFunc(2*time.Second, func() { close(timeout) })
	defer timer.Stop()

	select {
	case <-done:
	case <-timeout:
		assert.Fail(t, "timed out waiting for a2.OpenStream test goroutine")
	}

	_ = s11.Close()
	select {
	case <-a1.readLoopCloseCh:
	case <-timeout:
		assert.Fail(t, "timed out waiting for a1 read loop to close")
	}
	select {
	case <-a2.readLoopCloseCh:
	case <-timeout:
		assert.Fail(t, "timed out waiting for a2 read loop to close")
	}

	time.Sleep(time.Millisecond) // give readLoop a ms to completely exit.
}

0 comments on commit e4d7a2f

Please sign in to comment.