diff --git a/pkg/configmanager/upgrade.go b/pkg/configmanager/upgrade.go index 19254c1..7cdfc0e 100644 --- a/pkg/configmanager/upgrade.go +++ b/pkg/configmanager/upgrade.go @@ -42,6 +42,9 @@ func newVersionAvailable() (bool, string, error) { if i1 > i2 { return true, latestVersion, nil } + if i1 < i2 { + return false, latestVersion, nil + } } } return false, latestVersion, nil diff --git a/pkg/configmanager/upgrade_test.go b/pkg/configmanager/upgrade_test.go index 3724bb4..bcc5f18 100644 --- a/pkg/configmanager/upgrade_test.go +++ b/pkg/configmanager/upgrade_test.go @@ -145,3 +145,66 @@ func TestNewVersionAvailableBogus2(t *testing.T) { t.Fatalf("expected new version not to be available: %s", version) } } +func TestNewVersionAvailableHigherVersionMajor(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.RequestURI() == "/latest" { + currentVersionSplit := strings.Split(getVersion(), ".") + if len(currentVersionSplit) != 3 { + t.Fatalf("unsupported current version: %s", getVersion()) + } + i, err := strconv.Atoi(currentVersionSplit[1]) + if err != nil { + t.Fatalf("unsupported current version: %s", getVersion()) + } + i++ + newVersion := strings.Join([]string{currentVersionSplit[0], strconv.Itoa(i), "0"}, ".") + w.Write([]byte(newVersion)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + + defer server.Close() + + BINARIES_URL = server.URL + + available, version, err := newVersionAvailable() + if err != nil { + t.Fatalf("error: %s", err) + } + if !available { + t.Fatalf("expected new version expected to be available: %s", version) + } +} + +func TestNewVersionNotAvailableHigherVersionMajor(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.RequestURI() == "/latest" { + currentVersionSplit := strings.Split(getVersion(), ".") + if len(currentVersionSplit) != 3 { + t.Fatalf("unsupported current version: %s", getVersion()) + } + i, err := strconv.Atoi(currentVersionSplit[1]) + if err != nil { + t.Fatalf("unsupported current version: %s", getVersion()) + } + i-- + newVersion := strings.Join([]string{currentVersionSplit[0], strconv.Itoa(i), "99"}, ".") + w.Write([]byte(newVersion)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + + defer server.Close() + + BINARIES_URL = server.URL + + available, version, err := newVersionAvailable() + if err != nil { + t.Fatalf("error: %s", err) + } + if available { + t.Fatalf("expected new version expected not to be available: %s (current version: %s)", version, getVersion()) + } +} diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go index a754123..844b3d8 100644 --- a/pkg/rest/setup.go +++ b/pkg/rest/setup.go @@ -187,7 +187,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { var ( writeVPNConfig bool rewriteClientConfigs bool - rewriteServerConfig bool setupRequest VPNSetupRequest ) decoder := json.NewDecoder(r.Body) @@ -225,7 +224,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { vpnConfig.AddressRange = addressRangeParsed writeVPNConfig = true rewriteClientConfigs = true - rewriteServerConfig = true } if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix { vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix @@ -241,7 +239,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { vpnConfig.Port = port writeVPNConfig = true rewriteClientConfigs = true - rewriteServerConfig = true } nameservers := strings.Split(setupRequest.Nameservers, ",") @@ -256,12 +253,10 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config vpnConfig.ExternalInterface = setupRequest.ExternalInterface writeVPNConfig = true - rewriteServerConfig = true } if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config vpnConfig.DisableNAT = setupRequest.DisableNAT writeVPNConfig = true - rewriteServerConfig = true } // write vpn config if config has changed @@ -280,14 +275,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { return } } - if rewriteServerConfig { - // rewrite server config - err = wireguard.WriteWireGuardServerConfig(c.Storage.Client) - if err != nil { - c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest) - return - } - } out, err := json.Marshal(setupRequest) if err != nil { c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) diff --git a/pkg/wireguard/ip.go b/pkg/wireguard/ip.go index 18f6839..585561b 100644 --- a/pkg/wireguard/ip.go +++ b/pkg/wireguard/ip.go @@ -4,13 +4,15 @@ import ( "encoding/json" "fmt" "net" + "net/netip" "path" "github.com/in4it/wireguard-server/pkg/storage" ) -func getNextFreeIP(storage storage.Iface, startIP net.IP) (net.IP, error) { +func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, error) { ipList := []string{} + startIP := net.IP(addressRange.Addr().AsSlice()) clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR)) if err != nil { @@ -38,6 +40,15 @@ func getNextFreeIP(storage storage.Iface, startIP net.IP) (net.IP, error) { return nil, fmt.Errorf("getNextFreeIPFromList error: %s", err) } + newIPAddress, err := netip.ParseAddr(newIP.String()) + if err != nil { + return nil, fmt.Errorf("couldn't parse newIP: %s", newIP.String()) + } + + if !addressRange.Contains(newIPAddress) { + return nil, fmt.Errorf("newly allocated IP (%s) is not within address range (%s). Address Range might be too small", newIPAddress.String(), addressRange.String()) + } + return newIP, nil } func getNextFreeIPFromList(startIP net.IP, ipList []string) (net.IP, error) { diff --git a/pkg/wireguard/ip_test.go b/pkg/wireguard/ip_test.go index 029a35c..e0ee122 100644 --- a/pkg/wireguard/ip_test.go +++ b/pkg/wireguard/ip_test.go @@ -18,3 +18,17 @@ func TestGetNextFreeIPFromLisWithList(t *testing.T) { t.Fatalf("Wrong IP: %s", nextIP) } } + +func TestGetNextFreeIPFromLisWithList2(t *testing.T) { + startIP, _, err := net.ParseCIDR("10.189.184.1/21") + if err != nil { + t.Fatalf("error: %s", err) + } + nextIP, err := getNextFreeIPFromList(startIP, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"}) + if err != nil { + t.Fatalf("error: %s", err) + } + if nextIP.String() != "10.189.184.3" { + t.Fatalf("Wrong IP: %s", nextIP) + } +} diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go index 952b917..bbf5c5b 100644 --- a/pkg/wireguard/wireguardclientconfig.go +++ b/pkg/wireguard/wireguardclientconfig.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/http" + "net/netip" "path" "slices" "strconv" @@ -49,12 +50,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err } // get next IP address, write in client file - addressRangeSplit := strings.Split(vpnConfig.AddressRange.String(), "/") - firstIP := net.ParseIP(addressRangeSplit[0]) - if firstIP == nil { - return PeerConfig{}, fmt.Errorf("couldn't determine address range from vpn setup") - } - nextFreeIP, err := getNextFreeIP(storage, firstIP) + nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange) if err != nil { return PeerConfig{}, fmt.Errorf("getNextFreeIP error: %s", err) } @@ -130,6 +126,20 @@ func UpdateClientsConfig(storage storage.Iface) error { peerConfig.DNS = strings.Join(vpnConfig.Nameservers, ", ") } + addressParsed, err := netip.ParsePrefix(peerConfig.Address) + if err != nil { + return fmt.Errorf("couldn't parse existing address of vpn config %s", clientFilename) + } + if !vpnConfig.AddressRange.Contains(addressParsed.Addr()) { // client IP address is not in address range (address range might have changed) + nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange) + if err != nil { + return fmt.Errorf("getNextFreeIP error: %s", err) + } + peerConfig.Address = nextFreeIP.String() + vpnConfig.ClientAddressPrefix + peerConfig.ServerAllowedIPs = []string{nextFreeIP.String() + "/32"} + rewriteFile = true + } + if rewriteFile { peerConfigOut, err := json.Marshal(peerConfig) if err != nil { diff --git a/pkg/wireguard/wireguardclientconfig_test.go b/pkg/wireguard/wireguardclientconfig_test.go index 3e2c159..2371e9f 100644 --- a/pkg/wireguard/wireguardclientconfig_test.go +++ b/pkg/wireguard/wireguardclientconfig_test.go @@ -581,3 +581,109 @@ func TestUpdateClientConfig(t *testing.T) { } } + +func TestUpdateClientConfigNewAddressRange(t *testing.T) { + var ( + l net.Listener + err error + ) + for { + l, err = net.Listen("tcp", CONFIGMANAGER_URI) + if err != nil { + if !strings.HasSuffix(err.Error(), "address already in use") { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + } else { + break + } + } + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + if r.RequestURI == "/refresh-clients" { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("OK")) + return + } + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + storage := &testingmocks.MockMemoryStorage{} + + // first create a new vpn config + vpnconfig, err := CreateNewVPNConfig(storage) + if err != nil { + t.Fatalf("CreateNewVPNConfig error: %s", err) + } + prefix, err := netip.ParsePrefix(DEFAULT_VPN_PREFIX) + if err != nil { + t.Errorf("ParsePrefix error: %s", err) + } + if vpnconfig.AddressRange.String() != prefix.String() { + t.Fatalf("wrong AddressRange: %s vs %s", vpnconfig.AddressRange.String(), prefix.String()) + } + // generate the peerconfig + peerConfig, err := NewEmptyClientConfig(storage, "2-2-2-2") + if err != nil { + t.Fatalf("NewEmptyClientConfig error: %s", err) + } + + if peerConfig.ClientAllowedIPs[0] != "0.0.0.0/0" { + t.Fatalf("wrong client allowed ips") + } + + newClientRoutes := []string{"1.2.3.4/32"} + vpnconfig.ClientRoutes = newClientRoutes + vpnconfig.AddressRange, err = netip.ParsePrefix("10.190.190.1/21") + if err != nil { + t.Fatalf("can't parse new ip range") + } + err = WriteVPNConfig(storage, vpnconfig) + if err != nil { + t.Fatalf("WriteVPNConfig error: %s", err) + } + err = UpdateClientsConfig(storage) + if err != nil { + t.Fatalf("UpdateClientsConfig error: %s", err) + } + + peerConfigCurrent, err := getPeerConfig(storage, "2-2-2-2-1") + if err != nil { + t.Fatalf("getPeerConfig error: %s", err) + } + + if peerConfigCurrent.ServerAllowedIPs[0] != "10.190.190.2/32" { + t.Fatalf("expected different server allowed IP. Got: %s", strings.Join(peerConfigCurrent.ServerAllowedIPs, ", ")) + } + if peerConfigCurrent.Address != "10.190.190.2/32" { + t.Fatalf("expected different client config address. Got: %s", peerConfigCurrent.Address) + } + + peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2") + if err != nil { + t.Fatalf("NewEmptyClientConfig error: %s", err) + } + + if peerConfig.ClientAllowedIPs[0] != "1.2.3.4/32" { + t.Fatalf("wrong client allowed ips") + } + + if peerConfig.ServerAllowedIPs[0] != "10.190.190.3/32" { + t.Fatalf("expected different server allowed IP. Got: %s", strings.Join(peerConfig.ServerAllowedIPs, ", ")) + } + if peerConfig.Address != "10.190.190.3/32" { + t.Fatalf("expected different client config address. Got: %s", peerConfig.Address) + } + +} diff --git a/webapp/src/Routes/Setup/GeneralSetup.tsx b/webapp/src/Routes/Setup/GeneralSetup.tsx index ee17276..95e84e0 100644 --- a/webapp/src/Routes/Setup/GeneralSetup.tsx +++ b/webapp/src/Routes/Setup/GeneralSetup.tsx @@ -59,8 +59,10 @@ export function GeneralSetup() { }, onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['users'] }) + queryClient.invalidateQueries({ queryKey: ['general-setup'] }) setSaved(true) setSaveError("") + window.scrollTo(0, 0) }, onError: (error:AxiosError) => { const errorMessage = error.response?.data as GeneralSetupError diff --git a/webapp/src/Routes/Setup/Setup.tsx b/webapp/src/Routes/Setup/Setup.tsx index d21bd88..54c236a 100644 --- a/webapp/src/Routes/Setup/Setup.tsx +++ b/webapp/src/Routes/Setup/Setup.tsx @@ -1,6 +1,6 @@ import { Container, Tabs, Title, rem } from "@mantine/core"; import classes from './Setup.module.css'; -import { IconFile, IconNetwork, IconRestore, IconRewindBackward10, IconSettings } from "@tabler/icons-react"; +import { IconFile, IconNetwork, IconRestore, IconSettings } from "@tabler/icons-react"; import { GeneralSetup } from "./GeneralSetup"; import { VPNSetup } from "./VPNSetup"; import { TemplateSetup } from "./TemplateSetup"; diff --git a/webapp/src/Routes/Setup/TemplateSetup.tsx b/webapp/src/Routes/Setup/TemplateSetup.tsx index db0486d..243ef01 100644 --- a/webapp/src/Routes/Setup/TemplateSetup.tsx +++ b/webapp/src/Routes/Setup/TemplateSetup.tsx @@ -2,7 +2,7 @@ import { Container, Button, Alert, Textarea, Space } from "@mantine/core"; import { useEffect, useState } from "react"; import { IconInfoCircle } from "@tabler/icons-react"; import { AppSettings } from "../../Constants/Constants"; -import { useMutation, useQuery } from "@tanstack/react-query"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useAuthContext } from "../../Auth/Auth"; import { useForm } from '@mantine/form'; import axios, { AxiosError } from "axios"; @@ -19,6 +19,7 @@ export function TemplateSetup() { const [saved, setSaved] = useState(false) const [saveError, setSaveError] = useState("") const {authInfo} = useAuthContext(); + const queryClient = useQueryClient() const { isPending, error, data, isSuccess } = useQuery({ queryKey: ['templates-setup'], queryFn: () => @@ -52,6 +53,8 @@ export function TemplateSetup() { onSuccess: () => { setSaved(true) setSaveError("") + queryClient.invalidateQueries({ queryKey: ['templates-setup'] }) + window.scrollTo(0, 0) }, onError: (error:AxiosError) => { const errorMessage = error.response?.data as TemplateSetupError diff --git a/webapp/src/Routes/Setup/VPNSetup.tsx b/webapp/src/Routes/Setup/VPNSetup.tsx index c05090a..14b78d4 100644 --- a/webapp/src/Routes/Setup/VPNSetup.tsx +++ b/webapp/src/Routes/Setup/VPNSetup.tsx @@ -4,7 +4,7 @@ import { useEffect, useState } from "react"; import classes from './Setup.module.css'; import { IconInfoCircle } from "@tabler/icons-react"; import { AppSettings } from "../../Constants/Constants"; -import { useMutation, useQuery } from "@tanstack/react-query"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useAuthContext } from "../../Auth/Auth"; import { useForm } from '@mantine/form'; import axios, { AxiosError } from "axios"; @@ -28,6 +28,7 @@ export function VPNSetup() { const [saved, setSaved] = useState(false) const [saveError, setSaveError] = useState("") const {authInfo} = useAuthContext(); + const queryClient = useQueryClient() const { isPending, error, data, isSuccess } = useQuery({ queryKey: ['vpn-setup'], queryFn: () => @@ -66,6 +67,8 @@ export function VPNSetup() { onSuccess: () => { setSaved(true) setSaveError("") + queryClient.invalidateQueries({ queryKey: ['vpn-setup'] }) + window.scrollTo(0, 0) }, onError: (error:AxiosError) => { const errorMessage = error.response?.data as VPNSetupError @@ -91,7 +94,7 @@ export function VPNSetup() { return ( - Changes to Address Range, Port, External Interface, or NAT will need a wireguard reload. You can click the "Reload Wireguard" button at the bottom after submitting the changes. This will disconnect active VPN clients. + Changes to Address Range, Port, External Interface, or NAT will need a wireguard reload. You can click the "Reload Wireguard" button at the bottom after submitting the changes. This will disconnect active VPN clients, and if the Address Range or Port is changed, all clients will need to download a new VPN Config. {saved && saveError === "" ? Settings Saved! : null} {saveError !== "" ? {saveError} : null}