Skip to content

Commit

Permalink
feat(saml): improved error handling
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultHerard <[email protected]>

Co-authored-by: sebferrer <[email protected]>
  • Loading branch information
ThibHrrd and sebferrer committed Nov 25, 2022
1 parent 111050d commit fc408c6
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 38 deletions.
2 changes: 1 addition & 1 deletion selfservice/strategy/saml/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestInitSAMLWithoutPoviderID(t *testing.T) {

resp, _ := NewTestClient(t, nil).Get(ts.URL + "/self-service/methods/saml/metadata/samlProvider")
body, _ := ioutil.ReadAll(resp.Body)
assert.Contains(t, string(body), "\"code\":404,\"status\":\"Not Found\"")
assert.Contains(t, string(body), "Invalid SAML configuration in the configuration file")
}

func TestInitSAMLWithoutPoviderLabel(t *testing.T) {
Expand Down
28 changes: 27 additions & 1 deletion selfservice/strategy/saml/error.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package saml

import "github.com/ory/herodot"
import (
"net/http"

"github.com/ory/herodot"
"google.golang.org/grpc/codes"
)

var (
ErrScopeMissing = herodot.ErrBadRequest.
Expand All @@ -13,4 +18,25 @@ var (

ErrAPIFlowNotSupported = herodot.ErrBadRequest.WithError("API-based flows are not supported for this method").
WithReason("SAML SignIn and Registeration are only supported for flows initiated using the Browser endpoint.")

ErrInvalidSAMLMetadataError = herodot.DefaultError{
StatusField: http.StatusText(http.StatusOK),
ErrorField: "Not valid SAML metadata file",
CodeField: http.StatusOK,
GRPCCodeField: codes.InvalidArgument,
}

ErrInvalidCertificateError = herodot.DefaultError{
StatusField: http.StatusText(http.StatusOK),
ErrorField: "Not valid certificate",
CodeField: http.StatusOK,
GRPCCodeField: codes.InvalidArgument,
}

ErrInvalidSAMLConfiguration = herodot.DefaultError{
StatusField: http.StatusText(http.StatusOK),
ErrorField: "Invalid SAML configuration in the configuration file",
CodeField: http.StatusOK,
GRPCCodeField: codes.InvalidArgument,
}
)
61 changes: 32 additions & 29 deletions selfservice/strategy/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/crewjam/saml/samlsp"
"github.com/julienschmidt/httprouter"

"github.com/ory/herodot"
"github.com/ory/kratos/continuity"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/selfservice/errorx"
Expand Down Expand Up @@ -154,71 +155,70 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
// Key pair to encrypt and sign SAML requests
keyPair, err := tls.LoadX509KeyPair(strings.Replace(providerConfig.PublicCertPath, "file://", "", 1), strings.Replace(providerConfig.PrivateKeyPath, "file://", "", 1))
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

var idpMetadata *samlidp.EntityDescriptor

// We check if the metadata file is provided
if providerConfig.IDPInformation["idp_metadata_url"] != "" {

// The metadata file is provided
metadataURL := providerConfig.IDPInformation["idp_metadata_url"]

metadataBuffer, err := fetcher.NewFetcher().Fetch(metadataURL)
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

metadata, err := ioutil.ReadAll(metadataBuffer)
if err != nil {
return err
return herodot.ErrInternalServerError.WithTrace(err)
}

idpMetadata, err = samlsp.ParseMetadata(metadata)
if err != nil {
return err
return ErrInvalidSAMLMetadataError.WithTrace(err)
}

} else {
// The metadata file is not provided
// So were are creating a minimalist IDP metadata based on what is provided by the user on the config file
entityIDURL, err := url.Parse(providerConfig.IDPInformation["idp_entity_id"])
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

// The IDP SSO URL
IDPSSOURL, err := url.Parse(providerConfig.IDPInformation["idp_sso_url"])
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

// The IDP Logout URL
IDPlogoutURL, err := url.Parse(providerConfig.IDPInformation["idp_logout_url"])
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

// The certificate of the IDP
certificateBuffer, err := fetcher.NewFetcher().Fetch(providerConfig.IDPInformation["idp_certificate_path"])
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

certificate, err := ioutil.ReadAll(certificateBuffer)
if err != nil {
return err
return herodot.ErrInternalServerError.WithTrace(err)
}

// We parse it into a x509.Certificate object
IDPCertificate, err := MustParseCertificate(certificate)
if err != nil {
return err
return ErrInvalidCertificateError.WithTrace(err)
}

// Because the metadata file is not provided, we need to simulate an IDP to create artificial metadata from the data entered in the conf file
Expand All @@ -238,7 +238,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
// The main URL
rootURL, err := url.Parse(config.SelfServiceBrowserDefaultReturnTo(ctx).String())
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}

// Here we create a MiddleWare to transform Kratos into a Service Provider
Expand All @@ -263,7 +263,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
},
})
if err != nil {
return err
return herodot.ErrInternalServerError.WithTrace(err)
}

// It's better to use SHA256 than SHA1
Expand All @@ -278,7 +278,7 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi

u, err := url.Parse(publicUrlString + RouteSamlAcsWithSlash)
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}
samlMiddleWare.ServiceProvider.AcsURL = *u

Expand All @@ -287,15 +287,15 @@ func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Confi
publicUrlString = publicUrlString[:len(publicUrlString)-1]
u, err := url.Parse(publicUrlString + RouteSamlAcsWithSlash)
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}
samlMiddleWare.ServiceProvider.AcsURL = *u
}

