Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add blocking write mode for association #356

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

"github.com/pion/logging"
"github.com/pion/randutil"
"github.com/pion/transport/v3/deadline"
)

// Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation
Expand Down Expand Up @@ -251,6 +252,10 @@
delayedAckTriggered bool
immediateAckTriggered bool

blockWrite bool
writePending bool
writeNotify chan struct{}

name string
log logging.LeveledLogger
}
Expand All @@ -264,6 +269,7 @@
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory
BlockWrite bool

// congestion control configuration
// RTOMax is the maximum retransmission timeout in milliseconds
Expand Down Expand Up @@ -375,6 +381,8 @@
stats: &associationStats{},
log: config.LoggerFactory.NewLogger("sctp"),
name: config.Name,
blockWrite: config.BlockWrite,
writeNotify: make(chan struct{}, 1),
}

if a.name == "" {
Expand Down Expand Up @@ -675,6 +683,20 @@
}
}

func (a *Association) isBlockWrite() bool {
return a.blockWrite
}

// Mark the association is writable and unblock the waiting write,
// the caller should hold the association write lock.
func (a *Association) notifyBlockWritable() {
a.writePending = false
select {
case a.writeNotify <- struct{}{}:
default:

Check warning on line 696 in association.go

View check run for this annotation

Codecov / codecov/patch

association.go#L696

Added line #L696 was not covered by tests
}
}

// unregisterStream un-registers a stream from the association
// The caller should hold the association write lock.
func (a *Association) unregisterStream(s *Stream, err error) {
Expand Down Expand Up @@ -1555,6 +1577,7 @@
reassemblyQueue: newReassemblyQueue(streamIdentifier),
log: a.log,
name: fmt.Sprintf("%d:%s", streamIdentifier, a.name),
writeDeadline: deadline.New(),
}

s.readNotifier = sync.NewCond(&s.lock)
Expand Down Expand Up @@ -2338,6 +2361,11 @@
}
}

if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 {
a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name)
a.notifyBlockWritable()
}

return chunks, sisToReset
}

Expand Down Expand Up @@ -2375,21 +2403,35 @@
}

// sendPayloadData sends the data chunks.
func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error {
func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error {
a.lock.Lock()
defer a.lock.Unlock()

state := a.getState()
if state != established {
a.lock.Unlock()
return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist,
getAssociationStateString(state))
}

if a.blockWrite {
for a.writePending {
a.lock.Unlock()
select {
case <-ctx.Done():
return ctx.Err()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how should this error be handled? does this bubble up to the app level or it is all handled internally? Am curious about how this affects reliable transmission mode and if app need to retry send?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will return to upper application, when WriteSCTP returns error, it means the data has not been sent (incorrect state or deadline timeout) so no retransmission here, the retransmission only works for sent messages (write operation succeeded).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, wondering if the application is supposed to try Write again?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should do that if blocking mode is enabled, like other connections. The default behavior is still unblock mode

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, thank you 🙏

case <-a.writeNotify:
a.lock.Lock()
}
}
a.writePending = true
}

// Push the chunks into the pending queue first.
for _, c := range chunks {
a.pendingQueue.push(c)
}

a.lock.Unlock()
a.awakeWriteLoop()
return nil
}
Expand Down
95 changes: 93 additions & 2 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2845,7 +2845,7 @@ func TestAssociationReceiveWindow(t *testing.T) {

done := make(chan bool)
go func() {
chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
chunk := chunks[0]
// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
Expand Down Expand Up @@ -3016,7 +3016,7 @@ func TestAssociationMaxTSNOffset(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint16(1), s2.streamIdentifier)

chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
sendChunk := func(tsn uint32) {
chunk := chunks[0]
Expand Down Expand Up @@ -3616,3 +3616,94 @@ func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) {
require.Equal(t, 0, len(a1.streams))
require.Equal(t, 0, len(a2.streams))
}

func TestAssociation_BlockWrite(t *testing.T) {
checkGoroutineLeaks(t)

conn1, conn2 := createUDPConnPair()
a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{BlockWrite: true, MaxReceiveBufferSize: 4000})
require.NoError(t, err)

defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
s2, err := a2.AcceptStream()
require.NoError(t, err)

data := make([]byte, 4000)
n, err := s2.Read(data)
require.NoError(t, err)
require.Equal(t, "hello", string(data[:n]))

// Write should block until data is sent
dbConn1, ok := conn1.(*dumbConn2)
require.True(t, ok)
dbConn2, ok := conn2.(*dumbConn2)
require.True(t, ok)

dbConn1.remoteInboundHandler = dbConn2.inboundHandler

_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.NoError(t, err)

// test write deadline
// a2's awnd is 0, so write should be blocked
require.NoError(t, s1.SetWriteDeadline(time.Now().Add(100*time.Millisecond)))
_, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err)

// test write deadline cancel
require.NoError(t, s1.SetWriteDeadline(time.Time{}))
var deadLineCanceled atomic.Bool
writeCanceled := make(chan struct{}, 2)
// both write should be blocked and canceled by deadline
go func() {
_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err1)
require.True(t, deadLineCanceled.Load())
writeCanceled <- struct{}{}
}()
go func() {
_, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary)
require.ErrorIs(t, err, context.DeadlineExceeded, err1)
require.True(t, deadLineCanceled.Load())
writeCanceled <- struct{}{}
}()
time.Sleep(100 * time.Millisecond)
deadLineCanceled.Store(true)
require.NoError(t, s1.SetWriteDeadline(time.Now().Add(-1*time.Second)))
<-writeCanceled
<-writeCanceled
require.NoError(t, s1.SetWriteDeadline(time.Time{}))

rn, rerr := s2.Read(data)
require.NoError(t, rerr)
require.Equal(t, 4000, rn)

// slow reader and fast writer, make sure all write is blocked
go func() {
for {
bytes := make([]byte, 4000)
rn, rerr = s2.Read(bytes)
if errors.Is(rerr, io.EOF) {
return
}
require.NoError(t, rerr)
require.Equal(t, 4000, rn)
time.Sleep(5 * time.Millisecond)
}
}()

for i := 0; i < 10; i++ {
_, err = s1.Write(data)
require.NoError(t, err)
// bufferedAmount should not exceed RWND+message size (inflight + pending)
require.LessOrEqual(t, s1.BufferedAmount(), uint64(4000*2))
}
}
43 changes: 37 additions & 6 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"time"

"github.com/pion/logging"
"github.com/pion/transport/v3/deadline"
)

const (
Expand Down Expand Up @@ -65,6 +66,8 @@
readNotifier *sync.Cond
readErr error
readTimeoutCancel chan struct{}
writeDeadline *deadline.Deadline
writeLock sync.Mutex
unordered bool
reliabilityType byte
reliabilityValue uint32
Expand Down Expand Up @@ -272,16 +275,44 @@
return 0, ErrStreamClosed
}

chunks := s.packetize(p, ppi)
// the send could fail if the association is blocked for writing (timeout), it will left a hole
// in the stream sequence number space, so we need to lock the write to avoid concurrent send and decrement
// the sequence number in case of failure
if s.association.isBlockWrite() {
s.writeLock.Lock()
}
chunks, unordered := s.packetize(p, ppi)
n := len(p)
err := s.association.sendPayloadData(chunks)
err := s.association.sendPayloadData(s.writeDeadline, chunks)
if err != nil {
return n, ErrStreamClosed
s.lock.Lock()
s.bufferedAmount -= uint64(n)
if !unordered {
s.sequenceNumber--
}
s.lock.Unlock()
}
if s.association.isBlockWrite() {
s.writeLock.Unlock()
}
return n, err
}

// SetWriteDeadline sets the write deadline in an identical way to net.Conn, it will only work for blocking writes
func (s *Stream) SetWriteDeadline(deadline time.Time) error {
s.writeDeadline.Set(deadline)
return nil
}

// SetDeadline sets the read and write deadlines in an identical way to net.Conn
func (s *Stream) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err

Check warning on line 310 in stream.go

View check run for this annotation

Codecov / codecov/patch

stream.go#L308-L310

Added lines #L308 - L310 were not covered by tests
}
return n, nil
return s.SetWriteDeadline(t)

Check warning on line 312 in stream.go

View check run for this annotation

Codecov / codecov/patch

stream.go#L312

Added line #L312 was not covered by tests
}

func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) {
s.lock.Lock()
defer s.lock.Unlock()

Expand Down Expand Up @@ -336,7 +367,7 @@
s.bufferedAmount += uint64(len(raw))
s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)

return chunks
return chunks, unordered
}

// Close closes the write-direction of the stream.
Expand Down
Loading