Skip to content

Commit

Permalink
Add PeerConneciton.CloseCleanly
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Jul 1, 2024
1 parent 0a97ff6 commit 6f8a580
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
9 changes: 9 additions & 0 deletions datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type DataChannel struct {
readyState atomic.Value // DataChannelState
bufferedAmountLowThreshold uint64
detachCalled bool
readLoopActive chan struct{}

// The binaryType represents attribute MUST, on getting, return the value to
// which it was last set. On setting, if the new value is either the string
Expand Down Expand Up @@ -327,6 +328,7 @@ func (d *DataChannel) handleOpen(dc *datachannel.DataChannel, isRemote, isAlread
defer d.mu.Unlock()

if !d.api.settingEngine.detach.DataChannels {
d.readLoopActive = make(chan struct{})
go d.readLoop()
}
}
Expand All @@ -350,6 +352,7 @@ func (d *DataChannel) onError(err error) {
}

func (d *DataChannel) readLoop() {
defer close(d.readLoopActive)
buffer := make([]byte, dataChannelBufferSize)
for {
n, isString, err := d.dataChannel.ReadDataChannel(buffer)
Expand Down Expand Up @@ -457,6 +460,12 @@ func (d *DataChannel) Close() error {
return nil
}

if d.readLoopActive != nil {
defer func() {
<-d.readLoopActive
}()
}

d.setReadyState(DataChannelStateClosing)
if !haveSctpTransport {
return nil
Expand Down
30 changes: 29 additions & 1 deletion peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type PeerConnection struct {
idpLoginURL *string

isClosed *atomicBool
isClosedDone chan struct{}
isNegotiationNeeded *atomicBool
updateNegotiationNeededFlagOnEmptyChain *atomicBool

Expand Down Expand Up @@ -116,6 +117,7 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
ICECandidatePoolSize: 0,
},
isClosed: &atomicBool{},
isClosedDone: make(chan struct{}),
isNegotiationNeeded: &atomicBool{},
updateNegotiationNeededFlagOnEmptyChain: &atomicBool{},
lastOffer: "",
Expand Down Expand Up @@ -2049,8 +2051,11 @@ func (pc *PeerConnection) Close() error {
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1)
// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2)
if pc.isClosed.swap(true) {
// someone else got here first but may still be closing (e.g. via DTLS close_notify)
<-pc.isClosedDone
return nil
}
defer close(pc.isClosedDone)

// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3)
pc.signalingState.Set(SignalingStateClosed)
Expand Down Expand Up @@ -2100,6 +2105,26 @@ func (pc *PeerConnection) Close() error {
return util.FlattenErrs(closeErrs)
}

// CloseCleanly attempts to close a PeerConnection as cleanly as possible such that
// all background workers spawned are terminated. It acomplishes this by closing down any
// resources that have been created for the PeerConnection and then by calling Close.
// This is a combination of accommodating for golang idioms in addition to following the
// W3C specification's close procedure.
// If you purely want the W3C behavior, just call Close.
func (pc *PeerConnection) CloseCleanly() error {
var closeErrs []error

pc.mu.Lock()
for _, d := range pc.sctpTransport.dataChannels {
closeErrs = append(closeErrs, d.Close())
}
pc.mu.Unlock()

closeErrs = append(closeErrs, pc.Close())

return util.FlattenErrs(closeErrs)
}

// addRTPTransceiver appends t into rtpTransceivers
// and fires onNegotiationNeeded;
// caller of this method should hold `pc.mu` lock
Expand Down Expand Up @@ -2268,8 +2293,11 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
}

pc.dtlsTransport.internalOnCloseHandler = func() {
pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
if pc.isClosed.get() {
return
}

pc.log.Info("Closing PeerConnection from DTLS CloseNotify")
go func() {
if pcClosErr := pc.Close(); pcClosErr != nil {
pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
Expand Down
102 changes: 102 additions & 0 deletions peerconnection_close_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package webrtc

import (
"runtime"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -179,3 +181,103 @@ func TestPeerConnection_Close_DuringICE(t *testing.T) {
t.Error("pcOffer.Close() Timeout")
}
}

func TestPeerConnection_CloseCleanly(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

report := CheckRoutinesIntolerant(t)
defer report()

pcOffer, pcAnswer, err := newPair()
if err != nil {
t.Fatal(err)
}

var dcAnswer *DataChannel
answerDataChannelOpened := make(chan struct{})
pcAnswer.OnDataChannel(func(d *DataChannel) {
// Make sure this is the data channel we were looking for. (Not the one
// created in signalPair).
if d.Label() != "data" {
return
}
dcAnswer = d
close(answerDataChannelOpened)
})

dcOffer, err := pcOffer.CreateDataChannel("data", nil)
if err != nil {
t.Fatal(err)
}

offerDataChannelOpened := make(chan struct{})
dcOffer.OnOpen(func() {
close(offerDataChannelOpened)
})

err = signalPair(pcOffer, pcAnswer)
if err != nil {
t.Fatal(err)
}

<-offerDataChannelOpened
<-answerDataChannelOpened

msgNum := 0
dcOffer.OnMessage(func(_ DataChannelMessage) {
t.Log("msg", msgNum)
msgNum++
})

// send 50 messages, then close pcOffer, and then send another 50
for i := 0; i < 100; i++ {
if i == 50 {
err = pcOffer.CloseCleanly()
if err != nil {
t.Fatal(err)
}
}
_ = dcAnswer.Send([]byte("hello!"))
}

err = pcAnswer.CloseCleanly()
if err != nil {
t.Fatal(err)
}
}

// CheckRoutinesIntolerant is used to check for leaked go-routines.
// It differs from test.CheckRoutines in that it won't wait at all
// for lingering goroutines. This is helpful for tests that need
// to ensure clean closure of resources.
func CheckRoutinesIntolerant(t *testing.T) func() {
return func() {
routines := getRoutines()
if len(routines) == 0 {
return
}
t.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint
}
}

func getRoutines() []string {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
return filterRoutines(strings.Split(string(buf), "\n\n"))
}

func filterRoutines(routines []string) []string {
result := []string{}
for _, stack := range routines {
if stack == "" || // Empty
strings.Contains(stack, "testing.Main(") || // Tests
strings.Contains(stack, "testing.(*T).Run(") || // Test run
strings.Contains(stack, "getRoutines(") { // This routine
continue
}
result = append(result, stack)
}
return result
}

0 comments on commit 6f8a580

Please sign in to comment.