diff --git a/internal/data/dhcp.go b/internal/data/dhcp.go index 0fdbd635..531f8cc3 100644 --- a/internal/data/dhcp.go +++ b/internal/data/dhcp.go @@ -38,27 +38,35 @@ func incrementIP(ip net.IP, inc uint) net.IP { } -func getNextIP(subnet string) (string, error) { +func chooseInitalIP(cidr *net.IPNet) (net.IP, error) { - serverIP, cidr, err := net.ParseCIDR(subnet) - if err != nil { - return "", err - } - - max := 32 - if serverIP.To16() != nil { - max = 128 + max := 128 + if cidr.IP.To4() != nil { + max = 32 } used, _ := cidr.Mask.Size() maxNumberOfAddresses := int(math.Pow(2, float64(max-used))) - 2 // Do not allocate largest address or 0 if maxNumberOfAddresses < 1 { - return "", errors.New("subnet is too small to contain a new device") + return nil, errors.New("subnet is too small to contain a new device") } // Choose a random number that cannot be 0 addressAttempt := rand.Intn(maxNumberOfAddresses) + 1 - addr := incrementIP(cidr.IP, uint(addressAttempt)) + return incrementIP(cidr.IP, uint(addressAttempt)), nil +} + +func getNextIP(subnet string) (string, error) { + + serverIP, cidr, err := net.ParseCIDR(subnet) + if err != nil { + return "", err + } + + addr, err := chooseInitalIP(cidr) + if err != nil { + return "", err + } lease, err := clientv3.NewLease(etcd).Grant(context.Background(), 3) if err != nil { diff --git a/internal/data/dhcp_test.go b/internal/data/dhcp_test.go index 4cb3f58e..38965cab 100644 --- a/internal/data/dhcp_test.go +++ b/internal/data/dhcp_test.go @@ -131,3 +131,34 @@ func TestIncrementIPOverflow(t *testing.T) { }) } } + +func TestChooseInitial(t *testing.T) { + + _, cidr, err := net.ParseCIDR("192.168.3.4/24") + if err != nil { + t.Fatal(err) + } + + addr, err := chooseInitalIP(cidr) + if err != nil { + t.Fatal(err) + } + + if !cidr.Contains(addr) { + t.Fatalf("does not contain address, %s", addr) + } + + _, cidr, err = net.ParseCIDR("2001:db8:abcd:1234:c000::/66") + if err != nil { + t.Fatal(err) + } + + addr, err = chooseInitalIP(cidr) + if err != nil { + t.Fatal(err) + } + + if !cidr.Contains(addr) { + t.Fatalf("does not contain address, %s", addr) + } +} diff --git a/internal/mfaportal/web.go b/internal/mfaportal/web.go index 9abd240c..5412c1d1 100644 --- a/internal/mfaportal/web.go +++ b/internal/mfaportal/web.go @@ -112,7 +112,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err tunnel.GetOrPost("/", mfaPortal.index) address := config.Values.Wireguard.ServerAddress.String() - if config.Values.Wireguard.ServerAddress.To16() != nil { + if config.Values.Wireguard.ServerAddress.To4() == nil && config.Values.Wireguard.ServerAddress.To16() != nil { address = "[" + address + "]" } diff --git a/internal/routetypes/parser.go b/internal/routetypes/parser.go index 846c627e..ca14222e 100644 --- a/internal/routetypes/parser.go +++ b/internal/routetypes/parser.go @@ -394,6 +394,7 @@ func parseAddress(address string) (resultAddresses []net.IPNet, err error) { if addr.To16() != nil { addedSomething = true resultAddresses = append(resultAddresses, net.IPNet{IP: addr.To16(), Mask: net.CIDRMask(128, 128)}) + continue } }