diff --git a/association.go b/association.go index ce3725ec..f0f1d7be 100644 --- a/association.go +++ b/association.go @@ -254,10 +254,17 @@ func Server(config Config) (*Association, error) { // Client opens a SCTP stream over a conn func Client(config Config) (*Association, error) { + return createClientWithContext(context.Background(), config) +} + +func createClientWithContext(ctx context.Context, config Config) (*Association, error) { a := createAssociation(config) a.init(true) select { + case <-ctx.Done(): + a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState())) + return nil, ctx.Err() case err := <-a.handshakeCompletedCh: if err != nil { return nil, err diff --git a/association_test.go b/association_test.go index e50f2f87..7bfcce74 100644 --- a/association_test.go +++ b/association_test.go @@ -2570,65 +2570,106 @@ func TestAssocMaxMessageSize(t *testing.T) { }) } -func createAssocs(t *testing.T) (a1, a2 *Association) { - addr1 := &net.UDPAddr{ - IP: net.IP{127, 0, 0, 1}, - Port: 1234, +// crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other +func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) { + udp1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) + if err != nil { + return nil, nil, err } + addr1, ok := udp1.LocalAddr().(*net.UDPAddr) + require.True(t, ok) + err = udp1.Close() + require.NoError(t, err) - addr2 := &net.UDPAddr{ - IP: net.IP{127, 0, 0, 1}, - Port: 5678, + udp2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) + if err != nil { + return nil, nil, err } + addr2, ok := udp2.LocalAddr().(*net.UDPAddr) + require.True(t, ok) + err = udp2.Close() + require.NoError(t, err) - udp1, err := net.DialUDP("udp", addr1, addr2) + udp1, err = net.DialUDP("udp", addr1, addr2) if err != nil { - panic(err) + return nil, nil, err } - udp2, err := net.DialUDP("udp", addr2, addr1) + udp2, err = net.DialUDP("udp", addr2, addr1) if err != nil { - panic(err) + return nil, nil, err + } + + return udp1, udp2, nil +} + +func createAssocs(t *testing.T) (*Association, *Association, error) { + udp1, udp2, err := createUDPConnPair(t) + if err != nil { + return nil, nil, err } loggerFactory := logging.NewDefaultLoggerFactory() - a1Chan := make(chan *Association) - a2Chan := make(chan *Association) + a1Chan := make(chan interface{}) + a2Chan := make(chan interface{}) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() go func() { - a, err := Client(Config{ + a, err2 := createClientWithContext(ctx, Config{ NetConn: udp1, LoggerFactory: loggerFactory, }) - require.NoError(t, err) - - a1Chan <- a + if err2 != nil { + a1Chan <- err2 + } else { + a1Chan <- a + } }() go func() { - a, err := Client(Config{ + a, err2 := createClientWithContext(ctx, Config{ NetConn: udp2, LoggerFactory: loggerFactory, }) - require.NoError(t, err) - - a2Chan <- a + if err2 != nil { + a2Chan <- err2 + } else { + a2Chan <- a + } }() - select { - case a1 = <-a1Chan: - case <-time.After(time.Second): - assert.Fail(t, "timed out waiting for a1") - } + var a1 *Association + var a2 *Association - select { - case a2 = <-a2Chan: - case <-time.After(time.Second): - assert.Fail(t, "timed out waiting for a2") +loop: + for { + select { + case v1 := <-a1Chan: + switch v := v1.(type) { + case *Association: + a1 = v + if a2 != nil { + break loop + } + case error: + return nil, nil, v + } + case v2 := <-a2Chan: + switch v := v2.(type) { + case *Association: + a2 = v + if a1 != nil { + break loop + } + case error: + return nil, nil, v + } + } } - - return a1, a2 + return a1, a2, nil } func TestAssociation_Shutdown(t *testing.T) { @@ -2640,7 +2681,8 @@ func TestAssociation_Shutdown(t *testing.T) { assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() - a1, a2 := createAssocs(t) + a1, a2, err := createAssocs(t) + require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) @@ -2683,7 +2725,8 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) { assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() - a1, a2 := createAssocs(t) + a1, a2, err := createAssocs(t) + require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) @@ -2899,7 +2942,8 @@ func TestAssociation_Abort(t *testing.T) { assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() - a1, a2 := createAssocs(t) + a1, a2, err := createAssocs(t) + require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err)