Skip to content

Commit

Permalink
Improve ip address selection to be threadsafe, and just generally sma…
Browse files Browse the repository at this point in the history
…rter, fix registration tokens having 2 uses when it was single use
  • Loading branch information
NHAS committed Jan 21, 2024
1 parent e5c5d23 commit 6ff547b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 71 deletions.
6 changes: 3 additions & 3 deletions internal/data/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func AddDevice(username, publickey string) (Device, error) {
return Device{}, err
}

address, err := allocateIPAddress(config.Values().Wireguard.Range.String())
address, err := getNextIP(config.Values().Wireguard.Range.String())
if err != nil {
return Device{}, err
}
Expand Down Expand Up @@ -289,7 +289,7 @@ func DeleteDevice(username, id string) error {
otherReferenceKey = "deviceref-" + d.Address
}

_, err = etcd.Txn(context.Background()).Then(clientv3.OpDelete(string(realKey.Kvs[0].Value)), clientv3.OpDelete(refKey), clientv3.OpDelete(otherReferenceKey)).Commit()
_, err = etcd.Txn(context.Background()).Then(clientv3.OpDelete(string(realKey.Kvs[0].Value)), clientv3.OpDelete(refKey), clientv3.OpDelete(otherReferenceKey), clientv3.OpDelete("allocated_ips/"+d.Address)).Commit()
if err != nil {
return err
}
Expand All @@ -312,7 +312,7 @@ func DeleteDevices(username string) error {
return err
}

ops = append(ops, clientv3.OpDelete("devicesref-"+d.Publickey), clientv3.OpDelete("deviceref-"+d.Address))
ops = append(ops, clientv3.OpDelete("devicesref-"+d.Publickey), clientv3.OpDelete("deviceref-"+d.Address), clientv3.OpDelete("allocated_ips/"+d.Address))
}

_, err = etcd.Txn(context.Background()).Then(ops...).Commit()
Expand Down
100 changes: 36 additions & 64 deletions internal/data/dhcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,92 +2,64 @@ package data

import (
"context"
"fmt"
"math"
"math/rand"
"net"
"slices"
"time"

clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/clientv3util"
)

// This is almost certainly unsafe from splicing during multiple client registration
func allocateIPAddress(subnet string) (string, error) {
// https://gist.github.com/udhos/b468fbfd376aa0b655b6b0c539a88c03
func incrementIP(ip net.IP, inc uint) net.IP {
i := ip.To4()
v := uint(i[0])<<24 + uint(i[1])<<16 + uint(i[2])<<8 + uint(i[3])
v += inc
v3 := byte(v & 0xFF)
v2 := byte((v >> 8) & 0xFF)
v1 := byte((v >> 16) & 0xFF)
v0 := byte((v >> 24) & 0xFF)
return net.IPv4(v0, v1, v2, v3)
}

// Retrieve the list of allocated IPs
allocatedIPs, err := getAllocatedIPs()
if err != nil {
return "", err
}
func getNextIP(subnet string) (string, error) {

// Find an unallocated IP address within the given subnet
ip, err := findUnallocatedIP(subnet, allocatedIPs)
serverIP, cidr, err := net.ParseCIDR(subnet)
if err != nil {
return "", err
}

// Mark the selected IP as allocated
if err := markIPAsAllocated(ip); err != nil {
return "", err
}

return ip, nil
}

func getAllocatedIPs() ([]string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
used, _ := cidr.Mask.Size()
addresses := int(math.Pow(2, float64(32-used))) - 2 // Do not allocate largest address or 0

resp, err := etcd.Get(ctx, "allocated_ips", clientv3.WithPrefix())
if err != nil {
return nil, err
}
// Choose a random number that cannot be 0
addressAttempt := rand.Intn(addresses) + 1
addr := incrementIP(cidr.IP, uint(addressAttempt))

var allocatedIPs []string
for _, kv := range resp.Kvs {
allocatedIPs = append(allocatedIPs, string(kv.Value))
if serverIP.Equal(addr) {
addr = incrementIP(addr, 1)
}

return allocatedIPs, nil
}

func findUnallocatedIP(subnet string, allocatedIPs []string) (string, error) {
_, ipNet, err := net.ParseCIDR(subnet)
lease, err := clientv3.NewLease(etcd).Grant(context.Background(), 3)
if err != nil {
return "", err
}

for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) {
// Check if the IP is unallocated
for {
txn := etcd.Txn(context.Background())
txn.If(clientv3util.KeyMissing("deviceref-"+addr.String()), clientv3util.KeyMissing("ip-hold-"+addr.String()))
txn.Then(clientv3.OpPut("ip-hold-"+addr.String(), addr.String(), clientv3.WithLease(lease.ID)))

if !slices.Contains(allocatedIPs, ip.String()) {
return ip.String(), nil
resp, err := txn.Commit()
if err != nil {
return "", err
}
}

return "", fmt.Errorf("no available unallocated IP addresses in the subnet")
}

func markIPAsAllocated(ip string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

_, err := etcd.Put(ctx, fmt.Sprintf("allocated_ips/%s", ip), ip)
return err
}

func markIPAsUnallocated(ip string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

_, err := etcd.Delete(ctx, fmt.Sprintf("allocated_ips/%s", ip))
return err
}

func incrementIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
if resp.Succeeded {
return addr.String(), nil
}

addr = incrementIP(addr, 1)
}

}
8 changes: 4 additions & 4 deletions internal/data/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ func FinaliseRegistration(token string) error {
return "", false, err
}

result.NumUses--

if result.NumUses <= 0 {
return "", false, errVal
err = errVal
}

result.NumUses--

b, _ := json.Marshal(result)

return string(b), false, nil
return string(b), false, err
})

if err == errVal {
Expand Down

0 comments on commit 6ff547b

Please sign in to comment.