diff --git a/datachannel.go b/datachannel.go index 96ae9121f48..a846eb58e3d 100644 --- a/datachannel.go +++ b/datachannel.go @@ -420,7 +420,6 @@ func (d *DataChannel) ensureOpen() error { // resulting DataChannel object. func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { d.mu.Lock() - defer d.mu.Unlock() if !d.api.settingEngine.detach.DataChannels { return nil, errDetachNotEnabled @@ -432,7 +431,28 @@ func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { d.detachCalled = true - return d.dataChannel, nil + dataChannel := d.dataChannel + d.mu.Unlock() + + // Remove the reference from SCTPTransport so that the datachannel + // can be garbage collected on close + d.sctpTransport.lock.Lock() + n := len(d.sctpTransport.dataChannels) + j := 0 + for i := 0; i < n; i++ { + if d == d.sctpTransport.dataChannels[i] { + continue + } + d.sctpTransport.dataChannels[j] = d.sctpTransport.dataChannels[i] + j++ + } + for i := j; i < n; i++ { + d.sctpTransport.dataChannels[i] = nil + } + d.sctpTransport.dataChannels = d.sctpTransport.dataChannels[:j] + d.sctpTransport.lock.Unlock() + + return dataChannel, nil } // Close Closes the DataChannel. It may be called regardless of whether diff --git a/datachannel_go_test.go b/datachannel_go_test.go index 8c02a820b83..8abe4ac634c 100644 --- a/datachannel_go_test.go +++ b/datachannel_go_test.go @@ -692,3 +692,58 @@ func TestDataChannel_Dial(t *testing.T) { closePair(t, offerPC, answerPC, done) }) } + +func TestDetachRemovesDatachannelReference(t *testing.T) { + // Use Detach data channels mode + s := SettingEngine{} + s.DetachDataChannels() + api := NewAPI(WithSettingEngine(s)) + + // Set up two peer connections. + config := Configuration{} + pca, err := api.NewPeerConnection(config) + if err != nil { + t.Fatal(err) + } + pcb, err := api.NewPeerConnection(config) + if err != nil { + t.Fatal(err) + } + + defer closePairNow(t, pca, pcb) + + dcChan := make(chan *DataChannel, 1) + pcb.OnDataChannel(func(d *DataChannel) { + d.OnOpen(func() { + _, err := d.Detach() + if err != nil { + t.Error(err) + } + + dcChan <- d + }) + }) + + if err = signalPair(pca, pcb); err != nil { + t.Fatal(err) + } + + attached, err := pca.CreateDataChannel("", nil) + if err != nil { + t.Fatal(err) + } + open := make(chan struct{}, 1) + attached.OnOpen(func() { + open <- struct{}{} + }) + <-open + + d := <-dcChan + d.sctpTransport.lock.RLock() + defer d.sctpTransport.lock.RUnlock() + for _, dc := range d.sctpTransport.dataChannels[:cap(d.sctpTransport.dataChannels)] { + if dc == d { + t.Errorf("expected sctpTransport to drop reference to datachannel") + } + } +} diff --git a/peerconnection.go b/peerconnection.go index aef7b7b32dd..b69952ba1f9 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -2029,6 +2029,9 @@ func (pc *PeerConnection) CreateDataChannel(label string, options *DataChannelIn pc.sctpTransport.lock.Lock() pc.sctpTransport.dataChannels = append(pc.sctpTransport.dataChannels, d) + if d.ID() != nil { + pc.sctpTransport.dataChannelIDsUsed[*d.ID()] = struct{}{} + } pc.sctpTransport.dataChannelsRequested++ pc.sctpTransport.lock.Unlock() diff --git a/sctptransport.go b/sctptransport.go index 8b0a8da7cf9..4632102d9ec 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -52,6 +52,7 @@ type SCTPTransport struct { // DataChannels dataChannels []*DataChannel + dataChannelIDsUsed map[uint16]struct{} dataChannelsOpened uint32 dataChannelsRequested uint32 dataChannelsAccepted uint32 @@ -65,10 +66,11 @@ type SCTPTransport struct { // meant to be used together with the basic WebRTC API. func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport { res := &SCTPTransport{ - dtlsTransport: dtls, - state: SCTPTransportStateConnecting, - api: api, - log: api.settingEngine.LoggerFactory.NewLogger("ortc"), + dtlsTransport: dtls, + state: SCTPTransportStateConnecting, + api: api, + log: api.settingEngine.LoggerFactory.NewLogger("ortc"), + dataChannelIDsUsed: make(map[uint16]struct{}), } res.updateMessageSize() @@ -287,6 +289,13 @@ func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) { r.lock.Lock() r.dataChannels = append(r.dataChannels, dc) r.dataChannelsAccepted++ + if dc.ID() != nil { + r.dataChannelIDsUsed[*dc.ID()] = struct{}{} + } else { + // This cannot happen, the constructor for this datachannel in the caller + // takes a pointer to the id. + r.log.Errorf("accepted data channel with no ID") + } handler := r.onDataChannelHandler r.lock.Unlock() @@ -393,21 +402,12 @@ func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **u r.lock.Lock() defer r.lock.Unlock() - // Create map of ids so we can compare without double-looping each time. - idsMap := make(map[uint16]struct{}, len(r.dataChannels)) - for _, dc := range r.dataChannels { - if dc.ID() == nil { - continue - } - - idsMap[*dc.ID()] = struct{}{} - } - for ; id < max-1; id += 2 { - if _, ok := idsMap[id]; ok { + if _, ok := r.dataChannelIDsUsed[id]; ok { continue } *idOut = &id + r.dataChannelIDsUsed[id] = struct{}{} return nil } diff --git a/sctptransport_test.go b/sctptransport_test.go index 74313debca0..9943e8f0629 100644 --- a/sctptransport_test.go +++ b/sctptransport_test.go @@ -10,11 +10,15 @@ import "testing" func TestGenerateDataChannelID(t *testing.T) { sctpTransportWithChannels := func(ids []uint16) *SCTPTransport { - ret := &SCTPTransport{dataChannels: []*DataChannel{}} + ret := &SCTPTransport{ + dataChannels: []*DataChannel{}, + dataChannelIDsUsed: make(map[uint16]struct{}), + } for i := range ids { id := ids[i] ret.dataChannels = append(ret.dataChannels, &DataChannel{id: &id}) + ret.dataChannelIDsUsed[id] = struct{}{} } return ret @@ -46,5 +50,8 @@ func TestGenerateDataChannelID(t *testing.T) { if *idPtr != testCase.result { t.Errorf("Wrong id: %d expected %d", *idPtr, testCase.result) } + if _, ok := testCase.s.dataChannelIDsUsed[*idPtr]; !ok { + t.Errorf("expected new id to be added to the map: %d", *idPtr) + } } }