Skip to content

Commit

Permalink
Support saml idp-initiated sso flow
Browse files Browse the repository at this point in the history
  • Loading branch information
tung2744 committed Aug 26, 2024
1 parent d0841d7 commit 5591759
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 56 deletions.
2 changes: 1 addition & 1 deletion pkg/auth/handler/saml/handle_login_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (h *LoginResultHandler) handleLoginResult(

var response *samlprotocol.Response
err := h.Database.WithTx(func() error {
authnRequest := samlSessionEntry.AuthnRequest()
authnRequest, _ := samlSessionEntry.AuthnRequest()

resp, err := h.SAMLService.IssueSuccessResponse(
callbackURL,
Expand Down
135 changes: 92 additions & 43 deletions pkg/auth/handler/saml/login.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package saml

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -77,6 +78,17 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

if err != nil {
if errors.Is(err, samlbinding.ErrNoRequest) {
// This is a IdP-initated flow
result := h.handleIdpInitiated(
r.Context(),
sp,
callbackURL,
)
h.writeResult(rw, r, result)
return
}

var parseRequestFailedErr *samlerror.ParseRequestFailedError
if errors.As(err, &parseRequestFailedErr) {
errResponse := samlprotocolhttp.NewExpectedSAMLErrorResult(err,
Expand Down Expand Up @@ -139,18 +151,90 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
callbackURL = authnRequest.AssertionConsumerServiceURL
}

result := h.startSSOFlow(
r.Context(),
sp,
string(authnRequest.ToXMLBytes()),
callbackURL,
relayState,
authnRequest.GetIsPassive(),
)
h.writeResult(rw, r, result)
}

func (h *LoginHandler) handleUnknownError(
rw http.ResponseWriter, r *http.Request,
callbackURL string,
relayState string,
err any,
) {
now := h.Logger.Time.UTC()
e := panicutil.MakeError(err)

result := samlprotocolhttp.NewUnexpectedSAMLErrorResult(e,
samlprotocolhttp.SAMLResult{
CallbackURL: callbackURL,
Binding: samlprotocol.SAMLBindingHTTPPost,
Response: samlprotocol.NewUnexpectedServerErrorResponse(now, h.SAMLService.IdpEntityID()),
RelayState: relayState,
},
)
h.writeResult(rw, r, result)
}

func (h *LoginHandler) writeResult(
rw http.ResponseWriter, r *http.Request,
result httputil.Result,
) {
switch result := result.(type) {
case *samlprotocolhttp.SAMLErrorResult:
if result.IsUnexpected {
h.Logger.WithError(result.Cause).Error("unexpected error")
} else {
h.Logger.WithError(result).Warnln("saml login failed with expected error")
}
}
result.WriteResponse(rw, r)
}

func (h *LoginHandler) handleIdpInitiated(
ctx context.Context,
sp *config.SAMLServiceProviderConfig,
callbackURL string,
) httputil.Result {
return h.startSSOFlow(
ctx,
sp,
"",
callbackURL,
"",
false,
)
}

func (h *LoginHandler) startSSOFlow(
ctx context.Context,
sp *config.SAMLServiceProviderConfig,
authnRequestXML string,
callbackURL string,
relayState string,
isPassive bool,
) httputil.Result {
now := h.Clock.NowUTC()
issuer := h.SAMLService.IdpEntityID()

samlSessionEntry := &samlsession.SAMLSessionEntry{
ServiceProviderID: sp.ID,
AuthnRequestXML: string(authnRequest.ToXMLBytes()),
AuthnRequestXML: authnRequestXML,
CallbackURL: callbackURL,
RelayState: relayState,
}

if authnRequest.GetIsPassive() == true {
if isPassive == true {
// If IsPassive=true, no ui should be displayed.
// Authenticate by existing session or error.
var resolvedSession session.ResolvedSession
if s := session.GetSession(r.Context()); s != nil {
if s := session.GetSession(ctx); s != nil {
resolvedSession = s
}
// Ignore any session that is not allow to be used here
Expand All @@ -160,7 +244,8 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {

if resolvedSession == nil {
// No session, return NoPassive error.
errResponse := samlprotocolhttp.NewExpectedSAMLErrorResult(err,
err := fmt.Errorf("no session but IsPassive=true")
result := samlprotocolhttp.NewExpectedSAMLErrorResult(err,
samlprotocolhttp.SAMLResult{
CallbackURL: callbackURL,
Binding: samlprotocol.SAMLBindingHTTPPost,
Expand All @@ -171,16 +256,14 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
RelayState: relayState,
},
)
h.writeResult(rw, r, errResponse)
return
return result
} else {
// Else, authenticate with the existing session.
authInfo := resolvedSession.CreateNewAuthenticationInfoByThisSession()
// TODO(saml): If <Subject> is provided in the request,
// ensure the user of current session matches the subject.
result := h.LoginResultHandler.handleLoginResult(&authInfo, samlSessionEntry)
h.writeResult(rw, r, result)
return
return result
}

}
Expand All @@ -204,40 +287,6 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
result := &httputil.ResultRedirect{
URL: endpoint.String(),
}
h.writeResult(rw, r, result)
}

func (h *LoginHandler) handleUnknownError(
rw http.ResponseWriter, r *http.Request,
callbackURL string,
relayState string,
err any,
) {
now := h.Logger.Time.UTC()
e := panicutil.MakeError(err)

result := samlprotocolhttp.NewUnexpectedSAMLErrorResult(e,
samlprotocolhttp.SAMLResult{
CallbackURL: callbackURL,
Binding: samlprotocol.SAMLBindingHTTPPost,
Response: samlprotocol.NewUnexpectedServerErrorResponse(now, h.SAMLService.IdpEntityID()),
RelayState: relayState,
},
)
h.writeResult(rw, r, result)
}

func (h *LoginHandler) writeResult(
rw http.ResponseWriter, r *http.Request,
result httputil.Result,
) {
switch result := result.(type) {
case *samlprotocolhttp.SAMLErrorResult:
if result.IsUnexpected {
h.Logger.WithError(result.Cause).Error("unexpected error")
} else {
h.Logger.WithError(result).Warnln("saml login failed with expected error")
}
}
result.WriteResponse(rw, r)
return result
}
7 changes: 7 additions & 0 deletions pkg/lib/saml/samlbinding/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package samlbinding

import "github.com/authgear/authgear-server/pkg/lib/saml/samlerror"

var ErrNoRequest = &samlerror.ParseRequestFailedError{
Reason: "no SAMLRequest provided",
}
3 changes: 3 additions & 0 deletions pkg/lib/saml/samlbinding/http_redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func SAMLBindingHTTPRedirectParse(r *http.Request) (
signature := r.URL.Query().Get("Signature")
sigAlg := r.URL.Query().Get("SigAlg")
samlRequest := r.URL.Query().Get("SAMLRequest")
if samlRequest == "" {
return nil, ErrNoRequest
}
compressedRequest, err := base64.StdEncoding.DecodeString(samlRequest)
if err != nil {
return result, &samlerror.ParseRequestFailedError{
Expand Down
7 changes: 5 additions & 2 deletions pkg/lib/saml/samlsession/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ func NewSAMLSession(entry *SAMLSessionEntry, uiInfo *SAMLUIInfo) *SAMLSession {
}
}

func (s *SAMLSessionEntry) AuthnRequest() *samlprotocol.AuthnRequest {
func (s *SAMLSessionEntry) AuthnRequest() (*samlprotocol.AuthnRequest, bool) {
if s.AuthnRequestXML == "" {
return nil, false
}
r, err := samlprotocol.ParseAuthnRequest([]byte(s.AuthnRequestXML))
if err != nil {
// We should ensure only valid request stored in the session
// So it is a panic if we got something invalid here
panic(err)
}
return r
return r, true
}
7 changes: 5 additions & 2 deletions pkg/lib/saml/samlsession/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,13 @@ func (r *UIService) RemoveSAMLSessionID(w http.ResponseWriter, req *http.Request

func (r *UIService) ResolveUIInfo(entry *SAMLSessionEntry) (*SAMLUIInfo, error) {
var prompt []string
authnRequest := entry.AuthnRequest()
authnRequest, authnRequestExist := entry.AuthnRequest()
switch {
case !authnRequestExist:
// This is an Idp-Initiated flow, allow user to select_account or login
prompt = []string{}
case authnRequest.GetIsPassive() == false && authnRequest.GetForceAuthn() == false:
prompt = []string{"select_account"}
prompt = []string{}
case authnRequest.GetIsPassive() == false && authnRequest.GetForceAuthn() == true:
prompt = []string{"login"}
case authnRequest.GetIsPassive() == true && authnRequest.GetForceAuthn() == false:
Expand Down
24 changes: 16 additions & 8 deletions pkg/lib/saml/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,11 @@ func (s *Service) IssueSuccessResponse(

now := s.Clock.NowUTC()
issuerID := s.IdpEntityID()
response := samlprotocol.NewSuccessResponse(now, issuerID, inResponseToAuthnRequest.ID)
inResponseTo := ""
if inResponseToAuthnRequest != nil {
inResponseTo = inResponseToAuthnRequest.ID
}
response := samlprotocol.NewSuccessResponse(now, issuerID, inResponseTo)

// Referencing other SAML Idp implementations,
// use ACS url as default value of destination, recipient and audience
Expand All @@ -276,15 +280,17 @@ func (s *Service) IssueSuccessResponse(
}

nameIDFormat := sp.NameIDFormat
if nameIDFormatInRequest, ok := inResponseToAuthnRequest.GetNameIDFormat(); ok {
nameIDFormat = nameIDFormatInRequest
if inResponseToAuthnRequest != nil {
if nameIDFormatInRequest, ok := inResponseToAuthnRequest.GetNameIDFormat(); ok {
nameIDFormat = nameIDFormatInRequest
}
}

// allow for some clock skew
notBefore := now.Add(-1 * duration.ClockSkew)
assertionValidDuration := sp.AssertionValidDuration.Duration()
notOnOrAfter := now.Add(assertionValidDuration)
if notBefore.Before(inResponseToAuthnRequest.IssueInstant) {
if inResponseToAuthnRequest != nil && notBefore.Before(inResponseToAuthnRequest.IssueInstant) {
notBefore = inResponseToAuthnRequest.IssueInstant
notOnOrAfter = notBefore.Add(assertionValidDuration)
}
Expand All @@ -293,7 +299,7 @@ func (s *Service) IssueSuccessResponse(
NotBefore: notBefore,
NotOnOrAfter: notOnOrAfter,
}
if inResponseToAuthnRequest.Conditions != nil {
if inResponseToAuthnRequest != nil && inResponseToAuthnRequest.Conditions != nil {
// Only allow conditions which are stricter than what we set by default
if !inResponseToAuthnRequest.Conditions.NotBefore.IsZero() && inResponseToAuthnRequest.Conditions.NotBefore.After(notBefore) {
conditions.NotBefore = inResponseToAuthnRequest.Conditions.NotBefore
Expand All @@ -312,8 +318,10 @@ func (s *Service) IssueSuccessResponse(
}

// Include audiences requested
for _, aud := range inResponseToAuthnRequest.CollectAudiences() {
audiences.Add(aud)
if inResponseToAuthnRequest != nil {
for _, aud := range inResponseToAuthnRequest.CollectAudiences() {
audiences.Add(aud)
}
}

audienceRestriction := samlprotocol.AudienceRestriction{
Expand Down Expand Up @@ -354,7 +362,7 @@ func (s *Service) IssueSuccessResponse(
{
Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer",
SubjectConfirmationData: &samlprotocol.SubjectConfirmationData{
InResponseTo: inResponseToAuthnRequest.ID,
InResponseTo: inResponseTo,
NotOnOrAfter: notOnOrAfter,
Recipient: recipient,
},
Expand Down

0 comments on commit 5591759

Please sign in to comment.