diff --git a/adminui/security.go b/adminui/security.go index 9a2d44d4..67aa7d31 100644 --- a/adminui/security.go +++ b/adminui/security.go @@ -1,38 +1 @@ package adminui - -import ( - "net/http" - "net/url" -) - -type security struct { - next http.Handler -} - -func (sh *security) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Frame-Options", "DENY") - w.Header().Set("Strict-Transport-Security", "max-age=31536000") - w.Header().Set("X-Content-Type-Options", "nosniff") - - if r.Method != "GET" { - u, err := url.Parse(r.Header.Get("Origin")) - if err != nil { - http.Error(w, "Bad Request", 400) - return - } - - //If origin != host header - if r.Host != u.Host { - http.Error(w, "Bad Request", 400) - return - } - } - - sh.next.ServeHTTP(w, r) -} - -func setSecurityHeaders(f http.Handler) http.Handler { - return &security{ - next: f, - } -} diff --git a/adminui/src/css/sb-admin-2.css b/adminui/src/css/sb-admin-2.css index a15ab904..6bb79cc5 100755 --- a/adminui/src/css/sb-admin-2.css +++ b/adminui/src/css/sb-admin-2.css @@ -10289,67 +10289,67 @@ a:focus { @keyframes noise-anim { 0% { - clip: rect(62px, 9999px, 66px, 0); + clip: rect(79px, 9999px, 67px, 0); } 5% { - clip: rect(44px, 9999px, 54px, 0); + clip: rect(2px, 9999px, 86px, 0); } 10% { - clip: rect(47px, 9999px, 34px, 0); + clip: rect(53px, 9999px, 16px, 0); } 15% { - clip: rect(75px, 9999px, 98px, 0); + clip: rect(38px, 9999px, 49px, 0); } 20% { - clip: rect(45px, 9999px, 70px, 0); + clip: rect(38px, 9999px, 94px, 0); } 25% { - clip: rect(29px, 9999px, 35px, 0); + clip: rect(41px, 9999px, 28px, 0); } 30% { - clip: rect(88px, 9999px, 100px, 0); + clip: rect(91px, 9999px, 92px, 0); } 35% { - clip: rect(80px, 9999px, 39px, 0); + clip: rect(32px, 9999px, 91px, 0); } 40% { - clip: rect(84px, 9999px, 60px, 0); + clip: rect(53px, 9999px, 57px, 0); } 45% { - clip: rect(24px, 9999px, 38px, 0); + clip: rect(5px, 9999px, 85px, 0); } 50% { - clip: rect(73px, 9999px, 80px, 0); + clip: rect(82px, 9999px, 43px, 0); } 55% { - clip: rect(66px, 9999px, 20px, 0); + clip: rect(41px, 9999px, 44px, 0); } 60% { - clip: rect(19px, 9999px, 18px, 0); + clip: rect(47px, 9999px, 97px, 0); } 65% { - clip: rect(66px, 9999px, 32px, 0); + clip: rect(35px, 9999px, 62px, 0); } 70% { - clip: rect(90px, 9999px, 22px, 0); + clip: rect(91px, 9999px, 62px, 0); } 75% { - clip: rect(28px, 9999px, 66px, 0); + clip: rect(49px, 9999px, 24px, 0); } 80% { - clip: rect(10px, 9999px, 67px, 0); + clip: rect(51px, 9999px, 84px, 0); } 85% { - clip: rect(79px, 9999px, 78px, 0); + clip: rect(67px, 9999px, 66px, 0); } 90% { - clip: rect(1px, 9999px, 14px, 0); + clip: rect(60px, 9999px, 20px, 0); } 95% { - clip: rect(8px, 9999px, 81px, 0); + clip: rect(80px, 9999px, 30px, 0); } 100% { - clip: rect(48px, 9999px, 74px, 0); + clip: rect(26px, 9999px, 100px, 0); } } .error:after { @@ -10367,67 +10367,67 @@ a:focus { @keyframes noise-anim-2 { 0% { - clip: rect(82px, 9999px, 91px, 0); + clip: rect(52px, 9999px, 16px, 0); } 5% { - clip: rect(24px, 9999px, 30px, 0); + clip: rect(79px, 9999px, 20px, 0); } 10% { - clip: rect(31px, 9999px, 41px, 0); + clip: rect(39px, 9999px, 37px, 0); } 15% { - clip: rect(66px, 9999px, 42px, 0); + clip: rect(36px, 9999px, 20px, 0); } 20% { - clip: rect(60px, 9999px, 93px, 0); + clip: rect(64px, 9999px, 79px, 0); } 25% { - clip: rect(98px, 9999px, 23px, 0); + clip: rect(91px, 9999px, 13px, 0); } 30% { - clip: rect(68px, 9999px, 72px, 0); + clip: rect(43px, 9999px, 93px, 0); } 35% { - clip: rect(91px, 9999px, 41px, 0); + clip: rect(50px, 9999px, 23px, 0); } 40% { - clip: rect(95px, 9999px, 49px, 0); + clip: rect(39px, 9999px, 59px, 0); } 45% { - clip: rect(50px, 9999px, 37px, 0); + clip: rect(32px, 9999px, 88px, 0); } 50% { - clip: rect(19px, 9999px, 2px, 0); + clip: rect(23px, 9999px, 84px, 0); } 55% { - clip: rect(44px, 9999px, 84px, 0); + clip: rect(12px, 9999px, 89px, 0); } 60% { - clip: rect(69px, 9999px, 41px, 0); + clip: rect(87px, 9999px, 15px, 0); } 65% { - clip: rect(11px, 9999px, 12px, 0); + clip: rect(4px, 9999px, 27px, 0); } 70% { - clip: rect(36px, 9999px, 35px, 0); + clip: rect(67px, 9999px, 4px, 0); } 75% { - clip: rect(1px, 9999px, 4px, 0); + clip: rect(34px, 9999px, 70px, 0); } 80% { - clip: rect(75px, 9999px, 83px, 0); + clip: rect(80px, 9999px, 11px, 0); } 85% { - clip: rect(51px, 9999px, 74px, 0); + clip: rect(70px, 9999px, 16px, 0); } 90% { - clip: rect(86px, 9999px, 21px, 0); + clip: rect(82px, 9999px, 91px, 0); } 95% { - clip: rect(67px, 9999px, 5px, 0); + clip: rect(95px, 9999px, 58px, 0); } 100% { - clip: rect(13px, 9999px, 60px, 0); + clip: rect(64px, 9999px, 5px, 0); } } .error:before { diff --git a/adminui/ui_webserver.go b/adminui/ui_webserver.go index af263455..20e47e74 100644 --- a/adminui/ui_webserver.go +++ b/adminui/ui_webserver.go @@ -95,8 +95,8 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) } u.Path = path.Join(u.Path, "/login/oidc/callback") - log.Println("Admin OIDC callback: ", u.String()) - log.Println("Connecting to Admin UI OIDC provider: ", config.Values.ManagementUI.OIDC.IssuerURL) + log.Println("[ADMINUI] OIDC callback: ", u.String()) + log.Println("[ADMINUI] Connecting to OIDC provider: ", config.Values.ManagementUI.OIDC.IssuerURL) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -106,7 +106,7 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) return nil, fmt.Errorf("unable to connect to oidc provider for admin ui. err %s", err) } - log.Println("Connected to admin oidc provider!") + log.Println("[ADMINUI] Connected to admin oidc provider!") } if *config.Values.ManagementUI.Password.Enabled { @@ -117,7 +117,7 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) } if len(admins) == 0 { - log.Println("[INFO] *************** Web interface enabled but no administrator users exist, generating new ones CREDENTIALS FOLLOW ***************") + log.Println("[ADMINUI] *************** Web interface enabled but no administrator users exist, generating new ones CREDENTIALS FOLLOW ***************") username, err := utils.GenerateRandomHex(8) if err != nil { @@ -129,10 +129,10 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) return nil, fmt.Errorf("failed to generate random password: %s", err) } - log.Println("Username: ", username) - log.Println("Password: ", password) + log.Println("[ADMINUI] Username: ", username) + log.Println("[ADMINUI] Password: ", password) - log.Println("This information will not be shown again. ") + log.Println("[ADMINUI] This information will not be shown again. ") err = adminUI.ctrl.AddAdminUser(username, password, true) if err != nil { @@ -303,7 +303,7 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, TLSConfig: tlsConfig, - Handler: setSecurityHeaders(allRoutes), + Handler: utils.SetSecurityHeaders(allRoutes), } if err := adminUI.https.ListenAndServeTLS(config.Values.ManagementUI.CertPath, config.Values.ManagementUI.KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -318,7 +318,7 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, - Handler: setSecurityHeaders(allRoutes), + Handler: utils.SetSecurityHeaders(allRoutes), } if err := adminUI.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errs <- fmt.Errorf("webserver management listener failed: %v", adminUI.http.ListenAndServe()) @@ -328,7 +328,7 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) } }() - log.Println("Started Managemnt UI:\n\t\t\tListening:", config.Values.ManagementUI.ListenAddress) + log.Println("[ADMINUI] Started Managemnt UI listening:", config.Values.ManagementUI.ListenAddress) return &adminUI, nil } diff --git a/commands/start.go b/commands/start.go index c32edfa0..1b8c15ca 100644 --- a/commands/start.go +++ b/commands/start.go @@ -14,6 +14,7 @@ import ( "github.com/NHAS/wag/adminui" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" + "github.com/NHAS/wag/internal/enrolment" "github.com/NHAS/wag/internal/mfaportal" "github.com/NHAS/wag/internal/router" @@ -90,9 +91,10 @@ func startWag(noIptables bool, cancel <-chan bool, errorChan chan<- error) func( routerFw *router.Firewall - controlServer *server.WagControlSocketServer - mfaPortal *mfaportal.MfaPortal - adminUI *adminui.AdminUI + controlServer *server.WagControlSocketServer + mfaPortal *mfaportal.MfaPortal + enrolmentServer *enrolment.EnrolmentServer + adminUI *adminui.AdminUI err error ) @@ -109,6 +111,10 @@ func startWag(noIptables bool, cancel <-chan bool, errorChan chan<- error) func( mfaPortal.Close() } + if enrolmentServer != nil { + enrolmentServer.Close() + } + if adminUI != nil { adminUI.Close() } @@ -176,6 +182,12 @@ func startWag(noIptables bool, cancel <-chan bool, errorChan chan<- error) func( return } + enrolmentServer, err = enrolment.New(routerFw, errorChan) + if err != nil { + errorChan <- fmt.Errorf("unable to start enrolment server: %v", err) + return + } + if config.Values.ManagementUI.Enabled { adminUI, err = adminui.New(routerFw, errorChan) if err != nil { diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 8e5a8786..b0721a07 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -17,6 +17,8 @@ services: command: sleep infinity ports: - "4433:4433/tcp" + - "53230:53230/udp" + - "8081:8081/tcp" volumes: - .:/usr/local/bin/wag:ro - ./docker-test-config.json:/opt/config.json diff --git a/internal/enrolment/resources/embed.go b/internal/enrolment/resources/embed.go new file mode 100644 index 00000000..dc971a5d --- /dev/null +++ b/internal/enrolment/resources/embed.go @@ -0,0 +1,44 @@ +package resources + +import ( + "embed" + "html/template" + "io" + "path" + + "github.com/NHAS/wag/internal/config" +) + +type Interface struct { + ClientPrivateKey string + ClientAddress string + ClientPresharedKey string + + ServerAddress string + ServerPublicKey string + CapturedAddresses []string + DNS []string +} + +type QrCodeRegistrationDisplay struct { + ImageData template.URL + Username string +} + +//go:embed templates/* +var templates embed.FS + +func Render(page string, out io.Writer, data interface{}) error { + return RenderWithFuncs(page, out, data, nil) +} + +func RenderWithFuncs(page string, out io.Writer, data interface{}, templateFuncs template.FuncMap) error { + var currentTemplate *template.Template + if len(config.Values.MFATemplatesDirectory) != 0 { + currentTemplate = template.Must(template.New(path.Base(page)).Funcs(templateFuncs).ParseFiles(path.Join(config.Values.MFATemplatesDirectory, page))) + } else { + currentTemplate = template.Must(template.New(path.Base(page)).Funcs(templateFuncs).ParseFS(templates, "templates/"+page)) + } + + return currentTemplate.Execute(out, data) +} diff --git a/internal/mfaportal/resources/templates/qrcode_registration.html b/internal/enrolment/resources/templates/qrcode_enrolment.html similarity index 100% rename from internal/mfaportal/resources/templates/qrcode_registration.html rename to internal/enrolment/resources/templates/qrcode_enrolment.html diff --git a/internal/mfaportal/resources/templates/interface.tmpl b/internal/enrolment/resources/templates/wgconf_enrolment.tmpl similarity index 100% rename from internal/mfaportal/resources/templates/interface.tmpl rename to internal/enrolment/resources/templates/wgconf_enrolment.tmpl diff --git a/internal/enrolment/web.go b/internal/enrolment/web.go new file mode 100644 index 00000000..56e404ec --- /dev/null +++ b/internal/enrolment/web.go @@ -0,0 +1,412 @@ +package enrolment + +import ( + "bytes" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "html/template" + "image/png" + "log" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/NHAS/wag/internal/config" + "github.com/NHAS/wag/internal/data" + "github.com/NHAS/wag/internal/mfaportal/resources" + "github.com/NHAS/wag/internal/router" + "github.com/NHAS/wag/internal/routetypes" + "github.com/NHAS/wag/internal/users" + "github.com/NHAS/wag/internal/utils" + "github.com/NHAS/wag/pkg/httputils" + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type EnrolmentServer struct { + publicHTTPServ *http.Server + publicTLSServ *http.Server + firewall *router.Firewall +} + +func (es *EnrolmentServer) Close() { + + if es.publicHTTPServ != nil { + es.publicHTTPServ.Close() + } + + if es.publicTLSServ != nil { + es.publicTLSServ.Close() + } +} + +func (es *EnrolmentServer) registerDevice(w http.ResponseWriter, r *http.Request) { + remoteAddr := utils.GetIPFromRequest(r) + + key, err := url.PathUnescape(r.URL.Query().Get("key")) + if err != nil { + http.NotFound(w, r) + return + } + + if len(key) == 0 { + log.Println("unknown", remoteAddr, "no registration key specified, ignoring") + http.NotFound(w, r) + return + } + + username, overwrites, groups, err := data.GetRegistrationToken(key) + if err != nil { + log.Println(username, remoteAddr, "failed to get registration key:", err) + http.NotFound(w, r) + return + } + + if len(groups) != 0 { + err := data.SetUserGroupMembership(username, groups) + if err != nil { + log.Println(username, remoteAddr, "could not set user membership from registration token:", err) + http.Error(w, "Server error", http.StatusInternalServerError) + return + } + } + + var publickey, privatekey wgtypes.Key + pubkeyParam, err := url.PathUnescape(r.URL.Query().Get("pubkey")) + if err != nil { + log.Println(username, remoteAddr, "failed to url decode public key paramter:", err) + http.NotFound(w, r) + return + } + + if len(pubkeyParam) != 0 { + publickey, err = wgtypes.ParseKey(pubkeyParam) + if err != nil { + log.Println(username, remoteAddr, "failed to unmarshal wireguard public key:", err) + http.Error(w, "Server error", http.StatusInternalServerError) + return + } + } else { + privatekey, err = wgtypes.GeneratePrivateKey() + if err != nil { + log.Println(username, remoteAddr, "failed to generate wireguard keys:", err) + http.Error(w, "Server error", http.StatusInternalServerError) + return + } + publickey = privatekey.PublicKey() + } + + user, err := users.GetUser(username) + if err != nil { + user, err = users.CreateUser(username) + if err != nil { + log.Println(username, remoteAddr, "unable create new user: "+err.Error()) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + } + + var ( + address string + ) + if overwrites != "" { + + err = user.SetDevicePublicKey(publickey.String(), overwrites) + if err != nil { + log.Println(username, remoteAddr, "could update '", overwrites, "': ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + address = overwrites + + } else { + + // Make sure not to accidentally shadow the global err here as we're using a defer to monitor failures to delete the device + var device data.Device + device, err = user.AddDevice(publickey) + if err != nil { + log.Println(username, remoteAddr, "unable to add device: ", err) + + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + address = device.Address + + defer func() { + + if err != nil { + log.Println(username, remoteAddr, "removing device (due to registration failure)") + err := user.DeleteDevice(device.Address) + if err != nil { + log.Println(username, remoteAddr, "unable to remove wg device: ", err) + } + } + }() + } + + acl := data.GetEffectiveAcl(username) + + wgPublicKey, wgPort, err := es.firewall.ServerDetails() + if err != nil { + log.Println(username, remoteAddr, "unable access wireguard device: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + keyStr := privatekey.String() + //Empty value of a private key in wgtype.Key + if keyStr == "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" { + keyStr = "" + } + + presharedKey, err := user.GetDevicePresharedKey(address) + if err != nil { + log.Println(username, remoteAddr, "unable access device preshared key: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + dnsWithOutSubnet, err := data.GetDNS() + if err != nil { + log.Println(username, remoteAddr, "unable get dns: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + for i := 0; i < len(dnsWithOutSubnet); i++ { + dnsWithOutSubnet[i] = strings.TrimSuffix(dnsWithOutSubnet[i], "/32") + } + + routes, err := routetypes.AclsToRoutes(append(acl.Allow, acl.Mfa...)) + if err != nil { + log.Println(username, remoteAddr, "unable access parse acls to produce routes: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + wireguardInterface := resources.Interface{ + ClientPrivateKey: keyStr, + ClientAddress: address, + ServerPublicKey: wgPublicKey.String(), + CapturedAddresses: routes, + DNS: dnsWithOutSubnet, + ClientPresharedKey: presharedKey, + } + + externalAddress, err := data.GetExternalAddress() + if err != nil { + log.Println(username, remoteAddr, "unable to get server external address from datastore: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + // If the external address defined in the config has a port, use that, otherwise defaultly add the same port as the wireguard device + _, _, err = net.SplitHostPort(externalAddress) + if err != nil { + externalAddress = fmt.Sprintf("%s:%d", externalAddress, wgPort) + } + + wireguardInterface.ServerAddress = externalAddress + + if r.URL.Query().Get("type") == "mobile" { + w.Header().Set("Content-Type", "text/html; charset=UTF-8") + + var wireguardProfile bytes.Buffer + err = resources.RenderWithFuncs("interface.tmpl", &wireguardProfile, &wireguardInterface, template.FuncMap{ + "StringsJoin": strings.Join, + "Unescape": func(s string) template.HTML { return template.HTML(s) }, + }) + if err != nil { + log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + image, err := qr.Encode(wireguardProfile.String(), qr.M, qr.Auto) + if err != nil { + log.Println(username, remoteAddr, "failed to generate qr code:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + image, err = barcode.Scale(image, 400, 400) + if err != nil { + log.Println(username, remoteAddr, "failed to output barcode bytes:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + var buff bytes.Buffer + err = png.Encode(&buff, image) + if err != nil { + log.Println(user.Username, remoteAddr, "encoding mfa secret as png failed:", err) + http.Error(w, "Unknown error", http.StatusInternalServerError) + return + } + + qrCodeBytes := resources.QrCodeRegistrationDisplay{ + ImageData: template.URL("data:image/png;base64, " + base64.StdEncoding.EncodeToString(buff.Bytes())), + Username: username, + } + + err = resources.Render("qrcode_registration.html", w, &qrCodeBytes) + if err != nil { + log.Println(username, remoteAddr, "failed to execute template to show qr code wireguard config:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + } else { + + w.Header().Set("Content-Disposition", "attachment; filename="+data.GetWireguardConfigName()) + + err = resources.RenderWithFuncs("interface.tmpl", w, &wireguardInterface, template.FuncMap{ + "StringsJoin": strings.Join, + "Unescape": func(s string) template.HTML { return template.HTML(s) }, + }) + if err != nil { + log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + } + + //Finish registration process + err = data.FinaliseRegistration(key) + if err != nil { + log.Println(username, remoteAddr, "expiring registration token failed:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + logMsg := "registered as" + if overwrites != "" { + logMsg = "overwrote" + } + log.Println(username, remoteAddr, "successfully", logMsg, address, ":", publickey.String()) +} + +func (es *EnrolmentServer) reachability(w http.ResponseWriter, _ *http.Request) { + w.Header().Add("Content-Type", "text/plain") + + isDrained, err := data.IsDrained(data.GetServerID().String()) + if err != nil { + http.Error(w, "Failed to fetch state", http.StatusInternalServerError) + return + } + + if !isDrained { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + return + } + + w.WriteHeader(http.StatusGone) + w.Write([]byte("Drained")) + +} + +func New(firewall *router.Firewall, errChan chan<- error) (*EnrolmentServer, error) { + if firewall == nil { + panic("firewall was nil") + } + + var es EnrolmentServer + es.firewall = firewall + + //https://blog.cloudflare.com/exposing-go-on-the-internet/ + tlsConfig := &tls.Config{ + // Only use curves which have assembly implementations + CurvePreferences: []tls.CurveID{ + tls.CurveP256, + tls.X25519, // Go 1.8 only + }, + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + } + + public := httputils.NewMux() + public.Get("/static/", utils.EmbeddedStatic(resources.Static)) + public.Get("/reachability", es.reachability) + public.Get("/register_device", es.registerDevice) + + if config.Values.Webserver.Public.SupportsTLS() { + + go func() { + + es.publicTLSServ = &http.Server{ + Addr: config.Values.Webserver.Public.ListenAddress, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + TLSConfig: tlsConfig, + Handler: utils.SetSecurityHeaders(public), + } + + if err := es.publicTLSServ.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- fmt.Errorf("TLS webserver enrolment listener failed: %v", err) + } + }() + + if config.Values.NumberProxies == 0 { + go func() { + + address, port, err := net.SplitHostPort(config.Values.Webserver.Public.ListenAddress) + + if err != nil { + errChan <- fmt.Errorf("malformed listen address for enrolment listener: %v", err) + return + } + + // If we're supporting tls, add a redirection handler from 80 -> tls + port += ":" + port + if port == "443" { + port = "" + } + + es.publicHTTPServ = &http.Server{ + Addr: address + ":80", + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + Handler: utils.SetSecurityHeaders(utils.SetRedirectHandler(port)), + } + + log.Printf("Creating redirection from 80/tcp to TLS webserver enrolment listener failed: %v", es.publicHTTPServ.ListenAndServe()) + }() + } + + } else { + go func() { + es.publicHTTPServ = &http.Server{ + Addr: config.Values.Webserver.Public.ListenAddress, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + Handler: utils.SetSecurityHeaders(public), + } + + if err := es.publicHTTPServ.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- fmt.Errorf("HTTP webserver enrolment listener failed: %v", err) + } + }() + } + + log.Println("[ENROLMENT] Public enrolment listening: ", config.Values.Webserver.Public.ListenAddress) + + return &es, nil + +} diff --git a/internal/mfaportal/security.go b/internal/mfaportal/security.go deleted file mode 100644 index 16f3251d..00000000 --- a/internal/mfaportal/security.go +++ /dev/null @@ -1,48 +0,0 @@ -package mfaportal - -import ( - "net" - "net/http" - "strings" -) - -type securityHeaders struct { - next http.Handler -} - -func (sh *securityHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Security-Policy", "default-src 'none'; script-src 'self'; connect-src 'self'; object-src 'none'; img-src 'self' data:; require-trusted-types-for 'script'; style-src 'self' fonts.googleapis.com; font-src fonts.gstatic.com fonts.googleapis.com; ") - w.Header().Set("X-Frame-Options", "DENY") - w.Header().Set("Strict-Transport-Security", "max-age=31536000") - - sh.next.ServeHTTP(w, r) -} - -func setSecurityHeaders(f http.Handler) http.Handler { - return &securityHeaders{ - next: f, - } -} - -type httpRedirectHandler struct { - TLSPort string -} - -func (sh *httpRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - if strings.Contains(err.Error(), "missing port in address") { - host = r.Host - } else { - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - } - - http.Redirect(w, r, "https://"+host+r.RequestURI, http.StatusTemporaryRedirect) -} - -func setRedirectHandler(TLSPort string) http.Handler { - return &httpRedirectHandler{TLSPort: TLSPort} -} diff --git a/internal/mfaportal/static.go b/internal/mfaportal/static.go deleted file mode 100644 index 58fa38cb..00000000 --- a/internal/mfaportal/static.go +++ /dev/null @@ -1,51 +0,0 @@ -package mfaportal - -import ( - "log" - "net/http" - "path/filepath" - - "github.com/NHAS/wag/internal/mfaportal/resources" -) - -func setMimeType(w http.ResponseWriter, r *http.Request) { - headers := w.Header() - ext := filepath.Ext(r.URL.Path) - - switch ext { - case ".js": - headers.Set("Content-Type", "text/javascript") - case ".css": - headers.Set("Content-Type", "text/css") - case ".png": - headers.Set("Content-Type", "image/png") - case ".jpg": - headers.Set("Content-Type", "image/jpg") - case ".svg": - headers.Set("Content-Type", "image/svg") - } -} - -func embeddedStatic(w http.ResponseWriter, r *http.Request) { - - var err error - var fileContent []byte - - if len(r.URL.Path) > 0 { - r.URL.Path = r.URL.Path[1:] - } - - if fileContent, err = resources.Static.ReadFile(r.URL.Path); err != nil { - log.Println("Error getting static: ", err) - http.NotFound(w, r) - return - } - - setMimeType(w, r) - - _, err = w.Write(fileContent) - if err != nil { - log.Println("Unable to write static resource: ", err, " path: ", r.URL.Path) - http.Error(w, "Server Error", 500) - } -} diff --git a/internal/mfaportal/web.go b/internal/mfaportal/web.go index 4fdc4a6d..fc26eef6 100644 --- a/internal/mfaportal/web.go +++ b/internal/mfaportal/web.go @@ -1,18 +1,12 @@ package mfaportal import ( - "bytes" "crypto/tls" - "encoding/base64" "encoding/json" "errors" "fmt" - "html/template" - "image/png" "log" - "net" "net/http" - "net/url" "path" "strings" "time" @@ -22,23 +16,16 @@ import ( "github.com/NHAS/wag/internal/mfaportal/authenticators" "github.com/NHAS/wag/internal/mfaportal/resources" "github.com/NHAS/wag/internal/router" - "github.com/NHAS/wag/internal/routetypes" "github.com/NHAS/wag/internal/users" "github.com/NHAS/wag/internal/utils" "github.com/NHAS/wag/pkg/httputils" - "github.com/boombuler/barcode" - "github.com/boombuler/barcode/qr" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type MfaPortal struct { tunnelHTTPServ *http.Server tunnelTLSServ *http.Server - publicHTTPServ *http.Server - publicTLSServ *http.Server - firewall *router.Firewall + firewall *router.Firewall listenerKeys struct { Oidc string @@ -58,14 +45,6 @@ func (mp *MfaPortal) Close() { mp.tunnelTLSServ.Close() } - if mp.publicHTTPServ != nil { - mp.publicHTTPServ.Close() - } - - if mp.publicTLSServ != nil { - mp.publicTLSServ.Close() - } - mp.deregisterListeners() log.Println("Stopped MFA portal") @@ -97,73 +76,6 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err }, } - public := httputils.NewMux() - public.Get("/static/", embeddedStatic) - public.Get("/register_device", mfaPortal.registerDevice) - public.Get("/reachability", mfaPortal.reachability) - - if config.Values.Webserver.Public.SupportsTLS() { - - go func() { - - mfaPortal.publicTLSServ = &http.Server{ - Addr: config.Values.Webserver.Public.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - TLSConfig: tlsConfig, - Handler: setSecurityHeaders(public), - } - - if err := mfaPortal.publicTLSServ.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("TLS webserver public listener failed: %v", err) - } - }() - - if config.Values.NumberProxies == 0 { - go func() { - - address, port, err := net.SplitHostPort(config.Values.Webserver.Public.ListenAddress) - - if err != nil { - errChan <- fmt.Errorf("malformed listen address for public listener: %v", err) - return - } - - // If we're supporting tls, add a redirection handler from 80 -> tls - port += ":" + port - if port == "443" { - port = "" - } - - mfaPortal.publicHTTPServ = &http.Server{ - Addr: address + ":80", - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: setSecurityHeaders(setRedirectHandler(port)), - } - - log.Printf("Creating redirection from 80/tcp to TLS webserver public listener failed: %v", mfaPortal.publicHTTPServ.ListenAndServe()) - }() - } - - } else { - go func() { - mfaPortal.publicHTTPServ = &http.Server{ - Addr: config.Values.Webserver.Public.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: setSecurityHeaders(public), - } - - if err := mfaPortal.publicHTTPServ.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("HTTP webserver public listener failed: %v", err) - } - }() - } - tunnel := httputils.NewMux() tunnel.Get("/status/", mfaPortal.status) @@ -176,7 +88,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err tunnel.Handle("/custom/", http.StripPrefix("/custom/", fs)) } - tunnel.Get("/static/", embeddedStatic) + tunnel.Get("/static/", utils.EmbeddedStatic(resources.Static)) // Do inital state setup for our authentication methods err = authenticators.AddMFARoutes(tunnel, mfaPortal.firewall) @@ -210,7 +122,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, TLSConfig: tlsConfig, - Handler: setSecurityHeaders(tunnel), + Handler: utils.SetSecurityHeaders(tunnel), } if err := mfaPortal.tunnelTLSServ.ListenAndServeTLS(config.Values.Webserver.Tunnel.CertPath, config.Values.Webserver.Tunnel.KeyPath); err != nil && errors.Is(err, http.ErrServerClosed) { errChan <- fmt.Errorf("TLS webserver tunnel listener failed: %v", err) @@ -231,7 +143,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, - Handler: setSecurityHeaders(setRedirectHandler(port)), + Handler: utils.SetSecurityHeaders(utils.SetRedirectHandler(port)), } log.Printf("HTTP redirect to TLS webserver tunnel listener failed: %v", mfaPortal.tunnelHTTPServ.ListenAndServe()) @@ -244,7 +156,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, IdleTimeout: 120 * time.Second, - Handler: setSecurityHeaders(tunnel), + Handler: utils.SetSecurityHeaders(tunnel), } if err := mfaPortal.tunnelHTTPServ.ListenAndServe(); err != nil && errors.Is(err, http.ErrServerClosed) { @@ -255,9 +167,7 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err } //Group the print statement so that multithreading won't disorder them - log.Println("Started listening:\n", - "\t\t\tTunnel Listener: ", tunnelListenAddress, "\n", - "\t\t\tPublic Listener: ", config.Values.Webserver.Public.ListenAddress) + log.Println("[PORTAL] Captive portal started listening: ", tunnelListenAddress) return m, nil } @@ -386,273 +296,6 @@ func (mp *MfaPortal) authorise(w http.ResponseWriter, r *http.Request) { mfaMethod.MFAPromptUI(w, r, user.Username, clientTunnelIp.String()) } -func (mp *MfaPortal) reachability(w http.ResponseWriter, _ *http.Request) { - w.Header().Add("Content-Type", "text/plain") - - isDrained, err := data.IsDrained(data.GetServerID().String()) - if err != nil { - http.Error(w, "Failed to fetch state", http.StatusInternalServerError) - return - } - - if !isDrained { - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - return - } - - w.WriteHeader(http.StatusGone) - w.Write([]byte("Drained")) - -} - -func (mp *MfaPortal) registerDevice(w http.ResponseWriter, r *http.Request) { - remoteAddr := utils.GetIPFromRequest(r) - - key, err := url.PathUnescape(r.URL.Query().Get("key")) - if err != nil { - http.NotFound(w, r) - return - } - - if len(key) == 0 { - log.Println("unknown", remoteAddr, "no registration key specified, ignoring") - http.NotFound(w, r) - return - } - - username, overwrites, groups, err := data.GetRegistrationToken(key) - if err != nil { - log.Println(username, remoteAddr, "failed to get registration key:", err) - http.NotFound(w, r) - return - } - - if len(groups) != 0 { - err := data.SetUserGroupMembership(username, groups) - if err != nil { - log.Println(username, remoteAddr, "could not set user membership from registration token:", err) - http.Error(w, "Server error", http.StatusInternalServerError) - return - } - } - - var publickey, privatekey wgtypes.Key - pubkeyParam, err := url.PathUnescape(r.URL.Query().Get("pubkey")) - if err != nil { - log.Println(username, remoteAddr, "failed to url decode public key paramter:", err) - http.NotFound(w, r) - return - } - - if len(pubkeyParam) != 0 { - publickey, err = wgtypes.ParseKey(pubkeyParam) - if err != nil { - log.Println(username, remoteAddr, "failed to unmarshal wireguard public key:", err) - http.Error(w, "Server error", http.StatusInternalServerError) - return - } - } else { - privatekey, err = wgtypes.GeneratePrivateKey() - if err != nil { - log.Println(username, remoteAddr, "failed to generate wireguard keys:", err) - http.Error(w, "Server error", http.StatusInternalServerError) - return - } - publickey = privatekey.PublicKey() - } - - user, err := users.GetUser(username) - if err != nil { - user, err = users.CreateUser(username) - if err != nil { - log.Println(username, remoteAddr, "unable create new user: "+err.Error()) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - } - - var ( - address string - ) - if overwrites != "" { - - err = user.SetDevicePublicKey(publickey.String(), overwrites) - if err != nil { - log.Println(username, remoteAddr, "could update '", overwrites, "': ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - address = overwrites - - } else { - - // Make sure not to accidentally shadow the global err here as we're using a defer to monitor failures to delete the device - var device data.Device - device, err = user.AddDevice(publickey) - if err != nil { - log.Println(username, remoteAddr, "unable to add device: ", err) - - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - address = device.Address - - defer func() { - - if err != nil { - log.Println(username, remoteAddr, "removing device (due to registration failure)") - err := user.DeleteDevice(device.Address) - if err != nil { - log.Println(username, remoteAddr, "unable to remove wg device: ", err) - } - } - }() - } - - acl := data.GetEffectiveAcl(username) - - wgPublicKey, wgPort, err := mp.firewall.ServerDetails() - if err != nil { - log.Println(username, remoteAddr, "unable access wireguard device: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - keyStr := privatekey.String() - //Empty value of a private key in wgtype.Key - if keyStr == "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" { - keyStr = "" - } - - presharedKey, err := user.GetDevicePresharedKey(address) - if err != nil { - log.Println(username, remoteAddr, "unable access device preshared key: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - dnsWithOutSubnet, err := data.GetDNS() - if err != nil { - log.Println(username, remoteAddr, "unable get dns: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - for i := 0; i < len(dnsWithOutSubnet); i++ { - dnsWithOutSubnet[i] = strings.TrimSuffix(dnsWithOutSubnet[i], "/32") - } - - routes, err := routetypes.AclsToRoutes(append(acl.Allow, acl.Mfa...)) - if err != nil { - log.Println(username, remoteAddr, "unable access parse acls to produce routes: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - wireguardInterface := resources.Interface{ - ClientPrivateKey: keyStr, - ClientAddress: address, - ServerPublicKey: wgPublicKey.String(), - CapturedAddresses: routes, - DNS: dnsWithOutSubnet, - ClientPresharedKey: presharedKey, - } - - externalAddress, err := data.GetExternalAddress() - if err != nil { - log.Println(username, remoteAddr, "unable to get server external address from datastore: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - // If the external address defined in the config has a port, use that, otherwise defaultly add the same port as the wireguard device - _, _, err = net.SplitHostPort(externalAddress) - if err != nil { - externalAddress = fmt.Sprintf("%s:%d", externalAddress, wgPort) - } - - wireguardInterface.ServerAddress = externalAddress - - if r.URL.Query().Get("type") == "mobile" { - w.Header().Set("Content-Type", "text/html; charset=UTF-8") - - var wireguardProfile bytes.Buffer - err = resources.RenderWithFuncs("interface.tmpl", &wireguardProfile, &wireguardInterface, template.FuncMap{ - "StringsJoin": strings.Join, - "Unescape": func(s string) template.HTML { return template.HTML(s) }, - }) - if err != nil { - log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - image, err := qr.Encode(wireguardProfile.String(), qr.M, qr.Auto) - if err != nil { - log.Println(username, remoteAddr, "failed to generate qr code:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - image, err = barcode.Scale(image, 400, 400) - if err != nil { - log.Println(username, remoteAddr, "failed to output barcode bytes:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - var buff bytes.Buffer - err = png.Encode(&buff, image) - if err != nil { - log.Println(user.Username, remoteAddr, "encoding mfa secret as png failed:", err) - http.Error(w, "Unknown error", http.StatusInternalServerError) - return - } - - qrCodeBytes := resources.QrCodeRegistrationDisplay{ - ImageData: template.URL("data:image/png;base64, " + base64.StdEncoding.EncodeToString(buff.Bytes())), - Username: username, - } - - err = resources.Render("qrcode_registration.html", w, &qrCodeBytes) - if err != nil { - log.Println(username, remoteAddr, "failed to execute template to show qr code wireguard config:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - } else { - - w.Header().Set("Content-Disposition", "attachment; filename="+data.GetWireguardConfigName()) - - err = resources.RenderWithFuncs("interface.tmpl", w, &wireguardInterface, template.FuncMap{ - "StringsJoin": strings.Join, - "Unescape": func(s string) template.HTML { return template.HTML(s) }, - }) - if err != nil { - log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - } - - //Finish registration process - err = data.FinaliseRegistration(key) - if err != nil { - log.Println(username, remoteAddr, "expiring registration token failed:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - logMsg := "registered as" - if overwrites != "" { - logMsg = "overwrote" - } - log.Println(username, remoteAddr, "successfully", logMsg, address, ":", publickey.String()) -} - func (mp *MfaPortal) logout(w http.ResponseWriter, r *http.Request) { clientTunnelIp := utils.GetIPFromRequest(r) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 8954a0a1..8deae955 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,15 +2,113 @@ package utils import ( "crypto/rand" + "embed" "encoding/hex" "log" "net" "net/http" + "net/url" + "path/filepath" "strings" "github.com/NHAS/wag/internal/config" ) +func EmbeddedStatic(fs embed.FS) func(w http.ResponseWriter, r *http.Request) { + + return func(w http.ResponseWriter, r *http.Request) { + var err error + var fileContent []byte + + if len(r.URL.Path) > 0 { + r.URL.Path = r.URL.Path[1:] + } + + if fileContent, err = fs.ReadFile(r.URL.Path); err != nil { + log.Println("Error getting static: ", err) + http.NotFound(w, r) + return + } + + headers := w.Header() + ext := filepath.Ext(r.URL.Path) + + switch ext { + case ".js": + headers.Set("Content-Type", "text/javascript") + case ".css": + headers.Set("Content-Type", "text/css") + case ".png": + headers.Set("Content-Type", "image/png") + case ".jpg": + headers.Set("Content-Type", "image/jpg") + case ".svg": + headers.Set("Content-Type", "image/svg") + } + + _, err = w.Write(fileContent) + if err != nil { + log.Println("Unable to write static resource: ", err, " path: ", r.URL.Path) + http.Error(w, "Server Error", 500) + } + } +} + +type httpRedirectHandler struct { + TLSPort string +} + +func (sh *httpRedirectHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + if strings.Contains(err.Error(), "missing port in address") { + host = r.Host + } else { + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + } + + http.Redirect(w, r, "https://"+host+r.RequestURI, http.StatusTemporaryRedirect) +} + +func SetRedirectHandler(TLSPort string) http.Handler { + return &httpRedirectHandler{TLSPort: TLSPort} +} + +type security struct { + next http.Handler +} + +func (sh *security) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Strict-Transport-Security", "max-age=31536000") + w.Header().Set("X-Content-Type-Options", "nosniff") + + if r.Method != "GET" { + u, err := url.Parse(r.Header.Get("Origin")) + if err != nil { + http.Error(w, "Bad Request", 400) + return + } + + //If origin != host header + if r.Host != u.Host { + http.Error(w, "Bad Request", 400) + return + } + } + + sh.next.ServeHTTP(w, r) +} + +func SetSecurityHeaders(f http.Handler) http.Handler { + return &security{ + next: f, + } +} + func GetIP(addr string) string { for i := len(addr) - 1; i > 0; i-- { if addr[i] == ':' || addr[i] == '/' { diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index 92daa144..c1c60d37 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -91,7 +91,7 @@ func NewControlServer(firewall *router.Firewall) (*WagControlSocketServer, error } } - log.Println("Started control socket: \n\t\t\t", config.Values.Socket) + log.Println("[CONTROL] Started socket: ", config.Values.Socket) controlMux := httputils.NewMux()