Skip to content

Commit

Permalink
upgrade version fix, rewrite client config when ip range changes
Browse files Browse the repository at this point in the history
  • Loading branch information
wardviaene committed Aug 21, 2024
1 parent 81e3dda commit 26732b1
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pkg/configmanager/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func newVersionAvailable() (bool, string, error) {
if i1 > i2 {
return true, latestVersion, nil
}
if i1 < i2 {
return false, latestVersion, nil
}
}
}
return false, latestVersion, nil
Expand Down
63 changes: 63 additions & 0 deletions pkg/configmanager/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,66 @@ func TestNewVersionAvailableBogus2(t *testing.T) {
t.Fatalf("expected new version not to be available: %s", version)
}
}
func TestNewVersionAvailableHigherVersionMajor(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.RequestURI() == "/latest" {
currentVersionSplit := strings.Split(getVersion(), ".")
if len(currentVersionSplit) != 3 {
t.Fatalf("unsupported current version: %s", getVersion())
}
i, err := strconv.Atoi(currentVersionSplit[1])
if err != nil {
t.Fatalf("unsupported current version: %s", getVersion())
}
i++
newVersion := strings.Join([]string{currentVersionSplit[0], strconv.Itoa(i), "0"}, ".")
w.Write([]byte(newVersion))
return
}
w.WriteHeader(http.StatusNotFound)
}))

defer server.Close()

BINARIES_URL = server.URL

available, version, err := newVersionAvailable()
if err != nil {
t.Fatalf("error: %s", err)
}
if !available {
t.Fatalf("expected new version expected to be available: %s", version)
}
}

func TestNewVersionNotAvailableHigherVersionMajor(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.RequestURI() == "/latest" {
currentVersionSplit := strings.Split(getVersion(), ".")
if len(currentVersionSplit) != 3 {
t.Fatalf("unsupported current version: %s", getVersion())
}
i, err := strconv.Atoi(currentVersionSplit[1])
if err != nil {
t.Fatalf("unsupported current version: %s", getVersion())
}
i--
newVersion := strings.Join([]string{currentVersionSplit[0], strconv.Itoa(i), "99"}, ".")
w.Write([]byte(newVersion))
return
}
w.WriteHeader(http.StatusNotFound)
}))

defer server.Close()

BINARIES_URL = server.URL

available, version, err := newVersionAvailable()
if err != nil {
t.Fatalf("error: %s", err)
}
if available {
t.Fatalf("expected new version expected not to be available: %s (current version: %s)", version, getVersion())
}
}
13 changes: 0 additions & 13 deletions pkg/rest/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
var (
writeVPNConfig bool
rewriteClientConfigs bool
rewriteServerConfig bool
setupRequest VPNSetupRequest
)
decoder := json.NewDecoder(r.Body)
Expand Down Expand Up @@ -225,7 +224,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
vpnConfig.AddressRange = addressRangeParsed
writeVPNConfig = true
rewriteClientConfigs = true
rewriteServerConfig = true
}
if setupRequest.ClientAddressPrefix != vpnConfig.ClientAddressPrefix {
vpnConfig.ClientAddressPrefix = setupRequest.ClientAddressPrefix
Expand All @@ -241,7 +239,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
vpnConfig.Port = port
writeVPNConfig = true
rewriteClientConfigs = true
rewriteServerConfig = true
}

nameservers := strings.Split(setupRequest.Nameservers, ",")
Expand All @@ -256,12 +253,10 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
if setupRequest.ExternalInterface != vpnConfig.ExternalInterface { // don't rewrite client config
vpnConfig.ExternalInterface = setupRequest.ExternalInterface
writeVPNConfig = true
rewriteServerConfig = true
}
if setupRequest.DisableNAT != vpnConfig.DisableNAT { // don't rewrite client config
vpnConfig.DisableNAT = setupRequest.DisableNAT
writeVPNConfig = true
rewriteServerConfig = true
}

// write vpn config if config has changed
Expand All @@ -280,14 +275,6 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
return
}
}
if rewriteServerConfig {
// rewrite server config
err = wireguard.WriteWireGuardServerConfig(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest)
return
}
}
out, err := json.Marshal(setupRequest)
if err != nil {
c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
Expand Down
13 changes: 12 additions & 1 deletion pkg/wireguard/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"encoding/json"
"fmt"
"net"
"net/netip"
"path"

"github.com/in4it/wireguard-server/pkg/storage"
)

func getNextFreeIP(storage storage.Iface, startIP net.IP) (net.IP, error) {
func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, error) {
ipList := []string{}
startIP := net.IP(addressRange.Addr().AsSlice())

clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
if err != nil {
Expand Down Expand Up @@ -38,6 +40,15 @@ func getNextFreeIP(storage storage.Iface, startIP net.IP) (net.IP, error) {
return nil, fmt.Errorf("getNextFreeIPFromList error: %s", err)
}

newIPAddress, err := netip.ParseAddr(newIP.String())
if err != nil {
return nil, fmt.Errorf("couldn't parse newIP: %s", newIP.String())
}

if !addressRange.Contains(newIPAddress) {
return nil, fmt.Errorf("newly allocated IP (%s) is not within address range (%s). Address Range might be too small", newIPAddress.String(), addressRange.String())
}

return newIP, nil
}
func getNextFreeIPFromList(startIP net.IP, ipList []string) (net.IP, error) {
Expand Down
14 changes: 14 additions & 0 deletions pkg/wireguard/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,17 @@ func TestGetNextFreeIPFromLisWithList(t *testing.T) {
t.Fatalf("Wrong IP: %s", nextIP)
}
}

func TestGetNextFreeIPFromLisWithList2(t *testing.T) {
startIP, _, err := net.ParseCIDR("10.189.184.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
nextIP, err := getNextFreeIPFromList(startIP, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"})
if err != nil {
t.Fatalf("error: %s", err)
}
if nextIP.String() != "10.189.184.3" {
t.Fatalf("Wrong IP: %s", nextIP)
}
}
22 changes: 16 additions & 6 deletions pkg/wireguard/wireguardclientconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"path"
"slices"
"strconv"
Expand Down Expand Up @@ -49,12 +50,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err
}

// get next IP address, write in client file
addressRangeSplit := strings.Split(vpnConfig.AddressRange.String(), "/")
firstIP := net.ParseIP(addressRangeSplit[0])
if firstIP == nil {
return PeerConfig{}, fmt.Errorf("couldn't determine address range from vpn setup")
}
nextFreeIP, err := getNextFreeIP(storage, firstIP)
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
if err != nil {
return PeerConfig{}, fmt.Errorf("getNextFreeIP error: %s", err)
}
Expand Down Expand Up @@ -130,6 +126,20 @@ func UpdateClientsConfig(storage storage.Iface) error {
peerConfig.DNS = strings.Join(vpnConfig.Nameservers, ", ")
}

addressParsed, err := netip.ParsePrefix(peerConfig.Address)
if err != nil {
return fmt.Errorf("couldn't parse existing address of vpn config %s", clientFilename)
}
if !vpnConfig.AddressRange.Contains(addressParsed.Addr()) { // client IP address is not in address range (address range might have changed)
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
if err != nil {
return fmt.Errorf("getNextFreeIP error: %s", err)
}
peerConfig.Address = nextFreeIP.String() + vpnConfig.ClientAddressPrefix
peerConfig.ServerAllowedIPs = []string{nextFreeIP.String() + "/32"}
rewriteFile = true
}

if rewriteFile {
peerConfigOut, err := json.Marshal(peerConfig)
if err != nil {
Expand Down
106 changes: 106 additions & 0 deletions pkg/wireguard/wireguardclientconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,109 @@ func TestUpdateClientConfig(t *testing.T) {
}

}

func TestUpdateClientConfigNewAddressRange(t *testing.T) {
var (
l net.Listener
err error
)
for {
l, err = net.Listen("tcp", CONFIGMANAGER_URI)
if err != nil {
if !strings.HasSuffix(err.Error(), "address already in use") {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
} else {
break
}
}

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
if r.RequestURI == "/refresh-clients" {
w.WriteHeader(http.StatusAccepted)
w.Write([]byte("OK"))
return
}
w.WriteHeader(http.StatusBadRequest)
default:
w.WriteHeader(http.StatusBadRequest)
}
}))

ts.Listener.Close()
ts.Listener = l
ts.Start()
defer ts.Close()
defer l.Close()

storage := &testingmocks.MockMemoryStorage{}

// first create a new vpn config
vpnconfig, err := CreateNewVPNConfig(storage)
if err != nil {
t.Fatalf("CreateNewVPNConfig error: %s", err)
}
prefix, err := netip.ParsePrefix(DEFAULT_VPN_PREFIX)
if err != nil {
t.Errorf("ParsePrefix error: %s", err)
}
if vpnconfig.AddressRange.String() != prefix.String() {
t.Fatalf("wrong AddressRange: %s vs %s", vpnconfig.AddressRange.String(), prefix.String())
}
// generate the peerconfig
peerConfig, err := NewEmptyClientConfig(storage, "2-2-2-2")
if err != nil {
t.Fatalf("NewEmptyClientConfig error: %s", err)
}

if peerConfig.ClientAllowedIPs[0] != "0.0.0.0/0" {
t.Fatalf("wrong client allowed ips")
}

newClientRoutes := []string{"1.2.3.4/32"}
vpnconfig.ClientRoutes = newClientRoutes
vpnconfig.AddressRange, err = netip.ParsePrefix("10.190.190.1/21")
if err != nil {
t.Fatalf("can't parse new ip range")
}
err = WriteVPNConfig(storage, vpnconfig)
if err != nil {
t.Fatalf("WriteVPNConfig error: %s", err)
}
err = UpdateClientsConfig(storage)
if err != nil {
t.Fatalf("UpdateClientsConfig error: %s", err)
}

peerConfigCurrent, err := getPeerConfig(storage, "2-2-2-2-1")
if err != nil {
t.Fatalf("getPeerConfig error: %s", err)
}

if peerConfigCurrent.ServerAllowedIPs[0] != "10.190.190.2/32" {
t.Fatalf("expected different server allowed IP. Got: %s", strings.Join(peerConfigCurrent.ServerAllowedIPs, ", "))
}
if peerConfigCurrent.Address != "10.190.190.2/32" {
t.Fatalf("expected different client config address. Got: %s", peerConfigCurrent.Address)
}

peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2")
if err != nil {
t.Fatalf("NewEmptyClientConfig error: %s", err)
}

if peerConfig.ClientAllowedIPs[0] != "1.2.3.4/32" {
t.Fatalf("wrong client allowed ips")
}

if peerConfig.ServerAllowedIPs[0] != "10.190.190.3/32" {
t.Fatalf("expected different server allowed IP. Got: %s", strings.Join(peerConfig.ServerAllowedIPs, ", "))
}
if peerConfig.Address != "10.190.190.3/32" {
t.Fatalf("expected different client config address. Got: %s", peerConfig.Address)
}

}
2 changes: 2 additions & 0 deletions webapp/src/Routes/Setup/GeneralSetup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ export function GeneralSetup() {
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ['users'] })
queryClient.invalidateQueries({ queryKey: ['general-setup'] })
setSaved(true)
setSaveError("")
window.scrollTo(0, 0)
},
onError: (error:AxiosError) => {
const errorMessage = error.response?.data as GeneralSetupError
Expand Down
2 changes: 1 addition & 1 deletion webapp/src/Routes/Setup/Setup.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Container, Tabs, Title, rem } from "@mantine/core";
import classes from './Setup.module.css';
import { IconFile, IconNetwork, IconRestore, IconRewindBackward10, IconSettings } from "@tabler/icons-react";
import { IconFile, IconNetwork, IconRestore, IconSettings } from "@tabler/icons-react";
import { GeneralSetup } from "./GeneralSetup";
import { VPNSetup } from "./VPNSetup";
import { TemplateSetup } from "./TemplateSetup";
Expand Down
5 changes: 4 additions & 1 deletion webapp/src/Routes/Setup/TemplateSetup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Container, Button, Alert, Textarea, Space } from "@mantine/core";
import { useEffect, useState } from "react";
import { IconInfoCircle } from "@tabler/icons-react";
import { AppSettings } from "../../Constants/Constants";
import { useMutation, useQuery } from "@tanstack/react-query";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useAuthContext } from "../../Auth/Auth";
import { useForm } from '@mantine/form';
import axios, { AxiosError } from "axios";
Expand All @@ -19,6 +19,7 @@ export function TemplateSetup() {
const [saved, setSaved] = useState(false)
const [saveError, setSaveError] = useState("")
const {authInfo} = useAuthContext();
const queryClient = useQueryClient()
const { isPending, error, data, isSuccess } = useQuery({
queryKey: ['templates-setup'],
queryFn: () =>
Expand Down Expand Up @@ -52,6 +53,8 @@ export function TemplateSetup() {
onSuccess: () => {
setSaved(true)
setSaveError("")
queryClient.invalidateQueries({ queryKey: ['templates-setup'] })
window.scrollTo(0, 0)
},
onError: (error:AxiosError) => {
const errorMessage = error.response?.data as TemplateSetupError
Expand Down
Loading

0 comments on commit 26732b1

Please sign in to comment.