Skip to content

Commit

Permalink
Start removing config.Values()
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Jan 22, 2024
1 parent 3fde540 commit 7f98ea4
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 27 deletions.
24 changes: 21 additions & 3 deletions internal/data/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type OIDC struct {
IssuerURL string
ClientSecret string
ClientID string
GroupsClaimName string `json:",omitempty"`
GroupsClaimName string
}

type PAM struct {
Expand Down Expand Up @@ -140,8 +140,26 @@ func SetWireguardConfigName(wgConfig string) error {
return err
}

func GetWireguardConfigName() (string, error) {
return getGeneric(defaultWGFileNameKey)
func GetWireguardConfigName() string {
k, err := getGeneric(defaultWGFileNameKey)
if err != nil {
return "wg0.conf"
}

if k == "" {
return "wg0.conf"
}

return k
}

func SetDefaultMfaMethod(method string) error {
_, err := etcd.Put(context.Background(), defaultMFAMethodKey, method)
return err
}

func GetDefaultMfaMethod() (string, error) {
return getGeneric(defaultMFAMethodKey)
}

func SetAuthenticationMethods(methods []string) error {
Expand Down
34 changes: 27 additions & 7 deletions internal/router/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ func loadXDP() error {
return fmt.Errorf("loading objects: %s", err)
}

value := uint64(config.Values().SessionInactivityTimeoutMinutes) * 60000000000
if config.Values().SessionInactivityTimeoutMinutes < 0 {
sessionInactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes()
if err != nil {
return err
}

value := uint64(sessionInactivityTimeoutMinutes) * 60000000000
if sessionInactivityTimeoutMinutes < 0 {
value = math.MaxUint64
}

Expand Down Expand Up @@ -218,11 +223,16 @@ func isAuthed(address string) bool {
return false
}

inactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes()
if err != nil {
return false
}

currentTime := GetTimeStamp()

sessionValid := (deviceStruct.sessionExpiry > currentTime || deviceStruct.sessionExpiry == math.MaxUint64)

sessionActive := ((currentTime-deviceStruct.lastPacketTime) < uint64(config.Values().SessionInactivityTimeoutMinutes)*60000000000 || config.Values().SessionInactivityTimeoutMinutes < 0)
sessionActive := ((currentTime-deviceStruct.lastPacketTime) < uint64(inactivityTimeoutMinutes)*60000000000 || inactivityTimeoutMinutes < 0)

return isAccountLocked == 0 && sessionValid && sessionActive
}
Expand Down Expand Up @@ -536,8 +546,13 @@ func RefreshConfiguration() []error {
return []error{err}
}

value := uint64(config.Values().SessionInactivityTimeoutMinutes) * 60000000000
if config.Values().SessionInactivityTimeoutMinutes < 0 {
inactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes()
if err != nil {
return []error{err}
}

value := uint64(inactivityTimeoutMinutes) * 60000000000
if inactivityTimeoutMinutes < 0 {
value = math.MaxUint64
}

Expand Down Expand Up @@ -575,8 +590,13 @@ func SetAuthorized(internalAddress, username string) error {
var deviceStruct fwentry
deviceStruct.lastPacketTime = GetTimeStamp()

deviceStruct.sessionExpiry = GetTimeStamp() + uint64(config.Values().MaxSessionLifetimeMinutes)*60000000000
if config.Values().MaxSessionLifetimeMinutes < 0 {
maxSession, err := data.GetSessionLifetimeMinutes()
if err != nil {
return err
}

deviceStruct.sessionExpiry = GetTimeStamp() + uint64(maxSession)*60000000000
if maxSession < 0 {
deviceStruct.sessionExpiry = math.MaxUint64 // If the session timeout is disabled, (<0) then we set to max value
}

Expand Down
11 changes: 8 additions & 3 deletions internal/webserver/authenticators/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net/http"
"strings"

"github.com/NHAS/wag/internal/config"
"github.com/NHAS/wag/internal/data"
)

// from: https://github.com/duo-labs/webauthn.io/blob/3f03b482d21476f6b9fb82b2bf1458ff61a61d41/server/response.go#L15
Expand All @@ -25,11 +25,16 @@ func resultMessage(err error) (string, int) {
return "Success", http.StatusOK
}

mail, err := data.GetHelpMail()
if err != nil {
mail = "Server Error"
}

msg := "Validation failed"
if strings.Contains(err.Error(), "account is locked") {
msg = "Account is locked contact: " + config.Values().HelpMail
msg = "Account is locked contact: " + mail
} else if strings.Contains(err.Error(), "device is locked") {
msg = "Device is locked contact: " + config.Values().HelpMail
msg = "Device is locked contact: " + mail
}
return msg, http.StatusBadRequest
}
17 changes: 12 additions & 5 deletions internal/webserver/authenticators/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type issuer struct {

type Oidc struct {
provider rp.RelyingParty
details data.OIDC
}

func (o Oidc) state() string {
Expand Down Expand Up @@ -72,12 +73,12 @@ func (o *Oidc) Init() error {

log.Println("Connecting to OIDC provider")

oidc, err := data.GetOidc()
o.details, err = data.GetOidc()
if err != nil {
return err
}

o.provider, err = rp.NewRelyingPartyOIDC(oidc.IssuerURL, oidc.ClientID, oidc.ClientSecret, u.String(), []string{"openid"}, options...)
o.provider, err = rp.NewRelyingPartyOIDC(o.details.IssuerURL, o.details.ClientID, o.details.ClientSecret, u.String(), []string{"openid"}, options...)
if err != nil {
return err
}
Expand Down Expand Up @@ -156,7 +157,7 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) {

marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) {

groupsIntf, ok := tokens.IDTokenClaims.GetClaim(config.Values().Authenticators.OIDC.GroupsClaimName).([]interface{})
groupsIntf, ok := tokens.IDTokenClaims.GetClaim(o.details.GroupsClaimName).([]interface{})
if !ok {
log.Println("Error, could not convert group claim to []string, probably error in oidc idP configuration")

Expand Down Expand Up @@ -209,8 +210,14 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) {

w.WriteHeader(http.StatusUnauthorized)

err := resources.Render("oidc_error.html", w, &resources.Msg{
HelpMail: config.Values().HelpMail,
mail, err := data.GetHelpMail()
if err != nil {
log.Println("Error getting help mail: ", err)
http.Error(w, "Server Error", http.StatusInternalServerError)
return
}
err = resources.Render("oidc_error.html", w, &resources.Msg{
HelpMail: mail,
NumMethods: NumberOfMethods(),
Message: msg,
URL: rp.GetEndSessionEndpoint(),
Expand Down
14 changes: 9 additions & 5 deletions internal/webserver/authenticators/pam.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,20 @@ func (t *Pam) AuthoriseFunc(w http.ResponseWriter, r *http.Request) types.Authen

passwd := r.FormValue("password")

serviceName := config.Values().Authenticators.PAM.ServiceName
pamDetails, err := data.GetPAM()
if err != nil {
http.Error(w, "Unable to get pam details: "+err.Error(), 500)
return err
}

pamRulesFile := "config /etc/pam.d/" + serviceName
if serviceName == "" {
serviceName = "login"
pamRulesFile := "config /etc/pam.d/" + pamDetails.ServiceName
if pamDetails.ServiceName == "" {
pamDetails.ServiceName = "login"
pamRulesFile = "default PAM /etc/pam.d/login"
}

log.Println(username, "attempting to authorise with PAM (using ", pamRulesFile, ")")
t, err := pam.StartFunc(serviceName, username, func(s pam.Style, msg string) (string, error) {
t, err := pam.StartFunc(pamDetails.ServiceName, username, func(s pam.Style, msg string) (string, error) {

switch s {
case pam.PromptEchoOff:
Expand Down
15 changes: 12 additions & 3 deletions internal/webserver/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,10 @@ func registerMFA(w http.ResponseWriter, r *http.Request) {

method := r.URL.Query().Get("method")
if method == "" {
method = config.Values().Authenticators.DefaultMethod
method, err = data.GetDefaultMfaMethod()
if err != nil {
method = ""
}
}

if method == "" || method == "select" {
Expand Down Expand Up @@ -465,7 +468,12 @@ func registerDevice(w http.ResponseWriter, r *http.Request) {
return
}

dnsWithOutSubnet := config.Values().Wireguard.DNS
dnsWithOutSubnet, err := data.GetDNS()
if err != nil {
log.Println(username, remoteAddr, "unable get dns: ", err)
http.Error(w, "Server Error", 500)
return
}

for i := 0; i < len(dnsWithOutSubnet); i++ {
dnsWithOutSubnet[i] = strings.TrimSuffix(dnsWithOutSubnet[i], "/32")
Expand Down Expand Up @@ -543,7 +551,8 @@ func registerDevice(w http.ResponseWriter, r *http.Request) {
}

} else {
w.Header().Set("Content-Disposition", "attachment; filename="+config.Values().DownloadConfigFileName)

w.Header().Set("Content-Disposition", "attachment; filename="+data.GetWireguardConfigName())

err = resources.RenderWithFuncs("interface.tmpl", w, &wireguardInterface, template.FuncMap{
"StringsJoin": strings.Join,
Expand Down
4 changes: 3 additions & 1 deletion ui/check_updates.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/NHAS/wag/internal/config"
"github.com/NHAS/wag/internal/data"
)

type githubResponse struct {
Expand All @@ -32,7 +33,8 @@ var (

func getUpdate() Update {

if !config.Values().CheckUpdates {
should, err := data.CheckUpdates()
if err != nil || !should {
return Update{}
}

Expand Down

0 comments on commit 7f98ea4

Please sign in to comment.