Skip to content

Commit

Permalink
ensure next IP work with ranges, ensureclients are loaded after restart
Browse files Browse the repository at this point in the history
  • Loading branch information
wardviaene committed Aug 21, 2024
1 parent 26732b1 commit a54e3c4
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 41 deletions.
5 changes: 5 additions & 0 deletions pkg/configmanager/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ func (c *ConfigManager) restartVpn(w http.ResponseWriter, r *http.Request) {
returnError(w, fmt.Errorf("vpn start error: %s", err), http.StatusBadRequest)
return
}
err = refreshAllClientsAndServer(c.Storage)
if err != nil {
returnError(w, fmt.Errorf("could not refresh all clients: %s", err), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusAccepted)
default:
returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
Expand Down
6 changes: 0 additions & 6 deletions pkg/rest/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,6 @@ func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) {
c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest)
return
}
// rewrite server config
err = wireguard.WriteWireGuardServerConfig(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest)
return
}
}
out, err := json.Marshal(templateSetupRequest)
if err != nil {
Expand Down
49 changes: 29 additions & 20 deletions pkg/wireguard/ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ import (
"net"
"net/netip"
"path"
"strings"

"github.com/in4it/wireguard-server/pkg/storage"
)

func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, error) {
func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix, addressPrefix string) (net.IP, error) {
ipList := []string{}
startIP := net.IP(addressRange.Addr().AsSlice())
startIP, addressRangeParsed, err := net.ParseCIDR(addressRange.String())
if err != nil {
return nil, fmt.Errorf("cannot parse address range: %s: %s", addressRange, err)
}

clients, err := storage.ReadDir(storage.ConfigPath(VPN_CLIENTS_DIR))
if err != nil {
Expand All @@ -28,43 +32,48 @@ func getNextFreeIP(storage storage.Iface, addressRange netip.Prefix) (net.IP, er
if err != nil {
return nil, fmt.Errorf("cannot unmarshal %s: %s", clientFilename, err)
}
peerConfigAddress, _, err := net.ParseCIDR(peerConfig.Address)
if err != nil {
return nil, fmt.Errorf("could not parse peer config address %s: %s", peerConfig.Address, err)
}
ipList = append(ipList, peerConfigAddress.String())
ipList = append(ipList, peerConfig.Address)
}

newIP, err := getNextFreeIPFromList(startIP, ipList)
newIP, err := getNextFreeIPFromList(startIP, addressRangeParsed, ipList, addressPrefix)
if err != nil {
return nil, fmt.Errorf("getNextFreeIPFromList error: %s", err)
}

newIPAddress, err := netip.ParseAddr(newIP.String())
if err != nil {
return nil, fmt.Errorf("couldn't parse newIP: %s", newIP.String())
}

if !addressRange.Contains(newIPAddress) {
return nil, fmt.Errorf("newly allocated IP (%s) is not within address range (%s). Address Range might be too small", newIPAddress.String(), addressRange.String())
}

return newIP, nil
}
func getNextFreeIPFromList(startIP net.IP, ipList []string) (net.IP, error) {
func getNextFreeIPFromList(startIP net.IP, addressRange *net.IPNet, ipList []string, addressPrefix string) (net.IP, error) {
nextIPAddress := startIP
for i := 0; i < 100000; i++ {
nextIPAddress = nextIP(nextIPAddress, 1)
ipExists := false
for _, ip := range ipList {
if nextIPAddress.String() == ip {
ipRange := ip
if !strings.Contains(ip, "/") {
ipRange += addressPrefix
}
_, ipRangeParsed, err := net.ParseCIDR(ipRange)
if err != nil {
return nil, fmt.Errorf("cannot parse IP address: %s (ip range %s)", ip, ipRange)
}
if ipRangeParsed.Contains(nextIPAddress) {
ipExists = true
}
}
if !ipExists {
return nextIPAddress, nil
if !addressRange.Contains(nextIPAddress) {
return nil, fmt.Errorf("next IP (%s) is not within address range (%s). Address Range might be too small", nextIPAddress.String(), addressRange.String())
}
_, ipRangeParsed, err := net.ParseCIDR(nextIPAddress.String() + addressPrefix)
if err != nil {
return nil, fmt.Errorf("cannot parse new IP address range: %s: %s", nextIPAddress.String()+addressPrefix, err)
}
if !ipRangeParsed.Contains(startIP) { // don't pick a range where the start ip is in the range
return nextIPAddress, nil
}
}
}

return nil, fmt.Errorf("couldn't determine next ip address")
}

Expand Down
72 changes: 68 additions & 4 deletions pkg/wireguard/ip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package wireguard

import (
"net"
"strings"
"testing"
)

func TestGetNextFreeIPFromLisWithList(t *testing.T) {
startIP, _, err := net.ParseCIDR("10.189.184.1/21")
startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
nextIP, err := getNextFreeIPFromList(startIP, []string{"10.189.184.2"})
nextIP, err := getNextFreeIPFromList(startIP, addressRange, []string{"10.189.184.2"}, "/32")
if err != nil {
t.Fatalf("error: %s", err)
}
Expand All @@ -20,15 +21,78 @@ func TestGetNextFreeIPFromLisWithList(t *testing.T) {
}

func TestGetNextFreeIPFromLisWithList2(t *testing.T) {
startIP, _, err := net.ParseCIDR("10.189.184.1/21")
startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
nextIP, err := getNextFreeIPFromList(startIP, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"})
nextIP, err := getNextFreeIPFromList(startIP, addressRange, []string{"10.190.190.2", "10.189.184.2", "10.190.190.3"}, "/32")
if err != nil {
t.Fatalf("error: %s", err)
}
if nextIP.String() != "10.189.184.3" {
t.Fatalf("Wrong IP: %s", nextIP)
}
}

func TestGetNextFreeIPWithRange(t *testing.T) {
startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
networkPrefix := []string{
"/32",
"/32",
"/32",
"/32",
"/32",
"/32",
"/30",
"/30",
"/32",
}
testCases := [][]string{
{},
{"10.189.184.2"},
{"10.189.184.2/32"},
{"10.189.184.2", "10.189.184.3", "10.189.184.4/30"},
{"10.189.184.2", "10.189.184.3", "10.189.184.4/30", "10.189.184.8/32"},
{"10.189.184.1/30", "10.189.184.4/30", "10.189.184.8/30"},
{},
{"10.189.184.4/30", "10.189.184.8/30"},
{"10.189.189.2/32", "10.189.189.3/32", "10.189.189.4/32"},
}
expected := []string{
"10.189.184.2",
"10.189.184.3",
"10.189.184.3",
"10.189.184.8",
"10.189.184.9",
"10.189.184.12",
"10.189.184.4",
"10.189.184.12",
"10.189.184.2",
}

for k := range testCases {
nextIP, err := getNextFreeIPFromList(startIP, addressRange, testCases[k], networkPrefix[k])
if err != nil {
t.Fatalf("error: %s", err)
}
if nextIP.String() != expected[k] {
t.Fatalf("Wrong IP: %s", nextIP)
}
}

}

func TestIPNotInRange(t *testing.T) {
startIP, addressRange, err := net.ParseCIDR("10.189.184.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
_, err = getNextFreeIPFromList(startIP, addressRange, []string{"10.189.188.0/22"}, "/22")
if !strings.Contains(err.Error(), "not within address range") {
t.Fatalf("Expected error, got: %s", err)
}

}
11 changes: 8 additions & 3 deletions pkg/wireguard/wireguardclientconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err
}

// get next IP address, write in client file
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange, vpnConfig.ClientAddressPrefix)
if err != nil {
return PeerConfig{}, fmt.Errorf("getNextFreeIP error: %s", err)
}
Expand All @@ -75,7 +75,7 @@ func NewEmptyClientConfig(storage storage.Iface, userID string) (PeerConfig, err
DNS: strings.Join(vpnConfig.Nameservers, ", "),
Name: fmt.Sprintf("connection-%d", newConfigNumber),
Address: nextFreeIP.String() + vpnConfig.ClientAddressPrefix,
ServerAllowedIPs: []string{nextFreeIP.String() + "/32"},
ServerAllowedIPs: []string{nextFreeIP.String() + vpnConfig.ClientAddressPrefix},
ClientAllowedIPs: clientAllowedIPs,
}

Expand Down Expand Up @@ -131,7 +131,7 @@ func UpdateClientsConfig(storage storage.Iface) error {
return fmt.Errorf("couldn't parse existing address of vpn config %s", clientFilename)
}
if !vpnConfig.AddressRange.Contains(addressParsed.Addr()) { // client IP address is not in address range (address range might have changed)
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange)
nextFreeIP, err := getNextFreeIP(storage, vpnConfig.AddressRange, vpnConfig.ClientAddressPrefix)
if err != nil {
return fmt.Errorf("getNextFreeIP error: %s", err)
}
Expand All @@ -140,6 +140,11 @@ func UpdateClientsConfig(storage storage.Iface) error {
rewriteFile = true
}

if !strings.HasSuffix(peerConfig.Address, vpnConfig.ClientAddressPrefix) {
rewriteFile = true
peerConfig.Address = addressParsed.Addr().String() + vpnConfig.ClientAddressPrefix
}

if rewriteFile {
peerConfigOut, err := json.Marshal(peerConfig)
if err != nil {
Expand Down
106 changes: 104 additions & 2 deletions pkg/wireguard/wireguardclientconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ import (
)

func TestGetNextFreeIPFromList(t *testing.T) {
ip := net.ParseIP("10.0.0.1")
startIP, addressRange, err := net.ParseCIDR("10.0.0.1/21")
if err != nil {
t.Fatalf("error: %s", err)
}
ipList := []string{"10.0.0.2", "10.0.0.3"}
nextIP, err := getNextFreeIPFromList(ip, ipList)
nextIP, err := getNextFreeIPFromList(startIP, addressRange, ipList, "/32")
if err != nil {
t.Errorf("next IP error: %s", err)
}
Expand Down Expand Up @@ -646,6 +649,7 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
newClientRoutes := []string{"1.2.3.4/32"}
vpnconfig.ClientRoutes = newClientRoutes
vpnconfig.AddressRange, err = netip.ParsePrefix("10.190.190.1/21")
vpnconfig.Nameservers = []string{"3.4.5.6", "8.8.8.8"}
if err != nil {
t.Fatalf("can't parse new ip range")
}
Expand All @@ -669,6 +673,9 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
if peerConfigCurrent.Address != "10.190.190.2/32" {
t.Fatalf("expected different client config address. Got: %s", peerConfigCurrent.Address)
}
if peerConfigCurrent.DNS != strings.Join(vpnconfig.Nameservers, ", ") {
t.Fatalf("Unexpected DNS Servers: %s (expected %s)", peerConfig.DNS, strings.Join(vpnconfig.Nameservers, ", "))
}

peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2")
if err != nil {
Expand All @@ -685,5 +692,100 @@ func TestUpdateClientConfigNewAddressRange(t *testing.T) {
if peerConfig.Address != "10.190.190.3/32" {
t.Fatalf("expected different client config address. Got: %s", peerConfig.Address)
}
if peerConfig.DNS != strings.Join(vpnconfig.Nameservers, ", ") {
t.Fatalf("Unexpected DNS Servers: %s (expected %s)", peerConfig.DNS, strings.Join(vpnconfig.Nameservers, ", "))
}
}

func TestUpdateClientConfigNewClientAddressPrefix(t *testing.T) {
var (
l net.Listener
err error
)
for {
l, err = net.Listen("tcp", CONFIGMANAGER_URI)
if err != nil {
if !strings.HasSuffix(err.Error(), "address already in use") {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
} else {
break
}
}

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
if r.RequestURI == "/refresh-clients" {
w.WriteHeader(http.StatusAccepted)
w.Write([]byte("OK"))
return
}
w.WriteHeader(http.StatusBadRequest)
default:
w.WriteHeader(http.StatusBadRequest)
}
}))

ts.Listener.Close()
ts.Listener = l
ts.Start()
defer ts.Close()
defer l.Close()

storage := &testingmocks.MockMemoryStorage{}

// first create a new vpn config
vpnconfig, err := CreateNewVPNConfig(storage)
if err != nil {
t.Fatalf("CreateNewVPNConfig error: %s", err)
}
prefix, err := netip.ParsePrefix(DEFAULT_VPN_PREFIX)
if err != nil {
t.Errorf("ParsePrefix error: %s", err)
}
if vpnconfig.AddressRange.String() != prefix.String() {
t.Fatalf("wrong AddressRange: %s vs %s", vpnconfig.AddressRange.String(), prefix.String())
}
if vpnconfig.ClientAddressPrefix != "/32" {
t.Fatalf("unexpected default for address prefix: %s", vpnconfig.ClientAddressPrefix)
}
// generate the peerconfig
peerConfig, err := NewEmptyClientConfig(storage, "2-2-2-2")
if err != nil {
t.Fatalf("NewEmptyClientConfig error: %s", err)
}

if peerConfig.ClientAllowedIPs[0] != "0.0.0.0/0" {
t.Fatalf("wrong client allowed ips")
}

vpnconfig.ClientAddressPrefix = "/30"

err = WriteVPNConfig(storage, vpnconfig)
if err != nil {
t.Fatalf("WriteVPNConfig error: %s", err)
}
err = UpdateClientsConfig(storage)
if err != nil {
t.Fatalf("UpdateClientsConfig error: %s", err)
}

peerConfigCurrent, err := getPeerConfig(storage, "2-2-2-2-1")
if err != nil {
t.Fatalf("getPeerConfig error: %s", err)
}

if peerConfigCurrent.Address != "10.189.184.2/30" {
t.Fatalf("expected different client address. Got: %s", peerConfigCurrent.Address)
}
peerConfig, err = NewEmptyClientConfig(storage, "2-2-2-2")
if err != nil {
t.Fatalf("NewEmptyClientConfig error: %s", err)
}
if peerConfig.Address != "10.189.184.4/30" {
t.Fatalf("expected different client config address. Got: %s", peerConfig.Address)
}

}
Loading

0 comments on commit a54e3c4

Please sign in to comment.