Skip to content

Commit

Permalink
Merge branch 'master' into fix-interceptor-close-race
Browse files Browse the repository at this point in the history
  • Loading branch information
aalekseevx authored Dec 24, 2024
2 parents 114d37f + b82306a commit 5eed59c
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 7 deletions.
11 changes: 4 additions & 7 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (r *SCTPTransport) Start(_ SCTPCapabilities) error {
r.dataChannelsOpened += openedDCCount
r.lock.Unlock()

go r.acceptDataChannels(sctpAssociation)
go r.acceptDataChannels(sctpAssociation, dataChannels)

return nil
}
Expand All @@ -163,10 +163,9 @@ func (r *SCTPTransport) Stop() error {
return nil
}

func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
r.lock.RLock()
dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
for _, dc := range r.dataChannels {
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association, existingDataChannels []*DataChannel) {
dataChannels := make([]*datachannel.DataChannel, 0, len(existingDataChannels))
for _, dc := range existingDataChannels {
dc.mu.Lock()
isNil := dc.dataChannel == nil
dc.mu.Unlock()
Expand All @@ -175,8 +174,6 @@ func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
}
dataChannels = append(dataChannels, dc.dataChannel)
}
r.lock.RUnlock()

ACCEPT:
for {
dc, err := datachannel.Accept(a, &datachannel.Config{
Expand Down
124 changes: 124 additions & 0 deletions sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package webrtc

import (
"bytes"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -126,3 +127,126 @@ func TestSCTPTransportOnClose(t *testing.T) {
t.Fatal("timed out")
}
}

func TestSCTPTransportOutOfBandNegotiatedDataChannelDetach(t *testing.T) {
const N = 10
done := make(chan struct{}, N)
for i := 0; i < N; i++ {
go func() {
// Use Detach data channels mode
s := SettingEngine{}
s.DetachDataChannels()
api := NewAPI(WithSettingEngine(s))

// Set up two peer connections.
config := Configuration{}
offerPC, err := api.NewPeerConnection(config)
if err != nil {
t.Error(err)
return
}
answerPC, err := api.NewPeerConnection(config)
if err != nil {
t.Error(err)
return
}

defer closePairNow(t, offerPC, answerPC)
defer func() { done <- struct{}{} }()

negotiated := true
id := uint16(0)
readDetach := make(chan struct{})
dc1, err := offerPC.CreateDataChannel("", &DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
t.Error(err)
return
}
dc1.OnOpen(func() {
_, _ = dc1.Detach()
close(readDetach)
})

writeDetach := make(chan struct{})
dc2, err := answerPC.CreateDataChannel("", &DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
t.Error(err)
return
}
dc2.OnOpen(func() {
_, _ = dc2.Detach()
close(writeDetach)
})

var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
connestd := make(chan struct{}, 1)
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
connestd <- struct{}{}
}
})
select {
case <-connestd:
case <-time.After(10 * time.Second):
t.Error("conn establishment timed out")
return
}
<-readDetach
err1 := dc1.dataChannel.SetReadDeadline(time.Now().Add(10 * time.Second))
if err1 != nil {
t.Error(err)
return
}
buf := make([]byte, 10)
n, err1 := dc1.dataChannel.Read(buf)
if err1 != nil {
t.Error(err)
return
}
if string(buf[:n]) != "hello" {
t.Error("invalid read")
}
}()
go func() {
defer wg.Done()
connestd := make(chan struct{}, 1)
answerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
connestd <- struct{}{}
}
})
select {
case <-connestd:
case <-time.After(10 * time.Second):
t.Error("connection establishment timed out")
return
}
<-writeDetach
n, err1 := dc2.dataChannel.Write([]byte("hello"))
if err1 != nil || n != len("hello") {
t.Error(err)
}
}()
err = signalPair(offerPC, answerPC)
require.NoError(t, err)
wg.Wait()
}()
}

for i := 0; i < N; i++ {
select {
case <-done:
case <-time.After(20 * time.Second):
t.Fatal("timed out")
}
}
}

0 comments on commit 5eed59c

Please sign in to comment.