diff --git a/internal/data/devices.go b/internal/data/devices.go index 0dd9e12a..d077063e 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -206,7 +206,7 @@ func AddDevice(username, publickey string) (Device, error) { return Device{}, err } - address, err := allocateIPAddress(config.Values().Wireguard.Range.String()) + address, err := getNextIP(config.Values().Wireguard.Range.String()) if err != nil { return Device{}, err } @@ -289,7 +289,7 @@ func DeleteDevice(username, id string) error { otherReferenceKey = "deviceref-" + d.Address } - _, err = etcd.Txn(context.Background()).Then(clientv3.OpDelete(string(realKey.Kvs[0].Value)), clientv3.OpDelete(refKey), clientv3.OpDelete(otherReferenceKey)).Commit() + _, err = etcd.Txn(context.Background()).Then(clientv3.OpDelete(string(realKey.Kvs[0].Value)), clientv3.OpDelete(refKey), clientv3.OpDelete(otherReferenceKey), clientv3.OpDelete("allocated_ips/"+d.Address)).Commit() if err != nil { return err } @@ -312,7 +312,7 @@ func DeleteDevices(username string) error { return err } - ops = append(ops, clientv3.OpDelete("devicesref-"+d.Publickey), clientv3.OpDelete("deviceref-"+d.Address)) + ops = append(ops, clientv3.OpDelete("devicesref-"+d.Publickey), clientv3.OpDelete("deviceref-"+d.Address), clientv3.OpDelete("allocated_ips/"+d.Address)) } _, err = etcd.Txn(context.Background()).Then(ops...).Commit() diff --git a/internal/data/dhcp.go b/internal/data/dhcp.go index 1a6bb776..75d5515c 100644 --- a/internal/data/dhcp.go +++ b/internal/data/dhcp.go @@ -2,92 +2,64 @@ package data import ( "context" - "fmt" + "math" + "math/rand" "net" - "slices" - "time" clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/clientv3util" ) -// This is almost certainly unsafe from splicing during multiple client registration -func allocateIPAddress(subnet string) (string, error) { +// https://gist.github.com/udhos/b468fbfd376aa0b655b6b0c539a88c03 +func incrementIP(ip net.IP, inc uint) net.IP { + i := ip.To4() + v := uint(i[0])<<24 + uint(i[1])<<16 + uint(i[2])<<8 + uint(i[3]) + v += inc + v3 := byte(v & 0xFF) + v2 := byte((v >> 8) & 0xFF) + v1 := byte((v >> 16) & 0xFF) + v0 := byte((v >> 24) & 0xFF) + return net.IPv4(v0, v1, v2, v3) +} - // Retrieve the list of allocated IPs - allocatedIPs, err := getAllocatedIPs() - if err != nil { - return "", err - } +func getNextIP(subnet string) (string, error) { - // Find an unallocated IP address within the given subnet - ip, err := findUnallocatedIP(subnet, allocatedIPs) + serverIP, cidr, err := net.ParseCIDR(subnet) if err != nil { return "", err } - // Mark the selected IP as allocated - if err := markIPAsAllocated(ip); err != nil { - return "", err - } - - return ip, nil -} - -func getAllocatedIPs() ([]string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + used, _ := cidr.Mask.Size() + addresses := int(math.Pow(2, float64(32-used))) - 2 // Do not allocate largest address or 0 - resp, err := etcd.Get(ctx, "allocated_ips", clientv3.WithPrefix()) - if err != nil { - return nil, err - } + // Choose a random number that cannot be 0 + addressAttempt := rand.Intn(addresses) + 1 + addr := incrementIP(cidr.IP, uint(addressAttempt)) - var allocatedIPs []string - for _, kv := range resp.Kvs { - allocatedIPs = append(allocatedIPs, string(kv.Value)) + if serverIP.Equal(addr) { + addr = incrementIP(addr, 1) } - return allocatedIPs, nil -} - -func findUnallocatedIP(subnet string, allocatedIPs []string) (string, error) { - _, ipNet, err := net.ParseCIDR(subnet) + lease, err := clientv3.NewLease(etcd).Grant(context.Background(), 3) if err != nil { return "", err } - for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incrementIP(ip) { - // Check if the IP is unallocated + for { + txn := etcd.Txn(context.Background()) + txn.If(clientv3util.KeyMissing("deviceref-"+addr.String()), clientv3util.KeyMissing("ip-hold-"+addr.String())) + txn.Then(clientv3.OpPut("ip-hold-"+addr.String(), addr.String(), clientv3.WithLease(lease.ID))) - if !slices.Contains(allocatedIPs, ip.String()) { - return ip.String(), nil + resp, err := txn.Commit() + if err != nil { + return "", err } - } - - return "", fmt.Errorf("no available unallocated IP addresses in the subnet") -} - -func markIPAsAllocated(ip string) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - _, err := etcd.Put(ctx, fmt.Sprintf("allocated_ips/%s", ip), ip) - return err -} - -func markIPAsUnallocated(ip string) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - _, err := etcd.Delete(ctx, fmt.Sprintf("allocated_ips/%s", ip)) - return err -} -func incrementIP(ip net.IP) { - for j := len(ip) - 1; j >= 0; j-- { - ip[j]++ - if ip[j] > 0 { - break + if resp.Succeeded { + return addr.String(), nil } + + addr = incrementIP(addr, 1) } + } diff --git a/internal/data/registration.go b/internal/data/registration.go index 1d1f8291..0128b09e 100644 --- a/internal/data/registration.go +++ b/internal/data/registration.go @@ -86,15 +86,15 @@ func FinaliseRegistration(token string) error { return "", false, err } + result.NumUses-- + if result.NumUses <= 0 { - return "", false, errVal + err = errVal } - result.NumUses-- - b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), false, err }) if err == errVal {