From fc408c6aac0b3342ca1bed3eb22476f5064b6a8a Mon Sep 17 00:00:00 2001 From: ThibaultHerard Date: Fri, 25 Nov 2022 17:52:34 +0000 Subject: [PATCH] feat(saml): improved error handling Signed-off-by: ThibaultHerard Co-authored-by: sebferrer --- selfservice/strategy/saml/config_test.go | 2 +- selfservice/strategy/saml/error.go | 28 ++++++++++- selfservice/strategy/saml/handler.go | 61 +++++++++++++----------- selfservice/strategy/saml/strategy.go | 14 +++--- 4 files changed, 67 insertions(+), 38 deletions(-) diff --git a/selfservice/strategy/saml/config_test.go b/selfservice/strategy/saml/config_test.go index d69274920aef..a6d90c8e5d19 100644 --- a/selfservice/strategy/saml/config_test.go +++ b/selfservice/strategy/saml/config_test.go @@ -102,7 +102,7 @@ func TestInitSAMLWithoutPoviderID(t *testing.T) { resp, _ := NewTestClient(t, nil).Get(ts.URL + "/self-service/methods/saml/metadata/samlProvider") body, _ := ioutil.ReadAll(resp.Body) - assert.Contains(t, string(body), "\"code\":404,\"status\":\"Not Found\"") + assert.Contains(t, string(body), "Invalid SAML configuration in the configuration file") } func TestInitSAMLWithoutPoviderLabel(t *testing.T) { diff --git a/selfservice/strategy/saml/error.go b/selfservice/strategy/saml/error.go index ab984670e1a0..68031dff7bda 100644 --- a/selfservice/strategy/saml/error.go +++ b/selfservice/strategy/saml/error.go @@ -1,6 +1,11 @@ package saml -import "github.com/ory/herodot" +import ( + "net/http" + + "github.com/ory/herodot" + "google.golang.org/grpc/codes" +) var ( ErrScopeMissing = herodot.ErrBadRequest. @@ -13,4 +18,25 @@ var ( ErrAPIFlowNotSupported = herodot.ErrBadRequest.WithError("API-based flows are not supported for this method"). WithReason("SAML SignIn and Registeration are only supported for flows initiated using the Browser endpoint.") + + ErrInvalidSAMLMetadataError = herodot.DefaultError{ + StatusField: http.StatusText(http.StatusOK), + ErrorField: "Not valid SAML metadata file", + CodeField: http.StatusOK, + GRPCCodeField: codes.InvalidArgument, + } + + ErrInvalidCertificateError = herodot.DefaultError{ + StatusField: http.StatusText(http.StatusOK), + ErrorField: "Not valid certificate", + CodeField: http.StatusOK, + GRPCCodeField: codes.InvalidArgument, + } + + ErrInvalidSAMLConfiguration = herodot.DefaultError{ + StatusField: http.StatusText(http.StatusOK), + ErrorField: "Invalid SAML configuration in the configuration file", + CodeField: http.StatusOK, + GRPCCodeField: codes.InvalidArgument, + } ) diff --git a/selfservice/strategy/saml/handler.go b/selfservice/strategy/saml/handler.go index bfcc82c0acf5..6ffac7ae70fc 100644 --- a/selfservice/strategy/saml/handler.go +++ b/selfservice/strategy/saml/handler.go @@ -19,6 +19,7 @@ import ( "github.com/crewjam/saml/samlsp" "github.com/julienschmidt/httprouter" + "github.com/ory/herodot" "github.com/ory/kratos/continuity" "github.com/ory/kratos/driver/config" "github.com/ory/kratos/selfservice/errorx" @@ -154,34 +155,33 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi // Key pair to encrypt and sign SAML requests keyPair, err := tls.LoadX509KeyPair(strings.Replace(providerConfig.PublicCertPath, "file://", "", 1), strings.Replace(providerConfig.PrivateKeyPath, "file://", "", 1)) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } var idpMetadata *samlidp.EntityDescriptor // We check if the metadata file is provided if providerConfig.IDPInformation["idp_metadata_url"] != "" { - // The metadata file is provided metadataURL := providerConfig.IDPInformation["idp_metadata_url"] metadataBuffer, err := fetcher.NewFetcher().Fetch(metadataURL) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } metadata, err := ioutil.ReadAll(metadataBuffer) if err != nil { - return err + return herodot.ErrInternalServerError.WithTrace(err) } idpMetadata, err = samlsp.ParseMetadata(metadata) if err != nil { - return err + return ErrInvalidSAMLMetadataError.WithTrace(err) } } else { @@ -189,36 +189,36 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi // So were are creating a minimalist IDP metadata based on what is provided by the user on the config file entityIDURL, err := url.Parse(providerConfig.IDPInformation["idp_entity_id"]) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } // The IDP SSO URL IDPSSOURL, err := url.Parse(providerConfig.IDPInformation["idp_sso_url"]) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } // The IDP Logout URL IDPlogoutURL, err := url.Parse(providerConfig.IDPInformation["idp_logout_url"]) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } // The certificate of the IDP certificateBuffer, err := fetcher.NewFetcher().Fetch(providerConfig.IDPInformation["idp_certificate_path"]) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } certificate, err := ioutil.ReadAll(certificateBuffer) if err != nil { - return err + return herodot.ErrInternalServerError.WithTrace(err) } // We parse it into a x509.Certificate object IDPCertificate, err := MustParseCertificate(certificate) if err != nil { - return err + return ErrInvalidCertificateError.WithTrace(err) } // Because the metadata file is not provided, we need to simulate an IDP to create artificial metadata from the data entered in the conf file @@ -238,7 +238,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi // The main URL rootURL, err := url.Parse(config.SelfServiceBrowserDefaultReturnTo(ctx).String()) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } // Here we create a MiddleWare to transform Kratos into a Service Provider @@ -263,7 +263,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi }, }) if err != nil { - return err + return herodot.ErrInternalServerError.WithTrace(err) } // It's better to use SHA256 than SHA1 @@ -278,7 +278,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi u, err := url.Parse(publicUrlString + RouteSamlAcsWithSlash) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } samlMiddleWare.ServiceProvider.AcsURL = *u @@ -287,7 +287,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi publicUrlString = publicUrlString[:len(publicUrlString)-1] u, err := url.Parse(publicUrlString + RouteSamlAcsWithSlash) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } samlMiddleWare.ServiceProvider.AcsURL = *u } @@ -295,7 +295,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi // Crewjam library use default route for ACS and metadata but we want to overwrite them metadata, err := url.Parse(publicUrlString + RouteMetadata) if err != nil { - return err + return herodot.ErrNotFound.WithTrace(err) } samlMiddleWare.ServiceProvider.MetadataURL = *metadata @@ -337,22 +337,25 @@ func CreateSAMLProviderConfig(config config.Config, ctx context.Context, pid str if err := jsonx. NewStrictDecoder(bytes.NewBuffer(conf)). Decode(&c); err != nil { - return nil, errors.Wrapf(err, "Unable to decode config %v", string(conf)) + return nil, ErrInvalidSAMLConfiguration.WithReasonf("Unable to decode config %v", string(conf)).WithTrace(err) } if len(c.SAMLProviders) == 0 { - return nil, errors.Errorf("Please indicate a SAML Identity Provider in your configuration file") + return nil, ErrInvalidSAMLConfiguration.WithReason("Please indicate a SAML Identity Provider in your configuration file") } providerConfig, err := c.ProviderConfig(pid) if err != nil { - return nil, err + return nil, ErrInvalidSAMLConfiguration.WithTrace(err) } if providerConfig.IDPInformation == nil { - return nil, errors.Errorf("Please include your Identity Provider information in the configuration file.") + return nil, ErrInvalidSAMLConfiguration.WithReasonf("Please include your Identity Provider information in the configuration file.").WithTrace(err) } + /** + * SAMLTODO errors + */ // _, sso_exists := providerConfig.IDPInformation["idp_sso_url"] _, sso_exists := providerConfig.IDPInformation["idp_sso_url"] _, entity_id_exists := providerConfig.IDPInformation["idp_entity_id"] @@ -361,35 +364,35 @@ func CreateSAMLProviderConfig(config config.Config, ctx context.Context, pid str _, metadata_exists := providerConfig.IDPInformation["idp_metadata_url"] if (!metadata_exists && (!sso_exists || !entity_id_exists || !certificate_exists || !logout_url_exists)) || len(providerConfig.IDPInformation) > 4 { - return nil, errors.Errorf("Please check your IDP information in the configuration file") + return nil, ErrInvalidSAMLConfiguration.WithReason("Please check your IDP information in the configuration file").WithTrace(err) } if providerConfig.ID == "" { - return nil, errors.Errorf("Provider must have an ID") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have an ID").WithTrace(err) } if providerConfig.Label == "" { - return nil, errors.Errorf("Provider must have a label") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a label").WithTrace(err) } if providerConfig.PrivateKeyPath == "" { - return nil, errors.Errorf("Provider must have a private key") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a private key").WithTrace(err) } if providerConfig.PublicCertPath == "" { - return nil, errors.Errorf("Provider must have a public certificate") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a public certificate").WithTrace(err) } if providerConfig.AttributesMap == nil || len(providerConfig.AttributesMap) == 0 { - return nil, errors.Errorf("Provider must have an attributes map") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have an attributes map").WithTrace(err) } if providerConfig.AttributesMap["id"] == "" { - return nil, errors.Errorf("You must have an ID field in your attribute_map") + return nil, ErrInvalidSAMLConfiguration.WithReason("You must have an ID field in your attribute_map").WithTrace(err) } if providerConfig.Mapper == "" { - return nil, errors.Errorf("Provider must have a mapper url") + return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a mapper url").WithTrace(err) } return providerConfig, nil diff --git a/selfservice/strategy/saml/strategy.go b/selfservice/strategy/saml/strategy.go index cb11ceb747b5..93ecf7533a34 100644 --- a/selfservice/strategy/saml/strategy.go +++ b/selfservice/strategy/saml/strategy.go @@ -357,7 +357,7 @@ func (s *Strategy) Config(ctx context.Context) (*ConfigurationCollection, error) func (s *Strategy) populateMethod(r *http.Request, c *container.Container, message func(provider string) *text.Message) error { conf, err := s.Config(r.Context()) if err != nil { - return err + return ErrInvalidSAMLConfiguration.WithTrace(err) } // does not need sorting because there is only one field @@ -370,7 +370,7 @@ func (s *Strategy) populateMethod(r *http.Request, c *container.Container, messa func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Flow, provider string, traits []byte, err error) error { switch rf := f.(type) { case *login.Flow: - return err + return ErrAPIFlowNotSupported.WithTrace(err) case *registration.Flow: // Reset all nodes to not confuse users. // This is kinda hacky and will probably need to be updated at some point. @@ -384,24 +384,24 @@ func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Fl if traits != nil { ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context()) if err != nil { - return err + return ErrInvalidSAMLConfiguration.WithTrace(err) } traitNodes, err := container.NodesFromJSONSchema(r.Context(), node.SAMLGroup, ds.String(), "", nil) if err != nil { - return err + return herodot.ErrInternalServerError.WithTrace(err) } rf.UI.Nodes = append(rf.UI.Nodes, traitNodes...) rf.UI.UpdateNodeValuesFromJSON(traits, "traits", node.SAMLGroup) } - return err + return herodot.ErrInternalServerError.WithTrace(err) case *settings.Flow: - return err + return ErrAPIFlowNotSupported.WithTrace(err) } - return err + return herodot.ErrInternalServerError.WithTrace(err) } func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {