From 150d26da97d577a290b775843a6799ac8247899a Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Fri, 22 Apr 2022 23:47:56 -0400 Subject: [PATCH] Use ICETransport private fields for PeerConnection PeerConnection used the public OnConnectionStateChange to track the status of the ICETransport. This was incorrect because a user can override this value at anytime. Add a new internalOnConnectionStateChangeHandler that is set directly by the PeerConnection and not accessible to the user. --- icetransport.go | 11 +++++++---- icetransport_test.go | 30 ++++++++++++++++++++---------- peerconnection.go | 2 +- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/icetransport.go b/icetransport.go index 31cc4c9a5c6..5b1a1c3bd6c 100644 --- a/icetransport.go +++ b/icetransport.go @@ -22,8 +22,9 @@ type ICETransport struct { role ICERole - onConnectionStateChangeHandler atomic.Value // func(ICETransportState) - onSelectedCandidatePairChangeHandler atomic.Value // func(*ICECandidatePair) + onConnectionStateChangeHandler atomic.Value // func(ICETransportState) + internalOnConnectionStateChangeHandler atomic.Value // func(ICETransportState) + onSelectedCandidatePairChangeHandler atomic.Value // func(*ICECandidatePair) state atomic.Value // ICETransportState @@ -220,8 +221,10 @@ func (t *ICETransport) OnConnectionStateChange(f func(ICETransportState)) { } func (t *ICETransport) onConnectionStateChange(state ICETransportState) { - handler := t.onConnectionStateChangeHandler.Load() - if handler != nil { + if handler := t.onConnectionStateChangeHandler.Load(); handler != nil { + handler.(func(ICETransportState))(state) + } + if handler := t.internalOnConnectionStateChangeHandler.Load(); handler != nil { handler.(func(ICETransportState))(state) } } diff --git a/icetransport_test.go b/icetransport_test.go index c7edb4017bc..221441d0b8f 100644 --- a/icetransport_test.go +++ b/icetransport_test.go @@ -4,6 +4,7 @@ package webrtc import ( + "sync" "sync/atomic" "testing" "time" @@ -22,23 +23,32 @@ func TestICETransport_OnConnectionStateChange(t *testing.T) { pcOffer, pcAnswer, err := newPair() assert.NoError(t, err) - iceOfferComplete := make(chan struct{}) - iceAnswerComplete := make(chan struct{}) + var ( + iceComplete sync.WaitGroup + peerConnectionConnected sync.WaitGroup + ) + iceComplete.Add(2) + peerConnectionConnected.Add(2) - pcOffer.SCTP().Transport().ICETransport().OnConnectionStateChange(func(s ICETransportState) { + onIceComplete := func(s ICETransportState) { if s == ICETransportStateConnected { - close(iceOfferComplete) + iceComplete.Done() } - }) + } + pcOffer.SCTP().Transport().ICETransport().OnConnectionStateChange(onIceComplete) + pcAnswer.SCTP().Transport().ICETransport().OnConnectionStateChange(onIceComplete) - pcAnswer.SCTP().Transport().ICETransport().OnConnectionStateChange(func(s ICETransportState) { - if s == ICETransportStateConnected { - close(iceAnswerComplete) + onConnected := func(s PeerConnectionState) { + if s == PeerConnectionStateConnected { + peerConnectionConnected.Done() } - }) + } + pcOffer.OnConnectionStateChange(onConnected) + pcAnswer.OnConnectionStateChange(onConnected) assert.NoError(t, signalPair(pcOffer, pcAnswer)) - <-iceOfferComplete + iceComplete.Wait() + peerConnectionConnected.Wait() closePairNow(t, pcOffer, pcAnswer) } diff --git a/peerconnection.go b/peerconnection.go index 1794e8d41ef..64b75a2883e 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -755,7 +755,7 @@ func (pc *PeerConnection) updateConnectionState(iceConnectionState ICEConnection func (pc *PeerConnection) createICETransport() *ICETransport { t := pc.api.NewICETransport(pc.iceGatherer) - t.OnConnectionStateChange(func(state ICETransportState) { + t.internalOnConnectionStateChangeHandler.Store(func(state ICETransportState) { var cs ICEConnectionState switch state { case ICETransportStateNew: