From a54e3c44cd0d5e8d8e7e50ab98a0ab245f3db8ce Mon Sep 17 00:00:00 2001
From: Edward Viaene <ward@in4it.io>
Date: Wed, 21 Aug 2024 16:08:44 -0500
Subject: [PATCH] ensure next IP work with ranges, ensureclients are loaded
 after restart

---
 pkg/configmanager/handlers.go               |   5 +
 pkg/rest/setup.go                           |   6 --
 pkg/wireguard/ip.go                         |  49 +++++----
 pkg/wireguard/ip_test.go                    |  72 ++++++++++++-
 pkg/wireguard/wireguardclientconfig.go      |  11 +-
 pkg/wireguard/wireguardclientconfig_test.go | 106 +++++++++++++++++++-
 webapp/src/Routes/Setup/Restart.tsx         |   5 +-
 webapp/src/Routes/Setup/TemplateSetup.tsx   |   2 +-
 webapp/src/Routes/Setup/VPNSetup.tsx        |   6 +-
 webapp/src/Routes/Users/Users.tsx           |   2 +-
 10 files changed, 223 insertions(+), 41 deletions(-)

diff --git a/pkg/configmanager/handlers.go b/pkg/configmanager/handlers.go
index 351a024..2b8edba 100644
--- a/pkg/configmanager/handlers.go
+++ b/pkg/configmanager/handlers.go
@@ -131,6 +131,11 @@ func (c *ConfigManager) restartVpn(w http.ResponseWriter, r *http.Request) {
 			returnError(w, fmt.Errorf("vpn start error: %s", err), http.StatusBadRequest)
 			return
 		}
+		err = refreshAllClientsAndServer(c.Storage)
+		if err != nil {
+			returnError(w, fmt.Errorf("could not refresh all clients: %s", err), http.StatusBadRequest)
+			return
+		}
 		w.WriteHeader(http.StatusAccepted)
 	default:
 		returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go
index 844b3d8..236c1c9 100644
--- a/pkg/rest/setup.go
+++ b/pkg/rest/setup.go
@@ -342,12 +342,6 @@ func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) {
 				c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest)
 				return
 			}
-			// rewrite server config
-			err = wireguard.WriteWireGuardServerConfig(c.Storage.Client)
-			if err != nil {
-				c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest)
-				return
-			}
 		}
 		out, err := json.Marshal(templateSetupRequest)
 		if err != nil {
diff --git a/pkg/wireguard/ip.go b/pkg/wireguard/ip.go
index 585561b..3def8ad 100644
--- a/pkg/wireguard/ip.go
+++ b/pkg/wireguard/ip.go
@@ -6,13 +6,17 @@ import (
 	"net"
 	"net/netip"
 	"path"
+	"strings"
 
 	"github.com/in4it/wireguard-server/pkg/storage"
 )
 
-func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, error) {
+func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix, addressPrefix string) (net.IP, error) {
 	ipList := []string{}
-	startIP := net.IP(addressRange.Addr().AsSlice())
+	startIP, addressRangeParsed, err := net.ParseCIDR(addressRange.String())
+	if err != nil {
+		return nil, fmt.Errorf("cannot parse address range: %s: %s", addressRange, err)
+	}
 
 	clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
 	if err != nil {
@@ -28,43 +32,48 @@ func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, er
 		if err != nil {
 			return nil, fmt.Errorf("cannot unmarshal %s: %s", clientFilename, err)
 		}
-		peerConfigAddress, _, err := net.ParseCIDR(peerConfig.Address)
-		if err != nil {
-			return nil, fmt.Errorf("could not parse peer config address %s: %s", peerConfig.Address, err)
-		}
-		ipList = append(ipList, peerConfigAddress.String())
+		ipList = append(ipList, peerConfig.Address)
 	}
 
-	newIP, err := getNextFreeIPFromList(startIP, ipList)
+	newIP, err := getNextFreeIPFromList(startIP, addressRangeParsed, ipList, addressPrefix)
 	if err != nil {
 		return nil, fmt.Errorf("getNextFreeIPFromList error: %s", err)
 	}
 
-	newIPAddress, err := netip.ParseAddr(newIP.String())
-	if err != nil {
-		return nil, fmt.Errorf("couldn't parse newIP: %s", newIP.String())
-	}
-
-	if !addressRange.Contains(newIPAddress) {
-		return nil, fmt.Errorf("newly allocated IP (%s) is not within address range (%s). Address Range might be too small", newIPAddress.String(), addressRange.String())
-	}
-
 	return newIP, nil
 }
