Skip to content

Commit

Permalink
vpn config, template config
Browse files Browse the repository at this point in the history
  • Loading branch information
wardviaene committed Aug 20, 2024
1 parent 3d3b706 commit 574638e
Show file tree
Hide file tree
Showing 14 changed files with 337 additions and 36 deletions.
20 changes: 20 additions & 0 deletions pkg/configmanager/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,26 @@ func (c *ConfigManager) version(w http.ResponseWriter, r *http.Request) {
}
}

func (c *ConfigManager) restartVpn(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
err := stopVPN(c.Storage)
if err != nil { // don't exit, as the VPN might be down already.
fmt.Println("========= Warning =========")
fmt.Printf("Warning: vpn stop error: %s\n", err)
fmt.Println("=========================")
}
err = startVPN(c.Storage)
if err != nil {
returnError(w, fmt.Errorf("vpn start error: %s", err), http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusAccepted)
default:
returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
}
}

func returnError(w http.ResponseWriter, err error, statusCode int) {
fmt.Println("========= ERROR =========")
fmt.Printf("Error: %s\n", err)
Expand Down
1 change: 1 addition & 0 deletions pkg/configmanager/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ func (c *ConfigManager) getRouter() *http.ServeMux {
mux.Handle("/pubkey", http.HandlerFunc(c.getPubKey))
mux.Handle("/refresh-clients", http.HandlerFunc(c.refreshClients))
mux.Handle("/upgrade", http.HandlerFunc(c.upgrade))
mux.Handle("/restart-vpn", http.HandlerFunc(c.restartVpn))
mux.Handle("/version", http.HandlerFunc(c.version))

return mux
Expand Down
5 changes: 5 additions & 0 deletions pkg/configmanager/start_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ func startVPN(storage storage.Iface) error {
fmt.Printf("Warning: startVPN is not implemented in darwin\n")
return nil
}

func stopVPN(storage storage.Iface) error {
fmt.Printf("Warning: startVPN is not implemented in darwin\n")
return nil
}
4 changes: 4 additions & 0 deletions pkg/configmanager/start_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@ func startVPN(storage storage.Iface) error {
}
return wireguard.StartVPN()
}

func stopVPN(storage storage.Iface) error {
return wireguard.StopVPN()
}
1 change: 1 addition & 0 deletions pkg/rest/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux {
mux.Handle("/api/oidc/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderElementHandler)))))
mux.Handle("/api/setup/general", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.setupHandler)))))
mux.Handle("/api/setup/vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.vpnSetupHandler)))))
mux.Handle("/api/setup/templates", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.templateSetupHandler)))))
mux.Handle("/api/scim-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.scimSetupHandler)))))
mux.Handle("/api/saml-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupHandler)))))
mux.Handle("/api/saml-setup/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupElementHandler)))))
Expand Down
98 changes: 90 additions & 8 deletions pkg/rest/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/netip"
"reflect"
"strconv"
"strings"

"github.com/google/uuid"
Expand Down Expand Up @@ -169,7 +170,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
VPNEndpoint: vpnConfig.Endpoint,
AddressRange: vpnConfig.AddressRange.String(),
ClientAddressPrefix: vpnConfig.ClientAddressPrefix,
Port: vpnConfig.Port,
Port: strconv.Itoa(vpnConfig.Port),
ExternalInterface: vpnConfig.ExternalInterface,
Nameservers: strings.Join(vpnConfig.Nameservers, ","),
DisableNAT: vpnConfig.DisableNAT,
Expand All @@ -196,10 +197,12 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
if strings.TrimSpace(network) == "::/0" {
validatedNetworks = append(validatedNetworks, "::/0")
} else {
_, ipnet, err := net.ParseCIDR(network)
if err == nil {
validatedNetworks = append(validatedNetworks, ipnet.String())
_, ipnet, err := net.ParseCIDR(strings.TrimSpace(network))
if err != nil {
c.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest)
return
}
validatedNetworks = append(validatedNetworks, ipnet.String())
}
}
vpnConfig.ClientRoutes = validatedNetworks
Expand Down Expand Up @@ -227,8 +230,13 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
writeVPNConfig = true
rewriteClientConfigs = true
}
if setupRequest.Port != vpnConfig.Port {
vpnConfig.Port = setupRequest.Port
port, err := strconv.Atoi(setupRequest.Port)
if err != nil {
c.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest)
return
}
if port != vpnConfig.Port {
vpnConfig.Port = port
writeVPNConfig = true
rewriteClientConfigs = true
rewriteServerConfig = true
Expand All @@ -239,7 +247,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
nameservers[k] = strings.TrimSpace(nameservers[k])
}
if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) {
vpnConfig.ExternalInterface = setupRequest.ExternalInterface
vpnConfig.Nameservers = nameservers
writeVPNConfig = true
rewriteClientConfigs = true
}
Expand Down Expand Up @@ -274,7 +282,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
// rewrite server config
err = wireguard.WriteWireGuardServerConfig(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest)
return
}
}
Expand All @@ -289,6 +297,80 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) {
}
}

