diff --git a/pkg/auth/handler/saml/handle_login_result.go b/pkg/auth/handler/saml/handle_login_result.go index a1c97b8d73..9adb643ef0 100644 --- a/pkg/auth/handler/saml/handle_login_result.go +++ b/pkg/auth/handler/saml/handle_login_result.go @@ -50,7 +50,7 @@ func (h *LoginResultHandler) handleLoginResult( var response *samlprotocol.Response err := h.Database.WithTx(func() error { - authnRequest := samlSessionEntry.AuthnRequest() + authnRequest, _ := samlSessionEntry.AuthnRequest() resp, err := h.SAMLService.IssueSuccessResponse( callbackURL, diff --git a/pkg/auth/handler/saml/login.go b/pkg/auth/handler/saml/login.go index d53288d0b7..3e3a67de9e 100644 --- a/pkg/auth/handler/saml/login.go +++ b/pkg/auth/handler/saml/login.go @@ -1,6 +1,7 @@ package saml import ( + "context" "errors" "fmt" "net/http" @@ -77,6 +78,17 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } if err != nil { + if errors.Is(err, samlbinding.ErrNoRequest) { + // This is a IdP-initated flow + result := h.handleIdpInitiated( + r.Context(), + sp, + callbackURL, + ) + h.writeResult(rw, r, result) + return + } + var parseRequestFailedErr *samlerror.ParseRequestFailedError if errors.As(err, &parseRequestFailedErr) { errResponse := samlprotocolhttp.NewExpectedSAMLErrorResult(err, @@ -139,18 +151,90 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { callbackURL = authnRequest.AssertionConsumerServiceURL } + result := h.startSSOFlow( + r.Context(), + sp, + string(authnRequest.ToXMLBytes()), + callbackURL, + relayState, + authnRequest.GetIsPassive(), + ) + h.writeResult(rw, r, result) +} + +func (h *LoginHandler) handleUnknownError( + rw http.ResponseWriter, r *http.Request, + callbackURL string, + relayState string, + err any, +) { + now := h.Logger.Time.UTC() + e := panicutil.MakeError(err) + + result := samlprotocolhttp.NewUnexpectedSAMLErrorResult(e, + samlprotocolhttp.SAMLResult{ + CallbackURL: callbackURL, + Binding: samlprotocol.SAMLBindingHTTPPost, + Response: samlprotocol.NewUnexpectedServerErrorResponse(now, h.SAMLService.IdpEntityID()), + RelayState: relayState, + }, + ) + h.writeResult(rw, r, result) +} + +func (h *LoginHandler) writeResult( + rw http.ResponseWriter, r *http.Request, + result httputil.Result, +) { + switch result := result.(type) { + case *samlprotocolhttp.SAMLErrorResult: + if result.IsUnexpected { + h.Logger.WithError(result.Cause).Error("unexpected error") + } else { + h.Logger.WithError(result).Warnln("saml login failed with expected error") + } + } + result.WriteResponse(rw, r) +} + +func (h *LoginHandler) handleIdpInitiated( + ctx context.Context, + sp *config.SAMLServiceProviderConfig, + callbackURL string, +) httputil.Result { + return h.startSSOFlow( + ctx, + sp, + "", + callbackURL, + "", + false, + ) +} + +func (h *LoginHandler) startSSOFlow( + ctx context.Context, + sp *config.SAMLServiceProviderConfig, + authnRequestXML string, + callbackURL string, + relayState string, + isPassive bool, +) httputil.Result { + now := h.Clock.NowUTC() + issuer := h.SAMLService.IdpEntityID() + samlSessionEntry := &samlsession.SAMLSessionEntry{ ServiceProviderID: sp.ID, - AuthnRequestXML: string(authnRequest.ToXMLBytes()), + AuthnRequestXML: authnRequestXML, CallbackURL: callbackURL, RelayState: relayState, } - if authnRequest.GetIsPassive() == true { + if isPassive == true { // If IsPassive=true, no ui should be displayed. // Authenticate by existing session or error. var resolvedSession session.ResolvedSession - if s := session.GetSession(r.Context()); s != nil { + if s := session.GetSession(ctx); s != nil { resolvedSession = s } // Ignore any session that is not allow to be used here @@ -160,7 +244,8 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { if resolvedSession == nil { // No session, return NoPassive error. - errResponse := samlprotocolhttp.NewExpectedSAMLErrorResult(err, + err := fmt.Errorf("no session but IsPassive=true") + result := samlprotocolhttp.NewExpectedSAMLErrorResult(err, samlprotocolhttp.SAMLResult{ CallbackURL: callbackURL, Binding: samlprotocol.SAMLBindingHTTPPost, @@ -171,16 +256,14 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { RelayState: relayState, }, ) - h.writeResult(rw, r, errResponse) - return + return result } else { // Else, authenticate with the existing session. authInfo := resolvedSession.CreateNewAuthenticationInfoByThisSession() // TODO(saml): If is provided in the request, // ensure the user of current session matches the subject. result := h.LoginResultHandler.handleLoginResult(&authInfo, samlSessionEntry) - h.writeResult(rw, r, result) - return + return result } } @@ -204,40 +287,6 @@ func (h *LoginHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { result := &httputil.ResultRedirect{ URL: endpoint.String(), } - h.writeResult(rw, r, result) -} - -func (h *LoginHandler) handleUnknownError( - rw http.ResponseWriter, r *http.Request, - callbackURL string, - relayState string, - err any, -) { - now := h.Logger.Time.UTC() - e := panicutil.MakeError(err) - - result := samlprotocolhttp.NewUnexpectedSAMLErrorResult(e, - samlprotocolhttp.SAMLResult{ - CallbackURL: callbackURL, - Binding: samlprotocol.SAMLBindingHTTPPost, - Response: samlprotocol.NewUnexpectedServerErrorResponse(now, h.SAMLService.IdpEntityID()), - RelayState: relayState, - }, - ) - h.writeResult(rw, r, result) -} -func (h *LoginHandler) writeResult( - rw http.ResponseWriter, r *http.Request, - result httputil.Result, -) { - switch result := result.(type) { - case *samlprotocolhttp.SAMLErrorResult: - if result.IsUnexpected { - h.Logger.WithError(result.Cause).Error("unexpected error") - } else { - h.Logger.WithError(result).Warnln("saml login failed with expected error") - } - } - result.WriteResponse(rw, r) + return result } diff --git a/pkg/lib/saml/samlbinding/error.go b/pkg/lib/saml/samlbinding/error.go new file mode 100644 index 0000000000..0370191cf1 --- /dev/null +++ b/pkg/lib/saml/samlbinding/error.go @@ -0,0 +1,7 @@ +package samlbinding + +import "github.com/authgear/authgear-server/pkg/lib/saml/samlerror" + +var ErrNoRequest = &samlerror.ParseRequestFailedError{ + Reason: "no SAMLRequest provided", +} diff --git a/pkg/lib/saml/samlbinding/http_redirect.go b/pkg/lib/saml/samlbinding/http_redirect.go index b166b8f140..4fbd2fa9df 100644 --- a/pkg/lib/saml/samlbinding/http_redirect.go +++ b/pkg/lib/saml/samlbinding/http_redirect.go @@ -35,6 +35,9 @@ func SAMLBindingHTTPRedirectParse(r *http.Request) ( signature := r.URL.Query().Get("Signature") sigAlg := r.URL.Query().Get("SigAlg") samlRequest := r.URL.Query().Get("SAMLRequest") + if samlRequest == "" { + return nil, ErrNoRequest + } compressedRequest, err := base64.StdEncoding.DecodeString(samlRequest) if err != nil { return result, &samlerror.ParseRequestFailedError{ diff --git a/pkg/lib/saml/samlsession/model.go b/pkg/lib/saml/samlsession/model.go index 61ec1946d5..06bf215cfb 100644 --- a/pkg/lib/saml/samlsession/model.go +++ b/pkg/lib/saml/samlsession/model.go @@ -32,12 +32,15 @@ func NewSAMLSession(entry *SAMLSessionEntry, uiInfo *SAMLUIInfo) *SAMLSession { } } -func (s *SAMLSessionEntry) AuthnRequest() *samlprotocol.AuthnRequest { +func (s *SAMLSessionEntry) AuthnRequest() (*samlprotocol.AuthnRequest, bool) { + if s.AuthnRequestXML == "" { + return nil, false + } r, err := samlprotocol.ParseAuthnRequest([]byte(s.AuthnRequestXML)) if err != nil { // We should ensure only valid request stored in the session // So it is a panic if we got something invalid here panic(err) } - return r + return r, true } diff --git a/pkg/lib/saml/samlsession/ui.go b/pkg/lib/saml/samlsession/ui.go index a796e0fafb..a065104d52 100644 --- a/pkg/lib/saml/samlsession/ui.go +++ b/pkg/lib/saml/samlsession/ui.go @@ -66,10 +66,13 @@ func (r *UIService) RemoveSAMLSessionID(w http.ResponseWriter, req *http.Request func (r *UIService) ResolveUIInfo(entry *SAMLSessionEntry) (*SAMLUIInfo, error) { var prompt []string - authnRequest := entry.AuthnRequest() + authnRequest, authnRequestExist := entry.AuthnRequest() switch { + case !authnRequestExist: + // This is an Idp-Initiated flow, allow user to select_account or login + prompt = []string{} case authnRequest.GetIsPassive() == false && authnRequest.GetForceAuthn() == false: - prompt = []string{"select_account"} + prompt = []string{} case authnRequest.GetIsPassive() == false && authnRequest.GetForceAuthn() == true: prompt = []string{"login"} case authnRequest.GetIsPassive() == true && authnRequest.GetForceAuthn() == false: diff --git a/pkg/lib/saml/service.go b/pkg/lib/saml/service.go index 7bcf90088c..20ce2bdd1b 100644 --- a/pkg/lib/saml/service.go +++ b/pkg/lib/saml/service.go @@ -260,7 +260,11 @@ func (s *Service) IssueSuccessResponse( now := s.Clock.NowUTC() issuerID := s.IdpEntityID() - response := samlprotocol.NewSuccessResponse(now, issuerID, inResponseToAuthnRequest.ID) + inResponseTo := "" + if inResponseToAuthnRequest != nil { + inResponseTo = inResponseToAuthnRequest.ID + } + response := samlprotocol.NewSuccessResponse(now, issuerID, inResponseTo) // Referencing other SAML Idp implementations, // use ACS url as default value of destination, recipient and audience @@ -276,15 +280,17 @@ func (s *Service) IssueSuccessResponse( } nameIDFormat := sp.NameIDFormat - if nameIDFormatInRequest, ok := inResponseToAuthnRequest.GetNameIDFormat(); ok { - nameIDFormat = nameIDFormatInRequest + if inResponseToAuthnRequest != nil { + if nameIDFormatInRequest, ok := inResponseToAuthnRequest.GetNameIDFormat(); ok { + nameIDFormat = nameIDFormatInRequest + } } // allow for some clock skew notBefore := now.Add(-1 * duration.ClockSkew) assertionValidDuration := sp.AssertionValidDuration.Duration() notOnOrAfter := now.Add(assertionValidDuration) - if notBefore.Before(inResponseToAuthnRequest.IssueInstant) { + if inResponseToAuthnRequest != nil && notBefore.Before(inResponseToAuthnRequest.IssueInstant) { notBefore = inResponseToAuthnRequest.IssueInstant notOnOrAfter = notBefore.Add(assertionValidDuration) } @@ -293,7 +299,7 @@ func (s *Service) IssueSuccessResponse( NotBefore: notBefore, NotOnOrAfter: notOnOrAfter, } - if inResponseToAuthnRequest.Conditions != nil { + if inResponseToAuthnRequest != nil && inResponseToAuthnRequest.Conditions != nil { // Only allow conditions which are stricter than what we set by default if !inResponseToAuthnRequest.Conditions.NotBefore.IsZero() && inResponseToAuthnRequest.Conditions.NotBefore.After(notBefore) { conditions.NotBefore = inResponseToAuthnRequest.Conditions.NotBefore @@ -312,8 +318,10 @@ func (s *Service) IssueSuccessResponse( } // Include audiences requested - for _, aud := range inResponseToAuthnRequest.CollectAudiences() { - audiences.Add(aud) + if inResponseToAuthnRequest != nil { + for _, aud := range inResponseToAuthnRequest.CollectAudiences() { + audiences.Add(aud) + } } audienceRestriction := samlprotocol.AudienceRestriction{ @@ -354,7 +362,7 @@ func (s *Service) IssueSuccessResponse( { Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer", SubjectConfirmationData: &samlprotocol.SubjectConfirmationData{ - InResponseTo: inResponseToAuthnRequest.ID, + InResponseTo: inResponseTo, NotOnOrAfter: notOnOrAfter, Recipient: recipient, },