Skip to content
This repository has been archived by the owner on May 26, 2022. It is now read-only.

UDP support #3

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 90 additions & 25 deletions dial.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcpreuse
package reusetransport

import (
"context"
Expand Down Expand Up @@ -34,11 +34,15 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.
var d dialer
switch network {
case "tcp4":
d = t.v4.getDialer(network)
d = t.v4.getTcpDialer(network)
case "udp4":
d = t.v4.getUdpDialer(network)
case "tcp6":
d = t.v6.getDialer(network)
d = t.v6.getTcpDialer(network)
case "udp6":
d = t.v6.getUdpDialer(network)
default:
return nil, ErrWrongProto
return nil, ErrWrongDialProto
}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
Expand All @@ -52,20 +56,62 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.
return maconn, nil
}

func (n *network) getDialer(network string) dialer {
n.mu.RLock()
d := n.dialer
n.mu.RUnlock()
if d == nil {
n.mu.Lock()
defer n.mu.Unlock()
func (n *network) getTcpDialer(network string) dialer {
n.mu.Lock()
defer n.mu.Unlock()

if n.dialer == nil {
n.dialer = n.makeDialer(network)
}
d = n.dialer
if n.tcpDialer != nil {
return n.tcpDialer
}
n.tcpDialer = n.makeDialer(network)
return n.tcpDialer
}

func (n *network) getUdpDialer(network string) dialer {
n.mu.Lock()
defer n.mu.Unlock()

if n.udpDialer != nil {
return n.udpDialer
}
n.udpDialer = n.makeDialer(network)
return n.udpDialer
}

func tcpAddresses(listeners map[*listener]struct{}) []net.Addr {
result := make([]net.Addr, 0, len(listeners))
for l := range listeners {
result = append(result, l.Addr())
}
return result
}

func udpAddresses(listeners map[*udpListener]struct{}) []net.Addr {
result := make([]net.Addr, 0, len(listeners))
for l := range listeners {
result = append(result, l.Connection().LocalAddr()) // TODO make udpListener's interface comparable to listener
}
return result
}

func ipOf(addr net.Addr) net.IP {
if a, ok := addr.(*net.TCPAddr); ok {
return a.IP
}
if a, ok := addr.(*net.UDPAddr); ok {
return a.IP
}
return d
panic("only support tcp and udp address")
}

func portOf(addr net.Addr) int {
if a, ok := addr.(*net.TCPAddr); ok {
return a.Port
}
if a, ok := addr.(*net.UDPAddr); ok {
return a.Port
}
panic("only support tcp and udp address")
}

func (n *network) makeDialer(network string) dialer {
Expand All @@ -75,26 +121,35 @@ func (n *network) makeDialer(network string) dialer {
}

var unspec net.IP
var listenAddrs []net.Addr
switch network {
case "tcp4":
unspec = net.IPv4zero
listenAddrs = tcpAddresses(n.tcpListeners)
case "udp4":
unspec = net.IPv4zero
listenAddrs = udpAddresses(n.udpListeners)
case "tcp6":
unspec = net.IPv6unspecified
listenAddrs = tcpAddresses(n.tcpListeners)
case "udp6":
unspec = net.IPv6unspecified
listenAddrs = udpAddresses(n.udpListeners)
default:
panic("invalid network: must be either tcp4 or tcp6")
panic("invalid network: must be either tcp4, tcp6, udp4 or udp6")
}

// How many ports are we listening on.
var port = 0
for l := range n.listeners {
newPort := l.Addr().(*net.TCPAddr).Port
for _, l := range listenAddrs {
newPort := portOf(l)
switch {
case newPort == 0: // Any port, ignore (really, we shouldn't get this case...).
case port == 0: // Haven't selected a port yet, choose this one.
port = newPort
case newPort == port: // Same as the selected port, continue...
default: // Multiple ports, use the multi dialer
return newMultiDialer(unspec, n.listeners)
return newMultiDialer(unspec, listenAddrs, network)
}
}

Expand All @@ -104,10 +159,20 @@ func (n *network) makeDialer(network string) dialer {
}

// One. Always dial from the single port we're listening on.
laddr := &net.TCPAddr{
IP: unspec,
Port: port,
switch network {
case "tcp4", "tcp6":
laddr := &net.TCPAddr{
IP: unspec,
Port: port,
}
return singleDialer{laddr}
case "udp4", "udp6":
laddr := &net.UDPAddr{
IP: unspec,
Port: port,
}
return singleDialer{laddr}
default:
panic("invalid network: must be either tcp4, tcp6, udp4 or udp6")
}

return (*singleDialer)(laddr)
}
83 changes: 74 additions & 9 deletions listen.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tcpreuse
package reusetransport

import (
"net"
Expand All @@ -13,14 +13,27 @@ type listener struct {
network *network
}

type udpListener struct {
manet.PacketConn
network *network
}

func (l *listener) Close() error {
l.network.mu.Lock()
delete(l.network.listeners, l)
l.network.dialer = nil
delete(l.network.tcpListeners, l)
l.network.tcpDialer = nil
l.network.mu.Unlock()
return l.Listener.Close()
}

func (l *udpListener) Close() error {
l.network.mu.Lock()
delete(l.network.udpListeners, l)
l.network.udpDialer = nil
l.network.mu.Unlock()
return l.PacketConn.Close()
}

// Listen listens on the given multiaddr.
//
// If reuseport is supported, it will be enabled for this listener and future
Expand All @@ -40,7 +53,7 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) {
case "tcp6":
n = &t.v6
default:
return nil, ErrWrongProto
return nil, ErrWrongListenProto
}

if !reuseport.Available() {
Expand All @@ -53,7 +66,7 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) {

if _, ok := nl.Addr().(*net.TCPAddr); !ok {
nl.Close()
return nil, ErrWrongProto
return nil, ErrWrongListenProto
}

malist, err := manet.WrapNetListener(nl)
Expand All @@ -70,11 +83,63 @@ func (t *Transport) Listen(laddr ma.Multiaddr) (manet.Listener, error) {
n.mu.Lock()
defer n.mu.Unlock()

if n.listeners == nil {
n.listeners = make(map[*listener]struct{})
if n.tcpListeners == nil {
n.tcpListeners = make(map[*listener]struct{})
}
n.tcpListeners[list] = struct{}{}
n.tcpDialer = nil

return list, nil
}

// ListenPacket is the UDP equivalent of `Listen`
func (t *Transport) ListenPacket(laddr ma.Multiaddr) (manet.PacketConn, error) {
nw, naddr, err := manet.DialArgs(laddr)
if err != nil {
return nil, err
}
var n *network
switch nw {
case "udp4":
n = &t.v4
case "udp6":
n = &t.v6
default:
return nil, ErrWrongListenPacketProto
}

if !reuseport.Available() {
return manet.ListenPacket(laddr)
}
nl, err := reuseport.ListenPacket(nw, naddr)
if err != nil {
return manet.ListenPacket(laddr)
}

if _, ok := nl.LocalAddr().(*net.UDPAddr); !ok {
nl.Close()
return nil, ErrWrongListenPacketProto
}

malist, err := manet.WrapPacketConn(nl)
if err != nil {
nl.Close()
return nil, err
}

list := &udpListener{
PacketConn: malist,
network: n,
}

n.mu.Lock()
defer n.mu.Unlock()

if n.udpListeners == nil {
n.udpListeners = make(map[*udpListener]struct{})
}
n.listeners[list] = struct{}{}
n.dialer = nil
n.udpListeners[list] = struct{}{}
n.udpDialer = nil

return list, nil
}
Loading