func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
return
}
serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
return
}
setupRequest := TemplateSetupRequest{
ClientTemplate: string(clientTemplate),
ServerTemplate: string(serverTemplate),
}
out, err := json.Marshal(setupRequest)
if err != nil {
c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
return
}
c.write(w, out)
case http.MethodPost:
var templateSetupRequest TemplateSetupRequest
decoder := json.NewDecoder(r.Body)
decoder.Decode(&templateSetupRequest)
clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest)
return
}
serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest)
return
}
if string(clientTemplate) != templateSetupRequest.ClientTemplate {
err = wireguard.WriteClientTemplate(c.Storage.Client, []byte(templateSetupRequest.ClientTemplate))
if err != nil {
c.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest)
return
}
// rewrite client configs
err = wireguard.UpdateClientsConfig(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest)
return
}
}
if string(serverTemplate) != templateSetupRequest.ServerTemplate {
err = wireguard.WriteServerTemplate(c.Storage.Client, []byte(templateSetupRequest.ServerTemplate))
if err != nil {
c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest)
return
}
// rewrite server config
err = wireguard.WriteWireGuardServerConfig(c.Storage.Client)
if err != nil {
c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest)
return
}
}
out, err := json.Marshal(templateSetupRequest)
if err != nil {
c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest)
return
}
c.write(w, out)
default:
c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest)
}
}

func (c *Context) scimSetupHandler(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
Expand Down
9 changes: 6 additions & 3 deletions pkg/rest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,24 @@ type GeneralSetupRequest struct {
RedirectToHttps bool `json:"redirectToHttps"`
DisableLocalAuth bool `json:"disableLocalAuth"`
EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal"`
Routes string `json:"routes"`
VPNEndpoint string `json:"vpnEndpoint"`
}

type VPNSetupRequest struct {
Routes string `json:"routes"`
VPNEndpoint string `json:"vpnEndpoint"`
AddressRange string `json:"addressRange"`
ClientAddressPrefix string `json:"clientAddressPrefix"`
Port int `json:"port"`
Port string `json:"port"`
ExternalInterface string `json:"externalInterface"`
Nameservers string `json:"nameservers"`
DisableNAT bool `json:"disableNAT"`
}

type TemplateSetupRequest struct {
ClientTemplate string `json:"clientTemplate"`
ServerTemplate string `json:"serverTemplate"`
}

type NewConnectionResponse struct {
Name string `json:"name"`
}
Expand Down
19 changes: 18 additions & 1 deletion pkg/wireguard/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,27 @@ func StartVPN() error {

if err := cmd.Wait(); err != nil {
if exiterr, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("exit Status: %d", exiterr.ExitCode())
return fmt.Errorf("start vpn exit Status: %d", exiterr.ExitCode())
} else {
return fmt.Errorf("error while waiting for the VPN to start: %v", err)
}
}
return nil
}

func StopVPN() error {
cmd := exec.Command("wg-quick", "down", "vpn")

if err := cmd.Start(); err != nil {
return fmt.Errorf("VPN stop error: %v", err)
}

if err := cmd.Wait(); err != nil {
if exiterr, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("stop vpn exit Status: %d", exiterr.ExitCode())
} else {
return fmt.Errorf("error while waiting for the VPN to stop: %v", err)
}
}
return nil
}
29 changes: 23 additions & 6 deletions pkg/wireguard/wireguardclientconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ func getPeerConfig(storage storage.Iface, connectionID string) (PeerConfig, erro
return peerConfig, nil
}

func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) ([]byte, error) {
clientConfigMutex.Lock()
defer clientConfigMutex.Unlock()
func GetClientTemplate(storage storage.Iface) ([]byte, error) {
filename := storage.ConfigPath("templates/client.tmpl")
err := storage.EnsurePath(storage.ConfigPath("templates"))
if err != nil {
Expand All @@ -173,6 +171,25 @@ func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string)
return nil, fmt.Errorf("could not create initial client template: %s", err)
}
}
data, err := storage.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("could not read client template: %s", err)
}
return data, err
}

func WriteClientTemplate(storage storage.Iface, body []byte) error {
filename := storage.ConfigPath("templates/client.tmpl")
err := storage.WriteFile(filename, body)
if err != nil {
return fmt.Errorf("could not write client template: %s", err)
}
return nil
}

func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) ([]byte, error) {
clientConfigMutex.Lock()
defer clientConfigMutex.Unlock()

// parse template
privateKey, publicKey, err := GenerateKeys()
Expand Down Expand Up @@ -208,12 +225,12 @@ func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string)
AllowedIPs: peerConfig.ClientAllowedIPs,
}

templatefileContents, err := storage.ReadFile(filename)
templatefileContents, err := GetClientTemplate(storage)
if err != nil {
return nil, fmt.Errorf("could not read client template: %s", err)
return nil, fmt.Errorf("could not get client template: %s", err)
}

tmpl, err := template.New(path.Base(filename)).Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(string(templatefileContents))
tmpl, err := template.New("client.tmpl").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(string(templatefileContents))
if err != nil {
return nil, fmt.Errorf("could not parse client template: %s", err)
}
Expand Down
27 changes: 23 additions & 4 deletions pkg/wireguard/wireguardserverconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func WriteWireGuardServerConfig(storage storage.Iface) error {
return nil
}

func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) {
func GetServerTemplate(storage storage.Iface) ([]byte, error) {
templatefile := storage.ConfigPath(path.Join(WIREGUARD_TEMPLATE_DIR, WIREGUARD_TEMPLATE_SERVER))
err := storage.EnsurePath(storage.ConfigPath(WIREGUARD_TEMPLATE_DIR))
if err != nil {
Expand All @@ -45,6 +45,25 @@ func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) {
}
}

templateContents, err := storage.ReadFile(templatefile)
if err != nil {
return nil, fmt.Errorf("cannot read template file (%s): %s", templatefile, err)
}
return templateContents, nil
}

func WriteServerTemplate(storage storage.Iface, body []byte) error {
templatefile := storage.ConfigPath(path.Join(WIREGUARD_TEMPLATE_DIR, WIREGUARD_TEMPLATE_SERVER))

err := storage.WriteFile(templatefile, body)
if err != nil {
return fmt.Errorf("could not write template (%s): %s", templatefile, err)
}

return nil
}

func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) {
vpnConfig, err := GetVPNConfig(storage)
if err != nil {
return nil, fmt.Errorf("failed to get vpn config: %s", err)
Expand All @@ -61,11 +80,11 @@ func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) {
ExternalInterface: vpnConfig.ExternalInterface,
}

templateContents, err := storage.ReadFile(templatefile)
templateContents, err := GetServerTemplate(storage)
if err != nil {
return nil, fmt.Errorf("cannot read template file (%s): %s", templatefile, err)
return nil, fmt.Errorf("cannot get template file: %s", err)
}
tmpl, err := template.New(path.Base(templatefile)).Parse(string(templateContents))
tmpl, err := template.New(WIREGUARD_TEMPLATE_SERVER).Parse(string(templateContents))
if err != nil {
return nil, fmt.Errorf("could not parse client template: %s", err)
}
Expand Down
Loading

0 comments on commit 574638e

Please sign in to comment.