diff --git a/dial.go b/dial.go index a16b5b2..944bcc6 100644 --- a/dial.go +++ b/dial.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "context" @@ -34,11 +34,15 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. var d dialer switch network { case "tcp4": - d = t.v4.getDialer(network) + d = t.v4.getTcpDialer(network) + case "udp4": + d = t.v4.getUdpDialer(network) case "tcp6": - d = t.v6.getDialer(network) + d = t.v6.getTcpDialer(network) + case "udp6": + d = t.v6.getUdpDialer(network) default: - return nil, ErrWrongProto + return nil, ErrWrongDialProto } conn, err := d.DialContext(ctx, network, addr) if err != nil { @@ -52,20 +56,62 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. return maconn, nil } -func (n *network) getDialer(network string) dialer { - n.mu.RLock() - d := n.dialer - n.mu.RUnlock() - if d == nil { - n.mu.Lock() - defer n.mu.Unlock() +func (n *network) getTcpDialer(network string) dialer { + n.mu.Lock() + defer n.mu.Unlock() - if n.dialer == nil { - n.dialer = n.makeDialer(network) - } - d = n.dialer + if n.tcpDialer != nil { + return n.tcpDialer + } + n.tcpDialer = n.makeDialer(network) + return n.tcpDialer +} + +func (n *network) getUdpDialer(network string) dialer { + n.mu.Lock() + defer n.mu.Unlock() + + if n.udpDialer != nil { + return n.udpDialer + } + n.udpDialer = n.makeDialer(network) + return n.udpDialer +} + +func tcpAddresses(listeners map[*listener]struct{}) []net.Addr { + result := make([]net.Addr, 0, len(listeners)) + for l := range listeners { + result = append(result, l.Addr()) + } + return result +} + +func udpAddresses(listeners map[*udpListener]struct{}) []net.Addr { + result := make([]net.Addr, 0, len(listeners)) + for l := range listeners { + result = append(result, l.Connection().LocalAddr()) // TODO make udpListener's interface comparable to listener + } + return result +} + +func ipOf(addr net.Addr) net.IP { + if a, ok := addr.(*net.TCPAddr); ok { + return a.IP + } + if a, ok := addr.(*net.UDPAddr); ok { + return a.IP } - return d + panic("only support tcp and udp address") +} + +func portOf(addr net.Addr) int { + if a, ok := addr.(*net.TCPAddr); ok { + return a.Port + } + if a, ok := addr.(*net.UDPAddr); ok { + return a.Port + } + panic("only support tcp and udp address") } func (n *network) makeDialer(network string) dialer { @@ -75,26 +121,35 @@ func (n *network) makeDialer(network string) dialer { } var unspec net.IP + var listenAddrs []net.Addr switch network { case "tcp4": unspec = net.IPv4zero + listenAddrs = tcpAddresses(n.tcpListeners) + case "udp4": + unspec = net.IPv4zero + listenAddrs = udpAddresses(n.udpListeners) case "tcp6": unspec = net.IPv6unspecified + listenAddrs = tcpAddresses(n.tcpListeners) + case "udp6": + unspec = net.IPv6unspecified + listenAddrs = udpAddresses(n.udpListeners) default: - panic("invalid network: must be either tcp4 or tcp6") + panic("invalid network: must be either tcp4, tcp6, udp4 or udp6") } // How many ports are we listening on. var port = 0 - for l := range n.listeners { - newPort := l.Addr().(*net.TCPAddr).Port + for _, l := range listenAddrs { + newPort := portOf(l) switch { case newPort == 0: // Any port, ignore (really, we shouldn't get this case...). case port == 0: // Haven't selected a port yet, choose this one. port = newPort case newPort == port: // Same as the selected port, continue... default: // Multiple ports, use the multi dialer - return newMultiDialer(unspec, n.listeners) + return newMultiDialer(unspec, listenAddrs, network) } } @@ -104,10 +159,20 @@ func (n *network) makeDialer(network string) dialer { } // One. Always dial from the single port we're listening on. - laddr := &net.TCPAddr{ - IP: unspec, - Port: port, + switch network { + case "tcp4", "tcp6": + laddr := &net.TCPAddr{ + IP: unspec, + Port: port, + } + return singleDialer{laddr} + case "udp4", "udp6": + laddr := &net.UDPAddr{ + IP: unspec, + Port: port, + } + return singleDialer{laddr} + default: + panic("invalid network: must be either tcp4, tcp6, udp4 or udp6") } - - return (*singleDialer)(laddr) } diff --git a/listen.go b/listen.go index 7b2a4c3..ca4fc40 100644 --- a/listen.go +++ b/listen.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "net" @@ -13,14 +13,27 @@ type listener struct { network *network } +type udpListener struct { + manet.PacketConn + network *network +} + func (l *listener) Close() error { l.network.mu.Lock() - delete(l.network.listeners, l) - l.network.dialer = nil + delete(l.network.tcpListeners, l) + l.network.tcpDialer = nil l.network.mu.Unlock() return l.Listener.Close() } +func (l *udpListener) Close() error { + l.network.mu.Lock() + delete(l.network.udpListeners, l) + l.network.udpDialer = nil + l.network.mu.Unlock() + return l.PacketConn.Close() +} + // Listen listens on the given multiaddr. // // If reuseport is supported, it will be enabled for this listener and future @@ -40,7 +53,7 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) { case "tcp6": n = &t.v6 default: - return nil, ErrWrongProto + return nil, ErrWrongListenProto } if !reuseport.Available() { @@ -53,7 +66,7 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) { if _, ok := nl.Addr().(*net.TCPAddr); !ok { nl.Close() - return nil, ErrWrongProto + return nil, ErrWrongListenProto } malist, err := manet.WrapNetListener(nl) @@ -70,11 +83,63 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) { n.mu.Lock() defer n.mu.Unlock() - if n.listeners == nil { - n.listeners = make(map[*listener]struct{}) + if n.tcpListeners == nil { + n.tcpListeners = make(map[*listener]struct{}) + } + n.tcpListeners[list] = struct{}{} + n.tcpDialer = nil + + return list, nil +} + +// ListenPacket is the UDP equivalent of `Listen` +func (t *Transport) ListenPacket(laddr ma.Multiaddr) (manet.PacketConn, error) { + nw, naddr, err := manet.DialArgs(laddr) + if err != nil { + return nil, err + } + var n *network + switch nw { + case "udp4": + n = &t.v4 + case "udp6": + n = &t.v6 + default: + return nil, ErrWrongListenPacketProto + } + + if !reuseport.Available() { + return manet.ListenPacket(laddr) + } + nl, err := reuseport.ListenPacket(nw, naddr) + if err != nil { + return manet.ListenPacket(laddr) + } + + if _, ok := nl.LocalAddr().(*net.UDPAddr); !ok { + nl.Close() + return nil, ErrWrongListenPacketProto + } + + malist, err := manet.WrapPacketConn(nl) + if err != nil { + nl.Close() + return nil, err + } + + list := &udpListener{ + PacketConn: malist, + network: n, + } + + n.mu.Lock() + defer n.mu.Unlock() + + if n.udpListeners == nil { + n.udpListeners = make(map[*udpListener]struct{}) } - n.listeners[list] = struct{}{} - n.dialer = nil + n.udpListeners[list] = struct{}{} + n.udpDialer = nil return list, nil } diff --git a/multidialer.go b/multidialer.go index c7d388c..2d02184 100644 --- a/multidialer.go +++ b/multidialer.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "context" @@ -8,16 +8,16 @@ import ( ) type multiDialer struct { - loopback []*net.TCPAddr - unspecified []*net.TCPAddr - global *net.TCPAddr + loopback []net.Addr + unspecified []net.Addr + global net.Addr } func (d *multiDialer) Dial(network, addr string) (net.Conn, error) { return d.DialContext(context.Background(), network, addr) } -func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { +func randAddr(addrs []net.Addr) net.Addr { if len(addrs) > 0 { return addrs[rand.Intn(len(addrs))] } @@ -25,9 +25,22 @@ func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { } func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - tcpAddr, err := net.ResolveTCPAddr(network, addr) - if err != nil { - return nil, err + var ip net.IP + switch network { + case "tcp", "tcp4", "tcp6": + resolved, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return nil, err + } + ip = resolved.IP + case "udp", "udp4", "udp6": + resolved, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + ip = resolved.IP + default: + return nil, net.UnknownNetworkError(network) } // We pick the source *port* based on the following algorithm. @@ -52,7 +65,6 @@ func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (ne // the port we pick). In the future, we could use netlink (on Linux) to // figure out the right source address but we're going to punt on that. - ip := tcpAddr.IP source := d.global switch { case ip.IsLoopback(): @@ -68,19 +80,20 @@ func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (ne source = randAddr(d.unspecified) } default: - return nil, fmt.Errorf("undialable IP: %s", tcpAddr.IP) + return nil, fmt.Errorf("undialable IP: %s", ip) } return reuseDial(ctx, source, network, addr) } -func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) dialer { +func newMultiDialer(unspec net.IP, listenAddrs []net.Addr, network string) dialer { m := new(multiDialer) - for l := range listeners { - laddr := l.Addr().(*net.TCPAddr) + for _, l := range listenAddrs { + ip := ipOf(l) + port := portOf(l) switch { - case laddr.IP.IsLoopback(): - m.loopback = append(m.loopback, laddr) - case laddr.IP.IsGlobalUnicast(): + case ip.IsLoopback(): + m.loopback = append(m.loopback, l) + case ip.IsGlobalUnicast(): // Different global ports? Crap. // // The *proper* way to deal with this is to, e.g., use @@ -95,15 +108,26 @@ func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) dialer { // // TODO: Port priority? Addr priority? if m.global == nil { - m.global = &net.TCPAddr{ - IP: unspec, - Port: laddr.Port, + switch network { + case "tcp4", "tcp6": + m.global = &net.TCPAddr{ + IP: unspec, + Port: port, + } + case "udp4", "udp6": + m.global = &net.UDPAddr{ + IP: unspec, + Port: port, + } + default: + panic("invalid network: must be either tcp4, tcp6, udp4 or udp6") } } else { - log.Warning("listening on external interfaces on multiple ports, will dial from %d, not %s", m.global, laddr) + log.Warning("listening on external interfaces on multiple ports, will dial from %d, not %s:%d", + m.global, ip, port) } - case laddr.IP.IsUnspecified(): - m.unspecified = append(m.unspecified, laddr) + case ip.IsUnspecified(): + m.unspecified = append(m.unspecified, l) } } return m diff --git a/reuseport.go b/reuseport.go index 47ceac2..58b3801 100644 --- a/reuseport.go +++ b/reuseport.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "context" @@ -39,7 +39,7 @@ func reuseErrShouldRetry(err error) bool { } // Dials using reusport and then redials normally if that fails. -func reuseDial(ctx context.Context, laddr *net.TCPAddr, network, raddr string) (net.Conn, error) { +func reuseDial(ctx context.Context, laddr net.Addr, network, raddr string) (net.Conn, error) { if laddr == nil { return fallbackDialer.DialContext(ctx, network, raddr) } diff --git a/reuseport_test.go b/reuseport_test.go index b0ee4de..c9b4e56 100644 --- a/reuseport_test.go +++ b/reuseport_test.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "net" diff --git a/singledialer.go b/singledialer.go index efb96eb..a0db012 100644 --- a/singledialer.go +++ b/singledialer.go @@ -1,16 +1,18 @@ -package tcpreuse +package reusetransport import ( "context" "net" ) -type singleDialer net.TCPAddr +type singleDialer struct { + net.Addr +} -func (d *singleDialer) Dial(network, address string) (net.Conn, error) { +func (d singleDialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -func (d *singleDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return reuseDial(ctx, (*net.TCPAddr)(d), network, address) +func (d singleDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return reuseDial(ctx, d.Addr, network, address) } diff --git a/transport.go b/transport.go index 5f094d1..9a41208 100644 --- a/transport.go +++ b/transport.go @@ -1,4 +1,4 @@ -package tcpreuse +package reusetransport import ( "errors" @@ -9,17 +9,25 @@ import ( var log = logging.Logger("reuseport-transport") -// ErrWrongProto is returned when dialing a protocol other than tcp. -var ErrWrongProto = errors.New("can only dial TCP over IPv4 or IPv6") +// ErrWrongListenProto is returned when listen a protocol other than tcp. +var ErrWrongListenProto = errors.New("can only listen TCP over IPv4 or IPv6") -// Transport is a TCP reuse transport that reuses listener ports. +// ErrWrongListenPacketProto is returned when listen a protocol other than udp. +var ErrWrongListenPacketProto = errors.New("can only listen UDP packet over IPv4 or IPv6") + +// ErrWrongDialProto is returned when dialing a protocol we cannot handle +var ErrWrongDialProto = errors.New("can only dial tcp4, tcp6, udp4, udp6") + +// Transport is a reuse transport that reuses listener ports. type Transport struct { v4 network v6 network } type network struct { - mu sync.RWMutex - listeners map[*listener]struct{} - dialer dialer + mu sync.RWMutex + tcpListeners map[*listener]struct{} + udpListeners map[*udpListener]struct{} + tcpDialer dialer + udpDialer dialer } diff --git a/transport_test.go b/transport_test.go index 942019b..9bf2580 100644 --- a/transport_test.go +++ b/transport_test.go @@ -1,6 +1,7 @@ -package tcpreuse +package reusetransport import ( + "errors" "net" "testing" @@ -13,8 +14,19 @@ var loopbackV6, _ = ma.NewMultiaddr("/ip6/::1/tcp/0") var unspecV6, _ = ma.NewMultiaddr("/ip6/::/tcp/0") var unspecV4, _ = ma.NewMultiaddr("/ip4/0.0.0.0/tcp/0") +var udpLoopbackV4, _ = ma.NewMultiaddr("/ip4/127.0.0.1/udp/0") +var udpLoopbackV6, _ = ma.NewMultiaddr("/ip6/::1/udp/0") +var udpUnspecV6, _ = ma.NewMultiaddr("/ip6/::/udp/0") +var udpUnspecV4, _ = ma.NewMultiaddr("/ip4/0.0.0.0/udp/0") + +var udpMsg = []byte("udp-reuse-port-transport-test") +var udpMsgSize = len(udpMsg) +var errBadMsgSize = errors.New("bad message size") + var globalV4 ma.Multiaddr var globalV6 ma.Multiaddr +var udpGlobalV4 ma.Multiaddr +var udpGlobalV6 ma.Multiaddr func init() { addrs, err := manet.InterfaceMultiaddrs() @@ -24,15 +36,22 @@ func init() { for _, addr := range addrs { if !manet.IsIP6LinkLocal(addr) && !manet.IsIPLoopback(addr) { tcp, _ := ma.NewMultiaddr("/tcp/0") + udp, _ := ma.NewMultiaddr("/udp/0") switch addr.Protocols()[0].Code { case ma.P_IP4: if globalV4 == nil { globalV4 = addr.Encapsulate(tcp) } + if udpGlobalV4 == nil { + udpGlobalV4 = addr.Encapsulate(udp) + } case ma.P_IP6: if globalV6 == nil { globalV6 = addr.Encapsulate(tcp) } + if udpGlobalV6 == nil { + udpGlobalV6 = addr.Encapsulate(udp) + } } } } @@ -53,6 +72,25 @@ func acceptOne(t *testing.T, listener manet.Listener) <-chan struct{} { return done } +func udpAcceptOne(t *testing.T, listener manet.PacketConn) <-chan struct{} { + t.Helper() + done := make(chan struct{}) + go func() { + defer close(done) + buffer := make([]byte, udpMsgSize+1) // +1 in case we receive larger message than expected + n, _, err := listener.ReadFrom(buffer) + if err != nil { + t.Error(err) + return + } + if n != udpMsgSize { + t.Error(errBadMsgSize) + return + } + }() + return done +} + func dialOne(t *testing.T, tr *Transport, listener manet.Listener, expected ...int) int { t.Helper() @@ -76,6 +114,36 @@ func dialOne(t *testing.T, tr *Transport, listener manet.Listener, expected ...i return 0 } +func udpDialOne(t *testing.T, tr *Transport, listener manet.PacketConn, expected ...int) int { + t.Helper() + + done := udpAcceptOne(t, listener) + c, err := tr.Dial(listener.Multiaddr()) + if err != nil { + t.Fatal(err) + } + n, err := c.Write(udpMsg) + if err != nil { + t.Fatal(err) + } + if n != udpMsgSize { + t.Fatal(errBadMsgSize) + } + port := c.LocalAddr().(*net.UDPAddr).Port + <-done + c.Close() + if len(expected) == 0 { + return port + } + for _, p := range expected { + if p == port { + return port + } + } + t.Errorf("dialed from %d, expected to dial from one of %v", port, expected) + return 0 +} + func TestNoneAndSingle(t *testing.T) { var trA Transport var trB Transport @@ -96,6 +164,26 @@ func TestNoneAndSingle(t *testing.T) { dialOne(t, &trB, listenerA, listenerB.Addr().(*net.TCPAddr).Port) } +func TestUdpNoneAndSingle(t *testing.T) { + var trA Transport + var trB Transport + listenerA, err := trA.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerA.Close() + + udpDialOne(t, &trB, listenerA) + + listenerB, err := trB.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerB.Close() + + udpDialOne(t, &trB, listenerA, listenerB.Connection().LocalAddr().(*net.UDPAddr).Port) +} + func TestTwoLocal(t *testing.T) { var trA Transport var trB Transport @@ -122,6 +210,32 @@ func TestTwoLocal(t *testing.T) { listenerB2.Addr().(*net.TCPAddr).Port) } +func TestUdpTwoLocal(t *testing.T) { + var trA Transport + var trB Transport + listenerA, err := trA.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerA.Close() + + listenerB1, err := trB.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerB1.Close() + + listenerB2, err := trB.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerB2.Close() + + udpDialOne(t, &trB, listenerA, + listenerB1.Connection().LocalAddr().(*net.UDPAddr).Port, + listenerB2.Connection().LocalAddr().(*net.UDPAddr).Port) +} + func TestGlobalPreferenceV4(t *testing.T) { if globalV4 == nil { t.Skip("no global IPv4 addresses configured") @@ -132,6 +246,16 @@ func TestGlobalPreferenceV4(t *testing.T) { testPrefer(t, globalV4, unspecV4, globalV4) testPrefer(t, globalV4, unspecV4, loopbackV4) + + if udpGlobalV4 == nil { + t.Skip("no global IPv4 addresses configured") + return + } + testUdpPrefer(t, udpLoopbackV4, udpLoopbackV4, udpGlobalV4) + testUdpPrefer(t, udpLoopbackV4, udpUnspecV4, udpGlobalV4) + + testUdpPrefer(t, udpGlobalV4, udpUnspecV4, udpGlobalV4) + testUdpPrefer(t, udpGlobalV4, udpUnspecV4, udpLoopbackV4) } func TestGlobalPreferenceV6(t *testing.T) { @@ -144,11 +268,24 @@ func TestGlobalPreferenceV6(t *testing.T) { testPrefer(t, globalV6, unspecV6, globalV6) testPrefer(t, globalV6, unspecV6, loopbackV6) + + if udpGlobalV6 == nil { + t.Skip("no global IPv6 addresses configured") + return + } + testUdpPrefer(t, udpLoopbackV6, udpLoopbackV6, udpGlobalV6) + testUdpPrefer(t, udpLoopbackV6, udpUnspecV6, udpGlobalV6) + + testUdpPrefer(t, udpGlobalV6, udpUnspecV6, udpGlobalV6) + testUdpPrefer(t, udpGlobalV6, udpUnspecV6, udpLoopbackV6) } func TestLoopbackPreference(t *testing.T) { testPrefer(t, loopbackV4, loopbackV4, unspecV4) testPrefer(t, loopbackV6, loopbackV6, unspecV6) + + testUdpPrefer(t, udpLoopbackV4, udpLoopbackV4, udpUnspecV4) + testUdpPrefer(t, udpLoopbackV6, udpLoopbackV6, udpUnspecV6) } func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { @@ -182,9 +319,43 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port) } +func testUdpPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { + var trA Transport + var trB Transport + listenerA, err := trA.ListenPacket(listen) + if err != nil { + t.Fatal(err) + } + defer listenerA.Close() + + listenerB1, err := trB.ListenPacket(avoid) + if err != nil { + t.Fatal(err) + } + defer listenerB1.Close() + + udpDialOne(t, &trB, listenerA, listenerB1.Connection().LocalAddr().(*net.UDPAddr).Port) + + listenerB2, err := trB.ListenPacket(prefer) + if err != nil { + t.Fatal(err) + } + defer listenerB2.Close() + + udpDialOne(t, &trB, listenerA, listenerB2.Connection().LocalAddr().(*net.UDPAddr).Port) + + // Closing the listener should reset the dialer. + listenerB2.Close() + + udpDialOne(t, &trB, listenerA, listenerB1.Connection().LocalAddr().(*net.UDPAddr).Port) +} + func TestV6V4(t *testing.T) { testUseFirst(t, loopbackV4, loopbackV4, loopbackV6) testUseFirst(t, loopbackV6, loopbackV6, loopbackV4) + + testUdpUseFirst(t, udpLoopbackV4, udpLoopbackV4, udpLoopbackV6) + testUdpUseFirst(t, udpLoopbackV6, udpLoopbackV6, udpLoopbackV4) } func TestGlobalToGlobal(t *testing.T) { @@ -194,6 +365,13 @@ func TestGlobalToGlobal(t *testing.T) { } testUseFirst(t, globalV4, globalV4, loopbackV4) testUseFirst(t, globalV6, globalV6, loopbackV6) + + if udpGlobalV4 == nil { + t.Skip("no globalV4 addresses configured") + return + } + testUdpUseFirst(t, udpGlobalV4, udpGlobalV4, udpLoopbackV4) + testUdpUseFirst(t, udpGlobalV6, udpGlobalV6, udpLoopbackV6) } func testUseFirst(t *testing.T, listen, use, never ma.Multiaddr) { @@ -230,6 +408,40 @@ func testUseFirst(t *testing.T, listen, use, never ma.Multiaddr) { dialOne(t, &trB, listenerA) } +func testUdpUseFirst(t *testing.T, listen, use, never ma.Multiaddr) { + var trA Transport + var trB Transport + listenerA, err := trA.ListenPacket(udpGlobalV4) + if err != nil { + t.Fatal(err) + } + defer listenerA.Close() + + listenerB1, err := trB.ListenPacket(udpLoopbackV4) + if err != nil { + t.Fatal(err) + } + defer listenerB1.Close() + + // It works (random port) + udpDialOne(t, &trB, listenerA) + + listenerB2, err := trB.ListenPacket(udpGlobalV4) + if err != nil { + t.Fatal(err) + } + defer listenerB2.Close() + + // Uses globalV4 port. + udpDialOne(t, &trB, listenerA, listenerB2.Connection().LocalAddr().(*net.UDPAddr).Port) + + // Closing the listener should reset the dialer. + listenerB2.Close() + + // It still works. + udpDialOne(t, &trB, listenerA) +} + func TestDuplicateGlobal(t *testing.T) { if globalV4 == nil { t.Skip("no globalV4 addresses configured") @@ -264,3 +476,38 @@ func TestDuplicateGlobal(t *testing.T) { dialOne(t, &trB, listenerA, port) } } + +func TestUdpDuplicateGlobal(t *testing.T) { + if udpGlobalV4 == nil { + t.Skip("no globalV4 addresses configured") + return + } + + var trA Transport + var trB Transport + listenerA, err := trA.ListenPacket(udpGlobalV4) + if err != nil { + t.Fatal(err) + } + defer listenerA.Close() + + listenerB1, err := trB.ListenPacket(udpGlobalV4) + if err != nil { + t.Fatal(err) + } + defer listenerB1.Close() + + listenerB2, err := trB.ListenPacket(udpGlobalV4) + if err != nil { + t.Fatal(err) + } + defer listenerB2.Close() + + // Check which port we're using + port := udpDialOne(t, &trB, listenerA) + + // Check consistency + for i := 0; i < 10; i++ { + udpDialOne(t, &trB, listenerA, port) + } +}