Skip to content

Commit

Permalink
Remove overly complex behavior in safe update, fix admin login failin…
Browse files Browse the repository at this point in the history
…g open
  • Loading branch information
NHAS committed Jan 22, 2024
1 parent 72e9c45 commit 03fc035
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 141 deletions.
13 changes: 0 additions & 13 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,6 @@ type Acls struct {
Policies map[string]*acls.Acl
}

func (a Acls) GetUserGroups(username string) (result []string) {
if values.Acls.rGroupLookup == nil {
return []string{}
}

result = make([]string, 0, len(values.Acls.rGroupLookup[username]))
for group := range values.Acls.rGroupLookup[username] {
result = append(result, group)
}

return
}

type Config struct {
path string
Socket string `json:",omitempty"`
Expand Down
4 changes: 2 additions & 2 deletions internal/data/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ func GetAllSettings() (s Settings, err error) {
}

if response.Responses[6].GetResponseRange().Count == 1 {
s.HelpMail = string(response.Responses[6].GetResponseRange().Kvs[0].Value)
s.Issuer = string(response.Responses[6].GetResponseRange().Kvs[0].Value)
}

if response.Responses[7].GetResponseRange().Count == 1 {
s.ExternalAddress = string(response.Responses[7].GetResponseRange().Kvs[0].Value)
s.Domain = string(response.Responses[7].GetResponseRange().Kvs[0].Value)
}

return
Expand Down
66 changes: 22 additions & 44 deletions internal/data/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"

"github.com/NHAS/wag/internal/config"
"github.com/NHAS/wag/internal/utils"
clientv3 "go.etcd.io/etcd/client/v3"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
Expand All @@ -28,25 +25,6 @@ type Device struct {
Authorised time.Time
}

func stringToUDPaddr(address string) (r *net.UDPAddr) {
parts := strings.Split(address, ":")
if len(parts) < 2 {
return nil
}

port, err := strconv.Atoi(parts[len(parts)-1])
if err != nil {
return nil
}

r = &net.UDPAddr{
IP: net.ParseIP(utils.GetIP(address)),
Port: port,
}

return
}

func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error {

realKey, err := etcd.Get(context.Background(), "deviceref-"+address)
Expand All @@ -58,22 +36,22 @@ func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error {
return errors.New("device was not found")
}

return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, bool, error) {
return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, error) {
if len(gr.Kvs) != 1 {
return "", false, errors.New("user device has multiple keys")
return "", errors.New("user device has multiple keys")
}

var device Device
err := json.Unmarshal(gr.Kvs[0].Value, &device)
if err != nil {
return "", false, err
return "", err
}

device.Endpoint = endpoint

b, _ := json.Marshal(device)

return string(b), false, err
return string(b), err
})
}

Expand All @@ -99,33 +77,33 @@ func GetDevice(username, id string) (device Device, err error) {

// Set device as authorized and clear authentication attempts
func AuthoriseDevice(username, address string) error {
return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) {
return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) {
if len(gr.Kvs) != 1 {
return "", false, errors.New("user device has multiple keys")
return "", errors.New("user device has multiple keys")
}

var device Device
err := json.Unmarshal(gr.Kvs[0].Value, &device)
if err != nil {
return "", false, err
return "", err
}

u, err := GetUserData(device.Username)
if err != nil {
// We may want to make this lock the device if the user is not found. At the moment settle with doing nothing
return "", false, err
return "", err
}

if u.Locked {
return "", false, errors.New("account is locked")
return "", errors.New("account is locked")
}

device.Authorised = time.Now()
device.Attempts = 0

b, _ := json.Marshal(device)

return string(b), false, err
return string(b), err
})
}

Expand All @@ -140,42 +118,42 @@ func DeauthenticateDevice(address string) error {
return errors.New("device was not found")
}

return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, bool, error) {
return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, error) {
if len(gr.Kvs) != 1 {
return "", false, errors.New("user device has multiple keys")
return "", errors.New("user device has multiple keys")
}

var device Device
err := json.Unmarshal(gr.Kvs[0].Value, &device)
if err != nil {
return "", false, err
return "", err
}

device.Authorised = time.Time{}

b, _ := json.Marshal(device)

return string(b), false, err
return string(b), err
})
}

func SetDeviceAuthenticationAttempts(username, address string, attempts int) error {
return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) {
return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) {
if len(gr.Kvs) != 1 {
return "", false, errors.New("user device has multiple keys")
return "", errors.New("user device has multiple keys")
}

var device Device
err := json.Unmarshal(gr.Kvs[0].Value, &device)
if err != nil {
return "", false, err
return "", err
}

device.Attempts = attempts

b, _ := json.Marshal(device)

return string(b), false, err
return string(b), err
})
}

Expand Down Expand Up @@ -326,22 +304,22 @@ func UpdateDevicePublicKey(username, address string, publicKey wgtypes.Key) erro
return err
}

err = doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) {
err = doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) {
if len(gr.Kvs) != 1 {
return "", false, errors.New("user device has multiple keys")
return "", errors.New("user device has multiple keys")
}

var device Device
err := json.Unmarshal(gr.Kvs[0].Value, &device)
if err != nil {
return "", false, err
return "", err
}

device.Publickey = publicKey.String()

b, _ := json.Marshal(device)

return string(b), false, err
return string(b), err
})

if err != nil {
Expand Down
38 changes: 30 additions & 8 deletions internal/data/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/NHAS/wag/pkg/control"
clientv3 "go.etcd.io/etcd/client/v3"
"golang.org/x/exp/maps"
)

func SetGroup(group string, members []string, overwrite bool) error {
Expand Down Expand Up @@ -36,16 +37,16 @@ func SetGroup(group string, members []string, overwrite bool) error {
}
}

err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) {
err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) {

if len(gr.Kvs) != 1 {
return "", false, errors.New("bad number of membership keys")
return "", errors.New("bad number of membership keys")
}

var rGroupLookup map[string]map[string]bool
err = json.Unmarshal(gr.Kvs[0].Value, &rGroupLookup)
if err != nil {
return "", false, err
return "", err
}

for _, member := range oldMembers {
Expand All @@ -62,7 +63,7 @@ func SetGroup(group string, members []string, overwrite bool) error {

reverseMappingJson, _ := json.Marshal(rGroupLookup)

return string(reverseMappingJson), false, nil
return string(reverseMappingJson), nil
})

return err
Expand Down Expand Up @@ -111,16 +112,16 @@ func RemoveGroup(groupName string) error {
}
}

err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) {
err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) {

if len(gr.Kvs) != 1 {
return "", false, errors.New("bad number of membership keys")
return "", errors.New("bad number of membership keys")
}

var rGroupLookup map[string]map[string]bool
err = json.Unmarshal(gr.Kvs[0].Value, &rGroupLookup)
if err != nil {
return "", false, err
return "", err
}

for _, member := range oldMembers {
Expand All @@ -129,8 +130,29 @@ func RemoveGroup(groupName string) error {

reverseMappingJson, _ := json.Marshal(rGroupLookup)

return string(reverseMappingJson), false, nil
return string(reverseMappingJson), nil
})

return err
}

func GetUserGroupMembership(username string) ([]string, error) {

response, err := etcd.Get(context.Background(), "wag-membership")
if err != nil {
return nil, err
}

var rGroupLookup map[string]map[string]bool

err = json.Unmarshal(response.Kvs[0].Value, &rGroupLookup)
if err != nil {
return nil, err
}

if rGroupLookup[username] == nil {
return []string{}, nil
}

return maps.Keys(rGroupLookup[username]), nil
}
6 changes: 3 additions & 3 deletions internal/data/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ func TearDown() {
}
}

func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.GetResponse) (value string, onErrwrite bool, err error)) error {
func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.GetResponse) (value string, err error)) error {
//https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go#L382
opts := []clientv3.OpOption{}

Expand All @@ -387,8 +387,8 @@ func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.Get
return errors.New("no record found")
}

newValue, onErrwrite, err := mutateFunc(origState)
if err != nil && !onErrwrite {
newValue, err := mutateFunc(origState)
if err != nil {
return err
}

Expand Down
6 changes: 3 additions & 3 deletions internal/data/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ func DeleteRegistrationToken(identifier string) error {
func FinaliseRegistration(token string) error {

errVal := errors.New("registration token has expired")
err := doSafeUpdate(context.Background(), "tokens-"+token, func(gr *clientv3.GetResponse) (string, bool, error) {
err := doSafeUpdate(context.Background(), "tokens-"+token, func(gr *clientv3.GetResponse) (string, error) {

var result control.RegistrationResult
err := json.Unmarshal(gr.Kvs[0].Value, &result)
if err != nil {
return "", false, err
return "", err
}

result.NumUses--
Expand All @@ -94,7 +94,7 @@ func FinaliseRegistration(token string) error {

b, _ := json.Marshal(result)

return string(b), false, err
return string(b), err
})

if err == errVal {
Expand Down
23 changes: 23 additions & 0 deletions internal/data/sql_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package data
import (
"database/sql"
"encoding/json"
"net"
"strconv"
"strings"

"github.com/NHAS/wag/internal/utils"
"github.com/NHAS/wag/pkg/control"
)

Expand Down Expand Up @@ -119,3 +123,22 @@ func sqlGetAllDevices() (devices []Device, err error) {

return devices, nil
}

func stringToUDPaddr(address string) (r *net.UDPAddr) {
parts := strings.Split(address, ":")
if len(parts) < 2 {
return nil
}

port, err := strconv.Atoi(parts[len(parts)-1])
if err != nil {
return nil
}

r = &net.UDPAddr{
IP: net.ParseIP(utils.GetIP(address)),
Port: port,
}

return
}
Loading

0 comments on commit 03fc035

Please sign in to comment.