Skip to content

Commit

Permalink
Makes UDPMux IPv4/IPv6 aware
Browse files Browse the repository at this point in the history
UDPMux before only worked with UDP4 traffic.
UDP6 traffic would simply be ignored.

This commit implements 2 connections per ufrag. When requesting a
connection for a ufrag the user must specify if they want IPv4 or IPv6.

Relates to pion/webrtc#1915
  • Loading branch information
Antonito committed Mar 2, 2022
1 parent 427ac0f commit fafa43e
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 39 deletions.
10 changes: 5 additions & 5 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
return errUDPMuxDisabled
}

localIPs, err := localInterfaces(a.net, a.interfaceFilter, []NetworkType{NetworkTypeUDP4})
localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.networkTypes)
switch {
case err != nil:
return err
Expand All @@ -254,7 +254,7 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error {
}
}

conn, err := a.udpMux.GetConn(a.localUfrag)
conn, err := a.udpMux.GetConn(a.localUfrag, candidateIP.To4() == nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -351,7 +351,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne

for i := range urls {
wg.Add(1)
go func(url URL, network string) {
go func(url URL, network string, isIPv6 bool) {
defer wg.Done()

hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port)
Expand All @@ -367,7 +367,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
return
}

conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String())
conn, err := a.udpMuxSrflx.GetConnForURL(a.localUfrag, url.String(), isIPv6)
if err != nil {
a.log.Warnf("could not find connection in UDPMuxSrflx %s %s: %v\n", network, url, err)
return
Expand Down Expand Up @@ -397,7 +397,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne
}
a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v\n", err)
}
}(*urls[i], networkType.String())
}(*urls[i], networkType.String(), networkType.IsIPv6())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ func (m *universalUDPMuxMock) GetRelayedAddr(turnAddr net.Addr, deadline time.Du
return nil, errNotImplemented
}

func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
func (m *universalUDPMuxMock) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getConnForURLTimes++
Expand Down
87 changes: 60 additions & 27 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
// UDPMux allows multiple connections to go over a single UDP port
type UDPMux interface {
io.Closer
GetConn(ufrag string) (net.PacketConn, error)
GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error)
RemoveConnByUfrag(ufrag string)
}

Expand All @@ -25,8 +25,8 @@ type UDPMuxDefault struct {
closedChan chan struct{}
closeOnce sync.Once

// conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType
conns map[string]*udpMuxedConn
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn

addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
Expand Down Expand Up @@ -54,7 +54,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
Expand All @@ -76,43 +77,47 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {

// GetConn returns a PacketConn given the connection's ufrag and network
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
func (m *UDPMuxDefault) GetConn(ufrag string, isIPv6 bool) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()

if m.IsClosed() {
return nil, io.ErrClosedPipe
}

if c, ok := m.conns[ufrag]; ok {
return c, nil
if conn, ok := m.getConn(ufrag, isIPv6); ok {
return conn, nil
}

c := m.createMuxedConn(ufrag)
go func() {
<-c.CloseChannel()
m.removeConn(ufrag)
}()
m.conns[ufrag] = c

if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
m.connsIPv4[ufrag] = c
}

return c, nil
}

// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
removedConns := make([]*udpMuxedConn, 0)
for key := range m.conns {
if key != ufrag {
continue
}
removedConns := make([]*udpMuxedConn, 0, 2)

c := m.conns[key]
delete(m.conns, key)
if c != nil {
removedConns = append(removedConns, c)
}
// Keep lock section small to avoid deadlock with conn lock
m.mu.Lock()
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
}
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()

m.addressMapMu.Lock()
Expand Down Expand Up @@ -143,21 +148,39 @@ func (m *UDPMuxDefault) Close() error {
m.mu.Lock()
defer m.mu.Unlock()

for _, c := range m.conns {
for _, c := range m.connsIPv4 {
_ = c.Close()
}
m.conns = make(map[string]*udpMuxedConn)
for _, c := range m.connsIPv6 {
_ = c.Close()
}

m.connsIPv4 = make(map[string]*udpMuxedConn)
m.connsIPv6 = make(map[string]*udpMuxedConn)

close(m.closedChan)
})
return err
}

func (m *UDPMuxDefault) removeConn(key string) {
m.mu.Lock()
c := m.conns[key]
delete(m.conns, key)
// keep lock section small to avoid deadlock with conn lock
m.mu.Unlock()
c := func() *udpMuxedConn {
m.mu.Lock()
defer m.mu.Unlock()

if c, ok := m.connsIPv4[key]; ok {
delete(m.connsIPv4, key)
return c
}

if c, ok := m.connsIPv6[key]; ok {
delete(m.connsIPv6, key)
return c
}

return nil
}()

if c == nil {
return
Expand Down Expand Up @@ -255,9 +278,10 @@ func (m *UDPMuxDefault) connWorker() {
}

ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := udpAddr.IP.To4() == nil

m.mu.Lock()
destinationConn = m.conns[ufrag]
destinationConn, _ = m.getConn(ufrag, isIPv6)
m.mu.Unlock()
}

Expand All @@ -272,6 +296,15 @@ func (m *UDPMuxDefault) connWorker() {
}
}

func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
val, ok = m.connsIPv4[ufrag]
}
return
}

type bufferHolder struct {
buffer []byte
}
Expand Down
4 changes: 2 additions & 2 deletions udp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestUDPMux(t *testing.T) {
require.NoError(t, udpMux.Close())

// can't create more connections
_, err = udpMux.GetConn("failufrag")
_, err = udpMux.GetConn("failufrag", false)
require.Error(t, err)
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func TestAddressEncoding(t *testing.T) {
}

func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down
6 changes: 3 additions & 3 deletions udp_mux_universal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type UniversalUDPMux interface {
UDPMux
GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error)
GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error)
GetConnForURL(ufrag string, url string) (net.PacketConn, error)
GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error)
}

// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom.
Expand Down Expand Up @@ -84,8 +84,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time

// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url))
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, isIPv6 bool) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), isIPv6)
}

// ReadFrom is called by UDPMux connWorker and handles packets coming from the STUN server discovering a mapped address.
Expand Down
2 changes: 1 addition & 1 deletion udp_mux_universal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestUniversalUDPMux(t *testing.T) {
}

func testMuxSrflxConnection(t *testing.T, udpMux *UniversalUDPMuxDefault, ufrag string, network string) {
pktConn, err := udpMux.GetConn(ufrag)
pktConn, err := udpMux.GetConn(ufrag, false)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
Expand Down

0 comments on commit fafa43e

Please sign in to comment.