Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop reference to detached datachannels #2696

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions datachannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error {
// resulting DataChannel object.
func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {
d.mu.Lock()
defer d.mu.Unlock()

if !d.api.settingEngine.detach.DataChannels {
return nil, errDetachNotEnabled
Expand All @@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) {

d.detachCalled = true

return d.dataChannel, nil
dataChannel := d.dataChannel
d.mu.Unlock()

// Remove the reference from SCTPTransport so that the datachannel
// can be garbage collected on close
d.sctpTransport.lock.Lock()
n := len(d.sctpTransport.dataChannels)
j := 0
for i := 0; i < n; i++ {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does d.sctpTransport.dataChannels need to remain in order? Can you swap remove instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I can. I just didn't want to change that property here.

if d == d.sctpTransport.dataChannels[i] {
continue
}
d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i]
j++
}
for i := j; i < n; i++ {
d.sctpTransport.dataChannels[i] = nil
}
d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j]
d.sctpTransport.lock.Unlock()

return dataChannel, nil
}

// Close Closes the DataChannel. It may be called regardless of whether
Expand Down
54 changes: 54 additions & 0 deletions datachannel_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,57 @@ func TestDataChannel_Dial(t *testing.T) {
closePair(t, offerPC, answerPC, done)
})
}

func TestDetachRemovesDatachannelReference(t *testing.T) {
// Use Detach data channels mode
s := SettingEngine{}
s.DetachDataChannels()
api := NewAPI(WithSettingEngine(s))

// Set up two peer connections.
config := Configuration{}
pca, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}
pcb, err := api.NewPeerConnection(config)
if err != nil {
t.Fatal(err)
}

defer closePairNow(t, pca, pcb)

dcChan := make(chan *DataChannel, 1)
pcb.OnDataChannel(func(d *DataChannel) {
d.OnOpen(func() {
if _, detachErr := d.Detach(); detachErr != nil {
t.Error(detachErr)
}

dcChan <- d
})
})

if err = signalPair(pca, pcb); err != nil {
t.Fatal(err)
}

attached, err := pca.CreateDataChannel("", nil)
if err != nil {
t.Fatal(err)
}
open := make(chan struct{}, 1)
attached.OnOpen(func() {
open <- struct{}{}
})
<-open

d := <-dcChan
d.sctpTransport.lock.RLock()
defer d.sctpTransport.lock.RUnlock()
for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] {
if dc == d {
t.Errorf("expected sctpTransport to drop reference to datachannel")
}
}
}
3 changes: 3 additions & 0 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn

pc.sctpTransport.lock.Lock()
pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d)
if d.ID() != nil {
pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{}
}
pc.sctpTransport.dataChannelsRequested++
pc.sctpTransport.lock.Unlock()

Expand Down
30 changes: 15 additions & 15 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

// DataChannels
dataChannels []*DataChannel
dataChannelIDsUsed map[uint16]struct{}
dataChannelsOpened uint32
dataChannelsRequested uint32
dataChannelsAccepted uint32
Expand All @@ -65,10 +66,11 @@
// meant to be used together with the basic WebRTC API.
func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
res := &SCTPTransport{
dtlsTransport: dtls,
state: SCTPTransportStateConnecting,
api: api,
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
dtlsTransport: dtls,
state: SCTPTransportStateConnecting,
api: api,
log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
dataChannelIDsUsed: make(map[uint16]struct{}),
}

res.updateMessageSize()
Expand Down Expand Up @@ -287,6 +289,13 @@
r.lock.Lock()
r.dataChannels = append(r.dataChannels, dc)
r.dataChannelsAccepted++
if dc.ID() != nil {
r.dataChannelIDsUsed[*dc.ID()] = struct{}{}
} else {
// This cannot happen, the constructor for this datachannel in the caller
// takes a pointer to the id.
r.log.Errorf("accepted data channel with no ID")
}

Check warning on line 298 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L295-L298

Added lines #L295 - L298 were not covered by tests
handler := r.onDataChannelHandler
r.lock.Unlock()

Expand Down Expand Up @@ -393,21 +402,12 @@
r.lock.Lock()
defer r.lock.Unlock()

// Create map of ids so we can compare without double-looping each time.
idsMap := make(map[uint16]struct{}, len(r.dataChannels))
for _, dc := range r.dataChannels {
if dc.ID() == nil {
continue
}

idsMap[*dc.ID()] = struct{}{}
}

for ; id < max-1; id += 2 {
if _, ok := idsMap[id]; ok {
if _, ok := r.dataChannelIDsUsed[id]; ok {
continue
}
*idOut = &id
r.dataChannelIDsUsed[id] = struct{}{}
return nil
}

Expand Down
9 changes: 8 additions & 1 deletion sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import "testing"

func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
ret := &SCTPTransport{dataChannels: []*DataChannel{}}
ret := &SCTPTransport{
dataChannels: []*DataChannel{},
dataChannelIDsUsed: make(map[uint16]struct{}),
}

for i := range ids {
id := ids[i]
ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id})
ret.dataChannelIDsUsed[id] = struct{}{}
}

return ret
Expand Down Expand Up @@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) {
if *idPtr != testCase.result {
t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result)
}
if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok {
t.Errorf("expected new id to be added to the map: %d", *idPtr)
}
}
}
Loading