diff --git a/datachannel.go b/datachannel.go index 63bad0a843e..5fef7900174 100644 --- a/datachannel.go +++ b/datachannel.go @@ -450,6 +450,12 @@ func (d *DataChannel) ensureOpen() error { // pion/datachannel documentation for the correct way to handle the // resulting DataChannel object. func (d *DataChannel) Detach() (datachannel.ReadWriteCloser, error) { + return d.DetachWithDeadline() +} + +// DetachWithDeadline allows you to detach the underlying datachannel. +// It is the same as Detach but returns a ReadWriteCloserDeadliner. +func (d *DataChannel) DetachWithDeadline() (datachannel.ReadWriteCloserDeadliner, error) { d.mu.Lock() if !d.api.settingEngine.detach.DataChannels { diff --git a/sctptransport.go b/sctptransport.go index 68f40b9079d..9e227df31e5 100644 --- a/sctptransport.go +++ b/sctptransport.go @@ -98,7 +98,7 @@ func (r *SCTPTransport) GetCapabilities() SCTPCapabilities { // Start the SCTPTransport. Since both local and remote parties must mutually // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish // a connection over SCTP. -func (r *SCTPTransport) Start(SCTPCapabilities) error { +func (r *SCTPTransport) Start(_ SCTPCapabilities) error { if r.isStarted { return nil } @@ -114,6 +114,7 @@ func (r *SCTPTransport) Start(SCTPCapabilities) error { EnableZeroChecksum: r.api.settingEngine.sctp.enableZeroChecksum, LoggerFactory: r.api.settingEngine.LoggerFactory, RTOMax: float64(r.api.settingEngine.sctp.rtoMax) / float64(time.Millisecond), + BlockWrite: r.api.settingEngine.detach.DataChannels && r.api.settingEngine.dataChannelBlockWrite, }) if err != nil { return err diff --git a/settingengine.go b/settingengine.go index fb2c40bc5fc..edfb1549bc8 100644 --- a/settingengine.go +++ b/settingengine.go @@ -103,6 +103,7 @@ type SettingEngine struct { iceMaxBindingRequests *uint16 fireOnTrackBeforeFirstRTP bool disableCloseByDTLS bool + dataChannelBlockWrite bool } // getReceiveMTU returns the configured MTU. If SettingEngine's MTU is configured to 0 it returns the default @@ -121,6 +122,12 @@ func (e *SettingEngine) DetachDataChannels() { e.detach.DataChannels = true } +// EnableDataChannelBlockWrite allows data channels to block on write, +// it only works if DetachDataChannels is enabled +func (e *SettingEngine) EnableDataChannelBlockWrite(nonblockWrite bool) { + e.dataChannelBlockWrite = nonblockWrite +} + // SetSRTPProtectionProfiles allows the user to override the default SRTP Protection Profiles // The default srtp protection profiles are provided by the function `defaultSrtpProtectionProfiles` func (e *SettingEngine) SetSRTPProtectionProfiles(profiles ...dtls.SRTPProtectionProfile) { diff --git a/settingengine_test.go b/settingengine_test.go index 8b786859fc7..3496fc3acd4 100644 --- a/settingengine_test.go +++ b/settingengine_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/pion/datachannel" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/ice/v4" @@ -417,3 +418,45 @@ func TestDisableCloseByDTLS(t *testing.T) { assert.True(t, offer.ConnectionState() == PeerConnectionStateConnected) assert.NoError(t, offer.Close()) } + +func TestEnableDataChannelBlockWrite(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + s := SettingEngine{} + s.DetachDataChannels() + s.EnableDataChannelBlockWrite(true) + s.SetSCTPMaxReceiveBufferSize(1500) + + offer, answer, err := NewAPI(WithSettingEngine(s)).newPair(Configuration{}) + assert.NoError(t, err) + + dc, err := offer.CreateDataChannel("data", nil) + assert.NoError(t, err) + detachChan := make(chan datachannel.ReadWriteCloserDeadliner, 1) + dc.OnOpen(func() { + detached, err1 := dc.DetachWithDeadline() + assert.NoError(t, err1) + detachChan <- detached + }) + + assert.NoError(t, signalPair(offer, answer)) + untilConnectionState(PeerConnectionStateConnected, offer, answer).Wait() + + // write should block and return deadline exceeded since the receiver is not reading + // and the buffer size is 1500 bytes + rawDC := <-detachChan + assert.NoError(t, rawDC.SetWriteDeadline(time.Now().Add(time.Second))) + buf := make([]byte, 1000) + for i := 0; i < 10; i++ { + _, err = rawDC.Write(buf) + if err != nil { + break + } + } + assert.ErrorIs(t, err, context.DeadlineExceeded) + closePairNow(t, offer, answer) +}