Skip to content

Commit

Permalink
Add callback API to track allocation lifecycle
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Dec 5, 2024
1 parent 3ff9392 commit bd0b543
Show file tree
Hide file tree
Showing 20 changed files with 788 additions and 47 deletions.
93 changes: 86 additions & 7 deletions internal/allocation/allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Allocation struct {
channelBindings []*ChannelBind
lifetimeTimer *time.Timer
closed chan interface{}
username, realm string
callback EventHandler
log logging.LeveledLogger

// Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation
Expand All @@ -45,12 +47,13 @@ type Allocation struct {
}

// NewAllocation creates a new instance of NewAllocation.
func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging.LeveledLogger) *Allocation {
func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, callback EventHandler, log logging.LeveledLogger) *Allocation {
return &Allocation{
TurnSocket: turnSocket,
fiveTuple: fiveTuple,
permissions: make(map[string]*Permission, 64),
closed: make(chan interface{}),
callback: callback,
log: log,
}
}
Expand Down Expand Up @@ -81,6 +84,21 @@ func (a *Allocation) AddPermission(p *Permission) {
a.permissions[fingerprint] = p
a.permissionsLock.Unlock()

if a.callback != nil {
if u, ok := p.Addr.(*net.UDPAddr); ok {
a.callback(EventHandlerArgs{
Type: OnPermissionCreated,
SrcAddr: a.fiveTuple.SrcAddr,
DstAddr: a.fiveTuple.DstAddr,
Protocol: a.fiveTuple.Protocol,
Username: a.username,
Realm: a.realm,
RelayAddr: a.RelayAddr,
PeerIP: u.IP,
})
}
}

p.start(permissionTimeout)
}

Expand All @@ -89,6 +107,32 @@ func (a *Allocation) RemovePermission(addr net.Addr) {
a.permissionsLock.Lock()
defer a.permissionsLock.Unlock()
delete(a.permissions, ipnet.FingerprintAddr(addr))

if a.callback != nil {
if u, ok := addr.(*net.UDPAddr); ok {
a.callback(EventHandlerArgs{
Type: OnPermissionDeleted,
SrcAddr: a.fiveTuple.SrcAddr,
DstAddr: a.fiveTuple.DstAddr,
Protocol: a.fiveTuple.Protocol,
Username: a.username,
Realm: a.realm,
RelayAddr: a.RelayAddr,
PeerIP: u.IP,
})
}
}
}

// ListPermissions returns the permissions associated with an allocation.
func (a *Allocation) ListPermissions() []*Permission {
ps := []*Permission{}
a.permissionsLock.RLock()
defer a.permissionsLock.RUnlock()
for _, p := range a.permissions {
ps = append(ps, p)
}
return ps
}

// AddChannelBind adds a new ChannelBind to the allocation, it also updates the
Expand All @@ -113,6 +157,20 @@ func (a *Allocation) AddChannelBind(c *ChannelBind, lifetime time.Duration) erro

// Channel binds also refresh permissions.
a.AddPermission(NewPermission(c.Peer, a.log))

