From 574638e7f5d16d51ffded16d068af627f88dbb14 Mon Sep 17 00:00:00 2001 From: Edward Viaene Date: Tue, 20 Aug 2024 16:35:53 -0500 Subject: [PATCH] vpn config, template config --- pkg/configmanager/handlers.go | 20 ++++ pkg/configmanager/router.go | 1 + pkg/configmanager/start_darwin.go | 5 + pkg/configmanager/start_linux.go | 4 + pkg/rest/router.go | 1 + pkg/rest/setup.go | 98 +++++++++++++++++-- pkg/rest/types.go | 9 +- pkg/wireguard/startup.go | 19 +++- pkg/wireguard/wireguardclientconfig.go | 29 ++++-- pkg/wireguard/wireguardserverconfig.go | 27 +++++- webapp/src/Routes/Setup/GeneralSetup.tsx | 19 +++- webapp/src/Routes/Setup/Setup.tsx | 3 +- webapp/src/Routes/Setup/TemplateSetup.tsx | 111 ++++++++++++++++++++++ webapp/src/Routes/Setup/VPNSetup.tsx | 27 ++++-- 14 files changed, 337 insertions(+), 36 deletions(-) create mode 100644 webapp/src/Routes/Setup/TemplateSetup.tsx diff --git a/pkg/configmanager/handlers.go b/pkg/configmanager/handlers.go index fd36f51..351a024 100644 --- a/pkg/configmanager/handlers.go +++ b/pkg/configmanager/handlers.go @@ -117,6 +117,26 @@ func (c *ConfigManager) version(w http.ResponseWriter, r *http.Request) { } } +func (c *ConfigManager) restartVpn(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + err := stopVPN(c.Storage) + if err != nil { // don't exit, as the VPN might be down already. + fmt.Println("========= Warning =========") + fmt.Printf("Warning: vpn stop error: %s\n", err) + fmt.Println("=========================") + } + err = startVPN(c.Storage) + if err != nil { + returnError(w, fmt.Errorf("vpn start error: %s", err), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusAccepted) + default: + returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + func returnError(w http.ResponseWriter, err error, statusCode int) { fmt.Println("========= ERROR =========") fmt.Printf("Error: %s\n", err) diff --git a/pkg/configmanager/router.go b/pkg/configmanager/router.go index 662bd3d..59a32d8 100644 --- a/pkg/configmanager/router.go +++ b/pkg/configmanager/router.go @@ -8,6 +8,7 @@ func (c *ConfigManager) getRouter() *http.ServeMux { mux.Handle("/pubkey", http.HandlerFunc(c.getPubKey)) mux.Handle("/refresh-clients", http.HandlerFunc(c.refreshClients)) mux.Handle("/upgrade", http.HandlerFunc(c.upgrade)) + mux.Handle("/restart-vpn", http.HandlerFunc(c.restartVpn)) mux.Handle("/version", http.HandlerFunc(c.version)) return mux diff --git a/pkg/configmanager/start_darwin.go b/pkg/configmanager/start_darwin.go index 963b023..47864ee 100644 --- a/pkg/configmanager/start_darwin.go +++ b/pkg/configmanager/start_darwin.go @@ -13,3 +13,8 @@ func startVPN(storage storage.Iface) error { fmt.Printf("Warning: startVPN is not implemented in darwin\n") return nil } + +func stopVPN(storage storage.Iface) error { + fmt.Printf("Warning: startVPN is not implemented in darwin\n") + return nil +} diff --git a/pkg/configmanager/start_linux.go b/pkg/configmanager/start_linux.go index c1bbe27..5516755 100644 --- a/pkg/configmanager/start_linux.go +++ b/pkg/configmanager/start_linux.go @@ -17,3 +17,7 @@ func startVPN(storage storage.Iface) error { } return wireguard.StartVPN() } + +func stopVPN(storage storage.Iface) error { + return wireguard.StopVPN() +} diff --git a/pkg/rest/router.go b/pkg/rest/router.go index 07bbc5c..2197538 100644 --- a/pkg/rest/router.go +++ b/pkg/rest/router.go @@ -55,6 +55,7 @@ func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux { mux.Handle("/api/oidc/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderElementHandler))))) mux.Handle("/api/setup/general", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.setupHandler))))) mux.Handle("/api/setup/vpn", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.vpnSetupHandler))))) + mux.Handle("/api/setup/templates", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.templateSetupHandler))))) mux.Handle("/api/scim-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.scimSetupHandler))))) mux.Handle("/api/saml-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupHandler))))) mux.Handle("/api/saml-setup/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupElementHandler))))) diff --git a/pkg/rest/setup.go b/pkg/rest/setup.go index 04b1715..3367e13 100644 --- a/pkg/rest/setup.go +++ b/pkg/rest/setup.go @@ -8,6 +8,7 @@ import ( "net/http" "net/netip" "reflect" + "strconv" "strings" "github.com/google/uuid" @@ -169,7 +170,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { VPNEndpoint: vpnConfig.Endpoint, AddressRange: vpnConfig.AddressRange.String(), ClientAddressPrefix: vpnConfig.ClientAddressPrefix, - Port: vpnConfig.Port, + Port: strconv.Itoa(vpnConfig.Port), ExternalInterface: vpnConfig.ExternalInterface, Nameservers: strings.Join(vpnConfig.Nameservers, ","), DisableNAT: vpnConfig.DisableNAT, @@ -196,10 +197,12 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { if strings.TrimSpace(network) == "::/0" { validatedNetworks = append(validatedNetworks, "::/0") } else { - _, ipnet, err := net.ParseCIDR(network) - if err == nil { - validatedNetworks = append(validatedNetworks, ipnet.String()) + _, ipnet, err := net.ParseCIDR(strings.TrimSpace(network)) + if err != nil { + c.returnError(w, fmt.Errorf("client route %s in wrong format: %s", strings.TrimSpace(network), err), http.StatusBadRequest) + return } + validatedNetworks = append(validatedNetworks, ipnet.String()) } } vpnConfig.ClientRoutes = validatedNetworks @@ -227,8 +230,13 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { writeVPNConfig = true rewriteClientConfigs = true } - if setupRequest.Port != vpnConfig.Port { - vpnConfig.Port = setupRequest.Port + port, err := strconv.Atoi(setupRequest.Port) + if err != nil { + c.returnError(w, fmt.Errorf("port in wrong format: %s", err), http.StatusBadRequest) + return + } + if port != vpnConfig.Port { + vpnConfig.Port = port writeVPNConfig = true rewriteClientConfigs = true rewriteServerConfig = true @@ -239,7 +247,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { nameservers[k] = strings.TrimSpace(nameservers[k]) } if !reflect.DeepEqual(nameservers, vpnConfig.Nameservers) { - vpnConfig.ExternalInterface = setupRequest.ExternalInterface + vpnConfig.Nameservers = nameservers writeVPNConfig = true rewriteClientConfigs = true } @@ -274,7 +282,7 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { // rewrite server config err = wireguard.WriteWireGuardServerConfig(c.Storage.Client) if err != nil { - c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) + c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest) return } } @@ -289,6 +297,80 @@ func (c *Context) vpnSetupHandler(w http.ResponseWriter, r *http.Request) { } } +func (c *Context) templateSetupHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) + return + } + serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) + return + } + setupRequest := TemplateSetupRequest{ + ClientTemplate: string(clientTemplate), + ServerTemplate: string(serverTemplate), + } + out, err := json.Marshal(setupRequest) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var templateSetupRequest TemplateSetupRequest + decoder := json.NewDecoder(r.Body) + decoder.Decode(&templateSetupRequest) + clientTemplate, err := wireguard.GetClientTemplate(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve client template: %s", err), http.StatusBadRequest) + return + } + serverTemplate, err := wireguard.GetServerTemplate(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve server template: %s", err), http.StatusBadRequest) + return + } + if string(clientTemplate) != templateSetupRequest.ClientTemplate { + err = wireguard.WriteClientTemplate(c.Storage.Client, []byte(templateSetupRequest.ClientTemplate)) + if err != nil { + c.returnError(w, fmt.Errorf("WriteClientTemplate error: %s", err), http.StatusBadRequest) + return + } + // rewrite client configs + err = wireguard.UpdateClientsConfig(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not update client vpn configs: %s", err), http.StatusBadRequest) + return + } + } + if string(serverTemplate) != templateSetupRequest.ServerTemplate { + err = wireguard.WriteServerTemplate(c.Storage.Client, []byte(templateSetupRequest.ServerTemplate)) + if err != nil { + c.returnError(w, fmt.Errorf("WriteServerTemplate error: %s", err), http.StatusBadRequest) + return + } + // rewrite server config + err = wireguard.WriteWireGuardServerConfig(c.Storage.Client) + if err != nil { + c.returnError(w, fmt.Errorf("could not write wireguard server config: %s", err), http.StatusBadRequest) + return + } + } + out, err := json.Marshal(templateSetupRequest) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + func (c *Context) scimSetupHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: diff --git a/pkg/rest/types.go b/pkg/rest/types.go index 1131f79..d4287ac 100644 --- a/pkg/rest/types.go +++ b/pkg/rest/types.go @@ -99,8 +99,6 @@ type GeneralSetupRequest struct { RedirectToHttps bool `json:"redirectToHttps"` DisableLocalAuth bool `json:"disableLocalAuth"` EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal"` - Routes string `json:"routes"` - VPNEndpoint string `json:"vpnEndpoint"` } type VPNSetupRequest struct { @@ -108,12 +106,17 @@ type VPNSetupRequest struct { VPNEndpoint string `json:"vpnEndpoint"` AddressRange string `json:"addressRange"` ClientAddressPrefix string `json:"clientAddressPrefix"` - Port int `json:"port"` + Port string `json:"port"` ExternalInterface string `json:"externalInterface"` Nameservers string `json:"nameservers"` DisableNAT bool `json:"disableNAT"` } +type TemplateSetupRequest struct { + ClientTemplate string `json:"clientTemplate"` + ServerTemplate string `json:"serverTemplate"` +} + type NewConnectionResponse struct { Name string `json:"name"` } diff --git a/pkg/wireguard/startup.go b/pkg/wireguard/startup.go index 40f1770..708b7dc 100644 --- a/pkg/wireguard/startup.go +++ b/pkg/wireguard/startup.go @@ -14,10 +14,27 @@ func StartVPN() error { if err := cmd.Wait(); err != nil { if exiterr, ok := err.(*exec.ExitError); ok { - return fmt.Errorf("exit Status: %d", exiterr.ExitCode()) + return fmt.Errorf("start vpn exit Status: %d", exiterr.ExitCode()) } else { return fmt.Errorf("error while waiting for the VPN to start: %v", err) } } return nil } + +func StopVPN() error { + cmd := exec.Command("wg-quick", "down", "vpn") + + if err := cmd.Start(); err != nil { + return fmt.Errorf("VPN stop error: %v", err) + } + + if err := cmd.Wait(); err != nil { + if exiterr, ok := err.(*exec.ExitError); ok { + return fmt.Errorf("stop vpn exit Status: %d", exiterr.ExitCode()) + } else { + return fmt.Errorf("error while waiting for the VPN to stop: %v", err) + } + } + return nil +} diff --git a/pkg/wireguard/wireguardclientconfig.go b/pkg/wireguard/wireguardclientconfig.go index 4c7ee49..952b917 100644 --- a/pkg/wireguard/wireguardclientconfig.go +++ b/pkg/wireguard/wireguardclientconfig.go @@ -159,9 +159,7 @@ func getPeerConfig(storage storage.Iface, connectionID string) (PeerConfig, erro return peerConfig, nil } -func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) ([]byte, error) { - clientConfigMutex.Lock() - defer clientConfigMutex.Unlock() +func GetClientTemplate(storage storage.Iface) ([]byte, error) { filename := storage.ConfigPath("templates/client.tmpl") err := storage.EnsurePath(storage.ConfigPath("templates")) if err != nil { @@ -173,6 +171,25 @@ func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) return nil, fmt.Errorf("could not create initial client template: %s", err) } } + data, err := storage.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("could not read client template: %s", err) + } + return data, err +} + +func WriteClientTemplate(storage storage.Iface, body []byte) error { + filename := storage.ConfigPath("templates/client.tmpl") + err := storage.WriteFile(filename, body) + if err != nil { + return fmt.Errorf("could not write client template: %s", err) + } + return nil +} + +func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) ([]byte, error) { + clientConfigMutex.Lock() + defer clientConfigMutex.Unlock() // parse template privateKey, publicKey, err := GenerateKeys() @@ -208,12 +225,12 @@ func GenerateNewClientConfig(storage storage.Iface, connectionID, userID string) AllowedIPs: peerConfig.ClientAllowedIPs, } - templatefileContents, err := storage.ReadFile(filename) + templatefileContents, err := GetClientTemplate(storage) if err != nil { - return nil, fmt.Errorf("could not read client template: %s", err) + return nil, fmt.Errorf("could not get client template: %s", err) } - tmpl, err := template.New(path.Base(filename)).Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(string(templatefileContents)) + tmpl, err := template.New("client.tmpl").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(string(templatefileContents)) if err != nil { return nil, fmt.Errorf("could not parse client template: %s", err) } diff --git a/pkg/wireguard/wireguardserverconfig.go b/pkg/wireguard/wireguardserverconfig.go index d06df6f..be0155a 100644 --- a/pkg/wireguard/wireguardserverconfig.go +++ b/pkg/wireguard/wireguardserverconfig.go @@ -23,7 +23,7 @@ func WriteWireGuardServerConfig(storage storage.Iface) error { return nil } -func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) { +func GetServerTemplate(storage storage.Iface) ([]byte, error) { templatefile := storage.ConfigPath(path.Join(WIREGUARD_TEMPLATE_DIR, WIREGUARD_TEMPLATE_SERVER)) err := storage.EnsurePath(storage.ConfigPath(WIREGUARD_TEMPLATE_DIR)) if err != nil { @@ -45,6 +45,25 @@ func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) { } } + templateContents, err := storage.ReadFile(templatefile) + if err != nil { + return nil, fmt.Errorf("cannot read template file (%s): %s", templatefile, err) + } + return templateContents, nil +} + +func WriteServerTemplate(storage storage.Iface, body []byte) error { + templatefile := storage.ConfigPath(path.Join(WIREGUARD_TEMPLATE_DIR, WIREGUARD_TEMPLATE_SERVER)) + + err := storage.WriteFile(templatefile, body) + if err != nil { + return fmt.Errorf("could not write template (%s): %s", templatefile, err) + } + + return nil +} + +func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) { vpnConfig, err := GetVPNConfig(storage) if err != nil { return nil, fmt.Errorf("failed to get vpn config: %s", err) @@ -61,11 +80,11 @@ func generateWireGuardServerConfig(storage storage.Iface) ([]byte, error) { ExternalInterface: vpnConfig.ExternalInterface, } - templateContents, err := storage.ReadFile(templatefile) + templateContents, err := GetServerTemplate(storage) if err != nil { - return nil, fmt.Errorf("cannot read template file (%s): %s", templatefile, err) + return nil, fmt.Errorf("cannot get template file: %s", err) } - tmpl, err := template.New(path.Base(templatefile)).Parse(string(templateContents)) + tmpl, err := template.New(WIREGUARD_TEMPLATE_SERVER).Parse(string(templateContents)) if err != nil { return nil, fmt.Errorf("could not parse client template: %s", err) } diff --git a/webapp/src/Routes/Setup/GeneralSetup.tsx b/webapp/src/Routes/Setup/GeneralSetup.tsx index a19963e..ee17276 100644 --- a/webapp/src/Routes/Setup/GeneralSetup.tsx +++ b/webapp/src/Routes/Setup/GeneralSetup.tsx @@ -8,6 +8,10 @@ import { useAuthContext } from "../../Auth/Auth"; import { useForm } from '@mantine/form'; import axios, { AxiosError } from "axios"; +type GeneralSetupError = { + error: string; +} + type GeneralSetupRequest = { hostname: string; enableTLS: boolean; @@ -47,7 +51,7 @@ export function GeneralSetup() { const alertIcon = ; const setupMutation = useMutation({ mutationFn: (setupRequest: GeneralSetupRequest) => { - return axios.post(AppSettings.url + '/setup', setupRequest, { + return axios.post(AppSettings.url + '/setup/general', setupRequest, { headers: { "Authorization": "Bearer " + authInfo.token }, @@ -56,9 +60,15 @@ export function GeneralSetup() { onSuccess: () => { queryClient.invalidateQueries({ queryKey: ['users'] }) setSaved(true) + setSaveError("") }, onError: (error:AxiosError) => { - setSaveError("Error: "+ error.message) + const errorMessage = error.response?.data as GeneralSetupError + if(errorMessage?.error === undefined) { + setSaveError("Error: "+ error.message) + } else { + setSaveError("Error: "+ errorMessage.error) + } } }) @@ -90,8 +100,9 @@ export function GeneralSetup() { return ( - {saved ? Settings Saved! : null} - {saveError !== "" ? saveError : null} + {saved && saveError === "" ? Settings Saved! : null} + {saveError !== "" ? {saveError} : null} +
setupMutation.mutate(values))}> - Templates will go here + diff --git a/webapp/src/Routes/Setup/TemplateSetup.tsx b/webapp/src/Routes/Setup/TemplateSetup.tsx new file mode 100644 index 0000000..db0486d --- /dev/null +++ b/webapp/src/Routes/Setup/TemplateSetup.tsx @@ -0,0 +1,111 @@ +import { Container, Button, Alert, Textarea, Space } from "@mantine/core"; +import { useEffect, useState } from "react"; +import { IconInfoCircle } from "@tabler/icons-react"; +import { AppSettings } from "../../Constants/Constants"; +import { useMutation, useQuery } from "@tanstack/react-query"; +import { useAuthContext } from "../../Auth/Auth"; +import { useForm } from '@mantine/form'; +import axios, { AxiosError } from "axios"; + +type TemplateSetupError = { + error: string; +} + +type TemplateSetupRequest = { + clientTemplate: string; + serverTemplate: string; +}; +export function TemplateSetup() { + const [saved, setSaved] = useState(false) + const [saveError, setSaveError] = useState("") + const {authInfo} = useAuthContext(); + const { isPending, error, data, isSuccess } = useQuery({ + queryKey: ['templates-setup'], + queryFn: () => + fetch(AppSettings.url + '/setup/templates', { + headers: { + "Content-Type": "application/json", + "Authorization": "Bearer " + authInfo.token + }, + }).then((res) => { + return res.json() + } + + ), + }) + const form = useForm({ + mode: 'uncontrolled', + initialValues: { + clientTemplate: "", + serverTemplate: "", + }, + }); + const alertIcon = ; + const setupMutation = useMutation({ + mutationFn: (setupRequest: TemplateSetupRequest) => { + return axios.post(AppSettings.url + '/setup/templates', setupRequest, { + headers: { + "Authorization": "Bearer " + authInfo.token + }, + }) + }, + onSuccess: () => { + setSaved(true) + setSaveError("") + }, + onError: (error:AxiosError) => { + const errorMessage = error.response?.data as TemplateSetupError + if(errorMessage?.error === undefined) { + setSaveError("Error: "+ error.message) + } else { + setSaveError("Error: "+ errorMessage.error) + } + } + }) + + + useEffect(() => { + if (isSuccess) { + form.setValues({ ...data }); + } + }, [isSuccess]); + + + if(isPending) return "Loading..." + if(error) return 'A backend error has occurred: ' + error.message + + return ( + + The template files use the Golang template package (see also https://pkg.go.dev/text/template). + + {saved && saveError === "" ? Settings Saved! : null} + {saveError !== "" ? {saveError} : null} + + setupMutation.mutate(values))}> +