From 724cbc3c5679780e5bd685e673c4b9ca2b80684c Mon Sep 17 00:00:00 2001 From: NHAS Date: Mon, 13 May 2024 21:11:09 +1200 Subject: [PATCH] Fix registration issues, add validation in settings, improve dhcp to avoid deadlock when subnet is exhausted --- go.mod | 5 +++ go.sum | 10 +++++ internal/data/config.go | 72 +++++++++++++++++++++++------- internal/data/devices.go | 2 +- internal/data/dhcp.go | 41 ++++++++++++----- internal/data/user.go | 3 ++ internal/webserver/web.go | 56 ++++++++++++----------- ui/registration.go | 3 ++ ui/templates/settings/general.html | 3 +- 9 files changed, 141 insertions(+), 54 deletions(-) diff --git a/go.mod b/go.mod index ba64d188..d345b7d1 100644 --- a/go.mod +++ b/go.mod @@ -34,8 +34,12 @@ require ( github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/fxamacker/cbor/v2 v2.5.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-webauthn/x v0.1.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.4.2 // indirect @@ -54,6 +58,7 @@ require ( github.com/jonboulle/clockwork v0.2.2 // indirect github.com/josharian/native v1.1.0 // indirect github.com/json-iterator/go v1.1.11 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect diff --git a/go.sum b/go.sum index e27a6938..a1a1564c 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBF github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE= github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= @@ -71,6 +73,12 @@ github.com/go-logr/logr v1.3.0 h1:2y3SDp0ZXuc6/cjLSZ+Q3ir+QB9T/iG5yYRXqsagWSY= github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -160,6 +168,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= diff --git a/internal/data/config.go b/internal/data/config.go index d7123e47..c8bd6276 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/NHAS/wag/internal/data/validators" + "github.com/go-playground/validator/v10" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -342,21 +343,34 @@ type AllSettings struct { } type LoginSettings struct { - SessionInactivityTimeoutMinutes int - MaxSessionLifetimeMinutes int - Lockout int + SessionInactivityTimeoutMinutes int `validate:"required,number"` + MaxSessionLifetimeMinutes int `validate:"required,number"` + Lockout int `validate:"required,number"` - DefaultMFAMethod string - EnabledMFAMethods []string + DefaultMFAMethod string `validate:"required"` + EnabledMFAMethods []string `validate:"required,lt=10,dive,required"` - Domain string - Issuer string + Domain string `validate:"required"` + Issuer string `validate:"required"` OidcDetails OIDC PamDetails PAM } -func (lg *LoginSettings) ToWriteOps() (ret []clientv3.Op) { +func (lg *LoginSettings) Validate() error { + lg.Domain = strings.TrimSpace(lg.Domain) + lg.Issuer = strings.TrimSpace(lg.Issuer) + + validate := validator.New(validator.WithRequiredStructEnabled()) + + return validate.Struct(lg) +} + +func (lg *LoginSettings) ToWriteOps() (ret []clientv3.Op, err error) { + + if err := lg.Validate(); err != nil { + return nil, err + } b, _ := json.Marshal(lg.SessionInactivityTimeoutMinutes) ret = append(ret, clientv3.OpPut(InactivityTimeoutKey, string(b))) @@ -389,15 +403,33 @@ func (lg *LoginSettings) ToWriteOps() (ret []clientv3.Op) { } type GeneralSettings struct { - HelpMail string - ExternalAddress string - DNS []string + HelpMail string `validate:"required,email"` + ExternalAddress string `validate:"required,hostname|hostname_port|ip"` + DNS []string `validate:"omitempty,dive,ip"` - WireguardConfigFilename string + WireguardConfigFilename string `validate:"required"` CheckUpdates bool } -func (gs *GeneralSettings) ToWriteOps() (ret []clientv3.Op) { +func (gs *GeneralSettings) Validate() error { + + gs.HelpMail = strings.TrimSpace(gs.HelpMail) + gs.ExternalAddress = strings.TrimSpace(gs.ExternalAddress) + gs.WireguardConfigFilename = strings.TrimSpace(gs.WireguardConfigFilename) + for i := range gs.DNS { + gs.DNS[i] = strings.TrimSpace(gs.DNS[i]) + } + + validate := validator.New(validator.WithRequiredStructEnabled()) + + return validate.Struct(gs) +} + +func (gs *GeneralSettings) ToWriteOps() (ret []clientv3.Op, err error) { + + if err := gs.Validate(); err != nil { + return nil, err + } b, _ := json.Marshal(gs.HelpMail) ret = append(ret, clientv3.OpPut(helpMailKey, string(b))) @@ -541,14 +573,24 @@ func GetAllSettings() (s AllSettings, err error) { } func SetLoginSettings(loginSettings LoginSettings) error { + + writeOps, err := loginSettings.ToWriteOps() + if err != nil { + return err + } + txn := etcd.Txn(context.Background()) - _, err := txn.Then(loginSettings.ToWriteOps()...).Commit() + _, err = txn.Then(writeOps...).Commit() return err } func SetGeneralSettings(generalSettings GeneralSettings) error { txn := etcd.Txn(context.Background()) - _, err := txn.Then(generalSettings.ToWriteOps()...).Commit() + writeOPs, err := generalSettings.ToWriteOps() + if err != nil { + return err + } + _, err = txn.Then(writeOPs...).Commit() return err } diff --git a/internal/data/devices.go b/internal/data/devices.go index 1dbfeb09..d86cec91 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -194,7 +194,7 @@ func AddDevice(username, publickey string) (Device, error) { return Device{}, err } - address, err := getNextIP(config.Values.Wireguard.Range.String()) + address, err := getNextIP(config.Values.Wireguard.Address) if err != nil { return Device{}, err } diff --git a/internal/data/dhcp.go b/internal/data/dhcp.go index e078e5c1..226c5bb4 100644 --- a/internal/data/dhcp.go +++ b/internal/data/dhcp.go @@ -31,28 +31,39 @@ func getNextIP(subnet string) (string, error) { } used, _ := cidr.Mask.Size() - addresses := int(math.Pow(2, float64(32-used))) - 2 // Do not allocate largest address or 0 - if addresses < 1 { - return "", errors.New("no addresses available") + maxNumberOfAddresses := int(math.Pow(2, float64(32-used))) - 2 // Do not allocate largest address or 0 + if maxNumberOfAddresses < 1 { + return "", errors.New("subnet is too small to contain a new device") } // Choose a random number that cannot be 0 - addressAttempt := rand.Intn(addresses) + 1 + addressAttempt := rand.Intn(maxNumberOfAddresses) + 1 addr := incrementIP(cidr.IP, uint(addressAttempt)) - if serverIP.Equal(addr) { - addr = incrementIP(addr, 1) - } - lease, err := clientv3.NewLease(etcd).Grant(context.Background(), 3) if err != nil { return "", err } + if serverIP.Equal(addr) { + addr = incrementIP(addr, 1) + } + + startIP := addr for { + + if serverIP.Equal(addr) { + addr = incrementIP(addr, 1) + } + txn := etcd.Txn(context.Background()) - txn.If(clientv3util.KeyMissing("deviceref-"+addr.String()), clientv3util.KeyMissing("ip-hold-"+addr.String())) - txn.Then(clientv3.OpPut("ip-hold-"+addr.String(), addr.String(), clientv3.WithLease(lease.ID))) + txn.If( + clientv3util.KeyMissing("deviceref-"+addr.String()), + clientv3util.KeyMissing("ip-hold-"+addr.String()), + ) + txn.Then( + clientv3.OpPut("ip-hold-"+addr.String(), addr.String(), clientv3.WithLease(lease.ID)), + ) resp, err := txn.Commit() if err != nil { @@ -64,6 +75,16 @@ func getNextIP(subnet string) (string, error) { } addr = incrementIP(addr, 1) + if cidr.Contains(addr) { + continue + } else { + addr = incrementIP(cidr.IP, 1) + } + + if addr.Equal(startIP) { + return "", errors.New("unable to obtain ip lease, subnet is full") + } + } } diff --git a/internal/data/user.go b/internal/data/user.go index e339667f..5c182137 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -7,6 +7,7 @@ import ( "errors" "strings" + "github.com/NHAS/wag/internal/webserver/authenticators/types" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/clientv3util" ) @@ -323,6 +324,8 @@ func CreateUserDataAccount(username string) (UserModel, error) { newUser := UserModel{ Username: username, + Mfa: string(types.Unset), + MfaType: string(types.Unset), } b, _ := json.Marshal(&newUser) diff --git a/internal/webserver/web.go b/internal/webserver/web.go index 7672dc26..100d0899 100644 --- a/internal/webserver/web.go +++ b/internal/webserver/web.go @@ -330,7 +330,7 @@ func registerMFA(w http.ResponseWriter, r *http.Request) { err = resources.Render("register_mfa.html", w, &menu) if err != nil { log.Println(user.Username, clientTunnelIp, "unable to build template:", err) - http.Error(w, "Server error", 500) + http.Error(w, "Server error", http.StatusInternalServerError) } return @@ -435,7 +435,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { err := data.SetUserGroupMembership(username, groups) if err != nil { log.Println(username, remoteAddr, "could not set user membership from registration token:", err) - http.Error(w, "Server error", 500) + http.Error(w, "Server error", http.StatusInternalServerError) return } } @@ -452,14 +452,14 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { publickey, err = wgtypes.ParseKey(pubkeyParam) if err != nil { log.Println(username, remoteAddr, "failed to unmarshal wireguard public key:", err) - http.Error(w, "Server error", 500) + http.Error(w, "Server error", http.StatusInternalServerError) return } } else { privatekey, err = wgtypes.GeneratePrivateKey() if err != nil { log.Println(username, remoteAddr, "failed to generate wireguard keys:", err) - http.Error(w, "Server error", 500) + http.Error(w, "Server error", http.StatusInternalServerError) return } publickey = privatekey.PublicKey() @@ -470,18 +470,20 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { user, err = users.CreateUser(username) if err != nil { log.Println(username, remoteAddr, "unable create new user: "+err.Error()) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } } - var address string + var ( + address string + ) if overwrites != "" { err = user.SetDevicePublicKey(publickey.String(), overwrites) if err != nil { log.Println(username, remoteAddr, "could update '", overwrites, "': ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -489,16 +491,18 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { } else { - device, err := user.AddDevice(publickey) + var device data.Device + device, err = user.AddDevice(publickey) if err != nil { log.Println(username, remoteAddr, "unable to add device: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } address = device.Address defer func() { + if err != nil { log.Println(username, remoteAddr, "removing device (due to registration failure)") err := user.DeleteDevice(device.Address) @@ -514,7 +518,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { wgPublicKey, wgPort, err := router.ServerDetails() if err != nil { log.Println(username, remoteAddr, "unable access wireguard device: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -527,14 +531,14 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { presharedKey, err := user.GetDevicePresharedKey(address) if err != nil { log.Println(username, remoteAddr, "unable access device preshared key: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } dnsWithOutSubnet, err := data.GetDNS() if err != nil { log.Println(username, remoteAddr, "unable get dns: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -545,7 +549,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { routes, err := routetypes.AclsToRoutes(append(acl.Allow, acl.Mfa...)) if err != nil { log.Println(username, remoteAddr, "unable access parse acls to produce routes: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -561,7 +565,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { externalAddress, err := data.GetExternalAddress() if err != nil { log.Println(username, remoteAddr, "unable to get server external address from datastore: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -583,21 +587,21 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } image, err := qr.Encode(config.String(), qr.M, qr.Auto) if err != nil { log.Println(username, remoteAddr, "failed to generate qr code:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } image, err = barcode.Scale(image, 400, 400) if err != nil { log.Println(username, remoteAddr, "failed to output barcode bytes:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -605,7 +609,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { err = png.Encode(&buff, image) if err != nil { log.Println(user.Username, remoteAddr, "encoding mfa secret as png failed:", err) - http.Error(w, "Unknown error", 500) + http.Error(w, "Unknown error", http.StatusInternalServerError) return } @@ -617,7 +621,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { err = resources.Render("qrcode_registration.html", w, &qr) if err != nil { log.Println(username, remoteAddr, "failed to execute template to show qr code wireguard config:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -631,7 +635,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { }) if err != nil { log.Println(username, remoteAddr, "failed to execute template to generate wireguard config:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } } @@ -640,7 +644,7 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { err = data.FinaliseRegistration(key) if err != nil { log.Println(username, remoteAddr, "expiring registration token failed:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -674,7 +678,7 @@ func logout(w http.ResponseWriter, r *http.Request) { err = user.Deauthenticate(clientTunnelIp.String()) if err != nil { log.Println(user.Username, clientTunnelIp, "could not deauthenticate:", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -697,14 +701,14 @@ func routes(w http.ResponseWriter, r *http.Request) { user, err := users.GetUserFromAddress(remoteAddress) if err != nil { log.Println(user.Username, remoteAddress, "Could not find user: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } routes, err := router.GetRoutes(user.Username) if err != nil { log.Println(user.Username, remoteAddress, "Getting routes from xdp failed: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -724,7 +728,7 @@ func status(w http.ResponseWriter, r *http.Request) { user, err := users.GetUserFromAddress(remoteAddress) if err != nil { log.Println(user.Username, remoteAddress, "Could not find user: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } @@ -764,7 +768,7 @@ func publicKey(w http.ResponseWriter, r *http.Request) { wgPublicKey, _, err := router.ServerDetails() if err != nil { log.Println("unable access wireguard device: ", err) - http.Error(w, "Server Error", 500) + http.Error(w, "Server Error", http.StatusInternalServerError) return } diff --git a/ui/registration.go b/ui/registration.go index 93d9aaa3..d460db54 100644 --- a/ui/registration.go +++ b/ui/registration.go @@ -124,6 +124,9 @@ func registrationTokens(w http.ResponseWriter, r *http.Request) { return } + b.Username = strings.TrimSpace(b.Username) + b.Overwrites = strings.TrimSpace(b.Overwrites) + uses, err := strconv.Atoi(b.Uses) if err != nil { log.Println("client sent invalid number for token number of usees") diff --git a/ui/templates/settings/general.html b/ui/templates/settings/general.html index 87a5842d..eef6556c 100755 --- a/ui/templates/settings/general.html +++ b/ui/templates/settings/general.html @@ -36,8 +36,7 @@
General
- +