diff --git a/association.go b/association.go
index 9383625..559f988 100644
--- a/association.go
+++ b/association.go
@@ -17,6 +17,7 @@ import (
 	"github.com/pion/logging"
 	"github.com/pion/randutil"
+	"github.com/pion/transport/v3/deadline"
 )
 
 // Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation
@@ -251,6 +252,10 @@ type Association struct {
 	delayedAckTriggered   bool
 	immediateAckTriggered bool
 
+	blockWrite   bool
+	writePending bool
+	writeNotify  chan struct{}
+
 	name string
 	log  logging.LeveledLogger
 }
@@ -264,6 +269,7 @@ type Config struct {
 	MaxMessageSize       uint32
 	EnableZeroChecksum   bool
 	LoggerFactory        logging.LoggerFactory
+	BlockWrite           bool
 
 	// congestion control configuration
 	// RTOMax is the maximum retransmission timeout in milliseconds
@@ -375,6 +381,8 @@ func createAssociation(config Config) *Association {
 		stats:                &associationStats{},
 		log:                  config.LoggerFactory.NewLogger("sctp"),
 		name:                 config.Name,
+		blockWrite:           config.BlockWrite,
+		writeNotify:          make(chan struct{}, 1),
 	}
 
 	if a.name == "" {
@@ -675,6 +683,20 @@ func (a *Association) awakeWriteLoop() {
 	}
 }
 
+func (a *Association) isBlockWrite() bool {
+	return a.blockWrite
+}
+
+// notifyBlockWritable marks the association as writable and unblocks any waiting write.
+// The caller should hold the association write lock.
+func (a *Association) notifyBlockWritable() {
+	a.writePending = false
+	select {
+	case a.writeNotify <- struct{}{}:
+	default:
+	}
+}
+
 // unregisterStream un-registers a stream from the association
 // The caller should hold the association write lock.
 func (a *Association) unregisterStream(s *Stream, err error) {
@@ -1555,6 +1577,7 @@ func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream
 		reassemblyQueue: newReassemblyQueue(streamIdentifier),
 		log:             a.log,
 		name:            fmt.Sprintf("%d:%s", streamIdentifier, a.name),
+		writeDeadline:   deadline.New(),
 	}
 
 	s.readNotifier = sync.NewCond(&s.lock)
@@ -2338,6 +2361,11 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1
 		}
 	}
 
+	if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 {
+		a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name)
+		a.notifyBlockWritable()
+	}
+
 	return chunks, sisToReset
 }
@@ -2375,21 +2403,35 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []
 }
 
 // sendPayloadData sends the data chunks.
-func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error {
+func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error {
 	a.lock.Lock()
-	defer a.lock.Unlock()
 
 	state := a.getState()
 	if state != established {
+		a.lock.Unlock()
+
 		return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, getAssociationStateString(state))
 	}
 
+	if a.blockWrite {
+		for a.writePending {
+			a.lock.Unlock()
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-a.writeNotify:
+				a.lock.Lock()
+			}
+		}
+		a.writePending = true
+	}
+
 	// Push the chunks into the pending queue first.
 	for _, c := range chunks {
 		a.pendingQueue.push(c)
 	}
+	a.lock.Unlock()
 
 	a.awakeWriteLoop()
 
 	return nil
 }
diff --git a/association_test.go b/association_test.go
index def1c01..44a71ac 100644
--- a/association_test.go
+++ b/association_test.go
@@ -2845,7 +2845,7 @@ func TestAssociationReceiveWindow(t *testing.T) {
 	done := make(chan bool)
 
 	go func() {
-		chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
+		chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
 		chunks = chunks[:1]
 		chunk := chunks[0]
 		// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
@@ -3016,7 +3016,7 @@ func TestAssociationMaxTSNOffset(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, uint16(1), s2.streamIdentifier)
 
-	chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
+	chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
 	chunks = chunks[:1]
 	sendChunk := func(tsn uint32) {
 		chunk := chunks[0]
@@ -3616,3 +3616,94 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
 	require.Equal(t, 0, len(a1.streams))
 	require.Equal(t, 0, len(a2.streams))
 }
+
+func TestAssociation_BlockWrite(t *testing.T) {
+	checkGoroutineLeaks(t)
+
+	conn1, conn2 := createUDPConnPair()
+	a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{BlockWrite: true, MaxReceiveBufferSize: 4000})
+	require.NoError(t, err)
+
+	defer noErrorClose(t, a2.Close)
+	defer noErrorClose(t, a1.Close)
+	s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	defer noErrorClose(t, s1.Close)
+	_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	s2, err := a2.AcceptStream()
+	require.NoError(t, err)
+
+	data := make([]byte, 4000)
+	n, err := s2.Read(data)
+	require.NoError(t, err)
+	require.Equal(t, "hello", string(data[:n]))
+
+	// Write should block until data is sent
+	dbConn1, ok := conn1.(*dumbConn2)
+	require.True(t, ok)
+	dbConn2, ok := conn2.(*dumbConn2)
+	require.True(t, ok)
+
+	dbConn1.remoteInboundHandler = dbConn2.inboundHandler
+
+	_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+
+	// test write deadline
+	// a2's advertised receive window is 0, so the write should block and then time out
+	require.NoError(t, s1.SetWriteDeadline(time.Now().Add(100*time.Millisecond)))
+	_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+
+	// test write deadline cancel
+	require.NoError(t, s1.SetWriteDeadline(time.Time{}))
+	var deadLineCanceled atomic.Bool
+	writeCanceled := make(chan struct{}, 2)
+	// both writes should block and then be canceled by the deadline
+	go func() {
+		_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
+		require.ErrorIs(t, err1, context.DeadlineExceeded)
+		require.True(t, deadLineCanceled.Load())
+		writeCanceled <- struct{}{}
+	}()
+	go func() {
+		_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
+		require.ErrorIs(t, err1, context.DeadlineExceeded)
+		require.True(t, deadLineCanceled.Load())
+		writeCanceled <- struct{}{}
+	}()
+	time.Sleep(100 * time.Millisecond)
+	deadLineCanceled.Store(true)
+	require.NoError(t, s1.SetWriteDeadline(time.Now().Add(-1*time.Second)))
+	<-writeCanceled
+	<-writeCanceled
+	require.NoError(t, s1.SetWriteDeadline(time.Time{}))
+
+	rn, rerr := s2.Read(data)
+	require.NoError(t, rerr)
+	require.Equal(t, 4000, rn)
+
+	// slow reader and fast writer, make sure every write blocks
+	go func() {
+		for {
+			bytes := make([]byte, 4000)
+			rn, rerr = s2.Read(bytes)
+			if errors.Is(rerr, io.EOF) {
+				return
+			}
+			require.NoError(t, rerr)
+			require.Equal(t, 4000, rn)
+			time.Sleep(5 * time.Millisecond)
+		}
+	}()
+
+	for i := 0; i < 10; i++ {
+		_, err = s1.Write(data)
+		require.NoError(t, err)
+		// bufferedAmount should not exceed RWND+message size (inflight + pending)
+		require.LessOrEqual(t, s1.BufferedAmount(), uint64(4000*2))
+	}
+}
diff --git a/stream.go b/stream.go
index 47e06f3..ed06d69 100644
--- a/stream.go
+++ b/stream.go
@@ -13,6 +13,7 @@ import (
 	"time"
 
 	"github.com/pion/logging"
+	"github.com/pion/transport/v3/deadline"
 )
 
 const (
@@ -65,6 +66,8 @@ type Stream struct {
 	readNotifier      *sync.Cond
 	readErr           error
 	readTimeoutCancel chan struct{}
+	writeDeadline     *deadline.Deadline
+	writeLock         sync.Mutex
 	unordered         bool
 	reliabilityType   byte
 	reliabilityValue  uint32
@@ -272,16 +275,44 @@ func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error)
 		return 0, ErrStreamClosed
 	}
 
-	chunks := s.packetize(p, ppi)
+	// The send can fail when the association is blocked for writing and the deadline expires.
+	// That would leave a hole in the stream sequence number space, so hold the write lock to
+	// prevent concurrent sends and roll back the sequence number if the send fails.
+	if s.association.isBlockWrite() {
+		s.writeLock.Lock()
+	}
+	chunks, unordered := s.packetize(p, ppi)
 	n := len(p)
-	err := s.association.sendPayloadData(chunks)
+	err := s.association.sendPayloadData(s.writeDeadline, chunks)
 	if err != nil {
-		return n, ErrStreamClosed
+		s.lock.Lock()
+		s.bufferedAmount -= uint64(n)
+		if !unordered {
+			s.sequenceNumber--
+		}
+		s.lock.Unlock()
+	}
+	if s.association.isBlockWrite() {
+		s.writeLock.Unlock()
+	}
+	return n, err
+}
+
+// SetWriteDeadline sets the write deadline in an identical way to net.Conn; it only applies to blocking writes
+func (s *Stream) SetWriteDeadline(deadline time.Time) error {
+	s.writeDeadline.Set(deadline)
+	return nil
+}
+
+// SetDeadline sets the read and write deadlines in an identical way to net.Conn
+func (s *Stream) SetDeadline(t time.Time) error {
+	if err := s.SetReadDeadline(t); err != nil {
+		return err
 	}
 
-	return n, nil
+	return s.SetWriteDeadline(t)
 }
 
-func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
+func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) {
 	s.lock.Lock()
 	defer s.lock.Unlock()
@@ -336,7 +367,7 @@ func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPa
 	s.bufferedAmount += uint64(len(raw))
 	s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
 
-	return chunks
+	return chunks, unordered
 }
 
 // Close closes the write-direction of the stream.
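For context, below is a minimal, untested sketch of how a caller might opt into the new blocking-write mode and bound a blocked write with the new `Stream.SetWriteDeadline`. `sctp.Client`, `Config.NetConn`, and `LoggerFactory` already exist in the library and are outside this diff; the UDP address, stream identifier, and payload size are placeholder assumptions, and an established peer is assumed for the handshake.

```go
package main

import (
	"context"
	"errors"
	"log"
	"net"
	"time"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

func main() {
	// Any net.Conn that carries the SCTP packets; the address is a placeholder.
	conn, err := net.Dial("udp", "127.0.0.1:5000")
	if err != nil {
		log.Fatal(err)
	}

	assoc, err := sctp.Client(sctp.Config{
		NetConn:       conn,
		BlockWrite:    true, // writes block until the pending queue has drained
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	if err != nil {
		log.Fatal(err)
	}
	defer assoc.Close()

	stream, err := assoc.OpenStream(1, sctp.PayloadTypeWebRTCBinary)
	if err != nil {
		log.Fatal(err)
	}
	defer stream.Close()

	// Bound how long a blocked write may wait; context.DeadlineExceeded is
	// returned when the deadline fires before the pending data is flushed.
	if err := stream.SetWriteDeadline(time.Now().Add(2 * time.Second)); err != nil {
		log.Fatal(err)
	}
	if _, err := stream.Write(make([]byte, 64*1024)); err != nil {
		if errors.Is(err, context.DeadlineExceeded) {
			log.Println("write timed out while blocked on the peer's receive window")
		} else {
			log.Fatal(err)
		}
	}
}
```

Without `BlockWrite`, `Write` keeps its existing fire-and-forget behaviour and the deadline has no effect.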