// Crewjam library use default route for ACS and metadata but we want to overwrite them
metadata, err := url.Parse(publicUrlString + RouteMetadata)
if err != nil {
return err
return herodot.ErrNotFound.WithTrace(err)
}
samlMiddleWare.ServiceProvider.MetadataURL = *metadata

Expand Down Expand Up @@ -337,22 +337,25 @@ func CreateSAMLProviderConfig(config config.Config, ctx context.Context, pid str
if err := jsonx.
NewStrictDecoder(bytes.NewBuffer(conf)).
Decode(&c); err != nil {
return nil, errors.Wrapf(err, "Unable to decode config %v", string(conf))
return nil, ErrInvalidSAMLConfiguration.WithReasonf("Unable to decode config %v", string(conf)).WithTrace(err)
}

if len(c.SAMLProviders) == 0 {
return nil, errors.Errorf("Please indicate a SAML Identity Provider in your configuration file")
return nil, ErrInvalidSAMLConfiguration.WithReason("Please indicate a SAML Identity Provider in your configuration file")
}

providerConfig, err := c.ProviderConfig(pid)
if err != nil {
return nil, err
return nil, ErrInvalidSAMLConfiguration.WithTrace(err)
}

if providerConfig.IDPInformation == nil {
return nil, errors.Errorf("Please include your Identity Provider information in the configuration file.")
return nil, ErrInvalidSAMLConfiguration.WithReasonf("Please include your Identity Provider information in the configuration file.").WithTrace(err)
}

/**
* SAMLTODO errors
*/
// _, sso_exists := providerConfig.IDPInformation["idp_sso_url"]
_, sso_exists := providerConfig.IDPInformation["idp_sso_url"]
_, entity_id_exists := providerConfig.IDPInformation["idp_entity_id"]
Expand All @@ -361,35 +364,35 @@ func CreateSAMLProviderConfig(config config.Config, ctx context.Context, pid str
_, metadata_exists := providerConfig.IDPInformation["idp_metadata_url"]

if (!metadata_exists && (!sso_exists || !entity_id_exists || !certificate_exists || !logout_url_exists)) || len(providerConfig.IDPInformation) > 4 {
return nil, errors.Errorf("Please check your IDP information in the configuration file")
return nil, ErrInvalidSAMLConfiguration.WithReason("Please check your IDP information in the configuration file").WithTrace(err)
}

if providerConfig.ID == "" {
return nil, errors.Errorf("Provider must have an ID")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have an ID").WithTrace(err)
}

if providerConfig.Label == "" {
return nil, errors.Errorf("Provider must have a label")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a label").WithTrace(err)
}

if providerConfig.PrivateKeyPath == "" {
return nil, errors.Errorf("Provider must have a private key")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a private key").WithTrace(err)
}

if providerConfig.PublicCertPath == "" {
return nil, errors.Errorf("Provider must have a public certificate")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a public certificate").WithTrace(err)
}

if providerConfig.AttributesMap == nil || len(providerConfig.AttributesMap) == 0 {
return nil, errors.Errorf("Provider must have an attributes map")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have an attributes map").WithTrace(err)
}

if providerConfig.AttributesMap["id"] == "" {
return nil, errors.Errorf("You must have an ID field in your attribute_map")
return nil, ErrInvalidSAMLConfiguration.WithReason("You must have an ID field in your attribute_map").WithTrace(err)
}

if providerConfig.Mapper == "" {
return nil, errors.Errorf("Provider must have a mapper url")
return nil, ErrInvalidSAMLConfiguration.WithReason("Provider must have a mapper url").WithTrace(err)
}

return providerConfig, nil
Expand Down
14 changes: 7 additions & 7 deletions selfservice/strategy/saml/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func (s *Strategy) Config(ctx context.Context) (*ConfigurationCollection, error)
func (s *Strategy) populateMethod(r *http.Request, c *container.Container, message func(provider string) *text.Message) error {
conf, err := s.Config(r.Context())
if err != nil {
return err
return ErrInvalidSAMLConfiguration.WithTrace(err)
}

// does not need sorting because there is only one field
Expand All @@ -370,7 +370,7 @@ func (s *Strategy) populateMethod(r *http.Request, c *container.Container, messa
func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Flow, provider string, traits []byte, err error) error {
switch rf := f.(type) {
case *login.Flow:
return err
return ErrAPIFlowNotSupported.WithTrace(err)
case *registration.Flow:
// Reset all nodes to not confuse users.
// This is kinda hacky and will probably need to be updated at some point.
Expand All @@ -384,24 +384,24 @@ func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Fl
if traits != nil {
ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
return err
return ErrInvalidSAMLConfiguration.WithTrace(err)
}

traitNodes, err := container.NodesFromJSONSchema(r.Context(), node.SAMLGroup, ds.String(), "", nil)
if err != nil {
return err
return herodot.ErrInternalServerError.WithTrace(err)
}

rf.UI.Nodes = append(rf.UI.Nodes, traitNodes...)
rf.UI.UpdateNodeValuesFromJSON(traits, "traits", node.SAMLGroup)
}

return err
return herodot.ErrInternalServerError.WithTrace(err)
case *settings.Flow:
return err
return ErrAPIFlowNotSupported.WithTrace(err)
}

return err
return herodot.ErrInternalServerError.WithTrace(err)
}

func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
Expand Down

0 comments on commit fc408c6

Please sign in to comment.