diff --git a/sctptransport.go b/sctptransport.go index 9e227df31e5..6f03d1c0f0a 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -142,7 +142,7 @@ func (r *SCTPTransport) Start(_ SCTPCapabilities) error { r.dataChannelsOpened += openedDCCount r.lock.Unlock() - go r.acceptDataChannels(sctpAssociation) + go r.acceptDataChannels(sctpAssociation, dataChannels) return nil } @@ -163,10 +163,9 @@ func (r *SCTPTransport) Stop() error { return nil } -func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { - r.lock.RLock() - dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels)) - for _, dc := range r.dataChannels { +func (r *SCTPTransport) acceptDataChannels(a *sctp.Association, existingDataChannels []*DataChannel) { + dataChannels := make([]*datachannel.DataChannel, 0, len(existingDataChannels)) + for _, dc := range existingDataChannels { dc.mu.Lock() isNil := dc.dataChannel == nil dc.mu.Unlock() @@ -175,8 +174,6 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) { } dataChannels = append(dataChannels, dc.dataChannel) } - r.lock.RUnlock() - ACCEPT: for { dc, err := datachannel.Accept(a, &datachannel.Config{ diff --git a/sctptransport_test.go b/sctptransport_test.go index 889518dd05b..bd28120f55e 100644 --- a/sctptransport_test.go +++ b/sctptransport_test.go @@ -8,6 +8,7 @@ package webrtc import ( "bytes" + "sync" "testing" "time" @@ -126,3 +127,127 @@ func TestSCTPTransportOnClose(t *testing.T) { t.Fatal("timed out") } } + +func TestSCTPTransportOutOfBandNegotiatedDataChannelDetach(t *testing.T) { + // create lots of peer connections + done := make(chan struct{}, 50) + const N = 50 + for i := 0; i < N; i++ { + go func() { + // Use Detach data channels mode + s := SettingEngine{} + s.DetachDataChannels() + api := NewAPI(WithSettingEngine(s)) + + // Set up two peer connections. + config := Configuration{} + offerPC, err := api.NewPeerConnection(config) + if err != nil { + t.Error(err) + return + } + answerPC, err := api.NewPeerConnection(config) + if err != nil { + t.Error(err) + return + } + + defer closePairNow(t, offerPC, answerPC) + defer func() { done <- struct{}{} }() + + negotiated := true + id := uint16(0) + readDetach := make(chan struct{}) + dc1, err := offerPC.CreateDataChannel("", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + if err != nil { + t.Error(err) + return + } + dc1.OnOpen(func() { + _, _ = dc1.Detach() + close(readDetach) + }) + + writeDetach := make(chan struct{}) + dc2, err := answerPC.CreateDataChannel("", &DataChannelInit{ + Negotiated: &negotiated, + ID: &id, + }) + if err != nil { + t.Error(err) + return + } + dc2.OnOpen(func() { + _, _ = dc2.Detach() + close(writeDetach) + }) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + connestd := make(chan struct{}, 1) + offerPC.OnConnectionStateChange(func(state PeerConnectionState) { + if state == PeerConnectionStateConnected { + connestd <- struct{}{} + } + }) + select { + case <-connestd: + case <-time.After(10 * time.Second): + t.Error("conn establishment timed out") + return + } + <-readDetach + err := dc1.dataChannel.SetReadDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + t.Error(err) + return + } + buf := make([]byte, 10) + n, err := dc1.dataChannel.Read(buf) + if err != nil { + t.Error(err) + return + } + if string(buf[:n]) != "hello" { + t.Error("invalid read") + } + }() + go func() { + defer wg.Done() + connestd := make(chan struct{}, 1) + answerPC.OnConnectionStateChange(func(state PeerConnectionState) { + if state == PeerConnectionStateConnected { + connestd <- struct{}{} + } + }) + select { + case <-connestd: + case <-time.After(10 * time.Second): + t.Error("connection establishment timed out") + return + } + <-writeDetach + n, err := dc2.dataChannel.Write([]byte("hello")) + if err != nil || n != len("hello") { + t.Error(err) + } + }() + err = signalPair(offerPC, answerPC) + require.NoError(t, err) + wg.Wait() + }() + } + + for i := 0; i < N; i++ { + select { + case <-done: + case <-time.After(20 * time.Second): + t.Fatal("timed out") + } + } +}