From 9691603d17d736db27d184b3eeb6e24e585eb4cb Mon Sep 17 00:00:00 2001 From: NHAS Date: Mon, 22 Jan 2024 12:02:15 +1300 Subject: [PATCH] Link up management to setting some settings --- internal/data/acls.go | 22 ++-- internal/data/config.go | 199 +++++++++++++++++++++++++++++- internal/data/init.go | 208 +++++++++++++++++++++----------- internal/data/user.go | 8 +- internal/router/statemachine.go | 11 +- internal/users/user.go | 8 +- pkg/control/server/devices.go | 10 +- ui/ui_webserver.go | 48 ++++++-- 8 files changed, 408 insertions(+), 106 deletions(-) diff --git a/internal/data/acls.go b/internal/data/acls.go index 8339558a..68ced3d4 100644 --- a/internal/data/acls.go +++ b/internal/data/acls.go @@ -75,15 +75,8 @@ func GetEffectiveAcl(username string) acls.Acl { //Add the server address by default resultingACLs.Allow = []string{config.Values().Wireguard.ServerAddress.String() + "/32"} - // Add dns servers if defined - // Make sure we resolve the dns servers in case someone added them as domains, so that clients dont get stuck trying to use the domain dns servers to look up the dns servers - // Restrict dns servers to only having 53/any by default as per #49 - for _, server := range config.Values().Wireguard.DNS { - resultingACLs.Allow = append(resultingACLs.Allow, fmt.Sprintf("%s 53/any", server)) - } - txn := etcd.Txn(context.Background()) - txn.Then(clientv3.OpGet("wag-acls-*"), clientv3.OpGet("wag-acls-"+username), clientv3.OpGet("wag-membership")) + txn.Then(clientv3.OpGet("wag-acls-*"), clientv3.OpGet("wag-acls-"+username), clientv3.OpGet("wag-membership"), clientv3.OpGet(dnsKey)) resp, err := txn.Commit() if err != nil { return acls.Acl{} @@ -150,5 +143,18 @@ func GetEffectiveAcl(username string) acls.Acl { } } + // Add dns servers if defined + // Restrict dns servers to only having 53/any by default as per #49 + if resp.Responses[3].GetResponseRange().GetCount() != 0 { + + var dns []string + err = json.Unmarshal(resp.Responses[3].GetResponseRange().Kvs[0].Value, &dns) + if err == nil { + for _, server := range dns { + resultingACLs.Allow = append(resultingACLs.Allow, fmt.Sprintf("%s 53/any", server)) + } + } + } + return resultingACLs } diff --git a/internal/data/config.go b/internal/data/config.go index 23b34f68..dc782f5c 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -1,29 +1,216 @@ package data +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strconv" + + clientv3 "go.etcd.io/etcd/client/v3" +) + +const ( + fullJsonConfigKey = "wag-config-full" + helpMailKey = "wag-config-general-help-mail" + inactivityTimeoutKey = "wag-config-authentication-inactivity-timeout" + sessionLifetimeKey = "wag-config-authentication-max-session-lifetime" + lockoutKey = "wag-config-authentication-lockout" + issuerKey = "wag-config-authentication-issuer" + domainKey = "wag-config-authentication-domain" + defaultMFAMethodKey = "wag-config-authentication-default-method" + externalAddressKey = "wag-config-network-external-address" + dnsKey = "wag-config-network-dns" +) + +func getGeneric(key string) (string, error) { + resp, err := etcd.Get(context.Background(), key) + if err != nil { + return "", err + } + + if len(resp.Kvs) != 1 { + return "", fmt.Errorf("incorrect number of %s keys", key) + } + + return string(resp.Kvs[0].Value), nil +} + +func SetDomain(domain string) error { + _, err := etcd.Put(context.Background(), domainKey, domain) + return err +} + +func GetDomain() (string, error) { + return getGeneric(domainKey) +} + +func SetIssuer(issuer string) error { + _, err := etcd.Put(context.Background(), issuerKey, issuer) + return err +} + +func GetIssuer() (string, error) { + return getGeneric(issuerKey) +} + func SetHelpMail(helpMail string) error { - return nil + _, err := etcd.Put(context.Background(), helpMailKey, helpMail) + return err +} + +func GetHelpMail() (string, error) { + return getGeneric(helpMailKey) } func SetExternalAddress(externalAddress string) error { - return nil + _, err := etcd.Put(context.Background(), externalAddressKey, externalAddress) + return err +} + +func GetExternalAddress() (string, error) { + return getGeneric(externalAddressKey) } func SetDNS(dns []string) error { + jsonData, _ := json.Marshal(dns) + _, err := etcd.Put(context.Background(), dnsKey, string(jsonData)) + return err +} + +func GetDNS() ([]string, error) { + + jsonData, err := getGeneric(dnsKey) + if err != nil { + return nil, err + } - return nil + var servers []string + err = json.Unmarshal([]byte(jsonData), &servers) + if err != nil { + return nil, err + } + + return servers, nil +} + +type Settings struct { + ExternalAddress string + Lockout int + Issuer string + Domain string + SessionInactivityTimeoutMinutes int + MaxSessionLifetimeMinutes int + HelpMail string + DNS []string +} + +func GetAllSettings() (s Settings, err error) { + + txn := etcd.Txn(context.Background()) + response, err := txn.Then(clientv3.OpGet(helpMailKey), + clientv3.OpGet(externalAddressKey), + clientv3.OpGet(inactivityTimeoutKey), + clientv3.OpGet(sessionLifetimeKey), + clientv3.OpGet(lockoutKey), + clientv3.OpGet(dnsKey), + clientv3.OpGet(issuerKey), + clientv3.OpGet(domainKey)).Commit() + if err != nil { + return s, err + } + + if response.Responses[0].GetResponseRange().Count == 1 { + s.HelpMail = string(response.Responses[0].GetResponseRange().Kvs[0].Value) + } + + if response.Responses[1].GetResponseRange().Count == 1 { + s.ExternalAddress = string(response.Responses[1].GetResponseRange().Kvs[0].Value) + } + + if response.Responses[2].GetResponseRange().Count == 1 { + s.SessionInactivityTimeoutMinutes, err = strconv.Atoi(string(response.Responses[2].GetResponseRange().Kvs[0].Value)) + if err != nil { + return + } + } + + if response.Responses[3].GetResponseRange().Count == 1 { + s.MaxSessionLifetimeMinutes, err = strconv.Atoi(string(response.Responses[3].GetResponseRange().Kvs[0].Value)) + if err != nil { + return + } + } + + if response.Responses[4].GetResponseRange().Count == 1 { + s.Lockout, err = strconv.Atoi(string(response.Responses[4].GetResponseRange().Kvs[0].Value)) + if err != nil { + return + } + } + + if response.Responses[5].GetResponseRange().Count == 1 { + err = json.Unmarshal(response.Responses[5].GetResponseRange().Kvs[0].Value, &s.DNS) + if err != nil { + return s, err + } + + } + + if response.Responses[6].GetResponseRange().Count == 1 { + s.HelpMail = string(response.Responses[6].GetResponseRange().Kvs[0].Value) + } + + if response.Responses[7].GetResponseRange().Count == 1 { + s.ExternalAddress = string(response.Responses[7].GetResponseRange().Kvs[0].Value) + } + + return } +// Due to how these functions are used there is quite a highlikelihood that splicing will occur +// We need to update these to make it that it checks the key revision against the pulled version func SetSessionLifetimeMinutes(lifetimeMinutes int) error { + _, err := etcd.Put(context.Background(), sessionLifetimeKey, strconv.Itoa(lifetimeMinutes)) + return err +} - return nil +func GetSessionLifetimeMinutes() (int, error) { + sessionLifeTime, err := getGeneric(sessionLifetimeKey) + if err != nil { + return 0, err + } + + return strconv.Atoi(sessionLifeTime) } func SetSessionInactivityTimeoutMinutes(InactivityTimeout int) error { + _, err := etcd.Put(context.Background(), inactivityTimeoutKey, strconv.Itoa(InactivityTimeout)) + return err +} + +func GetSessionInactivityTimeoutMinutes() (int, error) { + inactivityTimeout, err := getGeneric(inactivityTimeoutKey) + if err != nil { + return 0, err + } - return nil + return strconv.Atoi(inactivityTimeout) } func SetLockout(accountLockout int) error { + if accountLockout < 1 { + return errors.New("cannot set lockout to be below 1 as all accounts would be locked out") + } + _, err := etcd.Put(context.Background(), lockoutKey, strconv.Itoa(accountLockout)) + return err +} + +func GetLockout() (int, error) { + lockout, err := getGeneric(lockoutKey) + if err != nil { + return 0, err + } - return nil + return strconv.Atoi(lockout) } diff --git a/internal/data/init.go b/internal/data/init.go index 89f96a55..8ff49315 100644 --- a/internal/data/init.go +++ b/internal/data/init.go @@ -18,6 +18,7 @@ import ( "github.com/NHAS/wag/pkg/fsops" _ "github.com/mattn/go-sqlite3" clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/clientv3util" "go.etcd.io/etcd/server/v3/embed" ) @@ -117,6 +118,143 @@ func Load(path string) error { log.Println("Successfully connected to etcd") + // This will be kept for 2 major releases with reduced support. + // It is a no-op if a migration has already taken place + err = migrateFromSql() + if err != nil { + return err + } + + // This will stay, so that the config can be used to easily spin up a new wag instance. + // After first run this will be a no-op + err = loadInitialSettings() + if err != nil { + return err + } + + go checkClusterHealth() + go watchEvents() + + return nil +} + +func loadInitialSettings() error { + response, err := etcd.Get(context.Background(), "wag-acls-", clientv3.WithPrefix()) + if err != nil { + return err + } + + if len(response.Kvs) == 0 { + log.Println("no acls found in database, importing from .json file (from this point the json file will be ignored)") + + for aclName, acl := range config.Values().Acls.Policies { + aclJson, _ := json.Marshal(acl) + _, err = etcd.Put(context.Background(), "wag-acls-"+aclName, string(aclJson)) + if err != nil { + return err + } + } + } + + response, err = etcd.Get(context.Background(), "wag-groups-", clientv3.WithPrefix()) + if err != nil { + return err + } + + if len(response.Kvs) == 0 { + log.Println("no groups found in database, importing from .json file (from this point the json file will be ignored)") + + // User to groups + rGroupLookup := map[string]map[string]bool{} + + for groupName, members := range config.Values().Acls.Groups { + groupJson, _ := json.Marshal(members) + _, err = etcd.Put(context.Background(), "wag-groups-"+groupName, string(groupJson)) + if err != nil { + return err + } + + for _, user := range members { + if rGroupLookup[user] == nil { + rGroupLookup[user] = make(map[string]bool) + } + + rGroupLookup[user][groupName] = true + } + } + + reverseMappingJson, _ := json.Marshal(rGroupLookup) + _, err = etcd.Put(context.Background(), "wag-membership", string(reverseMappingJson)) + if err != nil { + return err + } + } + + configData, _ := json.Marshal(config.Values()) + err = putIfNotFound(fullJsonConfigKey, string(configData), "full config") + if err != nil { + return err + } + + err = putIfNotFound(helpMailKey, config.Values().HelpMail, "help mail") + if err != nil { + return err + } + + err = putIfNotFound(externalAddressKey, config.Values().ExternalAddress, "external wag address") + if err != nil { + return err + } + + dnsData, _ := json.Marshal(config.Values().Wireguard.DNS) + err = putIfNotFound(dnsKey, string(dnsData), "dns") + if err != nil { + return err + } + + err = putIfNotFound(inactivityTimeoutKey, fmt.Sprintf("%d", config.Values().SessionInactivityTimeoutMinutes), "inactivity timeout") + if err != nil { + return err + } + + err = putIfNotFound(sessionLifetimeKey, fmt.Sprintf("%d", config.Values().MaxSessionLifetimeMinutes), "max session life") + if err != nil { + return err + } + + err = putIfNotFound(lockoutKey, fmt.Sprintf("%d", config.Values().Lockout), "lockout") + if err != nil { + return err + } + + err = putIfNotFound(issuerKey, config.Values().Authenticators.Issuer, "issuer name") + if err != nil { + return err + } + + err = putIfNotFound(domainKey, config.Values().Authenticators.DomainURL, "domain url") + if err != nil { + return err + } + + return nil +} + +func putIfNotFound(key, value, set string) error { + txn := etcd.Txn(context.Background()) + resp, err := txn.If(clientv3util.KeyMissing(key)).Then(clientv3.OpPut(key, value)).Commit() + if err != nil { + return err + } + + if resp.Succeeded { + log.Printf("setting %s from json, importing from .json file (from this point the json file will be ignored)", set) + } + + return nil +} + +func migrateFromSql() error { response, err := etcd.Get(context.Background(), "wag-migrated-sql") if err != nil { return err @@ -221,76 +359,6 @@ func Load(path string) error { } - response, err = etcd.Get(context.Background(), "wag-acls-", clientv3.WithPrefix()) - if err != nil { - return err - } - - if len(response.Kvs) == 0 { - log.Println("no acls found in database, importing from .json file (from this point the json file will be ignored)") - - for aclName, acl := range config.Values().Acls.Policies { - aclJson, _ := json.Marshal(acl) - _, err = etcd.Put(context.Background(), "wag-acls-"+aclName, string(aclJson)) - if err != nil { - return err - } - } - } - - response, err = etcd.Get(context.Background(), "wag-groups-", clientv3.WithPrefix()) - if err != nil { - return err - } - - if len(response.Kvs) == 0 { - log.Println("no groups found in database, importing from .json file (from this point the json file will be ignored)") - - // User to groups - rGroupLookup := map[string]map[string]bool{} - - for groupName, members := range config.Values().Acls.Groups { - groupJson, _ := json.Marshal(members) - _, err = etcd.Put(context.Background(), "wag-groups-"+groupName, string(groupJson)) - if err != nil { - return err - } - - for _, user := range members { - if rGroupLookup[user] == nil { - rGroupLookup[user] = make(map[string]bool) - } - - rGroupLookup[user][groupName] = true - } - } - - reverseMappingJson, _ := json.Marshal(rGroupLookup) - _, err = etcd.Put(context.Background(), "wag-membership", string(reverseMappingJson)) - if err != nil { - return err - } - - } - - response, err = etcd.Get(context.Background(), "wag-config") - if err != nil { - return err - } - - if len(response.Kvs) == 0 { - log.Println("no config found in database, importing from .json file (from this point the json file will be ignored)") - - configData, _ := json.Marshal(config.Values()) - _, err = etcd.Put(context.Background(), "wag-config", string(configData)) - if err != nil { - return err - } - } - - go checkClusterHealth() - go watchEvents() - return nil } diff --git a/internal/data/user.go b/internal/data/user.go index 5de9269e..d71d48c1 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" - "github.com/NHAS/wag/internal/config" clientv3 "go.etcd.io/etcd/client/v3" ) @@ -37,7 +36,12 @@ func IncrementAuthenticationAttempt(username, device string) error { return "", false, err } - if userDevice.Attempts < config.Values().Lockout { + l, err := GetLockout() + if err != nil { + return "", false, err + } + + if userDevice.Attempts < l { userDevice.Attempts++ } diff --git a/internal/router/statemachine.go b/internal/router/statemachine.go index a0304ff4..8520a18d 100644 --- a/internal/router/statemachine.go +++ b/internal/router/statemachine.go @@ -4,7 +4,6 @@ import ( "log" "github.com/NHAS/wag/internal/acls" - "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -45,7 +44,13 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) { } } - if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > config.Values().Lockout) || // If the number of authentication attempts on a device has exceeded the max + lockout, err := data.GetLockout() + if err != nil { + log.Println("cannot get lockout:", err) + return + } + + if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > lockout) || // If the number of authentication attempts on a device has exceeded the max device.CurrentValue.Endpoint.String() != device.Previous.Endpoint.String() || // If the client ip has changed device.CurrentValue.Authorised.IsZero() { // If we've explicitly deauthorised a device err := Deauthenticate(device.CurrentValue.Address) @@ -57,7 +62,7 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) { if device.CurrentValue.Authorised != device.Previous.Authorised { log.Println("authorisation state changed on device") - if !device.CurrentValue.Authorised.IsZero() && device.CurrentValue.Attempts <= config.Values().Lockout { + if !device.CurrentValue.Authorised.IsZero() && device.CurrentValue.Attempts <= lockout { log.Println("authorising device") err := SetAuthorized(device.CurrentValue.Address, device.CurrentValue.Username) diff --git a/internal/users/user.go b/internal/users/user.go index 930ac83d..9628d4a8 100644 --- a/internal/users/user.go +++ b/internal/users/user.go @@ -6,7 +6,6 @@ import ( "log" "net" - "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/webserver/authenticators" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -117,7 +116,12 @@ func (u *user) Authenticate(device, mfaType string, authenticator authenticators return err } - if attempts > config.Values().Lockout { + lockout, err := data.GetLockout() + if err != nil { + return err + } + + if attempts > lockout { return errors.New("device is locked") } diff --git a/pkg/control/server/devices.go b/pkg/control/server/devices.go index acbcd45d..af0ca078 100644 --- a/pkg/control/server/devices.go +++ b/pkg/control/server/devices.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" - "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/router" "github.com/NHAS/wag/internal/users" @@ -86,7 +85,14 @@ func lockDevice(w http.ResponseWriter, r *http.Request) { return } - err = user.SetDeviceAuthAttempts(address, config.Values().Lockout+1) + lockout, err := data.GetLockout() + if err != nil { + http.Error(w, "could not get lockout number: "+err.Error(), 404) + return + } + + // This will need to be changed at some point to make it that lockout is a state, rather than a simple int + err = user.SetDeviceAuthAttempts(address, lockout+1) if err != nil { http.Error(w, "could not lock device in db: "+err.Error(), 404) return diff --git a/ui/ui_webserver.go b/ui/ui_webserver.go index d2180be9..b0a04849 100644 --- a/ui/ui_webserver.go +++ b/ui/ui_webserver.go @@ -160,7 +160,15 @@ func populateDashboard(w http.ResponseWriter, r *http.Request) { return } - lockout := config.Values().Lockout + lockout, err := data.GetLockout() + if err != nil { + log.Println("error getting lockout: ", err) + + w.WriteHeader(http.StatusInternalServerError) + renderDefaults(w, r, nil, "error.html") + return + } + lockedDevices := 0 activeSessions := 0 for _, d := range allDevices { @@ -625,7 +633,14 @@ func StartWebServer(errs chan<- error) error { return } - c := config.Values() + datastoreSettings, err := data.GetAllSettings() + if err != nil { + log.Println("could not get settings from datastore: ", err) + + w.WriteHeader(http.StatusInternalServerError) + renderDefaults(w, r, nil, "error.html") + return + } d := GeneralSettings{ Page: Page{ @@ -636,20 +651,20 @@ func StartWebServer(errs chan<- error) error { WagVersion: WagVersion, }, - ExternalAddress: c.ExternalAddress, - Lockout: c.Lockout, - Issuer: c.Authenticators.Issuer, - Domain: c.Authenticators.DomainURL, - InactivityTimeoutMinutes: c.SessionInactivityTimeoutMinutes, - SessionLifeTimeMinutes: c.MaxSessionLifetimeMinutes, - HelpMail: c.HelpMail, - DNS: strings.Join(c.Wireguard.DNS, "\n"), + ExternalAddress: datastoreSettings.ExternalAddress, + Lockout: datastoreSettings.Lockout, + Issuer: datastoreSettings.Issuer, + Domain: datastoreSettings.Domain, + InactivityTimeoutMinutes: datastoreSettings.SessionInactivityTimeoutMinutes, + SessionLifeTimeMinutes: datastoreSettings.MaxSessionLifetimeMinutes, + HelpMail: datastoreSettings.HelpMail, + DNS: strings.Join(datastoreSettings.DNS, "\n"), TotpEnabled: true, OidcEnabled: false, WebauthnEnabled: false, } - err := renderDefaults(w, r, d, "settings/general.html") + err = renderDefaults(w, r, d, "settings/general.html") if err != nil { log.Println("unable to render general: ", err) @@ -1284,9 +1299,16 @@ func devicesMgmt(w http.ResponseWriter, r *http.Request) { return } - data := []DevicesData{} + lockout, err := data.GetLockout() + if err != nil { + log.Println("error getting lockout: ", err) - lockout := config.Values().Lockout + w.WriteHeader(http.StatusInternalServerError) + renderDefaults(w, r, nil, "error.html") + return + } + + data := []DevicesData{} for _, dev := range allDevices { data = append(data, DevicesData{