if a.callback != nil {
a.callback(EventHandlerArgs{
Type: OnChannelCreated,
SrcAddr: a.fiveTuple.SrcAddr,
DstAddr: a.fiveTuple.DstAddr,
Protocol: a.fiveTuple.Protocol,
Username: a.username,
Realm: a.realm,
RelayAddr: a.RelayAddr,
PeerAddr: c.Peer,
ChannelNumber: uint16(c.Number),
})
}
} else {
channelByNumber.refresh(lifetime)

Expand All @@ -130,6 +188,20 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool {

for i := len(a.channelBindings) - 1; i >= 0; i-- {
if a.channelBindings[i].Number == number {
if a.callback != nil {
a.callback(EventHandlerArgs{
Type: OnChannelDeleted,
SrcAddr: a.fiveTuple.SrcAddr,
DstAddr: a.fiveTuple.DstAddr,
Protocol: a.fiveTuple.Protocol,
Username: a.username,
Realm: a.realm,
RelayAddr: a.RelayAddr,
PeerAddr: a.channelBindings[i].Peer,
ChannelNumber: uint16(a.channelBindings[i].Number),
})
}

a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...)
return true
}
Expand Down Expand Up @@ -162,6 +234,15 @@ func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind {
return nil
}

// ListChannelBindings returns the channel bindings associated with an allocation.
func (a *Allocation) ListChannelBindings() []*ChannelBind {
cs := []*ChannelBind{}
a.channelBindingsLock.RLock()
defer a.channelBindingsLock.RUnlock()
cs = append(cs, a.channelBindings...)
return cs
}

// Refresh updates the allocations lifetime
func (a *Allocation) Refresh(lifetime time.Duration) {
if !a.lifetimeTimer.Reset(lifetime) {
Expand Down Expand Up @@ -196,17 +277,15 @@ func (a *Allocation) Close() error {

a.lifetimeTimer.Stop()

a.permissionsLock.RLock()
for _, p := range a.permissions {
for _, p := range a.ListPermissions() {
a.RemovePermission(p.Addr)
p.lifetimeTimer.Stop()
}
a.permissionsLock.RUnlock()

a.channelBindingsLock.RLock()
for _, c := range a.channelBindings {
for _, c := range a.ListChannelBindings() {
a.RemoveChannelBind(c.Number)
c.lifetimeTimer.Stop()
}
a.channelBindingsLock.RUnlock()

return a.RelaySocket.Close()
}
Expand Down
33 changes: 31 additions & 2 deletions internal/allocation/allocation_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type ManagerConfig struct {
AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
EventHandler EventHandler
}

type reservation struct {
Expand All @@ -36,6 +37,7 @@ type Manager struct {
allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error)
allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error)
permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool
EventHandler EventHandler
}

// NewManager creates a new instance of Manager.
Expand All @@ -55,6 +57,7 @@ func NewManager(config ManagerConfig) (*Manager, error) {
allocatePacketConn: config.AllocatePacketConn,
allocateConn: config.AllocateConn,
permissionHandler: config.PermissionHandler,
EventHandler: config.EventHandler,
}, nil
}

Expand Down Expand Up @@ -86,7 +89,7 @@ func (m *Manager) Close() error {
}

// CreateAllocation creates a new allocation and starts relaying
func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) {
func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration, username, realm string) (*Allocation, error) {
switch {
case fiveTuple == nil:
return nil, errNilFiveTuple
Expand All @@ -103,7 +106,9 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo
if a := m.GetAllocation(fiveTuple); a != nil {
return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple)
}
a := NewAllocation(turnSocket, fiveTuple, m.log)
a := NewAllocation(turnSocket, fiveTuple, m.EventHandler, m.log)
a.username = username
a.realm = realm

conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort)
if err != nil {
Expand All @@ -123,6 +128,19 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo
m.allocations[fiveTuple.Fingerprint()] = a
m.lock.Unlock()

if m.EventHandler != nil {
m.EventHandler(EventHandlerArgs{
Type: OnAllocationCreated,
SrcAddr: fiveTuple.SrcAddr,
DstAddr: fiveTuple.DstAddr,
Protocol: UDP,
Username: username,
Realm: realm,
RelayAddr: relayAddr,
RequestedPort: requestedPort,
})
}

go a.packetHandler(m)
return a, nil
}
Expand All @@ -143,6 +161,17 @@ func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) {
if err := allocation.Close(); err != nil {
m.log.Errorf("Failed to close allocation: %v", err)
}

if m.EventHandler != nil {
m.EventHandler(EventHandlerArgs{
Type: OnAllocationDeleted,
SrcAddr: fiveTuple.SrcAddr,
DstAddr: fiveTuple.DstAddr,
Protocol: UDP,
Username: allocation.username,
Realm: allocation.realm,
})
}
}

// CreateReservation stores the reservation for the token+port
Expand Down
20 changes: 10 additions & 10 deletions internal/allocation/allocation_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) {
m, err := newTestManager()
assert.NoError(t, err)

if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil {
if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil {
t.Errorf("Illegally created allocation with nil FiveTuple")
}
if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil {
if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil {
t.Errorf("Illegally created allocation with nil turnSocket")
}
if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil {
if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0, "", ""); a != nil || err == nil {
t.Errorf("Illegally created allocation with 0 lifetime")
}
}
Expand All @@ -69,7 +69,7 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) {
assert.NoError(t, err)

fiveTuple := randomFiveTuple()
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil {
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil {
t.Errorf("Failed to create allocation %v %v", a, err)
}

Expand All @@ -84,11 +84,11 @@ func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.Pack
assert.NoError(t, err)

fiveTuple := randomFiveTuple()
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil {
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil {
t.Errorf("Failed to create allocation %v %v", a, err)
}

if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil {
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil {
t.Errorf("Was able to create allocation with same FiveTuple twice")
}
}
Expand All @@ -98,7 +98,7 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) {
assert.NoError(t, err)

fiveTuple := randomFiveTuple()
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil {
if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil {
t.Errorf("Failed to create allocation %v %v", a, err)
}

Expand All @@ -123,7 +123,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) {
for index := range allocations {
fiveTuple := randomFiveTuple()

a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime)
a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime, "", "")
if err != nil {
t.Errorf("Failed to create allocation with %v", fiveTuple)
}
Expand All @@ -147,9 +147,9 @@ func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) {

allocations := make([]*Allocation, 2)

a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second)
a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second, "", "")
allocations[0] = a1
a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute)
a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute, "", "")
allocations[1] = a2

// Make a1 timeout
Expand Down
Loading

0 comments on commit bd0b543

Please sign in to comment.