-func getNextFreeIPFromList(startIP net.IP, ipList []string) (net.IP, error) {
+func getNextFreeIPFromList(startIP net.IP, addressRange *net.IPNet, ipList []string, addressPrefix string) (net.IP, error) {
 	nextIPAddress := startIP
 	for i := 0; i < 100000; i++ {
 		nextIPAddress = nextIP(nextIPAddress, 1)
 		ipExists := false
 		for _, ip := range ipList {
-			if nextIPAddress.String() == ip {
+			ipRange := ip
+			if !strings.Contains(ip, "/") {
+				ipRange += addressPrefix
+			}
+			_, ipRangeParsed, err := net.ParseCIDR(ipRange)
+			if err != nil {
+				return nil, fmt.Errorf("cannot parse IP address: %s (ip range %s)", ip, ipRange)
+			}
+			if ipRangeParsed.Contains(nextIPAddress) {
 				ipExists = true
 			}
 		}
 		if !ipExists {
-			return nextIPAddress, nil
+			if !addressRange.Contains(nextIPAddress) {
+				return nil, fmt.Errorf("next IP (%s) is not within address range (%s). Address Range might be too small", nextIPAddress.String(), addressRange.String())
+			}
+			_, ipRangeParsed, err := net.ParseCIDR(nextIPAddress.String() + addressPrefix)
+			if err != nil {
+				return nil, fmt.Errorf("cannot parse new IP address range: %s: %s", nextIPAddress.String()+addressPrefix, err)
+			}
+			if !ipRangeParsed.Contains(startIP) { // don't pick a range where the start ip is in the range
+				return nextIPAddress, nil
+			}
 		}
 	}
+
 	return nil, fmt.Errorf("couldn't determine next ip address")
 }
 
diff --git a/pkg/wireguard/ip_test.go b/pkg/wireguard/ip_test.go
index e0ee122..414fa8c 100644
--- a/pkg/wireguard/ip_test.go
+++ b/pkg/wireguard/ip_test.go
@@ -2,15 +2,16 @@ package wireguard
 
 import (
 	"net"
+	"strings"
 	"testing"
 )
 
 func TestGetNextFreeIPFromLisWithList(t *testing.T) {
-	startIP, _, err := net.ParseCIDR("10.189.184.1/21")
+	startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
 	if err != nil {
 		t.Fatalf("error: %s", err)
 	}
-	nextIP, err := getNextFreeIPFromList(startIP, []string{"10.189.184.2"})
+	nextIP, err := getNextFreeIPFromList(startIP, addressRange, []string{"10.189.184.2"}, "/32")
 	if err != nil {
 		t.Fatalf("error: %s", err)
 	}
@@ -20,11 +21,11 @@ func TestGetNextFreeIPFromLisWithList(t *testing.T) {
 }
 
 func TestGetNextFreeIPFromLisWithList2(t *testing.T) {
-	startIP, _, err := net.ParseCIDR("10.189.184.1/21")
+	startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
 	if err != nil {
 		t.Fatalf("error: %s", err)
 	}
-	nextIP, err := getNextFreeIPFromList(startIP, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"})
+	nextIP, err := getNextFreeIPFromList(startIP, addressRange, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"}, "/32")
 	if err != nil {
 		t.Fatalf("error: %s", err)
 	}
@@ -32,3 +33,66 @@ func TestGetNextFreeIPFromLisWithList2(t *testing.T) {
 		t.Fatalf("Wrong IP: %s", nextIP)
 	}
 }
+
+func TestGetNextFreeIPWithRange(t *testing.T) {
+	startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
+	if err != nil {
+		t.Fatalf("error: %s", err)
+	}
+	networkPrefix := []string{
+		"/32",
+		"/32",
+		"/32",
+		"/32",
+		"/32",
+		"/32",
+		"/30",
+		"/30",
+		"/32",
+	}
+	testCases := [][]string{
+		{},
+		{"10.189.184.2"},
+		{"10.189.184.2/32"},
+		{"10.189.184.2", "10.189.184.3", "10.189.184.4/30"},
+		{"10.189.184.2", "10.189.184.3", "10.189.184.4/30", "10.189.184.8/32"},
+		{"10.189.184.1/30", "10.189.184.4/30", "10.189.184.8/30"},
+		{},
+		{"10.189.184.4/30", "10.189.184.8/30"},
+		{"10.189.189.2/32", "10.189.189.3/32", "10.189.189.4/32"},
+	}
+	expected := []string{
+		"10.189.184.2",
+		"10.189.184.3",
+		"10.189.184.3",
+		"10.189.184.8",
+		"10.189.184.9",
+		"10.189.184.12",
+		"10.189.184.4",
+		"10.189.184.12",
+		"10.189.184.2",
+	}
+
+	for k := range testCases {
+		nextIP, err := getNextFreeIPFromList(startIP, addressRange, testCases[k], networkPrefix[k])
+		if err != nil {
+			t.Fatalf("error: %s", err)
+		}
+		if nextIP.String() != expected[k] {
+			t.Fatalf("Wrong IP: %s", nextIP)
+		}
+	}
+
+}
+
+func TestIPNotInRange(t *testing.T) {
+	startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
+	if err != nil {
+		t.Fatalf("error: %s", err)
+	}
+	_, err = getNextFreeIPFromList(startIP, addressRange, []string{"10.189.188.0/22"}, "/22")
+	if !strings.Contains(err.Error(), "not within address range") {
+		t.Fatalf("Expected error, got: %s", err)
+	}
+
+}
diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go
index bbf5c5b..9d03f0d 100644
--- a/pkg/wireguard/wireguardclientconfig.go
+++ b/pkg/wireguard/wireguardclientconfig.go
@@ -50,7 +50,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err
 	}
 
 	// get next IP address, write in client file
-	nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
+	nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange, vpnConfig.ClientAddressPrefix)
 	if err != nil {
 		return PeerConfig{}, fmt.Errorf("getNextFreeIP error: %s", err)
 	}
@@ -75,7 +75,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err
 		DNS:              strings.Join(vpnConfig.Nameservers, ", "),
 		Name:             fmt.Sprintf("connection-%d", newConfigNumber),
 		Address:          nextFreeIP.String() + vpnConfig.ClientAddressPrefix,
-		ServerAllowedIPs: []string{nextFreeIP.String() + "/32"},
+		ServerAllowedIPs: []string{nextFreeIP.String() + vpnConfig.ClientAddressPrefix},
 		ClientAllowedIPs: clientAllowedIPs,
 	}
 
@@ -131,7 +131,7 @@ func UpdateClientsConfig(storage storage.Iface) error {
 			return fmt.Errorf("couldn't parse existing address of vpn config %s", clientFilename)
 		}
 		if !vpnConfig.AddressRange.Contains(addressParsed.Addr()) { // client IP address is not in address range (address range might have changed)
-			nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
+			nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange, vpnConfig.ClientAddressPrefix)
 			if err != nil {
 				return fmt.Errorf("getNextFreeIP error: %s", err)
 			}
@@ -140,6 +140,11 @@ func UpdateClientsConfig(storage storage.Iface) error {
 			rewriteFile = true
 		}
 
+		if !strings.HasSuffix(peerConfig.Address, vpnConfig.ClientAddressPrefix) {
+			rewriteFile = true
+			peerConfig.Address = addressParsed.Addr().String() + vpnConfig.ClientAddressPrefix
+		}
+
 		if rewriteFile {
 			peerConfigOut, err := json.Marshal(peerConfig)
 			if err != nil {
diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go
index 2371e9f..5c1d346 100644
--- a/pkg/wireguard/wireguardclientconfig_test.go
+++ b/pkg/wireguard/wireguardclientconfig_test.go
@@ -16,9 +16,12 @@ import (
 )
 
 func TestGetNextFreeIPFromList(t *testing.T) {
-	ip := net.ParseIP("10.0.0.1")
+	startIP, addressRange, err := net.ParseCIDR("10.0.0.1/21")
+	if err != nil {
+		t.Fatalf("error: %s", err)
+	}
 	ipList := []string{"10.0.0.2", "10.0.0.3"}
-	nextIP, err := getNextFreeIPFromList(ip, ipList)
+	nextIP, err := getNextFreeIPFromList(startIP, addressRange, ipList, "/32")
 	if err != nil {
 		t.Errorf("next IP error: %s", err)
 	}
@@ -646,6 +649,7 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
 	newClientRoutes := []string{"1.2.3.4/32"}
 	vpnconfig.ClientRoutes = newClientRoutes
 	vpnconfig.AddressRange, err = netip.ParsePrefix("10.190.190.1/21")
+	vpnconfig.Nameservers = []string{"3.4.5.6", "8.8.8.8"}
 	if err != nil {
 		t.Fatalf("can't parse new ip range")
 	}
@@ -669,6 +673,9 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
 	if peerConfigCurrent.Address != "10.190.190.2/32" {
 		t.Fatalf("expected different client config address. Got: %s", peerConfigCurrent.Address)
 	}
+	if peerConfigCurrent.DNS != strings.Join(vpnconfig.Nameservers, ", ") {
+		t.Fatalf("Unexpected DNS Servers: %s (expected %s)", peerConfig.DNS, strings.Join(vpnconfig.Nameservers, ", "))
+	}
 
 	peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2")
 	if err != nil {
@@ -685,5 +692,100 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
 	if peerConfig.Address != "10.190.190.3/32" {
 		t.Fatalf("expected different client config address. Got: %s", peerConfig.Address)
 	}
+	if peerConfig.DNS != strings.Join(vpnconfig.Nameservers, ", ") {
+		t.Fatalf("Unexpected DNS Servers: %s (expected %s)", peerConfig.DNS, strings.Join(vpnconfig.Nameservers, ", "))
+	}
+}
+
+func TestUpdateClientConfigNewClientAddressPrefix(t *testing.T) {
+	var (
+		l   net.Listener
+		err error
+	)
+	for {
+		l, err = net.Listen("tcp", CONFIGMANAGER_URI)
+		if err != nil {
+			if !strings.HasSuffix(err.Error(), "address already in use") {
+				t.Fatal(err)
+			}
+			time.Sleep(1 * time.Second)
+		} else {
+			break
+		}
+	}
+
+	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch r.Method {
+		case http.MethodPost:
+			if r.RequestURI == "/refresh-clients" {
+				w.WriteHeader(http.StatusAccepted)
+				w.Write([]byte("OK"))
+				return
+			}
+			w.WriteHeader(http.StatusBadRequest)
+		default:
+			w.WriteHeader(http.StatusBadRequest)
+		}
+	}))
+
+	ts.Listener.Close()
+	ts.Listener = l
+	ts.Start()
+	defer ts.Close()
+	defer l.Close()
+
+	storage := &testingmocks.MockMemoryStorage{}
+
+	// first create a new vpn config
+	vpnconfig, err := CreateNewVPNConfig(storage)
+	if err != nil {
+		t.Fatalf("CreateNewVPNConfig error: %s", err)
+	}
+	prefix, err := netip.ParsePrefix(DEFAULT_VPN_PREFIX)
+	if err != nil {
+		t.Errorf("ParsePrefix error: %s", err)
+	}
+	if vpnconfig.AddressRange.String() != prefix.String() {
+		t.Fatalf("wrong AddressRange: %s vs %s", vpnconfig.AddressRange.String(), prefix.String())
+	}
+	if vpnconfig.ClientAddressPrefix != "/32" {
+		t.Fatalf("unexpected default for address prefix: %s", vpnconfig.ClientAddressPrefix)
+	}
+	// generate the peerconfig
+	peerConfig, err := NewEmptyClientConfig(storage, "2-2-2-2")
+	if err != nil {
+		t.Fatalf("NewEmptyClientConfig error: %s", err)
+	}
+
+	if peerConfig.ClientAllowedIPs[0] != "0.0.0.0/0" {
+		t.Fatalf("wrong client allowed ips")
+	}
+
+	vpnconfig.ClientAddressPrefix = "/30"
+
+	err = WriteVPNConfig(storage, vpnconfig)
+	if err != nil {
+		t.Fatalf("WriteVPNConfig error: %s", err)
+	}
+	err = UpdateClientsConfig(storage)
+	if err != nil {
+		t.Fatalf("UpdateClientsConfig error: %s", err)
+	}
+
+	peerConfigCurrent, err := getPeerConfig(storage, "2-2-2-2-1")
+	if err != nil {
+		t.Fatalf("getPeerConfig error: %s", err)
+	}
+
+	if peerConfigCurrent.Address != "10.189.184.2/30" {
+		t.Fatalf("expected different client address. Got: %s", peerConfigCurrent.Address)
+	}
+	peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2")
+	if err != nil {
+		t.Fatalf("NewEmptyClientConfig error: %s", err)
+	}
+	if peerConfig.Address != "10.189.184.4/30" {
+		t.Fatalf("expected different client config address. Got: %s", peerConfig.Address)
+	}
 
 }
diff --git a/webapp/src/Routes/Setup/Restart.tsx b/webapp/src/Routes/Setup/Restart.tsx
index 01f69d9..8337bf3 100644
--- a/webapp/src/Routes/Setup/Restart.tsx
+++ b/webapp/src/Routes/Setup/Restart.tsx
@@ -12,6 +12,7 @@ type RestartError = {
 
 export function Restart() {
     const [saved, setSaved] = useState(false)
+    const [pending, setPending] = useState(false)
     const [saveError, setSaveError] = useState("")
     const {authInfo} = useAuthContext();
     const alertIcon = <IconInfoCircle />;
@@ -26,8 +27,10 @@ export function Restart() {
       onSuccess: () => {
           setSaved(true)
           setSaveError("")
+          setTimeout(function() { setPending(false); }, 1000);
       },
       onError: (error:AxiosError) => {
+        setTimeout(function() { setPending(false); }, 1000);
         const errorMessage = error.response?.data as RestartError
         if(errorMessage?.error === undefined) {
             setSaveError("Error: "+ error.message)
@@ -43,7 +46,7 @@ export function Restart() {
           <Space h="md" />
           {saved && saveError === "" ? <Alert variant="light" color="green" title="Restarted!" icon={alertIcon}>VPN Restarted!</Alert> : null}
           {saveError !== "" ? <Alert variant="light" color="red" title="Error!" icon={alertIcon} style={{marginTop: 10}}>{saveError}</Alert> : null}
-            <Button type="submit" mt="md" onClick={() =>  setupMutation.mutate()}>
+            <Button type="submit" mt="md" onClick={() => { setPending(true); setupMutation.mutate() } } disabled={pending}>
               Reload WireGuard® VPN
             </Button>
         </Container>
diff --git a/webapp/src/Routes/Setup/TemplateSetup.tsx b/webapp/src/Routes/Setup/TemplateSetup.tsx
index 243ef01..e96035a 100644
--- a/webapp/src/Routes/Setup/TemplateSetup.tsx
+++ b/webapp/src/Routes/Setup/TemplateSetup.tsx
@@ -96,7 +96,7 @@ export function TemplateSetup() {
             />
             <Space h="md" />
             <Textarea
-                label="VPN Server config template"
+                label="VPN Server config template (WireGuard® Configuration Reload in restart tab required to apply)"
                 key={form.key('serverTemplate')}
                 {...form.getInputProps('serverTemplate')}
                 autosize
diff --git a/webapp/src/Routes/Setup/VPNSetup.tsx b/webapp/src/Routes/Setup/VPNSetup.tsx
index 14b78d4..78b817c 100644
--- a/webapp/src/Routes/Setup/VPNSetup.tsx
+++ b/webapp/src/Routes/Setup/VPNSetup.tsx
@@ -94,7 +94,7 @@ export function VPNSetup() {
 
     return (
         <Container my={40} size="40rem">
-            <Alert variant="light" color="blue" title="Note!" icon={alertIcon}>Changes to Address Range, Port, External Interface, or NAT will need a wireguard reload. You can click the "Reload Wireguard" button at the bottom after submitting the changes. This will disconnect active VPN clients, and if the Address Range or Port is changed, all clients will need to download a new VPN Config.</Alert>
+            <Alert variant="light" color="blue" title="Note!" icon={alertIcon}>Changes to Address Range, Port, External Interface, or NAT will need a wireguard reload. You can click the "Reload Wireguard" button in the Restart tab after submitting the changes. This will disconnect active VPN clients, and if the Address Range or Port is changed, all clients will need to download a new VPN Config.</Alert>
             {saved && saveError === "" ? <Alert variant="light" color="green" title="Update!" icon={alertIcon} style={{marginTop: 10}}>Settings Saved!</Alert> : null}
             {saveError !== "" ? <Alert variant="light" color="red" title="Error!" icon={alertIcon} style={{marginTop: 10}}>{saveError}</Alert> : null}
 
@@ -144,8 +144,8 @@ export function VPNSetup() {
 
                 <InputWrapper
                 id="input-client-address-prefix-input"
-                label="Client Address Prefix"
-                description="Address prefix for the VPN Client to use. /32 means it'll not be able to communicate to other VPN clients."
+                label="Client Address Network Prefix"
+                description="Network prefix for the VPN Client to use. /32 means only one IP address for a client."
                 style={{marginTop: 10}}
                 >
                 <TextInput
diff --git a/webapp/src/Routes/Users/Users.tsx b/webapp/src/Routes/Users/Users.tsx
index 91f6f12..676d457 100644
--- a/webapp/src/Routes/Users/Users.tsx
+++ b/webapp/src/Routes/Users/Users.tsx
@@ -14,7 +14,7 @@ export function Users() {
     const { isPending, error, data } = useQuery({
         queryKey: ['setup'],
         queryFn: () =>
-          fetch(AppSettings.url + '/setup', {
+          fetch(AppSettings.url + '/setup/general', {
             headers: {
               "Content-Type": "application/json",
               "Authorization": "Bearer " + authInfo.token