diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d9d9552 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/pkg/rest/static/ diff --git a/auth/provisioning/scim/middleware.go b/auth/provisioning/scim/middleware.go index 10cfacb..4f5c1cb 100644 --- a/auth/provisioning/scim/middleware.go +++ b/auth/provisioning/scim/middleware.go @@ -6,7 +6,7 @@ import ( "strings" ) -func (s *scim) authMiddleware(next http.Handler) http.Handler { +func (s *Scim) authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized) diff --git a/auth/provisioning/scim/new.go b/auth/provisioning/scim/new.go index 55fd94d..ab95988 100644 --- a/auth/provisioning/scim/new.go +++ b/auth/provisioning/scim/new.go @@ -5,8 +5,8 @@ import ( "github.com/in4it/go-devops-platform/users" ) -func New(storage storage.Iface, userStore *users.UserStore, token string, disableFunc DisableFunc, reactivateFunc ReactivateFunc) *scim { - s := &scim{ +func New(storage storage.Iface, userStore *users.UserStore, token string, disableFunc DisableFunc, reactivateFunc ReactivateFunc) *Scim { + s := &Scim{ Token: token, UserStore: userStore, storage: storage, diff --git a/auth/provisioning/scim/router.go b/auth/provisioning/scim/router.go index f8f0248..01fcb3a 100644 --- a/auth/provisioning/scim/router.go +++ b/auth/provisioning/scim/router.go @@ -4,7 +4,7 @@ import ( "net/http" ) -func (s *scim) GetRouter() *http.ServeMux { +func (s *Scim) GetRouter() *http.ServeMux { mux := http.NewServeMux() mux.Handle("/api/scim/", s.authMiddleware(http.HandlerFunc(notFoundHandler))) diff --git a/auth/provisioning/scim/types.go b/auth/provisioning/scim/types.go index 9a118a6..863a970 100644 --- a/auth/provisioning/scim/types.go +++ b/auth/provisioning/scim/types.go @@ -10,7 +10,7 @@ import ( type DisableFunc func(storage.Iface, users.User) error type ReactivateFunc func(storage.Iface, users.User) error -type scim struct { +type Scim struct { Token string `json:"token"` UserStore *users.UserStore `json:"userStore"` storage storage.Iface diff --git a/auth/provisioning/scim/update.go b/auth/provisioning/scim/update.go index e1abc5e..7284852 100644 --- a/auth/provisioning/scim/update.go +++ b/auth/provisioning/scim/update.go @@ -1,5 +1,5 @@ package scim -func (s *scim) UpdateToken(token string) { +func (s *Scim) UpdateToken(token string) { s.Token = token } diff --git a/auth/provisioning/scim/users.go b/auth/provisioning/scim/users.go index f910f73..1be9372 100644 --- a/auth/provisioning/scim/users.go +++ b/auth/provisioning/scim/users.go @@ -11,7 +11,7 @@ import ( ) // handler for multiple users -func (s *scim) usersHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) usersHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: s.GetUsersHandler(w, r) @@ -25,7 +25,7 @@ func (s *scim) usersHandler(w http.ResponseWriter, r *http.Request) { } // handler for a single user -func (s *scim) userHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) userHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: s.GetUsersHandler(w, r) @@ -41,7 +41,7 @@ func (s *scim) userHandler(w http.ResponseWriter, r *http.Request) { } } -func (s *scim) GetUsersHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) GetUsersHandler(w http.ResponseWriter, r *http.Request) { attributes := r.URL.Query().Get("attributes") filter := r.URL.Query().Get("filter") count, err := strconv.Atoi(r.URL.Query().Get("count")) @@ -70,7 +70,7 @@ func (s *scim) GetUsersHandler(w http.ResponseWriter, r *http.Request) { write(w, response) } -func (s *scim) getUserHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) getUserHandler(w http.ResponseWriter, r *http.Request) { user, err := s.UserStore.GetUserByID(r.PathValue("id")) if err != nil { returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) @@ -85,7 +85,7 @@ func (s *scim) getUserHandler(w http.ResponseWriter, r *http.Request) { write(w, response) } -func (s *scim) PutUserHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) PutUserHandler(w http.ResponseWriter, r *http.Request) { user, err := s.UserStore.GetUserByID(r.PathValue("id")) if err != nil { returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) @@ -137,7 +137,7 @@ func (s *scim) PutUserHandler(w http.ResponseWriter, r *http.Request) { write(w, response) } -func (s *scim) DeleteUserHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) DeleteUserHandler(w http.ResponseWriter, r *http.Request) { user, err := s.UserStore.GetUserByID(r.PathValue("id")) if err != nil { returnError(w, fmt.Errorf("get user by id error: %s", err), http.StatusBadRequest) @@ -159,7 +159,7 @@ func (s *scim) DeleteUserHandler(w http.ResponseWriter, r *http.Request) { write(w, []byte("")) } -func (s *scim) PostUsersHandler(w http.ResponseWriter, r *http.Request) { +func (s *Scim) PostUsersHandler(w http.ResponseWriter, r *http.Request) { var postUserRequest PostUserRequest err := json.NewDecoder(r.Body).Decode(&postUserRequest) if err != nil { diff --git a/go.mod b/go.mod index b9a5783..09b815b 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,11 @@ require ( golang.org/x/crypto v0.28.0 ) +require ( + golang.org/x/net v0.21.0 // indirect + golang.org/x/text v0.19.0 // indirect +) + require ( github.com/aws/aws-sdk-go-v2 v1.32.2 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect diff --git a/go.sum b/go.sum index 3e4416a..73f07dd 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,10 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/rest/auditlog/logentry.go b/rest/auditlog/logentry.go new file mode 100644 index 0000000..b166bd5 --- /dev/null +++ b/rest/auditlog/logentry.go @@ -0,0 +1,39 @@ +package auditlog + +import ( + "encoding/json" + "fmt" + "path" + "time" + + "github.com/in4it/go-devops-platform/storage" +) + +const TIMESTAMP_FORMAT = "2006-01-02T15:04:05" +const AUDITLOG_STATS_DIR = "stats" + +type LogEntry struct { + Timestamp LogTimestamp `json:"timestamp"` + UserID string `json:"userID"` + Action string `json:"action"` +} +type LogTimestamp time.Time + +func (t LogTimestamp) MarshalJSON() ([]byte, error) { + timestamp := fmt.Sprintf("\"%s\"", time.Time(t).Format(TIMESTAMP_FORMAT)) + return []byte(timestamp), nil +} + +func Write(storage storage.Iface, logEntry LogEntry) error { + statsPath := path.Join(AUDITLOG_STATS_DIR, "logins-"+time.Now().Format("2006-01-02")) + ".log" + logEntryBytes, err := json.Marshal(logEntry) + if err != nil { + return fmt.Errorf("could not parse log entry: %s", err) + } + err = storage.AppendFile(statsPath, logEntryBytes) + if err != nil { + return fmt.Errorf("could not append stats to file (%s): %s", statsPath, err) + } + + return nil +} diff --git a/rest/auth.go b/rest/auth.go new file mode 100644 index 0000000..79ae215 --- /dev/null +++ b/rest/auth.go @@ -0,0 +1,434 @@ +package rest + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "github.com/in4it/go-devops-platform/auth/oidc" + oidcstore "github.com/in4it/go-devops-platform/auth/oidc/store" + "github.com/in4it/go-devops-platform/auth/saml" + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/rest/login" +) + +func (c *Context) authHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + c.returnError(w, fmt.Errorf("not a post request"), http.StatusBadRequest) + return + } + + if c.LocalAuthDisabled { + c.returnError(w, fmt.Errorf("local auth is disabled in settings"), http.StatusForbidden) + return + } + + decoder := json.NewDecoder(r.Body) + var loginReq login.LoginRequest + err := decoder.Decode(&loginReq) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + + // check login attempts + tooManyLogins := login.CheckTooManyLogins(c.LoginAttempts, loginReq.Login) + if tooManyLogins { + c.returnError(w, fmt.Errorf("too many login failures, try again later"), http.StatusTooManyRequests) + return + } + + loginResponse, user, err := login.Authenticate(loginReq, c.UserStore, c.JWTKeys.PrivateKey, c.JWTKeysKID) + if err != nil { + c.returnError(w, fmt.Errorf("authentication error: %s", err), http.StatusBadRequest) + return + } + out, err := json.Marshal(loginResponse) + if err != nil { + c.returnError(w, fmt.Errorf("unable to marshal response: %s", err), http.StatusBadRequest) + return + } + if loginResponse.MFARequired { + c.write(w, out) // status ok, but unauthorized, because we need a second call with MFA code + return + } else if loginResponse.Authenticated { + login.ClearAttemptsForLogin(c.LoginAttempts, loginReq.Login) + user.LastLogin = time.Now() + err = c.UserStore.UpdateUser(user) + if err != nil { + logging.ErrorLog(fmt.Errorf("last login update error: %s", err)) + } + c.write(w, out) + } else { + // log login attempts + login.RecordAttempt(c.LoginAttempts, loginReq.Login) + // return Unauthorized + c.writeWithStatus(w, out, http.StatusUnauthorized) + } +} + +func (c *Context) oidcProviderHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + oidcProviders := make([]oidc.OIDCProvider, len(c.OIDCProviders)) + copy(oidcProviders, c.OIDCProviders) + for k := range oidcProviders { + oidcProviders[k].LoginURL = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, strings.Replace(oidcProviders[k].RedirectURI, "/callback/", "/login/", -1)) + oidcProviders[k].RedirectURI = fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProviders[k].RedirectURI) + } + out, err := json.Marshal(oidcProviders) + if err != nil { + c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var oidcProvider oidc.OIDCProvider + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&oidcProvider) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + oidcProvider.ID = uuid.New().String() + if oidcProvider.Name == "" { + c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest) + return + } + if oidcProvider.ClientID == "" { + c.returnError(w, fmt.Errorf("clientID not set"), http.StatusBadRequest) + return + } + if oidcProvider.ClientSecret == "" { + c.returnError(w, fmt.Errorf("clientSecret not set"), http.StatusBadRequest) + return + } + if oidcProvider.Scope == "" { + c.returnError(w, fmt.Errorf("scope not set"), http.StatusBadRequest) + return + } + if oidcProvider.DiscoveryURI == "" { + c.returnError(w, fmt.Errorf("discovery URL not set"), http.StatusBadRequest) + return + } + oidcProvider.RedirectURI = "/callback/oidc/" + oidcProvider.ID + c.OIDCProviders = append(c.OIDCProviders, oidcProvider) + out, err := json.Marshal(oidcProvider) + if err != nil { + c.returnError(w, fmt.Errorf("oidcProvider marshal error: %s", err), http.StatusBadRequest) + return + } + err = SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + return + } +} + +func (c *Context) oidcProviderElementHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodDelete: + match := -1 + for k, oidcProvider := range c.OIDCProviders { + if oidcProvider.ID == r.PathValue("id") { + match = k + } + } + if match == -1 { + c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest) + return + } + c.OIDCProviders = append(c.OIDCProviders[:match], c.OIDCProviders[match+1:]...) + // save config (changed providers) + err := SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`)) + } +} + +func (c *Context) authMethods(w http.ResponseWriter, r *http.Request) { + response := AuthMethodsResponse{ + LocalAuthDisabled: c.LocalAuthDisabled, + OIDCProviders: make([]AuthMethodsProvider, len(c.OIDCProviders)), + } + for k, oidcProvider := range c.OIDCProviders { + response.OIDCProviders[k] = AuthMethodsProvider{ + ID: oidcProvider.ID, + Name: oidcProvider.Name, + } + } + + out, err := json.Marshal(response) + if err != nil { + c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) +} +func (c *Context) authMethodsByID(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + switch r.PathValue("method") { + case "saml": + loginResponse := login.LoginResponse{} + var samlCallback SAMLCallback + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&samlCallback) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + if samlCallback.Code == "" { + c.returnError(w, fmt.Errorf("no code provided"), http.StatusBadRequest) + return + } + var samlProvider saml.Provider + for k := range *c.SAML.Providers { + if r.PathValue("id") == (*c.SAML.Providers)[k].ID { + samlProvider = (*c.SAML.Providers)[k] + } + } + if samlProvider.ID == "" { + c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest) + return + } + + samlSession, err := c.SAML.Client.GetAuthenticatedUser(samlProvider, samlCallback.Code) + if err != nil { + c.returnError(w, fmt.Errorf("saml session not found"), http.StatusBadRequest) + return + } + + // add user to the user database (or modify existing one) + user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, samlSession.Login, "saml", samlSession.ID) + if err != nil { + c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest) + return + } + + if user.Suspended { + loginResponse.Suspended = true + } + + token, err := login.GetJWTTokenWithExpiration(user.Login, user.Role, c.JWTKeys.PrivateKey, c.JWTKeysKID, samlSession.ExpiresAt) + if err != nil { + c.returnError(w, fmt.Errorf("token generation failed: %s", err), http.StatusBadRequest) + return + } + loginResponse.Authenticated = true + loginResponse.Token = token + + out, err := json.Marshal(loginResponse) + if err != nil { + c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + default: // oidc is default + loginResponse := login.LoginResponse{} + var oidcCallback OIDCCallback + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&oidcCallback) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + for _, oidcProvider := range c.OIDCProviders { + if r.PathValue("id") == oidcProvider.ID && oidcCallback.Code != "" { // we got the code back + oidcstore.RetrieveTokenLock.Lock() + defer oidcstore.RetrieveTokenLock.Unlock() + oauth2data, err := oidc.RetrieveOAUth2DataUsingState(c.OIDCStore.OAuth2Data, oidcCallback.State) // get the oauth2 struct based on the state (key) + if err != nil { + c.returnError(w, fmt.Errorf("cannot find oauth2 data using state provided: %s", err), http.StatusBadRequest) + return + } + if oauth2data.Token.AccessToken != "" { + if oauth2data.Suspended { + loginResponse.Suspended = true + } else if c.LicenseUserCount >= c.UserStore.UserCount() { + loginResponse.NoLicense = true + } else { + loginResponse.Authenticated = true + loginResponse.Token = oauth2data.Token.AccessToken + } + out, err := json.Marshal(loginResponse) + if err != nil { + c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + } + // no token, let's generate a new one + discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) + if err != nil { + c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest) + return + } + jwks, err := c.OIDCStore.GetJwks(discovery.JwksURI) + if err != nil { + c.returnError(w, fmt.Errorf("get jwks error: %s", err), http.StatusBadRequest) + return + } + updatedOauth2data, err := oidc.UpdateOAuth2DataWithToken(jwks, discovery, oidcProvider.ClientID, oidcProvider.ClientSecret, c.Protocol+"://"+c.Hostname+oidcCallback.RedirectURI, oidcCallback.Code, oidcCallback.State, oauth2data) + if err != nil { + c.returnError(w, fmt.Errorf("GetTokenFromCode error: %s", err), http.StatusBadRequest) + return + } + // add user to the user database (or modify existing one) + user, err := addOrModifyExternalUser(c.Storage.Client, c.UserStore, updatedOauth2data.UserInfo.Email, "oidc", updatedOauth2data.ID) + if err != nil { + c.returnError(w, fmt.Errorf("couldn't add/modify user in database: %s", err), http.StatusBadRequest) + return + } + if user.Suspended { + loginResponse.Suspended = true + updatedOauth2data.Suspended = true + } else { + updatedOauth2data.Suspended = false + } + // save oauth data (only when we're sure it's not a suspended user) + err = c.OIDCStore.SaveOAuth2Data(updatedOauth2data, oidcCallback.State) + if err != nil { + c.returnError(w, fmt.Errorf("oidc store save failed: %s", err), http.StatusBadRequest) + return + } + // cleanup oauth2 data + c.OIDCStore.CleanupOAuth2Data(updatedOauth2data) + + // save config (changed user info) + err = SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + + // set loginResponse + if !loginResponse.Suspended { + loginResponse.Authenticated = true + loginResponse.Token = updatedOauth2data.Token.AccessToken + } + out, err := json.Marshal(loginResponse) + if err != nil { + c.returnError(w, fmt.Errorf("loginResponse Marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + } + } + } + c.returnError(w, fmt.Errorf("oidc provider not found"), http.StatusBadRequest) + case http.MethodGet: + switch r.PathValue("method") { + case "saml": + id := r.PathValue("id") + samlProviderId := -1 + for k := range *c.SAML.Providers { + if (*c.SAML.Providers)[k].ID == id { + samlProviderId = k + } + } + if samlProviderId == -1 { + c.returnError(w, fmt.Errorf("cannot find saml provider"), http.StatusBadRequest) + return + } + redirectURI, err := c.SAML.Client.GetAuthURL((*c.SAML.Providers)[samlProviderId]) + if err != nil { + c.returnError(w, fmt.Errorf("cannot get auth url"), http.StatusBadRequest) + return + } + response := AuthMethodsProvider{ + ID: (*c.SAML.Providers)[samlProviderId].ID, + Name: (*c.SAML.Providers)[samlProviderId].Name, + RedirectURI: redirectURI, + } + out, err := json.Marshal(response) + if err != nil { + c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + default: + id := r.PathValue("id") + for _, oidcProvider := range c.OIDCProviders { + if id == oidcProvider.ID { + callback := fmt.Sprintf("%s://%s%s", c.Protocol, c.Hostname, oidcProvider.RedirectURI) + discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) + if err != nil { + c.returnError(w, fmt.Errorf("getDiscoveryURI error: %s", err), http.StatusBadRequest) + return + } + redirectURI, state, err := oidc.GetRedirectURI(discovery, oidcProvider.ClientID, oidcProvider.Scope, callback, c.EnableOIDCTokenRenewal) + if err != nil { + c.returnError(w, fmt.Errorf("GetRedirectURI error: %s", err), http.StatusBadRequest) + return + } + response := AuthMethodsProvider{ + ID: oidcProvider.ID, + Name: oidcProvider.Name, + RedirectURI: redirectURI, + } + out, err := json.Marshal(response) + if err != nil { + c.returnError(w, fmt.Errorf("response marshal error: %s", err), http.StatusBadRequest) + return + } + newOAuthEntry := oidc.OAuthData{ + ID: uuid.NewString(), + OIDCProviderID: response.ID, + CreatedAt: time.Now(), + } + err = c.OIDCStore.SaveOAuth2Data(newOAuthEntry, state) + if err != nil { + c.returnError(w, fmt.Errorf("unable to save state to oidc store: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + return + } + } + c.returnError(w, fmt.Errorf("element not found"), http.StatusBadRequest) + } + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) oidcRenewTokensHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + disabledUsers := c.OIDCRenewal.RenewAllOIDCConnections() + for _, user := range disabledUsers { + logging.DebugLog(fmt.Errorf("disable user with oidc id %s", user.ID)) + + err := c.UserStore.UserHooks.DisableFunc(c.Storage.Client, user) + if err != nil { + c.returnError(w, fmt.Errorf("DisableAllClientConfigs error for userID %s: %s", user.ID, err), http.StatusBadRequest) + return + } + user.ConnectionsDisabledOnAuthFailure = true + err = c.UserStore.UpdateUser(user) + if err != nil { + c.returnError(w, fmt.Errorf("could not update connectionsDisabledOnAuthFailure user with userID %s: %s", user.ID, err), http.StatusBadRequest) + return + } + } + c.write(w, []byte(`{"status": "done"}`)) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} diff --git a/rest/auth_test.go b/rest/auth_test.go new file mode 100644 index 0000000..af31db9 --- /dev/null +++ b/rest/auth_test.go @@ -0,0 +1,946 @@ +package rest + +import ( + "bytes" + "compress/flate" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/in4it/go-devops-platform/auth/oidc" + "github.com/in4it/go-devops-platform/auth/saml" + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/rest/login" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" + "github.com/russellhaering/gosaml2/types" + dsigtypes "github.com/russellhaering/goxmldsig/types" +) + +func getSAMLCertWithCustomCert(singleSignOnURL string, cert string) *types.EntityDescriptor { + return &types.EntityDescriptor{ + EntityID: "https://www.idp.inv/metadata", + IDPSSODescriptor: &types.IDPSSODescriptor{ + SingleSignOnServices: []types.SingleSignOnService{ + { + Location: singleSignOnURL, + }, + }, + KeyDescriptors: []types.KeyDescriptor{ + { + KeyInfo: dsigtypes.KeyInfo{ + X509Data: dsigtypes.X509Data{ + X509Certificates: []dsigtypes.X509Certificate{ + { + Data: cert, + }, + }, + }, + }, + }, + }, + }, + } +} +func getSAMLCert(singleSignOnURL string) *types.EntityDescriptor { + cert := `MIID2jCCA0MCAg39MA0GCSqGSIb3DQEBBQUAMIGbMQswCQYDVQQGEwJKUDEOMAwG +A1UECBMFVG9reW8xEDAOBgNVBAcTB0NodW8ta3UxETAPBgNVBAoTCEZyYW5rNERE +MRgwFgYDVQQLEw9XZWJDZXJ0IFN1cHBvcnQxGDAWBgNVBAMTD0ZyYW5rNEREIFdl +YiBDQTEjMCEGCSqGSIb3DQEJARYUc3VwcG9ydEBmcmFuazRkZC5jb20wHhcNMTIw +ODIyMDUyODAwWhcNMTcwODIxMDUyODAwWjBKMQswCQYDVQQGEwJKUDEOMAwGA1UE +CAwFVG9reW8xETAPBgNVBAoMCEZyYW5rNEREMRgwFgYDVQQDDA93d3cuZXhhbXBs +ZS5jb20wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCwvWITOLeyTbS1 +Q/UacqeILIK16UHLvSymIlbbiT7mpD4SMwB343xpIlXN64fC0Y1ylT6LLeX4St7A +cJrGIV3AMmJcsDsNzgo577LqtNvnOkLH0GojisFEKQiREX6gOgq9tWSqwaENccTE +sAXuV6AQ1ST+G16s00iN92hjX9V/V66snRwTsJ/p4WRpLSdAj4272hiM19qIg9zr +h92e2rQy7E/UShW4gpOrhg2f6fcCBm+aXIga+qxaSLchcDUvPXrpIxTd/OWQ23Qh +vIEzkGbPlBA8J7Nw9KCyaxbYMBFb1i0lBjwKLjmcoihiI7PVthAOu/B71D2hKcFj +Kpfv4D1Uam/0VumKwhwuhZVNjLq1BR1FKRJ1CioLG4wCTr0LVgtvvUyhFrS+3PdU +R0T5HlAQWPMyQDHgCpbOHW0wc0hbuNeO/lS82LjieGNFxKmMBFF9lsN2zsA6Qw32 +Xkb2/EFltXCtpuOwVztdk4MDrnaDXy9zMZuqFHpv5lWTbDVwDdyEQNclYlbAEbDe +vEQo/rAOZFl94Mu63rAgLiPeZN4IdS/48or5KaQaCOe0DuAb4GWNIQ42cYQ5TsEH +Wt+FIOAMSpf9hNPjDeu1uff40DOtsiyGeX9NViqKtttaHpvd7rb2zsasbcAGUl+f +NQJj4qImPSB9ThqZqPTukEcM/NtbeQIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAIAi +gU3My8kYYniDuKEXSJmbVB+K1upHxWDA8R6KMZGXfbe5BRd8s40cY6JBYL52Tgqd +l8z5Ek8dC4NNpfpcZc/teT1WqiO2wnpGHjgMDuDL1mxCZNL422jHpiPWkWp3AuDI +c7tL1QjbfAUHAQYwmHkWgPP+T2wAv0pOt36GgMCM` + return getSAMLCertWithCustomCert(singleSignOnURL, cert) +} + +func TestAuthHandler(t *testing.T) { + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context: %s", err) + } + c.UserStore.Empty() + _, err = c.UserStore.AddUser(users.User{ + Login: "john", + Password: "mypass", + }) + if err != nil { + t.Fatalf("Cannot create user") + } + + loginReq := login.LoginRequest{ + Login: "john", + Password: "mypass", + } + + payload, err := json.Marshal(loginReq) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "http://example.com/api/auth", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.authHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var loginResponse login.LoginResponse + + err = json.NewDecoder(resp.Body).Decode(&loginResponse) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if !loginResponse.Authenticated { + t.Fatalf("expected to be authenticated") + } + +} + +func TestNewSAMLConnection(t *testing.T) { + // generate new keypair + kp := saml.NewKeyPair(&memorystorage.MockMemoryStorage{}, "www.idp.inv") + _, cert, err := kp.GetKeyPair() + if err != nil { + t.Fatalf("Can't generate new keypair: %s", err) + } + certBase64 := base64.StdEncoding.EncodeToString(cert) + + testUrl := "127.0.0.1:12347" + l, err := net.Listen("tcp", testUrl) + if err != nil { + t.Fatal(err) + } + + singleSignOnURL := "http://" + testUrl + "/auth" + audienceURL := "http://" + testUrl + "/aud" + login := "john@example.inv" + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + requestURIParsed, _ := url.Parse(r.RequestURI) + if requestURIParsed.Path == "/auth" { + compressedSAMLReq, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("saml base64 decode error: %s", err))) + return + } + samlRequest := new(bytes.Buffer) + decompressor := flate.NewReader(bytes.NewReader(compressedSAMLReq)) + io.Copy(samlRequest, decompressor) + decompressor.Close() + + var authnReq saml.AuthnRequest + err = xml.Unmarshal(samlRequest.Bytes(), &authnReq) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("saml authn request decode error: %s", err))) + return + } + w.Write([]byte("OK")) + return + } + if r.RequestURI == "/metadata" { + out, _ := xml.Marshal(getSAMLCertWithCustomCert(singleSignOnURL, certBase64)) + w.Write(out) + return + } + w.WriteHeader(http.StatusBadRequest) + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + // first create a new user + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + + // create a new SAML connection + samlProvider := saml.Provider{ + Name: "testProvider", + MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), + } + + payload, err := json.Marshal(samlProvider) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "http://example.inv/api/saml-setup", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&samlProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if samlProvider.ID == "" { + t.Fatalf("Was expecting saml provider to have an ID") + } + + authURL, err := c.SAML.Client.GetAuthURL(samlProvider) + if err != nil { + t.Fatalf("cannot get Auth URL from saml: %s", err) + } + + if authURL == "" { + t.Fatalf("authURL is empty") + } + + resp, err = http.Get(authURL) + if err != nil { + t.Fatalf("http get auth url error: %s", err) + } + if resp.StatusCode != 200 { + t.Errorf("auth url get not status 200: %d", resp.StatusCode) + } + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("body read error: %s", err) + } + + // check SAML POST flow + tsSAML := httptest.NewServer(c.SAML.Client.GetRouter()) + defer tsSAML.Close() + + // build the SAML response + // example + /* + + + https://app.onelogin.com/saml/metadata/onelogin-id + + + + + + + + + + + 5eB3C+2/vwdigestvalue + + + sigvalue + + + MIIGETCCA/mgAcertt + + + + + + + + https://app.onelogin.com/saml/metadata/onelogin-id + + ward@in4it.io + + + + + + + https://vpn-server.in4it.io/saml/aud/provider-id + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + */ + samlResponse := saml.Response{ + Saml: "urn:oasis:names:tc:SAML:2.0:assertion", + Samlp: "urn:oasis:names:tc:SAML:2.0:protocol", + ID: "pfx181c1b93-cce6-a6c4-f1f7-99e539374d15", + Version: "2.0", + IssueInstant: time.Now().Format(time.RFC3339), + Destination: "http://example.inv/saml/acs/" + samlProvider.ID, + Issuer: ts.URL + "/metadata", + Signature: saml.ResponseSignature{ + Ds: "http://www.w3.org/2000/09/xmldsig#", + SignedInfo: saml.ResponseSignatureSignedInfo{ + CanonicalizationMethod: struct { + Text string "xml:\",chardata\"" + Algorithm string "xml:\"Algorithm,attr\"" + }{ + Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#", + }, + SignatureMethod: saml.ResponseSignatureSignedInfoSignatureMethod{ + Algorithm: "http://www.w3.org/2000/09/xmldsig#rsa-sha1", + }, + Reference: saml.ResponseSignatureSignedInfoReference{ + Transforms: struct { + Text string "xml:\",chardata\"" + Transform []struct { + Text string "xml:\",chardata\"" + Algorithm string "xml:\"Algorithm,attr\"" + } "xml:\"ds:Transform\"" + }{ + Transform: []struct { + Text string "xml:\",chardata\"" + Algorithm string "xml:\"Algorithm,attr\"" + }{ + { + Algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature", + }, + { + Algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#", + }, + }, + }, + DigestMethod: struct { + Text string "xml:\",chardata\"" + Algorithm string "xml:\"Algorithm,attr\"" + }{ + Algorithm: "http://www.w3.org/2000/09/xmldsig#sha1", + }, + DigestValue: "thisisthesignature", + }, + }, + KeyInfo: saml.ResponseSignatureKeyInfo{ + X509Data: struct { + Text string "xml:\",chardata\"" + X509Certificate string "xml:\"ds:X509Certificate\"" + }{ + X509Certificate: certBase64, + }, + }, + }, + Assertion: saml.ResponseAssertion{ + Subject: saml.ResponseSubject{ + NameID: struct { + Text string "xml:\",chardata\"" + Format string "xml:\"Format,attr\"" + }{ + Text: login, + Format: "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", + }, + }, + Conditions: saml.ResponseConditions{ + NotBefore: time.Now().Format(time.RFC3339), + NotOnOrAfter: time.Now().Add(10 * time.Minute).Format(time.RFC3339), + AudienceRestriction: saml.ResponseConditionsAdienceRestriction{ + Audience: audienceURL, + }, + }, + }, + } + samlResponseBytes, err := xml.Marshal(samlResponse) + if err != nil { + t.Fatalf("xml marshal error: %s", err) + } + //fmt.Printf("saml respons bytes: %s\n", samlResponseBytes) + samlResponseBytesDeflated := new(bytes.Buffer) + compressor, err := flate.NewWriter(samlResponseBytesDeflated, 1) + if err != nil { + t.Fatalf("deflate error: %s", err) + } + io.Copy(compressor, bytes.NewBuffer(samlResponseBytes)) + compressor.Close() + + samlResponseEncoded := base64.StdEncoding.EncodeToString(samlResponseBytesDeflated.Bytes()) + + form := url.Values{} + form.Add("SAMLResponse", samlResponseEncoded) + + resp, err = http.Post(tsSAML.URL+"/saml/acs/"+samlProvider.ID, "application/x-www-form-urlencoded", strings.NewReader(form.Encode())) + if err != nil { + t.Fatalf("http post acs url error: %s", err) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("body read error: %s", err) + } + + if strings.Contains(string(body), "provider not found") { + t.Fatalf("provider not found: body output: %s", body) + } + + // currently does't authenticate because of missing signatures, but we checked if the auth process kicked off + /*if resp.StatusCode != 200 { + t.Errorf("auth url get not status 200: %d", resp.StatusCode) + }*/ + +} +func TestAddModifyDeleteNewSAMLConnection(t *testing.T) { + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + c.Hostname = "example.inv" + c.Protocol = "https" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/metadata" { + out, err := xml.Marshal(getSAMLCert("http://localhost.inv")) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + w.Write(out) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + samlProvider := saml.Provider{ + Name: "testProvider", + MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), + } + + payload, err := json.Marshal(samlProvider) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&samlProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if samlProvider.ID == "" { + t.Fatalf("samlprovider id is empty") + } + + // GET authmethods and see if provider exists + req = httptest.NewRequest("GET", "http://example.com/authmethods/saml/"+samlProvider.ID, nil) + req.SetPathValue("method", "saml") + req.SetPathValue("id", samlProvider.ID) + w = httptest.NewRecorder() + c.authMethodsByID(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var authMethodsProvider AuthMethodsProvider + + err = json.NewDecoder(resp.Body).Decode(&authMethodsProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + if authMethodsProvider.ID != samlProvider.ID { + t.Fatalf("authmethods provider id is different than saml provider id: %s vs %s. authMethodsProvider: %+v", authMethodsProvider.ID, samlProvider.ID, authMethodsProvider) + } + + // PUT req + samlProvider.AllowMissingAttributes = true + payload, err = json.Marshal(samlProvider) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req = httptest.NewRequest("PUT", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload)) + req.SetPathValue("id", samlProvider.ID) + w = httptest.NewRecorder() + c.samlSetupElementHandler(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&samlProvider) + if err != nil { + t.Fatalf("marshal decode error: %s", err) + } + + if samlProvider.AllowMissingAttributes == false { + t.Fatalf("allow missing attributes is false") + } + + // GET on the saml endpoint to see if we can return it + req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil) + w = httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var samlProviders []saml.Provider + err = json.NewDecoder(resp.Body).Decode(&samlProviders) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + if len(samlProviders) == 0 { + t.Fatalf("samlProviders is zero length") + } + if samlProviders[len(samlProviders)-1].ID != samlProvider.ID { + t.Fatalf("ID doesn't match: %s vs %s ", samlProviders[len(samlProviders)-1].ID, samlProvider.ID) + } + if samlProviders[len(samlProviders)-1].Acs != fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProvider.ID) { + t.Fatalf("ACS doesn't match") + } + if samlProviders[len(samlProviders)-1].AllowMissingAttributes == false { + t.Fatalf("allow missing attributes is false when getting all samlproviders") + } + + // delete req + req = httptest.NewRequest("DELETE", "http://example.com/saml-setup/"+samlProvider.ID, bytes.NewBuffer(payload)) + req.SetPathValue("id", samlProvider.ID) + w = httptest.NewRecorder() + c.samlSetupElementHandler(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + // list to see if really deleted + req = httptest.NewRequest("GET", "http://example.com/saml-setup", nil) + w = httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var samlProviders2 []saml.Provider + err = json.NewDecoder(resp.Body).Decode(&samlProviders2) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + if len(samlProviders)-1 != len(samlProviders2) { + t.Fatalf("samlProviders has wrong length") + } + +} + +func TestSAMLCallback(t *testing.T) { + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + c.Hostname = "example.inv" + c.Protocol = "https" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/metadata" { + out, err := xml.Marshal(getSAMLCert("http://localhost.inv")) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + w.Write(out) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + samlProvider := saml.Provider{ + Name: "testProvider", + MetadataURL: fmt.Sprintf("%s/metadata", ts.URL), + } + + payload, err := json.Marshal(samlProvider) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest("POST", "http://example.com/api/saml-setup", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.samlSetupHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&samlProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if samlProvider.ID == "" { + t.Fatalf("samlprovider id is empty") + } + samlCallback := SAMLCallback{ + Code: "abc", + RedirectURI: "https://localhost.inv/something", + } + payload, err = json.Marshal(samlCallback) + if err != nil { + t.Fatal(err) + } + + c.SAML.Client.CreateSession(saml.SessionKey{ProviderID: samlProvider.ID, SessionID: "abc"}, saml.AuthenticatedUser{ID: "123", Login: "john@example.com", ExpiresAt: time.Now().AddDate(0, 0, 1)}) + + req = httptest.NewRequest("POST", "http://example.com/api/authmethods/saml/"+samlProvider.ID, bytes.NewBuffer(payload)) + req.SetPathValue("method", "saml") + req.SetPathValue("id", samlProvider.ID) + w = httptest.NewRecorder() + c.authMethodsByID(w, req) + + resp = w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var loginResponse login.LoginResponse + err = json.NewDecoder(resp.Body).Decode(&loginResponse) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if !loginResponse.Authenticated { + t.Fatalf("Expected to be authenticated") + } + +} + +func TestOIDCFlow(t *testing.T) { + testUrl := "127.0.0.1:12346" + l, err := net.Listen("tcp", testUrl) + if err != nil { + t.Fatal(err) + } + + authURL := "http://" + testUrl + "/auth" + + // create a new OIDC connection + oidcProvider := oidc.OIDCProvider{ + Name: "test-oidc", + ClientID: "1-2-3-4", + ClientSecret: "9-9-9-9", + Scope: "openid", + DiscoveryURI: "http://" + testUrl + "/discovery.json", + } + jwtPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatalf("can't generate jwt key: %s", err) + } + + // first create a new user + c, err := newContext(&memorystorage.MockMemoryStorage{}, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + c.Hostname = "example.inv" + c.Protocol = "http" + logging.Loglevel = 17 + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + code := "thisisthecode" + + switch r.Method { + case http.MethodGet: + parsedURI, _ := url.Parse(r.RequestURI) + switch parsedURI.Path { + case "/discovery.json": + discovery := oidc.Discovery{ + Issuer: "test-issuer", + AuthorizationEndpoint: authURL, + TokenEndpoint: "http://" + testUrl + "/token", + JwksURI: "http://" + testUrl + "/jwks.json", + } + out, err := json.Marshal(discovery) + if err != nil { + t.Fatalf("json marshal error: %s", err) + } + w.Write(out) + return + case "/auth": + if oidcProvider.ClientID != r.URL.Query().Get("client_id") { + w.Write([]byte("client id mismatch")) + w.WriteHeader(http.StatusBadRequest) + return + } + if oidcProvider.Scope != r.URL.Query().Get("scope") { + w.Write([]byte("scope mismatch")) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Write([]byte(code)) + case "/jwks.json": + publicKey := jwtPrivateKey.PublicKey + + jwks := oidc.Jwks{ + Keys: []oidc.JwksKey{ + { + Kid: "kid-id-1234", + Alg: "RS256", + Kty: "RSA", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), + E: "AQAB", + }, + }, + } + out, err := json.Marshal(jwks) + if err != nil { + w.Write([]byte("jwks marshal error")) + w.WriteHeader(http.StatusBadRequest) + return + } + w.Write(out) + default: + w.WriteHeader(http.StatusNotFound) + } + case http.MethodPost: + parsedURI, _ := url.Parse(r.RequestURI) + switch parsedURI.Path { + case "/token": + if r.FormValue("grant_type") != "authorization_code" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("wrong grant type")) + return + } + if r.FormValue("code") != code { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("wrong code")) + return + } + if oidcProvider.ClientID != r.FormValue("client_id") { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("client id mismatch")) + return + } + if oidcProvider.ClientSecret != r.FormValue("client_secret") { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("client secret mismatch")) + return + } + if c.Protocol+"://"+c.Hostname+oidcProvider.RedirectURI != r.FormValue("redirect_uri") { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("redirect uri mismatch: %s vs %s", oidcProvider.RedirectURI, r.FormValue("redirect_uri")))) + return + } + token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{ + "iss": "test-server", + "sub": "john", + "email": "john@example.inv", + "role": "user", + "exp": time.Now().AddDate(0, 0, 1).Unix(), + "iat": time.Now().Unix(), + }) + token.Header["kid"] = "kid-id-1234" + + tokenString, err := token.SignedString(jwtPrivateKey) + if err != nil { + t.Fatalf("can't generate jwt token: %s", err) + w.WriteHeader(http.StatusBadRequest) + } + tokenRes := oidc.Token{ + AccessToken: tokenString, + IDToken: tokenString, + ExpiresIn: 180, + } + tokenBytes, _ := json.Marshal(tokenRes) + + w.Write([]byte(tokenBytes)) + default: + w.WriteHeader(http.StatusNotFound) + } + default: + w.WriteHeader(http.StatusBadRequest) + } + })) + + ts.Listener.Close() + ts.Listener = l + ts.Start() + defer ts.Close() + defer l.Close() + + payload, err := json.Marshal(oidcProvider) + if err != nil { + t.Fatal(err) + } + + // create new oidc provider + req := httptest.NewRequest("POST", "http://example.inv/api/oidc", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.oidcProviderHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + err = json.NewDecoder(resp.Body).Decode(&oidcProvider) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + + if oidcProvider.ID == "" { + t.Fatalf("Was expecting oidc provider to have an ID") + } + + // get redirect URL + req = httptest.NewRequest("GET", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, nil) + req.SetPathValue("id", oidcProvider.ID) + w = httptest.NewRecorder() + c.authMethodsByID(w, req) + + resp = w.Result() + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + var authmethodsResponse AuthMethodsProvider + + err = json.NewDecoder(resp.Body).Decode(&authmethodsResponse) + + if err != nil { + t.Fatalf("cannot decode authmethodsresponse: %s", err) + } + if !strings.HasPrefix(authmethodsResponse.RedirectURI, authURL) { + t.Fatalf("expected authURL as prefix of redirect url. Redirect URL: %s", authmethodsResponse.RedirectURI) + } + + redirectURIParsed, err := url.Parse(authmethodsResponse.RedirectURI) + if err != nil { + t.Fatalf("could not parse redirect URI: %s", err) + } + state := redirectURIParsed.Query().Get("state") + if state == "" { + t.Fatalf("could not obtain state") + } + res, err := http.Get(authmethodsResponse.RedirectURI) + if err != nil { + t.Fatalf("http get redirect uri error: %s", err) + } + if res.StatusCode != 200 { + t.Fatalf("redirect uri statuscode not 200: %d", res.StatusCode) + } + code, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("body read error: %s", err) + } + + callback := OIDCCallback{ + Code: string(code), + State: state, + RedirectURI: oidcProvider.RedirectURI, + } + callbackPayload, err := json.Marshal(callback) + if err != nil { + t.Fatalf("callback marshal error: %s", err) + } + // execute callback + req = httptest.NewRequest("POST", "http://example.inv/api/authmethods/oidc/"+oidcProvider.ID, bytes.NewBuffer(callbackPayload)) + req.SetPathValue("id", oidcProvider.ID) + req.SetPathValue("method", "oidc") + w = httptest.NewRecorder() + c.authMethodsByID(w, req) + + resp = w.Result() + defer resp.Body.Close() + + if resp.StatusCode != 200 { + errorMessage, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("body read error after statuscode not 200 (%d): %s", resp.StatusCode, err) + } + t.Fatalf("status code is not 200: %d, errormessage: %s", resp.StatusCode, errorMessage) + } + + var loginResponse login.LoginResponse + + err = json.NewDecoder(resp.Body).Decode(&loginResponse) + if err != nil { + t.Fatalf("cannot decode login response: %s", err) + } + + if !loginResponse.Authenticated { + t.Fatalf("not authenticated: %+v", loginResponse) + } + if loginResponse.Token == "" { + t.Fatalf("no token received: %+v", loginResponse) + } +} diff --git a/rest/config.go b/rest/config.go new file mode 100644 index 0000000..51efdd6 --- /dev/null +++ b/rest/config.go @@ -0,0 +1,79 @@ +package rest + +import ( + "bytes" + "encoding/json" + "fmt" + "os/user" + "sync" + + "github.com/in4it/go-devops-platform/storage" +) + +var mu sync.Mutex + +func SaveConfig(c *Context) error { + mu.Lock() + defer mu.Unlock() + cCopy := *c + cCopy.SCIM = &SCIM{ // we don't save the client, but we want the token and enabled + EnableSCIM: c.SCIM.EnableSCIM, + Token: c.SCIM.Token, + } + cCopy.SAML = &SAML{ // we don't save the client, but we want the config + Providers: c.SAML.Providers, + } + cCopy.JWTKeys = nil // we retrieve JWTKeys from pem files at startup + cCopy.OIDCStore = nil // we save this separately + cCopy.UserStore = nil // we save this separately + cCopy.OIDCRenewal = nil // we don't save this + cCopy.LoginAttempts = nil // no need to save this + cCopy.Apps = nil // no need to save the app client + cCopy.Storage = nil // no need to save storage + out, err := json.Marshal(cCopy) + if err != nil { + return fmt.Errorf("context marshal error: %s", err) + } + err = c.Storage.Client.WriteFile(c.Storage.Client.ConfigPath("config.json"), out) + if err != nil { + return fmt.Errorf("config write error: %s", err) + } + // fix permissions + currentUser, err := user.Current() + if err != nil { + return fmt.Errorf("could not get current user: %s", err) + } + if currentUser.Username != "vpn" { + err = c.Storage.Client.EnsureOwnership(c.Storage.Client.ConfigPath("config.json"), "vpn") + if err != nil { + return fmt.Errorf("config write error: %s", err) + } + } + + return nil +} + +func GetConfig(storage storage.Iface) (*Context, error) { + var c *Context + + appDir := storage.GetPath() + + // check if config exists + if !storage.FileExists(storage.ConfigPath("config.json")) { + return getEmptyContext(appDir) + } + + data, err := storage.ReadFile(storage.ConfigPath("config.json")) + if err != nil { + return c, fmt.Errorf("config read error: %s", err) + } + decoder := json.NewDecoder(bytes.NewBuffer(data)) + err = decoder.Decode(&c) + if err != nil { + return c, fmt.Errorf("decode input error: %s", err) + } + + c.AppDir = appDir + + return c, nil +} diff --git a/rest/constants.go b/rest/constants.go new file mode 100644 index 0000000..48fe858 --- /dev/null +++ b/rest/constants.go @@ -0,0 +1,4 @@ +package rest + +const SERVER_TYPE_OBSERVABILITY = "observability" +const SERVER_TYPE_VPN = "vpn" diff --git a/rest/context.go b/rest/context.go new file mode 100644 index 0000000..da893cb --- /dev/null +++ b/rest/context.go @@ -0,0 +1,120 @@ +package rest + +import ( + "fmt" + "time" + + "github.com/in4it/go-devops-platform/auth/oidc" + oidcstore "github.com/in4it/go-devops-platform/auth/oidc/store" + oidcrenewal "github.com/in4it/go-devops-platform/auth/oidc/store/renewal" + "github.com/in4it/go-devops-platform/auth/provisioning/scim" + "github.com/in4it/go-devops-platform/auth/saml" + licensing "github.com/in4it/go-devops-platform/licensing" + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/rest/login" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" +) + +func NewContext(storage storage.Iface, serverType string, userStore *users.UserStore, scimInstance scim.Iface, licenseUserCount int, cloudType string, apps map[string]AppClient) (*Context, error) { + return newContextWithParams(storage, serverType, userStore, scimInstance, licenseUserCount, cloudType, apps) +} + +func newContext(storage storage.Iface, serverType string) (*Context, error) { + userStore, err := users.NewUserStore(storage, 100) + if err != nil { + return &Context{}, fmt.Errorf("userstore initialization error: %s", err) + } + return newContextWithParams(storage, serverType, userStore, scim.New(storage, userStore, "", nil, nil), 100, "", map[string]AppClient{}) +} + +func newContextWithParams(storage storage.Iface, serverType string, userStore *users.UserStore, scimInstance scim.Iface, licenseUserCount int, cloudType string, apps map[string]AppClient) (*Context, error) { + c, err := GetConfig(storage) + if err != nil { + return c, fmt.Errorf("getConfig error: %s", err) + } + c.ServerType = serverType + + c.Storage = &Storage{ + Client: storage, + } + + c.JWTKeys, err = getJWTKeys(storage) + if err != nil { + return c, fmt.Errorf("getJWTKeys error: %s", err) + } + c.OIDCStore, err = oidcstore.NewStore(storage) + if err != nil { + return c, fmt.Errorf("getOIDCStore error: %s", err) + } + if c.OIDCProviders == nil { + c.OIDCProviders = []oidc.OIDCProvider{} + } + + c.LicenseUserCount = licenseUserCount + c.CloudType = cloudType + + go func() { // run license refresh + logging.DebugLog(fmt.Errorf("starting license refresh in background (current licenses: %d, cloud type: %s)", c.LicenseUserCount, c.CloudType)) + for { + time.Sleep(time.Hour * 24) + newLicenseCount := licensing.RefreshLicense(storage, c.CloudType, c.LicenseUserCount) + if newLicenseCount != c.LicenseUserCount { + logging.InfoLog(fmt.Sprintf("License changed from %d users to %d users", c.LicenseUserCount, newLicenseCount)) + c.LicenseUserCount = newLicenseCount + } + } + }() + + c.UserStore = userStore + + c.OIDCRenewal, err = oidcrenewal.NewRenewal(storage, c.TokenRenewalTimeMinutes, c.LogLevel, c.EnableOIDCTokenRenewal, c.OIDCStore, c.OIDCProviders, c.UserStore) + if err != nil { + return c, fmt.Errorf("oidcrenewal init error: %s", err) + } + + if c.LoginAttempts == nil { + c.LoginAttempts = make(login.Attempts) + } + + if c.SCIM == nil { + c.SCIM = &SCIM{ + Client: scimInstance, + Token: "", + EnableSCIM: false, + } + } else { + c.SCIM.Client = scimInstance + } + if c.SAML == nil { + providers := []saml.Provider{} + c.SAML = &SAML{ + Client: saml.New(&providers, storage, &c.Protocol, &c.Hostname), + Providers: &providers, + } + } else { + c.SAML.Client = saml.New(c.SAML.Providers, storage, &c.Protocol, &c.Hostname) + } + + c.Apps = &Apps{ + Clients: apps, + } + + return c, nil +} + +func getEmptyContext(appDir string) (*Context, error) { + randomString, err := oidc.GetRandomString(64) + if err != nil { + return nil, fmt.Errorf("couldn't generate random string for local kid") + } + c := &Context{ + AppDir: appDir, + JWTKeysKID: randomString, + TokenRenewalTimeMinutes: oidcrenewal.DEFAULT_RENEWAL_TIME_MINUTES, + LogLevel: logging.LOG_ERROR, + SCIM: &SCIM{EnableSCIM: false}, + SAML: &SAML{Providers: &[]saml.Provider{}}, + } + return c, nil +} diff --git a/rest/helpers.go b/rest/helpers.go new file mode 100644 index 0000000..8bfab34 --- /dev/null +++ b/rest/helpers.go @@ -0,0 +1,79 @@ +package rest + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (c *Context) returnError(w http.ResponseWriter, err error, statusCode int) { + fmt.Println("========= ERROR =========") + fmt.Printf("Error: %s\n", err) + fmt.Println("=========================") + sendCorsHeaders(w, "", c.Hostname, c.Protocol) + w.WriteHeader(statusCode) + w.Write([]byte(`{"error": "` + strings.Replace(err.Error(), `"`, `\"`, -1) + `"}`)) +} + +func (c *Context) write(w http.ResponseWriter, res []byte) { + sendCorsHeaders(w, "", c.Hostname, c.Protocol) + w.WriteHeader(http.StatusOK) + w.Write(res) +} +func (c *Context) writeWithStatus(w http.ResponseWriter, res []byte, status int) { + sendCorsHeaders(w, "", c.Hostname, c.Protocol) + w.WriteHeader(status) + w.Write(res) +} + +func sendCorsHeaders(w http.ResponseWriter, headers string, hostname string, protocol string) { + if hostname == "" { + w.Header().Add("Access-Control-Allow-Origin", "*") + } else { + w.Header().Add("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", protocol, hostname)) + } + w.Header().Add("Access-Control-allow-methods", "GET,HEAD,POST,PUT,OPTIONS,DELETE,PATCH") + if headers != "" { + w.Header().Add("Access-Control-Allow-Headers", headers) + } +} + +func isAlphaNumeric(str string) bool { + for _, c := range str { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + return false + } + } + return true +} + +func getKidFromToken(token string) (string, error) { + jwtSplit := strings.Split(token, ".") + if len(jwtSplit) < 1 { + return "", fmt.Errorf("token split < 1") + } + data, err := base64.RawURLEncoding.DecodeString(jwtSplit[0]) + if err != nil { + return "", fmt.Errorf("could not base64 decode data part of jwt") + } + var header JwtHeader + err = json.Unmarshal(data, &header) + if err != nil { + return "", fmt.Errorf("could not unmarshal jwt data") + } + return header.Kid, nil +} + +func returnIndexOrNotFound(contents []byte) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/api") { + w.WriteHeader(http.StatusOK) + w.Write(contents) + } else { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("404 page not found\n")) + } + }) +} diff --git a/rest/helpers_test.go b/rest/helpers_test.go new file mode 100644 index 0000000..cfe500d --- /dev/null +++ b/rest/helpers_test.go @@ -0,0 +1,12 @@ +package rest + +import "testing" + +func TestIsAlphaNumeric(t *testing.T) { + if !isAlphaNumeric("abc123") { + t.Errorf("expected alphanumeric") + } + if isAlphaNumeric("abc@") { + t.Errorf("expected alphanumeric") + } +} diff --git a/rest/license.go b/rest/license.go new file mode 100644 index 0000000..9bc7fe7 --- /dev/null +++ b/rest/license.go @@ -0,0 +1,29 @@ +package rest + +import ( + "encoding/json" + "fmt" + "net/http" + + licensing "github.com/in4it/go-devops-platform/licensing" +) + +func (c *Context) licenseHandler(w http.ResponseWriter, r *http.Request) { + if r.PathValue("action") == "get-more" { + c.LicenseUserCount = licensing.RefreshLicense(c.Storage.Client, c.CloudType, c.LicenseUserCount) + } + + currentUserCount := c.UserStore.UserCount() + licenseResponse := LicenseResponse{LicenseUserCount: c.LicenseUserCount, CurrentUserCount: currentUserCount, CloudType: c.CloudType} + + if r.PathValue("action") == "get-more" { + licenseResponse.Key = licensing.GetLicenseKey(c.Storage.Client, c.CloudType) + } + + out, err := json.Marshal(licenseResponse) + if err != nil { + c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) + return + } + c.write(w, out) +} diff --git a/rest/login/attempt.go b/rest/login/attempt.go new file mode 100644 index 0000000..2fe2f0c --- /dev/null +++ b/rest/login/attempt.go @@ -0,0 +1,51 @@ +package login + +import ( + "sync" + "time" +) + +var mu sync.Mutex + +type Attempts map[string][]Attempt + +type Attempt struct { + Timestamp time.Time +} + +func ClearAttemptsForLogin(attempts Attempts, login string) { + mu.Lock() + defer mu.Unlock() + attempts[login] = []Attempt{} +} + +func RecordAttempt(attempts Attempts, login string) { + mu.Lock() + defer mu.Unlock() + _, ok := attempts[login] + if !ok { + attempts[login] = []Attempt{} + } + attempts[login] = append(attempts[login], Attempt{Timestamp: time.Now()}) +} + +func CheckTooManyLogins(attempts Attempts, login string) bool { + threeMinutes := 3 * time.Minute + _, ok := attempts[login] + if ok { + loginAttempts := 0 + for _, loginAttempt := range attempts[login] { + if time.Since(loginAttempt.Timestamp) <= threeMinutes { + loginAttempts++ + } + } + if loginAttempts >= 3 { + if len(attempts[login]) > 3 { + index := len(attempts[login]) - 3 + attempts[login] = attempts[login][index:] + } + return true + } + } + return false +} diff --git a/rest/login/auth.go b/rest/login/auth.go new file mode 100644 index 0000000..5e64de2 --- /dev/null +++ b/rest/login/auth.go @@ -0,0 +1,50 @@ +package login + +import ( + "crypto/rsa" + "fmt" + + "github.com/in4it/go-devops-platform/mfa/totp" + "github.com/in4it/go-devops-platform/users" +) + +func Authenticate(loginReq LoginRequest, authIface AuthIface, jwtPrivateKey *rsa.PrivateKey, jwtKeyID string) (LoginResponse, users.User, error) { + loginResponse := LoginResponse{} + user, auth := authIface.AuthUser(loginReq.Login, loginReq.Password) + if auth && !user.Suspended { + if len(user.Factors) == 0 { // authentication without MFA + token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID) + if err != nil { + return loginResponse, user, fmt.Errorf("token generation failed: %s", err) + } + loginResponse.Authenticated = true + loginResponse.Token = token + } else { + if loginReq.FactorResponse.Name == "" { + loginResponse.Authenticated = false + loginResponse.MFARequired = true + for _, factor := range user.Factors { + loginResponse.Factors = append(loginResponse.Factors, factor.Name) + } + } else { + for _, factor := range user.Factors { + if factor.Name == loginReq.FactorResponse.Name { + ok, err := totp.Verify(factor.Secret, loginReq.FactorResponse.Code) + if err != nil { + return loginResponse, user, fmt.Errorf("MFA (totp) verify failed: %s", err) + } + if ok { // authentication with MFA + token, err := GetJWTToken(user.Login, user.Role, jwtPrivateKey, jwtKeyID) + if err != nil { + return loginResponse, user, fmt.Errorf("token generation failed: %s", err) + } + loginResponse.Authenticated = true + loginResponse.Token = token + } + } + } + } + } + } + return loginResponse, user, nil +} diff --git a/rest/login/auth_test.go b/rest/login/auth_test.go new file mode 100644 index 0000000..0137e5d --- /dev/null +++ b/rest/login/auth_test.go @@ -0,0 +1,123 @@ +package login + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/base32" + "testing" + "time" + + "github.com/in4it/go-devops-platform/mfa/totp" + "github.com/in4it/go-devops-platform/users" +) + +type MockAuth struct { + AuthUserUser users.User + AuthUserResult bool +} + +func (m *MockAuth) AuthUser(login string, password string) (users.User, bool) { + return m.AuthUserUser, m.AuthUserResult +} + +func TestAuthenticate(t *testing.T) { + m := MockAuth{ + AuthUserUser: users.User{ + Login: "john", + }, + AuthUserResult: true, + } + loginReq := LoginRequest{ + Login: "john", + Password: "mypass", + } + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatalf("private key error: %s", err) + } + + loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") + if err != nil { + t.Fatalf("authentication error: %s", err) + } + if !loginResp.Authenticated { + t.Fatalf("expected to be authenticated") + } + if loginResp.Token == "" { + t.Fatalf("no token") + } +} +func TestAuthenticateMFANoToken(t *testing.T) { + m := MockAuth{ + AuthUserUser: users.User{ + Login: "john", + Factors: []users.Factor{ + { + Name: "test-factor", + Type: "test", + Secret: "secret", + }, + }, + }, + AuthUserResult: true, + } + loginReq := LoginRequest{ + Login: "john", + Password: "mypass", + } + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatalf("private key error: %s", err) + } + + loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") + if err != nil { + t.Fatalf("authentication error: %s", err) + } + if loginResp.Authenticated { + t.Fatalf("expected not to be authenticated") + } + if len(loginResp.Factors) == 0 { + t.Fatalf("expected to get factors") + } +} +func TestAuthenticateMFAWithToken(t *testing.T) { + secret := base32.StdEncoding.EncodeToString([]byte("secret")) + m := MockAuth{ + AuthUserUser: users.User{ + Login: "john", + Factors: []users.Factor{ + { + Name: "test-factor", + Type: "test", + Secret: secret, + }, + }, + }, + AuthUserResult: true, + } + token, err := totp.GetToken(secret, time.Now().Unix()/30) + if err != nil { + t.Fatalf("GetToken error: %s", err) + } + loginReq := LoginRequest{ + Login: "john", + Password: "mypass", + FactorResponse: FactorResponse{ + Name: "test-factor", + Code: token, + }, + } + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatalf("private key error: %s", err) + } + + loginResp, _, err := Authenticate(loginReq, &m, privateKey, "jwtKeyID") + if err != nil { + t.Fatalf("authentication error: %s", err) + } + if !loginResp.Authenticated { + t.Fatalf("expected to be authenticated") + } +} diff --git a/rest/login/jwt.go b/rest/login/jwt.go new file mode 100644 index 0000000..bba4e09 --- /dev/null +++ b/rest/login/jwt.go @@ -0,0 +1,27 @@ +package login + +import ( + "crypto/rsa" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +func GetJWTToken(login, role string, signKey *rsa.PrivateKey, kid string) (string, error) { + return GetJWTTokenWithExpiration(login, role, signKey, kid, time.Now().Add(time.Hour*72)) +} + +func GetJWTTokenWithExpiration(login, role string, signKey *rsa.PrivateKey, kid string, expiration time.Time) (string, error) { + token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), jwt.MapClaims{ + "iss": "wireguard-server", + "sub": login, + "role": role, + "exp": expiration.Unix(), + "iat": time.Now().Unix(), + }) + token.Header["kid"] = kid + + tokenString, err := token.SignedString(signKey) + + return tokenString, err +} diff --git a/rest/login/types.go b/rest/login/types.go new file mode 100644 index 0000000..5b4d8ac --- /dev/null +++ b/rest/login/types.go @@ -0,0 +1,27 @@ +package login + +import "github.com/in4it/go-devops-platform/users" + +type AuthIface interface { + AuthUser(login string, password string) (users.User, bool) +} + +type LoginRequest struct { + Login string `json:"login"` + Password string `json:"password"` + FactorResponse FactorResponse `json:"factorResponse"` +} + +type FactorResponse struct { + Name string `json:"name"` + Code string `json:"code"` +} + +type LoginResponse struct { + Authenticated bool `json:"authenticated"` + Suspended bool `json:"suspended"` + NoLicense bool `json:"noLicense"` + Token string `json:"token,omitempty"` + MFARequired bool `json:"mfaRequired"` + Factors []string `json:"factors"` +} diff --git a/rest/middleware.go b/rest/middleware.go new file mode 100644 index 0000000..a7e2dd4 --- /dev/null +++ b/rest/middleware.go @@ -0,0 +1,187 @@ +package rest + +import ( + "context" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/in4it/go-devops-platform/auth/oidc" + "github.com/in4it/go-devops-platform/users" +) + +type CustomValue string + +// auth middleware + +func (c *Context) authMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !c.SetupCompleted { + c.returnError(w, fmt.Errorf("setup not completed"), http.StatusUnauthorized) + return + } + if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { + c.writeWithStatus(w, []byte(`{"error": "token not found"}`), http.StatusUnauthorized) + return + } + tokenString := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", -1) + if len(tokenString) == 0 { + c.returnError(w, fmt.Errorf("empty token"), http.StatusUnauthorized) + return + } + + // determine token to parse + var tokenToParse string + // is token an access token or a jwt from local auth? + kid, _ := getKidFromToken(tokenString) + if kid == c.JWTKeysKID { // local auth token + tokenToParse = tokenString + } else { + for _, oauth2Data := range c.OIDCStore.OAuth2Data { + if oauth2Data.Token.AccessToken == tokenString { + tokenToParse = oauth2Data.Token.IDToken + } + } + if tokenToParse == "" { + c.returnError(w, fmt.Errorf("token error: access token not found (wrong token or token expired)"), http.StatusUnauthorized) + return + } + } + token, err := jwt.Parse(tokenToParse, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Header["kid"]; !ok { + return nil, fmt.Errorf("no kid header found in token") + } + if token.Header["kid"] == c.JWTKeysKID { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("local kid: unexpected signing method: %v", token.Header["alg"]) + } + return c.JWTKeys.PublicKey, nil + } + discoveryProviders := make([]oidc.Discovery, len(c.OIDCProviders)) + for k, oidcProvider := range c.OIDCProviders { + discovery, err := c.OIDCStore.GetDiscoveryURI(oidcProvider.DiscoveryURI) + if err != nil { + return nil, fmt.Errorf("couldn't retrieve discoveryURI from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err) + } + discoveryProviders[k] = discovery + } + allJwks, err := c.OIDCStore.GetAllJwks(discoveryProviders) + if err != nil { + return nil, fmt.Errorf("couldn't retrieve JWKS URL from OIDC Provider (check discovery URI in OIDC settings). Error: %s", err) + } + publicKey, err := oidc.GetPublicKeyForToken(allJwks, discoveryProviders, token) + if err != nil { + return nil, fmt.Errorf("GetPublicKeyForToken error: %s", err) + } + return publicKey, nil + }) + if err != nil { + c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized) + return + } + token.Claims.(jwt.MapClaims)["kid"] = token.Header["kid"] + ctx := context.WithValue(r.Context(), CustomValue("claims"), token.Claims.(jwt.MapClaims)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// logging middleware + +// responseWriter is a minimal wrapper for http.ResponseWriter that allows the +// written HTTP status code to be captured for logging. +// MIT licensed +type responseWriter struct { + http.ResponseWriter + status int + wroteHeader bool +} + +func (rw *responseWriter) Status() int { + return rw.status +} + +func (rw *responseWriter) WriteHeader(code int) { + if rw.wroteHeader { + return + } + + rw.status = code + rw.ResponseWriter.WriteHeader(code) + rw.wroteHeader = true +} + +func (c *Context) loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + wrappedResponse := &responseWriter{ResponseWriter: w} + next.ServeHTTP(wrappedResponse, r) + log.Printf("req=%s res=%d method=%s src=%s duration=%s", r.RequestURI, wrappedResponse.status, r.Method, r.RemoteAddr, time.Since(start)) + }) +} + +func (c *Context) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + sendCorsHeaders(w, r.Header.Get("Access-Control-Request-Headers"), c.Hostname, c.Protocol) + w.WriteHeader(http.StatusNoContent) + } else { + next.ServeHTTP(w, r) + } + }) +} +func (c *Context) injectUserMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, err := c.GetUserFromRequest(r) + if err != nil { + c.returnError(w, fmt.Errorf("token error: %s", err), http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), CustomValue("user"), user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (c *Context) isAdminMiddleware(next http.Handler) http.Handler { + return IsAdminMiddleware(next) +} + +func IsAdminMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(CustomValue("user")).(users.User) + if user.Role != "admin" { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(`{ "error": "endpoint forbidden" }`)) + return + } + next.ServeHTTP(w, r) + }) +} + +func (c *Context) httpsRedirectMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c.RedirectToHttps && r.TLS == nil { + if strings.HasPrefix(r.URL.Path, "/api") { + c.returnError(w, fmt.Errorf("non-tls requests disabled"), http.StatusForbidden) + return + } + http.Redirect(w, r, fmt.Sprintf("https://%s%s", r.Host, r.RequestURI), http.StatusMovedPermanently) + return + } + next.ServeHTTP(w, r) + }) +} + +func (c *Context) isSCIMEnabled(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !c.SCIM.EnableSCIM { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{ "error": "SCIM Not Enabled" }`)) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/rest/profile.go b/rest/profile.go new file mode 100644 index 0000000..ff0fa79 --- /dev/null +++ b/rest/profile.go @@ -0,0 +1,139 @@ +package rest + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/in4it/go-devops-platform/mfa/totp" + "github.com/in4it/go-devops-platform/users" +) + +func (c *Context) profilePasswordHandler(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(CustomValue("user")).(users.User) + switch r.Method { + case http.MethodPost: + var userInput users.User + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&userInput) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + if userInput.Password == "" { + c.returnError(w, fmt.Errorf("no password supplied"), http.StatusBadRequest) + return + } + err = c.UserStore.UpdatePassword(user.ID, userInput.Password) + if err != nil { + c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest) + return + } + + c.write(w, []byte(`{"result": "OK"}`)) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + + } +} + +func (c *Context) profileFactorsHandler(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(CustomValue("user")).(users.User) + switch r.Method { + case http.MethodGet: + factors := make([]users.Factor, len(user.Factors)) + copy(factors, user.Factors) + for k := range factors { + factors[k].Secret = "" // remove secret when outputting + } + out, err := json.Marshal(factors) + if err != nil { + c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var factor FactorRequest + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&factor) + if err != nil { + c.returnError(w, fmt.Errorf("decode factor error: %s", err), http.StatusBadRequest) + return + } + if factor.Secret == "" { + c.returnError(w, fmt.Errorf("no factor secret supplied"), http.StatusBadRequest) + return + } + if factor.Name == "" { + c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest) + return + } + if len(factor.Name) > 16 { + c.returnError(w, fmt.Errorf("factor name too long"), http.StatusBadRequest) + return + } + if factor.Type == "" { + c.returnError(w, fmt.Errorf("no factor type supplied"), http.StatusBadRequest) + return + } + if factor.Code == "" { + c.returnError(w, fmt.Errorf("no factor code supplied"), http.StatusBadRequest) + return + } + + ok, err := totp.VerifyMultipleIntervals(factor.Secret, factor.Code, 20) + if err != nil { + c.returnError(w, fmt.Errorf("totp verify error: %s", err), http.StatusBadRequest) + return + } + + if !ok { + c.returnError(w, fmt.Errorf("code doesn't match. Try entering code again or try with a new QR code"), http.StatusBadRequest) + return + } + + user.Factors = append(user.Factors, users.Factor{Type: factor.Type, Secret: factor.Secret, Name: factor.Name}) + out, err := json.Marshal(user.Factors) + if err != nil { + c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) + return + } + err = c.UserStore.UpdateUser(user) + if err != nil { + c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodDelete: + factorName := r.PathValue("name") + if factorName == "" { + c.returnError(w, fmt.Errorf("no factor name supplied"), http.StatusBadRequest) + return + } + toDelete := -1 + for k := range user.Factors { + if user.Factors[k].Name == factorName { + toDelete = k + } + } + if toDelete == -1 { + c.returnError(w, fmt.Errorf("factor not found"), http.StatusBadRequest) + return + } + user.Factors = append(user.Factors[:toDelete], user.Factors[toDelete+1:]...) + err := c.UserStore.UpdateUser(user) + if err != nil { + c.returnError(w, fmt.Errorf("coudn't update user: %s", err), http.StatusBadRequest) + return + } + out, err := json.Marshal(user.Factors) + if err != nil { + c.returnError(w, fmt.Errorf("factors marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + + } +} diff --git a/rest/resources/.gitignore b/rest/resources/.gitignore new file mode 100644 index 0000000..5e7d273 --- /dev/null +++ b/rest/resources/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/rest/router.go b/rest/router.go new file mode 100644 index 0000000..53e69f3 --- /dev/null +++ b/rest/router.go @@ -0,0 +1,62 @@ +package rest + +import ( + "io/fs" + "net/http" +) + +func (c *Context) getRouter(assets fs.FS, indexHtml []byte) *http.ServeMux { + mux := http.NewServeMux() + + // static files + mux.Handle("/assets/{filename}", http.FileServer(http.FS(assets))) + mux.Handle("/index.html", returnIndexOrNotFound(indexHtml)) + mux.Handle("/favicon.ico", http.FileServer(http.FS(assets))) + + // saml authentication + mux.Handle("/saml/", c.SAML.Client.GetRouter()) + + // endpoints with no authentication + mux.Handle("/api/context", http.HandlerFunc(c.contextHandler)) + mux.Handle("/api/auth", http.HandlerFunc(c.authHandler)) + mux.Handle("/api/authmethods", http.HandlerFunc(c.authMethods)) + mux.Handle("/api/authmethods/{method}/{id}", http.HandlerFunc(c.authMethodsByID)) + mux.Handle("/api/authmethods/{id}", http.HandlerFunc(c.authMethodsByID)) + mux.Handle("/api/version", http.HandlerFunc(c.version)) + mux.Handle("/api/upgrade", http.HandlerFunc(c.upgrade)) + mux.Handle("/", returnIndexOrNotFound(indexHtml)) + + // endpoints with no authentication (observability) + if c.ServerType == SERVER_TYPE_OBSERVABILITY { + mux.Handle("/api/observability/", c.Apps.Clients["observability"].GetRouter()) + } + + // scim + mux.Handle("/api/scim/", c.isSCIMEnabled(c.SCIM.Client.GetRouter())) + + // endpoints with authentication + mux.Handle("/api/userinfo", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.userinfoHandler)))) + mux.Handle("/api/profile/password", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profilePasswordHandler)))) + mux.Handle("/api/profile/factors", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler)))) + mux.Handle("/api/profile/factors/{name}", c.authMiddleware(c.injectUserMiddleware(http.HandlerFunc(c.profileFactorsHandler)))) + + // endpoint with authentication (VPN) + if c.ServerType == SERVER_TYPE_VPN { + mux.Handle("/api/vpn/", c.authMiddleware(c.injectUserMiddleware(c.Apps.Clients["vpn"].GetRouter()))) + } + + // endpoints with authentication, with admin role + mux.Handle("/api/license", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler))))) + mux.Handle("/api/license/{action}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.licenseHandler))))) + mux.Handle("/api/oidc", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderHandler))))) + mux.Handle("/api/oidc-renew-tokens", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcRenewTokensHandler))))) + mux.Handle("/api/oidc/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.oidcProviderElementHandler))))) + mux.Handle("/api/setup/general", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.setupHandler))))) + mux.Handle("/api/scim-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.scimSetupHandler))))) + mux.Handle("/api/saml-setup", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupHandler))))) + mux.Handle("/api/saml-setup/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.samlSetupElementHandler))))) + mux.Handle("/api/users", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.usersHandler))))) + mux.Handle("/api/user/{id}", c.authMiddleware(c.injectUserMiddleware(c.isAdminMiddleware(http.HandlerFunc(c.userHandler))))) + + return mux +} diff --git a/rest/rsa.go b/rest/rsa.go new file mode 100644 index 0000000..31fe8ea --- /dev/null +++ b/rest/rsa.go @@ -0,0 +1,126 @@ +package rest + +/* + * Genarate rsa keys. (https://github.com/wardviaene/http-echo/blob/master/rsa.go) + */ + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "path" + + "github.com/golang-jwt/jwt/v5" + "github.com/in4it/go-devops-platform/storage" +) + +type JWTKeys struct { + PrivateKey *rsa.PrivateKey `json:"privateKey,omitempty"` + PublicKey *rsa.PublicKey `json:"publicKey,omitempty"` +} + +func getJWTKeys(storage storage.Iface) (*JWTKeys, error) { + + filename := storage.ConfigPath("pki/private.pem") + filenamePublicKey := storage.ConfigPath("pki/public.pem") + + if !storage.FileExists(filename) { + err := storage.EnsurePath(path.Dir(filename)) + if err != nil { + return nil, fmt.Errorf("ensure path error: %s", err) + } + err = createJWTKeys(storage, storage.ConfigPath("pki")) + if err != nil { + return nil, fmt.Errorf("createJWTKeys error: %s", err) + } + } + + signBytes, err := storage.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("private key read error: %s", err) + } + publicBytes, err := storage.ReadFile(filenamePublicKey) + if err != nil { + return nil, fmt.Errorf("private key read error: %s", err) + } + + signKey, err := jwt.ParseRSAPrivateKeyFromPEM(signBytes) + if err != nil { + return nil, fmt.Errorf("can't parse private key: %s", err) + } + publicKey, err := jwt.ParseRSAPublicKeyFromPEM(publicBytes) + if err != nil { + return nil, fmt.Errorf("can't parse public key: %s", err) + } + return &JWTKeys{PrivateKey: signKey, PublicKey: publicKey}, nil +} + +func createJWTKeys(storage storage.Iface, path string) error { + reader := rand.Reader + bitSize := 4096 + + key, err := rsa.GenerateKey(reader, bitSize) + if err != nil { + return err + } + + publicKey := key.PublicKey + + err = savePEMKey(storage, path+"/private.pem", key) + if err != nil { + return err + } + err = savePublicPEMKey(storage, path+"/public.pem", publicKey) + if err != nil { + return err + } + + return nil +} + +func savePEMKey(storage storage.Iface, fileName string, key *rsa.PrivateKey) error { + var buf bytes.Buffer + + var privateKey = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + err := pem.Encode(&buf, privateKey) + if err != nil { + return err + } + + err = storage.WriteFile(fileName, buf.Bytes()) + if err != nil { + return fmt.Errorf("WriteFile error: %s", err) + } + return nil +} + +func savePublicPEMKey(storage storage.Iface, fileName string, pubkey rsa.PublicKey) error { + var buf bytes.Buffer + + asn1Bytes, err := x509.MarshalPKIXPublicKey(&pubkey) + if err != nil { + return err + } + + var pemkey = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: asn1Bytes, + } + + err = pem.Encode(&buf, pemkey) + if err != nil { + return err + } + err = storage.WriteFile(fileName, buf.Bytes()) + if err != nil { + return fmt.Errorf("WriteFile error: %s", err) + } + return nil +} diff --git a/rest/rsa_test.go b/rest/rsa_test.go new file mode 100644 index 0000000..eea219b --- /dev/null +++ b/rest/rsa_test.go @@ -0,0 +1,40 @@ +package rest + +import ( + "bytes" + "crypto/x509" + "encoding/pem" + "testing" + + memorystorage "github.com/in4it/go-devops-platform/storage/memory" +) + +func TestGetJWTKeys(t *testing.T) { + mockStorage := memorystorage.MockMemoryStorage{} + keys, err := getJWTKeys(&mockStorage) + if err != nil { + t.Fatalf("error: %s", err) + } + privateKeyFromFile, err := mockStorage.ReadFile(mockStorage.ConfigPath("pki/private.pem")) + if err != nil { + t.Fatalf("read error: %s", err) + } + _, err = mockStorage.ReadFile(mockStorage.ConfigPath("pki/public.pem")) + if err != nil { + t.Fatalf("read error: %s", err) + } + + var buf bytes.Buffer + var privateKey = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(keys.PrivateKey), + } + err = pem.Encode(&buf, privateKey) + if err != nil { + t.Fatalf("pem encode error: %s", err) + } + + if !bytes.Equal(privateKeyFromFile, buf.Bytes()) { + t.Fatalf("private keys don't match") + } +} diff --git a/rest/server.go b/rest/server.go new file mode 100644 index 0000000..c320cbd --- /dev/null +++ b/rest/server.go @@ -0,0 +1,69 @@ +package rest + +import ( + "crypto/tls" + "embed" + "fmt" + "io" + "io/fs" + "log" + "net/http" + + "github.com/in4it/go-devops-platform/logging" + "github.com/in4it/go-devops-platform/storage" + "golang.org/x/crypto/acme/autocert" +) + +var ( + //go:embed static + assets embed.FS + enableTLSWaiter chan (bool) = make(chan bool) + TLSWaiterCompleted bool +) + +func StartServer(httpPort, httpsPort int, serverType string, storage storage.Iface, c *Context) { + go handleSignals(c) + + assetsFS, err := fs.Sub(assets, "static") + if err != nil { + log.Fatalf("could not load static web assets") + } + + indexHtml, err := assetsFS.Open("index.html") + if err != nil { + log.Fatalf("could not load static web assets (index.html)") + } + indexHtmlBody, err := io.ReadAll(indexHtml) + if err != nil { + log.Fatalf("could not read static web assets (index.html)") + } + + certManager := autocert.Manager{} + + // HTTP Configuration + go func() { // start http server + log.Printf("Start http server on port %d", httpPort) + log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", httpPort), certManager.HTTPHandler(c.loggingMiddleware(c.httpsRedirectMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))))))) + }() + + // TLS Configuration + if !c.EnableTLS || !canEnableTLS(c.Hostname) { + <-enableTLSWaiter + } + // only enable when TLS is enabled + + logging.DebugLog(fmt.Errorf("enabling TLS endpoint with let's encrypt for hostname '%s'", c.Hostname)) + certManager.Prompt = autocert.AcceptTOS + certManager.HostPolicy = autocert.HostWhitelist(c.Hostname) + certManager.Cache = autocert.DirCache("tls-certs") + tlsServer := &http.Server{ + Addr: fmt.Sprintf(":%d", httpsPort), + TLSConfig: &tls.Config{ + GetCertificate: certManager.GetCertificate, + }, + Handler: c.loggingMiddleware(c.corsMiddleware(c.getRouter(assetsFS, indexHtmlBody))), + } + c.Protocol = "https" + TLSWaiterCompleted = true + log.Fatal(tlsServer.ListenAndServeTLS("", "")) +} diff --git a/rest/setup.go b/rest/setup.go new file mode 100644 index 0000000..9425f4c --- /dev/null +++ b/rest/setup.go @@ -0,0 +1,364 @@ +package rest + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + "github.com/in4it/go-devops-platform/auth/oidc" + "github.com/in4it/go-devops-platform/auth/saml" + licensing "github.com/in4it/go-devops-platform/licensing" + "github.com/in4it/go-devops-platform/users" +) + +func (c *Context) contextHandler(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + decoder := json.NewDecoder(r.Body) + var contextReq ContextRequest + err := decoder.Decode(&contextReq) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + if !c.Storage.Client.FileExists(SETUP_CODE_FILE) { + c.SetupCompleted = true + } + if !c.SetupCompleted { + // check if tag hash is chosen + accessGranted := false + switch c.CloudType { + case "digitalocean": // check if the hashtag is set + if contextReq.TagHash != "" { + if !strings.HasPrefix(contextReq.TagHash, "vpnsecret-") { + c.returnError(w, fmt.Errorf("tag doesn't have the correct prefix. The tag needs to start with 'vpnsecret-'"), http.StatusUnauthorized) + return + } + accessGranted, err = licensing.HasDigitalOceanTagSet(http.Client{Timeout: 5 * time.Second}, contextReq.TagHash) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve tags at this time: %s", err), http.StatusUnauthorized) + return + } + if !accessGranted { + c.returnError(w, fmt.Errorf("tag not found. Make sure the correct tag is attached to the droplet"), http.StatusUnauthorized) + return + } + } + case "aws": // check if the instance id is set + if contextReq.InstanceID != "" { + instanceID, err := licensing.GetAWSInstanceID(http.Client{Timeout: 5 * time.Second}) + if err != nil { + c.returnError(w, fmt.Errorf("could not retrieve instance id at this time: %s", err), http.StatusUnauthorized) + return + } + if strings.TrimPrefix(instanceID, "i-") == strings.TrimPrefix(contextReq.InstanceID, "i-") { + accessGranted = true + } else { + c.returnError(w, fmt.Errorf("instance id doesn't match"), http.StatusUnauthorized) + return + } + } + } + // check secret + if !accessGranted { + localSecret, err := c.Storage.Client.ReadFile(SETUP_CODE_FILE) + if err != nil { + c.returnError(w, fmt.Errorf("secret file read error: %s", err), http.StatusBadRequest) + return + } + if strings.TrimSpace(string(localSecret)) != contextReq.Secret { + c.returnError(w, fmt.Errorf("wrong secret provided"), http.StatusUnauthorized) + return + } + } + if contextReq.AdminPassword != "" { + adminUser := users.User{ + Login: "admin", + Password: contextReq.AdminPassword, + Role: "admin", + } + if c.UserStore.LoginExists("admin") { + err = c.UserStore.UpdateUser(adminUser) + if err != nil { + c.returnError(w, fmt.Errorf("could not update user: %s", err), http.StatusBadRequest) + return + } + } else { + _, err = c.UserStore.AddUser(adminUser) + if err != nil { + c.returnError(w, fmt.Errorf("could not add user: %s", err), http.StatusBadRequest) + return + } + } + + c.SetupCompleted = true + c.Hostname = contextReq.Hostname + protocol := contextReq.Protocol + protocol = strings.Replace(protocol, "http:", "http", -1) + protocol = strings.Replace(protocol, "https:", "https", -1) + c.Protocol = protocol + + err = SaveConfig(c) + if err != nil { + c.SetupCompleted = false + c.returnError(w, fmt.Errorf("unable to save file: %s", err), http.StatusBadRequest) + return + } + } + } + } + + out, err := json.Marshal(ContextSetupResponse{SetupCompleted: c.SetupCompleted, CloudType: c.CloudType, ServerType: c.ServerType}) + if err != nil { + c.returnError(w, err, http.StatusBadRequest) + return + } + c.write(w, out) +} + +func (c *Context) setupHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + setupRequest := GeneralSetupRequest{ + Hostname: c.Hostname, + EnableTLS: c.EnableTLS, + RedirectToHttps: c.RedirectToHttps, + DisableLocalAuth: c.LocalAuthDisabled, + EnableOIDCTokenRenewal: c.EnableOIDCTokenRenewal, + } + out, err := json.Marshal(setupRequest) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var setupRequest GeneralSetupRequest + decoder := json.NewDecoder(r.Body) + decoder.Decode(&setupRequest) + if c.Hostname != setupRequest.Hostname { + c.Hostname = setupRequest.Hostname + } + if c.RedirectToHttps != setupRequest.RedirectToHttps { + c.RedirectToHttps = setupRequest.RedirectToHttps + } + if c.EnableTLS != setupRequest.EnableTLS { + if !c.EnableTLS && setupRequest.EnableTLS && !TLSWaiterCompleted && canEnableTLS(c.Hostname) { + enableTLSWaiter <- true + } + c.EnableTLS = setupRequest.EnableTLS + } + if c.LocalAuthDisabled != setupRequest.DisableLocalAuth { + c.LocalAuthDisabled = setupRequest.DisableLocalAuth + } + if c.EnableOIDCTokenRenewal != setupRequest.EnableOIDCTokenRenewal { + c.EnableOIDCTokenRenewal = setupRequest.EnableOIDCTokenRenewal + c.OIDCRenewal.SetEnabled(c.EnableOIDCTokenRenewal) + } + err := SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest) + return + } + out, err := json.Marshal(setupRequest) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) scimSetupHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + scimSetup := SCIMSetup{ + Enabled: c.SCIM.EnableSCIM, + } + if c.SCIM.EnableSCIM { + scimSetup.Token = c.SCIM.Token + scimSetup.BaseURL = fmt.Sprintf("%s://%s/%s", c.Protocol, c.Hostname, "api/scim/v2/") + } + out, err := json.Marshal(scimSetup) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + saveConfig := false + var scimSetupRequest SCIMSetup + decoder := json.NewDecoder(r.Body) + decoder.Decode(&scimSetupRequest) + if scimSetupRequest.Enabled && !c.SCIM.EnableSCIM { + c.SCIM.EnableSCIM = true + saveConfig = true + } + if !scimSetupRequest.Enabled && c.SCIM.EnableSCIM { + c.SCIM.EnableSCIM = false + saveConfig = true + } + if scimSetupRequest.RegenerateToken || (scimSetupRequest.Enabled && c.SCIM.Token == "") { + // Generate new token + randomString, err := oidc.GetRandomString(64) + if err != nil { + c.returnError(w, fmt.Errorf("could not enable scim: %s", err), http.StatusBadRequest) + return + } + token := base64.StdEncoding.EncodeToString([]byte(randomString)) + scimSetupRequest.Token = token + c.SCIM.Token = token + c.SCIM.Client.UpdateToken(token) + saveConfig = true + } + if saveConfig { + // save config + err := SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("could not save config to disk: %s", err), http.StatusBadRequest) + return + } + } + out, err := json.Marshal(scimSetupRequest) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal scim setup: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) samlSetupHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + samlProviders := make([]saml.Provider, len(*c.SAML.Providers)) + copy(samlProviders, *c.SAML.Providers) + for k := range samlProviders { + samlProviders[k].Issuer = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ISSUER_URL, samlProviders[k].ID) + samlProviders[k].Audience = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.AUDIENCE_URL, samlProviders[k].ID) + samlProviders[k].Acs = fmt.Sprintf("%s://%s/%s/%s", c.Protocol, c.Hostname, saml.ACS_URL, samlProviders[k].ID) + } + out, err := json.Marshal(samlProviders) + if err != nil { + c.returnError(w, fmt.Errorf("oidcProviders marshal error"), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var samlProvider saml.Provider + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&samlProvider) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + samlProvider.ID = uuid.New().String() + if samlProvider.Name == "" { + c.returnError(w, fmt.Errorf("name not set"), http.StatusBadRequest) + return + } + if samlProvider.MetadataURL == "" { + c.returnError(w, fmt.Errorf("metadata URL not set"), http.StatusBadRequest) + return + } + _, err = c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL) + if err != nil { + c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest) + return + } + + *c.SAML.Providers = append(*c.SAML.Providers, samlProvider) + out, err := json.Marshal(samlProvider) + if err != nil { + c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest) + return + } + err = SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) samlSetupElementHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodDelete: + match := -1 + for k, samlProvider := range *c.SAML.Providers { + if samlProvider.ID == r.PathValue("id") { + match = k + } + } + if match == -1 { + c.returnError(w, fmt.Errorf("saml provider not found"), http.StatusBadRequest) + return + } + *c.SAML.Providers = append((*c.SAML.Providers)[:match], (*c.SAML.Providers)[match+1:]...) + // save config (changed providers) + err := SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + c.write(w, []byte(`{ "deleted": "`+r.PathValue("id")+`" }`)) + case http.MethodPut: + var samlProvider saml.Provider + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&samlProvider) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + samlProviderID := -1 + for k := range *c.SAML.Providers { + if (*c.SAML.Providers)[k].ID == r.PathValue("id") { + samlProviderID = k + } + } + if samlProviderID == -1 { + c.returnError(w, fmt.Errorf("cannot find saml provider: %s", err), http.StatusBadRequest) + return + } + saveConfig := false + if (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes != samlProvider.AllowMissingAttributes { + (*c.SAML.Providers)[samlProviderID].AllowMissingAttributes = samlProvider.AllowMissingAttributes + saveConfig = true + } + if (*c.SAML.Providers)[samlProviderID].MetadataURL != samlProvider.MetadataURL { + _, err := c.SAML.Client.HasValidMetadataURL(samlProvider.MetadataURL) + if err != nil { + c.returnError(w, fmt.Errorf("metadata error: %s", err), http.StatusBadRequest) + return + } + (*c.SAML.Providers)[samlProviderID].MetadataURL = samlProvider.MetadataURL + saveConfig = true + } + out, err := json.Marshal(samlProvider) + if err != nil { + c.returnError(w, fmt.Errorf("samlProvider marshal error: %s", err), http.StatusBadRequest) + return + } + if saveConfig { + err = SaveConfig(c) + if err != nil { + c.returnError(w, fmt.Errorf("saveConfig error: %s", err), http.StatusBadRequest) + return + } + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} diff --git a/rest/setup_test.go b/rest/setup_test.go new file mode 100644 index 0000000..58e5d43 --- /dev/null +++ b/rest/setup_test.go @@ -0,0 +1,293 @@ +package rest + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + licensing "github.com/in4it/go-devops-platform/licensing" + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" +) + +func TestContextHandlerSetupSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup code", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} + +func TestContextHandlerSetupWrongSecret(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} +func TestContextHandlerSetupWrongSecretPartial(t *testing.T) { + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + + payload := ContextRequest{ + Secret: "secret setup cod", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 401 { + t.Fatalf("status code is not 401: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to not be completed") + } +} + +func TestContextHandlerSetupAWSInstanceID(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/latest/api/token" { + w.Write([]byte("this is a test token")) + return + } + if r.RequestURI == "/latest/meta-data/instance-id" { + w.Write([]byte("i-012aaaaaaaaaaaaa1")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + licensing.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "aws" + + payload := ContextRequest{ + InstanceID: "i-012aaaaaaaaaaaaa1", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} +func TestContextHandlerSetupDigitalOceanTag(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI == "/metadata/v1/tags" { + w.Write([]byte("vpnsecret-this-is-a-secret-tag")) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + licensing.MetadataIP = strings.TrimPrefix(ts.URL, "http://") + + storage := &memorystorage.MockMemoryStorage{} + + storage.WriteFile(SETUP_CODE_FILE, []byte(`secret setup code`)) + + userStore, err := users.NewUserStore(storage, -1) + if err != nil { + t.Fatalf("new user store error") + } + c, err := getEmptyContext("appdir") + if err != nil { + t.Fatalf("cannot create empty context") + } + c.Storage = &Storage{Client: storage} + c.UserStore = userStore + c.CloudType = "digitalocean" + + payload := ContextRequest{ + TagHash: "vpnsecret-this-is-a-secret-tag", + AdminPassword: "adminPassword", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal error: %s", err) + } + req := httptest.NewRequest("POST", "http://example.com/setup", bytes.NewBuffer(payloadBytes)) + w := httptest.NewRecorder() + c.contextHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("got read error: %s", err) + } + + var contextSetupResponse ContextSetupResponse + err = json.Unmarshal(body, &contextSetupResponse) + if err != nil { + t.Fatalf("unmarshal error: %s", err) + } + if !contextSetupResponse.SetupCompleted { + t.Fatalf("expected setup to be completed") + } +} diff --git a/rest/signals.go b/rest/signals.go new file mode 100644 index 0000000..625667b --- /dev/null +++ b/rest/signals.go @@ -0,0 +1,38 @@ +package rest + +import ( + "fmt" + "log" + "os" + "os/signal" + "path" + "syscall" +) + +func handleSignals(c *Context) { + // write pid file so other process can find it + err := os.WriteFile(path.Join(c.AppDir, "rest-server.pid"), []byte(fmt.Sprintf("%d", os.Getpid())), 0664) + if err != nil { + log.Printf("Could not write pid file\n") + } + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, syscall.SIGHUP) + for sig := range signalChannel { + switch sig { + case syscall.SIGHUP: + c.ReloadConfig() + } + } +} + +func (c *Context) ReloadConfig() { + newC, err := NewContext(c.Storage.Client, c.ServerType, c.UserStore, c.SCIM.Client, c.LicenseUserCount, c.CloudType, c.Apps.Clients) + if err != nil { + log.Printf("ReloadConfig failed: %s\n", err) + } + c.AppDir = newC.AppDir + c.Hostname = newC.Hostname + c.SetupCompleted = newC.SetupCompleted + c.UserStore = newC.UserStore + log.Printf("Config Reloaded!\n") +} diff --git a/rest/static/.gitignore b/rest/static/.gitignore new file mode 100644 index 0000000..5e7d273 --- /dev/null +++ b/rest/static/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/rest/tls.go b/rest/tls.go new file mode 100644 index 0000000..5eb94b0 --- /dev/null +++ b/rest/tls.go @@ -0,0 +1,17 @@ +package rest + +import ( + "log" + "strings" +) + +func canEnableTLS(hostname string) bool { + hostnameSplit := strings.Split(hostname, ":") + if hostnameSplit[0] != "localhost" { + return true + } else { + log.Printf("Not enabling TLS with lets encrypt. Hostname is localhost") + } + + return false +} diff --git a/rest/types.go b/rest/types.go new file mode 100644 index 0000000..f961dba --- /dev/null +++ b/rest/types.go @@ -0,0 +1,170 @@ +package rest + +import ( + "net/http" + "time" + + "github.com/in4it/go-devops-platform/auth/oidc" + oidcstore "github.com/in4it/go-devops-platform/auth/oidc/store" + oidcrenewal "github.com/in4it/go-devops-platform/auth/oidc/store/renewal" + "github.com/in4it/go-devops-platform/auth/provisioning/scim" + "github.com/in4it/go-devops-platform/auth/saml" + "github.com/in4it/go-devops-platform/rest/login" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" +) + +const SETUP_CODE_FILE = "setup-code.txt" +const ADMIN_USER = "admin" + +type AppClient interface { + GetRouter() *http.ServeMux +} + +type Context struct { + AppDir string `json:"appDir,omitempty"` + ServerType string `json:"serverType,omitempty"` + SetupCompleted bool `json:"setupCompleted"` + Hostname string `json:"hostname,omitempty"` + Protocol string `json:"protocol,omitempty"` + JWTKeys *JWTKeys `json:"jwtKeys,omitempty"` + JWTKeysKID string `json:"jwtKeysKid,omitempty"` + OIDCProviders []oidc.OIDCProvider `json:"oidcProviders,omitempty"` + LocalAuthDisabled bool `json:"disableLocalAuth,omitempty"` + EnableTLS bool `json:"enableTLS,omitempty"` + RedirectToHttps bool `json:"redirectToHttps,omitempty"` + EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal,omitempty"` + OIDCStore *oidcstore.Store `json:"oidcStore,omitempty"` + UserStore *users.UserStore `json:"users,omitempty"` + OIDCRenewal *oidcrenewal.Renewal `json:"oidcRenewal,omitempty"` + LoginAttempts login.Attempts `json:"loginAttempts,omitempty"` + LicenseUserCount int `json:"licenseUserCount,omitempty"` + CloudType string `json:"cloudType,omitempty"` + TokenRenewalTimeMinutes int `json:"tokenRenewalTimeMinutes,omitempty"` + LogLevel int `json:"loglevel,omitempty"` + SCIM *SCIM `json:"scim,omitempty"` + SAML *SAML `json:"saml,omitempty"` + Apps *Apps `json:"apps,omitempty"` + Storage *Storage `json:"storage,omitempty"` +} +type SCIM struct { + EnableSCIM bool `json:"enableSCIM,omitempty"` + Token string `json:"token"` + Client scim.Iface `json:"client,omitempty"` +} +type SAML struct { + Providers *[]saml.Provider `json:"providers"` + Client saml.Iface `json:"client,omitempty"` +} +type Apps struct { + Clients map[string]AppClient `json:"clients,omitempty"` +} +type Storage struct { + Client storage.Iface `json:"client,omitempty"` +} + +type ContextRequest struct { + Secret string `json:"secret"` + TagHash string `json:"tagHash"` + InstanceID string `json:"instanceID"` + AdminPassword string `json:"adminPassword"` + Hostname string `json:"hostname"` + Protocol string `json:"protocol"` +} +type ContextSetupResponse struct { + SetupCompleted bool `json:"setupCompleted"` + CloudType string `json:"cloudType"` + ServerType string `json:"serverType"` +} + +type AuthMethodsResponse struct { + LocalAuthDisabled bool `json:"localAuthDisabled"` + OIDCProviders []AuthMethodsProvider `json:"oidcProviders"` +} + +type AuthMethodsProvider struct { + ID string `json:"id"` + Name string `json:"name"` + RedirectURI string `json:"redirectURI,omitempty"` +} + +type OIDCCallback struct { + Code string `json:"code"` + State string `json:"state"` + RedirectURI string `json:"redirectURI"` +} +type SAMLCallback struct { + Code string `json:"code"` + RedirectURI string `json:"redirectURI"` +} + +type UserInfoResponse struct { + Login string `json:"login"` + Role string `json:"role"` + UserType string `json:"userType"` +} + +type GeneralSetupRequest struct { + Hostname string `json:"hostname"` + EnableTLS bool `json:"enableTLS"` + RedirectToHttps bool `json:"redirectToHttps"` + DisableLocalAuth bool `json:"disableLocalAuth"` + EnableOIDCTokenRenewal bool `json:"enableOIDCTokenRenewal"` +} + +type LicenseResponse struct { + LicenseUserCount int `json:"licenseUserCount"` + CurrentUserCount int `json:"currentUserCount,omitempty"` + CloudType string `json:"cloudType"` + Key string `json:"key,omitempty"` +} + +type ConnectionLicenseResponse struct { + LicenseUserCount int `json:"licenseUserCount"` + ConnectionCount int `json:"connectionCount"` +} + +type JwtHeader struct { + Alg string `json:"alg"` + Typ string `json:"typ"` + Kid string `json:"kid"` +} + +type UsersResponse struct { + ID string `json:"id"` + Login string `json:"login"` + Role string `json:"role"` + OIDCID string `json:"oidcID"` + SAMLID string `json:"samlID"` + Provisioned bool `json:"provisioned"` + Suspended bool `json:"suspended"` + ConnectionsDisabledOnAuthFailure bool `json:"connectionsDisabledOnAuthFailure"` + LastTokenRenewal time.Time `json:"lastTokenRenewal,omitempty"` + LastLogin string `json:"lastLogin"` +} + +type FactorRequest struct { + Name string `json:"name"` + Type string `json:"type"` + Secret string `json:"secret"` + Code string `json:"code"` +} + +type SCIMSetup struct { + Enabled bool `json:"enabled"` + Token string `json:"token,omitempty"` + RegenerateToken bool `json:"regenerateToken,omitempty"` + BaseURL string `json:"baseURL,omitempty"` +} + +type SAMLSetup struct { + Enabled bool `json:"enabled"` + MetadataURL string `json:"metadataURL,omitempty"` + RegenerateCert bool `json:"regenerateCert,omitempty"` +} + +type NewUserRequest struct { + Login string `json:"login"` + Role string `json:"role"` + Password string `json:"password,omitempty"` +} diff --git a/rest/users.go b/rest/users.go new file mode 100644 index 0000000..fc5bf9e --- /dev/null +++ b/rest/users.go @@ -0,0 +1,274 @@ +package rest + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/in4it/go-devops-platform/storage" + "github.com/in4it/go-devops-platform/users" +) + +func (c *Context) GetUserFromRequest(r *http.Request) (users.User, error) { + claims := r.Context().Value(CustomValue("claims")).(jwt.MapClaims) + sub, ok := claims["sub"] + if !ok { + return users.User{}, fmt.Errorf("userinfoHandler: subject not found in token") + } + iss, ok := claims["iss"] + if !ok { + return users.User{}, fmt.Errorf("userinfoHandler: issuer not found in token") + } + + kid, ok := claims["kid"] + if !ok { + return users.User{}, fmt.Errorf("userinfoHandler: kid not found in token") + } + + if kid == c.JWTKeysKID { + user, err := c.UserStore.GetUserByLogin(sub.(string)) + if err != nil { + return users.User{}, fmt.Errorf("GetUserByLogin: user not found") + } + return user, nil + } else { // user comes from oidc + oauth2DataIDs := []string{} + for _, oauth2Data := range c.OIDCStore.OAuth2Data { + if oauth2Data.Issuer == iss && oauth2Data.Subject == sub { + oauth2DataIDs = append(oauth2DataIDs, oauth2Data.ID) + } + } + if len(oauth2DataIDs) == 0 { + return users.User{}, fmt.Errorf("userinfoHandler: couldn't find user in oidc database") + } + user, err := c.UserStore.GetUserByOIDCIDs(oauth2DataIDs) + if err != nil { + return user, fmt.Errorf("get user by oidc id failed: %s", err) + } + return user, nil + } +} + +func (c *Context) usersHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + users := c.UserStore.ListUsers() + userResponse := make([]UsersResponse, len(users)) + for k, user := range users { + userResponse[k].ID = user.ID + userResponse[k].Login = user.Login + userResponse[k].Role = user.Role + userResponse[k].OIDCID = user.OIDCID + userResponse[k].SAMLID = user.SAMLID + userResponse[k].Suspended = user.Suspended + userResponse[k].Provisioned = user.Provisioned + userResponse[k].ConnectionsDisabledOnAuthFailure = user.ConnectionsDisabledOnAuthFailure + if !user.LastLogin.IsZero() { + userResponse[k].LastLogin = user.LastLogin.UTC().Format(time.RFC3339) + } + for _, oauth2Data := range c.OIDCStore.OAuth2Data { + if oauth2Data.ID == user.OIDCID { + userResponse[k].LastTokenRenewal = oauth2Data.LastTokenRenewal + } + } + } + out, err := json.Marshal(userResponse) + if err != nil { + c.returnError(w, fmt.Errorf("could not marshal SetupRequest: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + case http.MethodPost: + var user NewUserRequest + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&user) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + if !isAlphaNumeric(user.Login) { + c.returnError(w, fmt.Errorf("login not valid"), http.StatusBadRequest) + return + } + if user.Login == "" { + c.returnError(w, fmt.Errorf("login is empty"), http.StatusBadRequest) + return + } + if user.Password == "" { + c.returnError(w, fmt.Errorf("password is empty"), http.StatusBadRequest) + return + } + if user.Role != "user" && user.Role != "admin" { + c.returnError(w, fmt.Errorf("invalid role"), http.StatusBadRequest) + return + } + if c.UserStore.UserCount() >= c.LicenseUserCount { + c.returnError(w, fmt.Errorf("no more licenses available"), http.StatusBadRequest) + return + } + + newUser, err := c.UserStore.AddUser(users.User{Login: user.Login, Password: user.Password, Role: user.Role}) + if err != nil { + c.returnError(w, fmt.Errorf("add user error: %s", err), http.StatusBadRequest) + return + } + out, err := json.Marshal(newUser) + if err != nil { + c.returnError(w, fmt.Errorf("new user marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) userHandler(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodDelete: + userID := r.PathValue("id") + err := c.UserStore.DeleteUserByID(userID) + if err != nil { + c.returnError(w, fmt.Errorf("delete user error: %s", err), http.StatusBadRequest) + return + } + err = c.UserStore.UserHooks.DeleteFunc(c.Storage.Client, users.User{ID: userID}) + if err != nil { + c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", userID, err), http.StatusBadRequest) + return + } + c.write(w, []byte(`{"deleted": "`+userID+`"}`)) + case http.MethodPatch: + dbUser, err := c.UserStore.GetUserByID(r.PathValue("id")) + if err != nil { + c.returnError(w, fmt.Errorf("user not found: %s", err), http.StatusBadRequest) + return + } + var user users.User + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&user) + if err != nil { + c.returnError(w, fmt.Errorf("decode input error: %s", err), http.StatusBadRequest) + return + } + updateUser := false + if user.Role != "" && dbUser.Role != user.Role { + dbUser.Role = user.Role + updateUser = true + } + if dbUser.Suspended != user.Suspended { + dbUser.Suspended = user.Suspended + updateUser = true + if user.Suspended { // user is now suspended + err := c.UserStore.UserHooks.DisableFunc(c.Storage.Client, user) + if err != nil { + c.returnError(w, fmt.Errorf("could not delete all clients for user %s: %s", user.ID, err), http.StatusBadRequest) + return + } + } else { // user is now unsuspended + err := c.UserStore.UserHooks.ReactivateFunc(c.Storage.Client, user) + if err != nil { + c.returnError(w, fmt.Errorf("could not reactivate all clients for user %s: %s", user.ID, err), http.StatusBadRequest) + return + } + } + } + if updateUser { + err = c.UserStore.UpdateUser(dbUser) + if err != nil { + c.returnError(w, fmt.Errorf("update user error: %s", err), http.StatusBadRequest) + return + } + } + if user.Password != "" { + err = c.UserStore.UpdatePassword(user.ID, user.Password) + if err != nil { + c.returnError(w, fmt.Errorf("update password error: %s", err), http.StatusBadRequest) + return + } + } + out, err := json.Marshal(dbUser) + if err != nil { + c.returnError(w, fmt.Errorf("marshal dbuser error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func addOrModifyExternalUser(storage storage.Iface, userStore *users.UserStore, login, authType, externalAuthID string) (users.User, error) { + if userStore.LoginExists(login) { + existingUser, err := userStore.GetUserByLogin(login) + if err != nil { + return existingUser, fmt.Errorf("couldn't find existing user in database: %s", login) + } + + if authType == "oidc" { + existingUser.OIDCID = externalAuthID + } + if authType == "saml" { + existingUser.SAMLID = externalAuthID + } + + if existingUser.ConnectionsDisabledOnAuthFailure { // we can enable connections again after auth + err := userStore.UserHooks.ReactivateFunc(storage, existingUser) + if err != nil { + return existingUser, fmt.Errorf("could not reactivate all clients for user %s: %s", existingUser.ID, err) + } + existingUser.ConnectionsDisabledOnAuthFailure = false + } + + existingUser.LastLogin = time.Now() + + err = userStore.UpdateUser(existingUser) + if err != nil { + return existingUser, fmt.Errorf("couldn't update user: %s", login) + } + return existingUser, nil + } else { + newUser := users.User{ + Login: login, + Role: "user", + } + if authType == "oidc" { + newUser.OIDCID = externalAuthID + } + if authType == "saml" { + newUser.SAMLID = externalAuthID + } + + newUser.LastLogin = time.Now() + + newUserAdded, err := userStore.AddUser(newUser) + if err != nil { + return newUserAdded, fmt.Errorf("could not add user: %s", err) + } + return newUserAdded, nil + } +} + +func (c *Context) userinfoHandler(w http.ResponseWriter, r *http.Request) { + var response UserInfoResponse + + user := r.Context().Value(CustomValue("user")).(users.User) + + response.Login = user.Login + response.Role = user.Role + if user.OIDCID == "" { + response.UserType = "local" + } else { + response.UserType = "oidc" + } + + out, err := json.Marshal(response) + if err != nil { + c.returnError(w, fmt.Errorf("cannot marshal userinfo response: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + +} diff --git a/rest/users_test.go b/rest/users_test.go new file mode 100644 index 0000000..5776aa3 --- /dev/null +++ b/rest/users_test.go @@ -0,0 +1,47 @@ +package rest + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "testing" + + memorystorage "github.com/in4it/go-devops-platform/storage/memory" + "github.com/in4it/go-devops-platform/users" +) + +func TestCreateUser(t *testing.T) { + // first create a new user + storage := &memorystorage.MockMemoryStorage{} + + c, err := newContext(storage, SERVER_TYPE_VPN) + if err != nil { + t.Fatalf("Cannot create context") + } + + err = c.UserStore.Empty() + if err != nil { + t.Fatalf("Cannot create context") + } + + // create a user + payload := []byte(`{"id": "", "login": "testuser", "password": "tttt213", "role": "user", "oidcID": "", "samlID": "", "lastLogin": "", "provisioned": false, "role":"user","samlID":"","suspended":false}`) + req := httptest.NewRequest("POST", "http://example.com/users", bytes.NewBuffer(payload)) + w := httptest.NewRecorder() + c.usersHandler(w, req) + + resp := w.Result() + + if resp.StatusCode != 200 { + t.Fatalf("status code is not 200: %d", resp.StatusCode) + } + + defer resp.Body.Close() + + var user users.User + err = json.NewDecoder(resp.Body).Decode(&user) + if err != nil { + t.Fatalf("Cannot decode response from create user: %s", err) + } + +} diff --git a/rest/version.go b/rest/version.go new file mode 100644 index 0000000..e56cbf9 --- /dev/null +++ b/rest/version.go @@ -0,0 +1,67 @@ +package rest + +import ( + _ "embed" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +//go:generate cp -r ../../latest ./resources/version +//go:embed resources/version + +var version string + +const UPGRADESERVER_URI = "127.0.0.1:8081" + +func (c *Context) version(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + out, err := json.Marshal(map[string]string{"version": strings.TrimSpace(version)}) + if err != nil { + c.returnError(w, fmt.Errorf("version marshal error: %s", err), http.StatusBadRequest) + return + } + c.write(w, out) + default: + c.returnError(w, fmt.Errorf("method not supported"), http.StatusBadRequest) + } +} + +func (c *Context) upgrade(w http.ResponseWriter, r *http.Request) { + client := http.Client{ + Timeout: 10 * time.Second, + } + req, err := http.NewRequest(r.Method, "http://"+UPGRADESERVER_URI+"/upgrade", nil) + if err != nil { + c.returnError(w, fmt.Errorf("upgrade request error: %s", err), http.StatusBadRequest) + return + } + resp, err := client.Do(req) + if err != nil { + c.returnError(w, fmt.Errorf("upgrade error: %s", err), http.StatusBadRequest) + return + } + if resp.StatusCode != http.StatusOK { + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Respons: %s", resp.StatusCode, bodyBytes), http.StatusBadRequest) + return + } + c.returnError(w, fmt.Errorf("upgrade error: got status code: %d. Couldn't get response", resp.StatusCode), http.StatusBadRequest) + return + } + + defer resp.Body.Close() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + c.returnError(w, fmt.Errorf("body read error: %s", err), http.StatusBadRequest) + return + } + + c.write(w, bodyBytes) + +} diff --git a/users/store.go b/users/store.go index 498e205..9f68123 100644 --- a/users/store.go +++ b/users/store.go @@ -10,6 +10,14 @@ import ( const USERSTORE_FILENAME = "users.json" +func NewUserStoreWithHooks(storage storage.Iface, maxUsers int, hooks UserHooks) (*UserStore, error) { + userStore, err := NewUserStore(storage, maxUsers) + if err != nil { + return userStore, err + } + userStore.UserHooks = hooks + return userStore, nil +} func NewUserStore(storage storage.Iface, maxUsers int) (*UserStore, error) { userStore := &UserStore{ autoSave: true, diff --git a/users/types.go b/users/types.go index 0cc9afa..6a0a1c9 100644 --- a/users/types.go +++ b/users/types.go @@ -7,10 +7,11 @@ import ( ) type UserStore struct { - Users []User `json:"users"` - autoSave bool - maxUsers int - storage storage.Iface + Users []User `json:"users"` + autoSave bool + maxUsers int + storage storage.Iface + UserHooks UserHooks } type User struct { @@ -33,3 +34,13 @@ type Factor struct { Type string `json:"type"` Secret string `json:"secret"` } + +type DisableFunc func(storage.Iface, User) error +type ReactivateFunc func(storage.Iface, User) error +type DeleteFunc func(storage.Iface, User) error + +type UserHooks struct { + DisableFunc DisableFunc + ReactivateFunc ReactivateFunc + DeleteFunc DeleteFunc +}