Skip to content

Commit

Permalink
Rename param / fix error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
motoki317 committed Jun 1, 2024
1 parent 69e36a3 commit ac2678e
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions internal/provider/generic_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (o *GenericOAuth) Setup() error {
}

// GetLoginURL provides the login url for the given redirect uri and state
func (o *GenericOAuth) GetLoginURL(redirectURI, state string, forcePrompt bool) string {
return o.OAuthGetLoginURL(redirectURI, state, forcePrompt)
func (o *GenericOAuth) GetLoginURL(redirectURI, state string, allowPrompt bool) string {
return o.OAuthGetLoginURL(redirectURI, state, allowPrompt)
}

// ExchangeCode exchanges the given redirect uri and code for a token
Expand Down
2 changes: 1 addition & 1 deletion internal/provider/generic_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestGenericOAuthGetLoginURL(t *testing.T) {
}

// Check url
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state", false))
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state", true))
assert.Nil(err)
assert.Equal("https", uri.Scheme)
assert.Equal("provider.com", uri.Host)
Expand Down
4 changes: 2 additions & 2 deletions internal/provider/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ func (o *OIDC) Setup() error {
}

// GetLoginURL provides the login url for the given redirect uri and state
func (o *OIDC) GetLoginURL(redirectURI, state string, forcePrompt bool) string {
return o.OAuthGetLoginURL(redirectURI, state, forcePrompt)
func (o *OIDC) GetLoginURL(redirectURI, state string, allowPrompt bool) string {
return o.OAuthGetLoginURL(redirectURI, state, allowPrompt)
}

// ExchangeCode exchanges the given redirect uri and code for a token
Expand Down
6 changes: 3 additions & 3 deletions internal/provider/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Providers struct {
// Provider is used to authenticate users
type Provider interface {
Name() string
GetLoginURL(redirectURI, state string, forcePrompt bool) string
GetLoginURL(redirectURI, state string, allowPrompt bool) string
ExchangeCode(redirectURI, code string) (string, error)
GetUser(token, UserPath string) (string, error)
Setup() error
Expand Down Expand Up @@ -62,11 +62,11 @@ func (p *OAuthProvider) ConfigCopy(redirectURI string) oauth2.Config {
}

// OAuthGetLoginURL provides a base "GetLoginURL" for proiders using OAauth2
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string, forcePrompt bool) string {
func (p *OAuthProvider) OAuthGetLoginURL(redirectURI, state string, allowPrompt bool) string {
config := p.ConfigCopy(redirectURI)

var options []oauth2.AuthCodeOption
if p.Prompt != "" && !forcePrompt {
if p.Prompt != "" && allowPrompt {
options = append(options, oauth2.SetAuthURLParam("prompt", p.Prompt))
}
if p.Resource != "" {
Expand Down
18 changes: 12 additions & 6 deletions internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (s *Server) authHandler(providerName, rule string, soft bool) http.HandlerF
unauthorized(w)
return
} else {
s.authRedirect(logger, w, r, p, currentUrl(r), false)
s.authRedirect(logger, w, r, p, currentUrl(r), true)
return
}
}
Expand Down Expand Up @@ -292,8 +292,14 @@ func (s *Server) AuthCallbackHandler() http.HandlerFunc {
// Check error
authError := r.URL.Query().Get("error")
if authError == "login_required" || authError == "consent_required" {
// Retry with without prompt (none) parameter
s.authRedirect(logger, w, r, p, redirect, true)
// Retry without the 'prompt' parameter (which was possibly 'none' or some other value)
s.authRedirect(logger, w, r, p, redirect, false)
return
}
if authError != "" {
// Other errors such as provider server error
logger.WithField("provider_error", authError).Error("Authorize failed with provider")
http.Error(w, "Provider error", 500)
return
}

Expand Down Expand Up @@ -363,7 +369,7 @@ func (s *Server) LoginHandler(providerName string) http.HandlerFunc {
}

// Login
s.authRedirect(logger, w, r, p, redirectURL.String(), false)
s.authRedirect(logger, w, r, p, redirectURL.String(), true)
}
}

Expand Down Expand Up @@ -394,7 +400,7 @@ func (s *Server) LogoutHandler() http.HandlerFunc {
}
}

func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider, returnUrl string, forcePrompt bool) {
func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, p provider.Provider, returnUrl string, allowPrompt bool) {
// Error indicates no cookie, generate nonce
err, nonce := Nonce()
if err != nil {
Expand All @@ -420,7 +426,7 @@ func (s *Server) authRedirect(logger *logrus.Entry, w http.ResponseWriter, r *ht
}

// Forward them on
loginURL := p.GetLoginURL(redirectUri(r), MakeState(returnUrl, p, nonce), forcePrompt)
loginURL := p.GetLoginURL(redirectUri(r), MakeState(returnUrl, p, nonce), allowPrompt)
http.Redirect(w, r, loginURL, http.StatusTemporaryRedirect)

logger.WithFields(logrus.Fields{
Expand Down

0 comments on commit ac2678e

Please sign in to comment.