Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide SCTP Association OnClose callback #2861

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
// OnStateChange func()

onErrorHandler func(error)
onCloseHandler func(error)

sctpAssociation *sctp.Association
onDataChannelHandler func(*DataChannel)
Expand Down Expand Up @@ -174,6 +175,7 @@
dataChannels = append(dataChannels, dc.dataChannel)
}
r.lock.RUnlock()

ACCEPT:
for {
dc, err := datachannel.Accept(a, &datachannel.Config{
Expand All @@ -183,6 +185,9 @@
if !errors.Is(err, io.EOF) {
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
r.onClose(err)

Check warning on line 188 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L188

Added line #L188 was not covered by tests
} else {
r.onClose(nil)
}
return
}
Expand Down Expand Up @@ -230,9 +235,14 @@
MaxRetransmits: maxRetransmits,
}, r, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
if err != nil {
// This data channel is invalid. Close it and log an error.
if err1 := dc.Close(); err1 != nil {
r.log.Errorf("Failed to close invalid data channel: %v", err1)

Check warning on line 240 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L239-L240

Added lines #L239 - L240 were not covered by tests
}
r.log.Errorf("Failed to accept data channel: %v", err)
r.onError(err)
return
// We've received a datachannel with invalid configuration. We can still receive other datachannels.
continue ACCEPT

Check warning on line 245 in sctptransport.go

View check run for this annotation

Codecov / codecov/patch

sctptransport.go#L245

Added line #L245 was not covered by tests
}

<-r.onDataChannel(rtcDC)
Expand All @@ -249,8 +259,7 @@
}
}

// OnError sets an event handler which is invoked when
// the SCTP connection error occurs.
// OnError sets an event handler which is invoked when the SCTP Association errors.
func (r *SCTPTransport) OnError(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
Expand All @@ -267,6 +276,23 @@
}
}

// OnClose sets an event handler which is invoked when the SCTP Association closes.
func (r *SCTPTransport) OnClose(f func(err error)) {
r.lock.Lock()
defer r.lock.Unlock()
r.onCloseHandler = f
}

func (r *SCTPTransport) onClose(err error) {
r.lock.RLock()
handler := r.onCloseHandler
r.lock.RUnlock()

if handler != nil {
go handler(err)
}
}

// OnDataChannel sets an event handler which is invoked when a data
// channel message arrives from a remote peer.
func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
Expand Down
73 changes: 72 additions & 1 deletion sctptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

package webrtc

import "testing"
import (
"bytes"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestGenerateDataChannelID(t *testing.T) {
sctpTransportWithChannels := func(ids []uint16) *SCTPTransport {
Expand Down Expand Up @@ -48,3 +54,68 @@ func TestGenerateDataChannelID(t *testing.T) {
}
}
}

func TestSCTPTransportOnClose(t *testing.T) {
offerPC, answerPC, err := newPair()
require.NoError(t, err)

defer closePairNow(t, offerPC, answerPC)

answerPC.OnDataChannel(func(dc *DataChannel) {
dc.OnMessage(func(_ DataChannelMessage) {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send message")
}
})
})

recvMsg := make(chan struct{}, 1)
offerPC.OnConnectionStateChange(func(state PeerConnectionState) {
if state == PeerConnectionStateConnected {
defer func() {
offerPC.OnConnectionStateChange(nil)
}()

dc, createErr := offerPC.CreateDataChannel(expectedLabel, nil)
if createErr != nil {
t.Errorf("Failed to create a PC pair for testing")
return
}
dc.OnMessage(func(msg DataChannelMessage) {
if !bytes.Equal(msg.Data, []byte("hello")) {
t.Error("invalid msg received")
}
recvMsg <- struct{}{}
})
dc.OnOpen(func() {
if err1 := dc.Send([]byte("hello")); err1 != nil {
t.Error("failed to send initial msg", err1)
}
})
}
})

err = signalPair(offerPC, answerPC)
require.NoError(t, err)

select {
case <-recvMsg:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}

// setup SCTP OnClose callback
ch := make(chan error, 1)
answerPC.SCTP().OnClose(func(err error) {
ch <- err
})

err = offerPC.Close() // This will trigger sctp onclose callback on remote
require.NoError(t, err)

select {
case <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}
}
Loading