Skip to content

Commit

Permalink
Drop reference to detached datachannels
Browse files Browse the repository at this point in the history
This allows users of detached datachannels to garbage collect
resources associated with the datachannel and the sctp stream.
There is no functional change here.
  • Loading branch information
sukunrt committed Mar 4, 2024
1 parent 09a4f60 commit 2268938
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 18 deletions.
24 changes: 22 additions & 2 deletions datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error {
// resulting DataChannel object.
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
d.mu.Lock()
defer d.mu.Unlock()

if !d.api.settingEngine.detach.DataChannels {
return nil, errDetachNotEnabled
Expand All @@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {

d.detachCalled = true

return d.dataChannel, nil
dataChannel := d.dataChannel
d.mu.Unlock()

// Remove the reference from SCTPTransport so that the datachannel
// can be garbage collected on close
d.sctpTransport.lock.Lock()
n := len(d.sctpTransport.dataChannels)
j := 0
for i := 0; i < n; i++ {
if d == d.sctpTransport.dataChannels[i] {
continue
}
d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i]
j++
}
for i := j; i < n; i++ {
d.sctpTransport.dataChannels[i] = nil
}
d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j]
d.sctpTransport.lock.Unlock()

return dataChannel, nil
}

// Close Closes the DataChannel. It may be called regardless of whether
Expand Down
55 changes: 55 additions & 0 deletions datachannel_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,58 @@ func TestDataChannel_Dial(t *testing.T) {
closePair(t, offerPC, answerPC, done)
})
}

func TestDetachRemovesDatachannelReference(t *testing.T) {
// Use Detach data channels mode
s := SettingEngine{}
s.DetachDataChannels()
api := NewAPI(WithSettingEngine(s))

// Set up two peer connections.
config := Configuration{}
pca, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}
pcb, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}

defer closePairNow(t, pca, pcb)

dcChan := make(chan *DataChannel, 1)
pcb.OnDataChannel(func(d *DataChannel) {
d.OnOpen(func() {
_, err := d.Detach()

Check failure on line 718 in datachannel_go_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

shadow: declaration of "err" shadows declaration at line 704 (govet)
if err != nil {
t.Error(err)
}

dcChan <- d
})
})

if err = signalPair(pca, pcb); err != nil {
t.Fatal(err)
}

attached, err := pca.CreateDataChannel("", nil)
if err != nil {
t.Fatal(err)
}
open := make(chan struct{}, 1)
attached.OnOpen(func() {
open <- struct{}{}
})
<-open

d := <-dcChan
d.sctpTransport.lock.RLock()
defer d.sctpTransport.lock.RUnlock()
for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
if dc == d {
t.Errorf("expected sctpTransport to drop reference to datachannel")
}
}
}
3 changes: 3 additions & 0 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2029,6 +2029,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn

pc.sctpTransport.lock.Lock()
pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d)
if d.ID() != nil {
pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{}
}
pc.sctpTransport.dataChannelsRequested++
pc.sctpTransport.lock.Unlock()

Expand Down
30 changes: 15 additions & 15 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type SCTPTransport struct {

// DataChannels
dataChannels []*DataChannel
dataChannelIDsUsed map[uint16]struct{}
dataChannelsOpened uint32
dataChannelsRequested uint32
dataChannelsAccepted uint32
Expand All @@ -65,10 +66,11 @@ type SCTPTransport struct {
// meant to be used together with the basic WebRTC API.
func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
res := &SCTPTransport{
dtlsTransport: dtls,
state: SCTPTransportStateConnecting,
api: api,
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
dtlsTransport: dtls,
state: SCTPTransportStateConnecting,
api: api,
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
dataChannelIDsUsed: make(map[uint16]struct{}),
}

res.updateMessageSize()
Expand Down Expand Up @@ -287,6 +289,13 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
r.lock.Lock()
r.dataChannels = append(r.dataChannels, dc)
r.dataChannelsAccepted++
if dc.ID() != nil {
r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
} else {
// This cannot happen, the constructor for this datachannel in the caller
// takes a pointer to the id.
r.log.Errorf("accepted data channel with no ID")
}

Check warning on line 298 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L295-L298

Added lines #L295 - L298 were not covered by tests
handler := r.onDataChannelHandler
r.lock.Unlock()

Expand Down Expand Up @@ -393,21 +402,12 @@ func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **u
r.lock.Lock()
defer r.lock.Unlock()

// Create map of ids so we can compare without double-looping each time.
idsMap := make(map[uint16]struct{}, len(r.dataChannels))
for _, dc := range r.dataChannels {
if dc.ID() == nil {
continue
}

idsMap[*dc.ID()] = struct{}{}
}

for ; id < max-1; id += 2 {
if _, ok := idsMap[id]; ok {
if _, ok := r.dataChannelIDsUsed[id]; ok {
continue
}
*idOut = &id
r.dataChannelIDsUsed[id] = struct{}{}
return nil
}

Expand Down
9 changes: 8 additions & 1 deletion sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import "testing"

func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
ret := &SCTPTransport{dataChannels: []*DataChannel{}}
ret := &SCTPTransport{
dataChannels: []*DataChannel{},
dataChannelIDsUsed: make(map[uint16]struct{}),
}

for i := range ids {
id := ids[i]
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
ret.dataChannelIDsUsed[id] = struct{}{}
}

return ret
Expand Down Expand Up @@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) {
if *idPtr != testCase.result {
t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
}
if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
t.Errorf("expected new id to be added to the map: %d", *idPtr)
}
}
}

0 comments on commit 2268938

Please sign in to comment.