From 7a827c66db35bc09e601e220bc26ee7e937ce9ce Mon Sep 17 00:00:00 2001 From: sebferrer Date: Wed, 22 Feb 2023 16:55:49 +0000 Subject: [PATCH] feat(saml): continuity manager relaystate refactoring + test new credentials saml + resolving conflicts with master Signed-off-by: sebferrer Co-authored-by: ThibaultHerard --- continuity/manager.go | 24 +-- continuity/manager_cookie.go | 29 +++- continuity/manager_relaystate.go | 149 ------------------ continuity/manager_relaystate_test.go | 19 ++- driver/registry_default.go | 18 +-- go.sum | 4 +- identity/credentials_saml.go | 9 +- identity/credentials_saml_test.go | 18 +++ selfservice/strategy/saml/config_test.go | 24 +-- selfservice/strategy/saml/provider_config.go | 2 +- selfservice/strategy/saml/provider_saml.go | 4 +- selfservice/strategy/saml/strategy.go | 109 ++++++------- selfservice/strategy/saml/strategy_login.go | 29 +--- .../strategy/saml/strategy_registration.go | 8 +- .../strategy/saml/testdata/sp2_cert.pem | 32 ---- .../strategy/saml/testdata/sp2_key.pem | 52 ------ .../saml/vulnerabilities_helper_test.go | 4 +- .../strategy/saml/vulnerabilities_test.go | 14 +- x/provider.go | 5 - 19 files changed, 167 insertions(+), 386 deletions(-) delete mode 100644 continuity/manager_relaystate.go create mode 100644 identity/credentials_saml_test.go delete mode 100644 selfservice/strategy/saml/testdata/sp2_cert.pem delete mode 100644 selfservice/strategy/saml/testdata/sp2_key.pem diff --git a/continuity/manager.go b/continuity/manager.go index 27541f32b381..e2323bfb7f60 100644 --- a/continuity/manager.go +++ b/continuity/manager.go @@ -20,22 +20,19 @@ type ManagementProvider interface { ContinuityManager() Manager } -type ManagementProviderRelayState interface { - RelayStateContinuityManager() Manager -} - type Manager interface { Pause(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error Continue(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) (*Container, error) - Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) error + Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error } type managerOptions struct { - iid uuid.UUID - ttl time.Duration - payload json.RawMessage - payloadRaw interface{} - cleanUp bool + iid uuid.UUID + ttl time.Duration + payload json.RawMessage + payloadRaw interface{} + cleanUp bool + useRelayState bool } type ManagerOption func(*managerOptions) error @@ -87,3 +84,10 @@ func WithPayload(payload interface{}) ManagerOption { return nil } } + +func UseRelayState() ManagerOption { + return func(o *managerOptions) error { + o.useRelayState = true + return nil + } +} diff --git a/continuity/manager_cookie.go b/continuity/manager_cookie.go index 314c33da4a68..8e3d8f76b999 100644 --- a/continuity/manager_cookie.go +++ b/continuity/manager_cookie.go @@ -10,6 +10,7 @@ import ( "net/http" "github.com/gofrs/uuid" + "github.com/gorilla/sessions" "github.com/pkg/errors" "github.com/ory/herodot" @@ -64,12 +65,12 @@ func (m *ManagerCookie) Pause(ctx context.Context, w http.ResponseWriter, r *htt } func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) (*Container, error) { - container, err := m.container(ctx, w, r, name) + o, err := newManagerOptions(opts) if err != nil { return nil, err } - o, err := newManagerOptions(opts) + container, err := m.container(ctx, w, r, name, o.useRelayState) if err != nil { return nil, err } @@ -95,9 +96,16 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r * return container, nil } -func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, error) { +func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, useRelayState bool) (uuid.UUID, error) { + var getStringFunction func(r *http.Request, s sessions.StoreExact, id string, key interface{}) (string, error) + if useRelayState { + getStringFunction = x.SessionGetStringRelayState + } else { + getStringFunction = x.SessionGetString + } + var sid uuid.UUID - if s, err := x.SessionGetString(r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil { + if s, err := getStringFunction(r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil { _ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name) return sid, errors.WithStack(ErrNotResumable.WithDebugf("%+v", err)) } else if sid = x.ParseUUID(s); sid == uuid.Nil { @@ -108,8 +116,8 @@ func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http. return sid, nil } -func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (*Container, error) { - sid, err := m.sid(ctx, w, r, name) +func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, useRelayState bool) (*Container, error) { + sid, err := m.sid(ctx, w, r, name, useRelayState) if err != nil { return nil, err } @@ -129,8 +137,13 @@ func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r return container, err } -func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) error { - sid, err := m.sid(ctx, w, r, name) +func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error { + o, err := newManagerOptions(opts) + if err != nil { + return err + } + + sid, err := m.sid(ctx, w, r, name, o.useRelayState) if errors.Is(err, &ErrNotResumable) { // We do not care about an error here return nil diff --git a/continuity/manager_relaystate.go b/continuity/manager_relaystate.go deleted file mode 100644 index 264a54e17064..000000000000 --- a/continuity/manager_relaystate.go +++ /dev/null @@ -1,149 +0,0 @@ -package continuity - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - - "github.com/gofrs/uuid" - "github.com/pkg/errors" - - "github.com/ory/herodot" - "github.com/ory/x/sqlcon" - - "github.com/ory/kratos/session" - "github.com/ory/kratos/x" -) - -var _ Manager = new(ManagerRelayState) -var ErrNotResumableRelayState = *herodot.ErrBadRequest.WithError("no resumable session found").WithReason("The browser does not contain the necessary RelayState value to resume the session. This is a security violation and was blocked. Please try again!") - -type ( - managerRelayStateDependencies interface { - PersistenceProvider - x.RelayStateProvider - session.ManagementProvider - } - ManagerRelayState struct { - dr managerRelayStateDependencies - dc managerCookieDependencies - } -) - -// To ensure continuity even after redirection to the IDP, we cannot use cookies because the IDP and the SP are on two different domains. -// So we have to pass the continuity value through the relaystate. -// This value corresponds to the session ID -func NewManagerRelayState(dr managerRelayStateDependencies, dc managerCookieDependencies) *ManagerRelayState { - return &ManagerRelayState{dr: dr, dc: dc} -} - -func (m *ManagerRelayState) Pause(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error { - if len(name) == 0 { - return errors.Errorf("continuity container name must be set") - } - - o, err := newManagerOptions(opts) - if err != nil { - return err - } - c := NewContainer(name, *o) - - if err := m.dr.ContinuityPersister().SaveContinuitySession(r.Context(), c); err != nil { - return errors.WithStack(err) - } - - if err = x.SessionPersistValues(w, r, m.dc.ContinuityCookieManager(ctx), CookieName, map[string]interface{}{ - name: c.ID.String(), - }); err != nil { - return err - } - - return nil -} - -func (m *ManagerRelayState) Continue(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) (*Container, error) { - container, err := m.container(ctx, w, r, name) - if err != nil { - return nil, err - } - - o, err := newManagerOptions(opts) - if err != nil { - return nil, err - } - - if err := container.Valid(o.iid); err != nil { - return nil, err - } - - if o.payloadRaw != nil && container.Payload != nil { - if err := json.NewDecoder(bytes.NewBuffer(container.Payload)).Decode(o.payloadRaw); err != nil { - return nil, errors.WithStack(err) - } - } - - if err := x.SessionUnsetKey(w, r, m.dc.ContinuityCookieManager(ctx), CookieName, name); err != nil { - return nil, err - } - - if err := m.dc.ContinuityPersister().DeleteContinuitySession(ctx, container.ID); err != nil && !errors.Is(err, sqlcon.ErrNoRows) { - return nil, err - } - - return container, nil -} - -func (m *ManagerRelayState) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, error) { - var sid uuid.UUID - if s, err := x.SessionGetStringRelayState(r, m.dc.ContinuityCookieManager(ctx), CookieName, name); err != nil { - return sid, errors.WithStack(ErrNotResumable.WithDebugf("%+v", err)) - - } else if sid = x.ParseUUID(s); sid == uuid.Nil { - return sid, errors.WithStack(ErrNotResumable.WithDebug("session id is not a valid uuid")) - - } - - return sid, nil -} - -func (m *ManagerRelayState) container(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (*Container, error) { - sid, err := m.sid(ctx, w, r, name) - if err != nil { - return nil, err - } - - container, err := m.dr.ContinuityPersister().GetContinuitySession(ctx, sid) - - if err != nil { - _ = x.SessionUnsetKey(w, r, m.dc.ContinuityCookieManager(ctx), CookieName, name) - } - - if errors.Is(err, sqlcon.ErrNoRows) { - return nil, errors.WithStack(ErrNotResumable.WithDebug("Resumable ID from RelayState could not be found in the datastore")) - } else if err != nil { - return nil, err - } - - return container, err -} - -func (m ManagerRelayState) Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) error { - sid, err := m.sid(ctx, w, r, name) - if errors.Is(err, &ErrNotResumable) { - // We do not care about an error here - return nil - } else if err != nil { - return err - } - - if err := x.SessionUnsetKey(w, r, m.dc.ContinuityCookieManager(ctx), CookieName, name); err != nil { - return err - } - - if err := m.dr.ContinuityPersister().DeleteContinuitySession(ctx, sid); err != nil && !errors.Is(err, sqlcon.ErrNoRows) { - return errors.WithStack(err) - } - - return nil -} diff --git a/continuity/manager_relaystate_test.go b/continuity/manager_relaystate_test.go index c37b4d5b8d45..d8bbfdbe703e 100644 --- a/continuity/manager_relaystate_test.go +++ b/continuity/manager_relaystate_test.go @@ -5,6 +5,7 @@ package continuity_test import ( "context" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -31,7 +32,7 @@ import ( "github.com/ory/kratos/x" ) -func TestManagerRelayState(t *testing.T) { +func TestManagerUseRelayState(t *testing.T) { ctx := context.Background() conf, reg := internal.NewFastRegistryWithMocks(t) @@ -57,6 +58,7 @@ func TestManagerRelayState(t *testing.T) { r.PostForm = make(url.Values) r.PostForm.Set("RelayState", relayState) + tc.wo = append(tc.wo, continuity.UseRelayState()) c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.wo...) if err != nil { writer.WriteError(w, r, err) @@ -71,7 +73,7 @@ func TestManagerRelayState(t *testing.T) { r.PostForm = make(url.Values) r.PostForm.Set("RelayState", relayState) - err := p.Abort(r.Context(), w, r, ps.ByName("name")) + err := p.Abort(r.Context(), w, r, ps.ByName("name"), continuity.UseRelayState()) if err != nil { writer.WriteError(w, r, err) return @@ -90,7 +92,7 @@ func TestManagerRelayState(t *testing.T) { return &http.Client{Jar: x.EasyCookieJar(t, nil)} } - p := reg.RelayStateContinuityManager() + p := reg.ContinuityManager() cl := newClient() t.Run("case=continue cookie persists with same http client", func(t *testing.T) { @@ -115,6 +117,7 @@ func TestManagerRelayState(t *testing.T) { t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) + assert.Equal(t, res.StatusCode, 200) require.Len(t, res.Cookies(), 1) assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName) }) @@ -147,7 +150,8 @@ func TestManagerRelayState(t *testing.T) { t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) - require.Len(t, res.Cookies(), 1) + assert.Equal(t, res.StatusCode, 200) + assert.Len(t, res.Cookies(), 1) assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName) }) @@ -165,6 +169,7 @@ func TestManagerRelayState(t *testing.T) { for _, c := range res.Cookies() { relayState = c.Value + fmt.Println(relayState) relayState = strings.Replace(relayState, "a", "b", 1) } require.Len(t, res.Cookies(), 1) @@ -180,7 +185,7 @@ func TestManagerRelayState(t *testing.T) { t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) - require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate") + assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0) }) t.Run("case=continue cookie not delivered without relaystate", func(t *testing.T) { @@ -205,7 +210,7 @@ func TestManagerRelayState(t *testing.T) { t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) - require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate") + assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0) }) t.Run("case=pause, abort, and continue session with failure", func(t *testing.T) { @@ -236,6 +241,6 @@ func TestManagerRelayState(t *testing.T) { t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) - require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate") + assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0) }) } diff --git a/driver/registry_default.go b/driver/registry_default.go index 78a5c5b69da0..3c33be26a111 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -110,9 +110,6 @@ type RegistryDefault struct { continuityManager continuity.Manager - x.RelayStateProvider - session.ManagementProvider - schemaHandler *schema.Handler sessionHandler *session.Handler @@ -676,25 +673,12 @@ func (m *RegistryDefault) Courier(ctx context.Context) (courier.Courier, error) } func (m *RegistryDefault) ContinuityManager() continuity.Manager { - // If m.continuityManager is nil or not a continuity.ManagerCookie - switch m.continuityManager.(type) { - case *continuity.ManagerCookie: - default: + if m.continuityManager == nil { m.continuityManager = continuity.NewManagerCookie(m) } return m.continuityManager } -func (m *RegistryDefault) RelayStateContinuityManager() continuity.Manager { - // If m.continuityManager is nil or not a continuity.ManagerRelayState - switch m.continuityManager.(type) { - case *continuity.ManagerRelayState: - default: - m.continuityManager = continuity.NewManagerRelayState(m, m) - } - return m.continuityManager -} - func (m *RegistryDefault) ContinuityPersister() continuity.Persister { return m.persister } diff --git a/go.sum b/go.sum index 0fabab367e28..76bd8bcd396c 100644 --- a/go.sum +++ b/go.sum @@ -1120,8 +1120,8 @@ github.com/ory/nosurf v1.2.7/go.mod h1:d4L3ZBa7Amv55bqxCBtCs63wSlyaiCkWVl4vKf3OU github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2 h1:zm6sDvHy/U9XrGpixwHiuAwpp0Ock6khSVHkrv6lQQU= github.com/ory/sessions v1.2.2-0.20220110165800-b09c17334dc2/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/ory/viper v1.7.5/go.mod h1:ypOuyJmEUb3oENywQZRgeAMwqgOyDqwboO1tj3DjTaM= -github.com/ory/x v0.0.534 h1:hc49pmcOuHdJ6rbHVGtJJ4/LU88dzDCtEQKfgeo/ecU= -github.com/ory/x v0.0.534/go.mod h1:CQopDsCC9t0tQsddE9UlyRFVEFd2xjKBVcw4nLMMMS0= +github.com/ory/x v0.0.537 h1:FB8Tioza6pihvy/RsVNzX08Qg3/VpIhI9vBnEQ4iFmQ= +github.com/ory/x v0.0.537/go.mod h1:CQopDsCC9t0tQsddE9UlyRFVEFd2xjKBVcw4nLMMMS0= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= diff --git a/identity/credentials_saml.go b/identity/credentials_saml.go index c420be179f1f..f79e6b5d5506 100644 --- a/identity/credentials_saml.go +++ b/identity/credentials_saml.go @@ -22,11 +22,18 @@ type CredentialsSAML struct { // swagger:model identityCredentialsSamlProvider type CredentialsSAMLProvider struct { Subject string `json:"subject"` - Provider string `json:"samlProvider"` + Provider string `json:"saml_provider"` } // Create an uniq identifier for user in database. Its look like "id + the id of the saml provider" func NewCredentialsSAML(subject string, provider string) (*Credentials, error) { + if provider == "" { + return nil, errors.New("received empty provider in saml credentials") + } + + if subject == "" { + return nil, errors.New("received empty provider in saml credentials") + } var b bytes.Buffer if err := json.NewEncoder(&b).Encode(CredentialsSAML{ Providers: []CredentialsSAMLProvider{ diff --git a/identity/credentials_saml_test.go b/identity/credentials_saml_test.go new file mode 100644 index 000000000000..e6c80a0be709 --- /dev/null +++ b/identity/credentials_saml_test.go @@ -0,0 +1,18 @@ +package identity + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewCredentialsSAML(t *testing.T) { + _, err := NewCredentialsSAML("not-empty", "") + require.Error(t, err) + + _, err = NewCredentialsSAML("", "not-empty") + require.Error(t, err) + + _, err = NewCredentialsSAML("not-empty", "not-empty") + require.NoError(t, err) +} diff --git a/selfservice/strategy/saml/config_test.go b/selfservice/strategy/saml/config_test.go index 0b509776c026..b32b882e6bc3 100644 --- a/selfservice/strategy/saml/config_test.go +++ b/selfservice/strategy/saml/config_test.go @@ -82,8 +82,8 @@ func TestInitSAMLWithoutPoviderID(t *testing.T) { saml.Configuration{ ID: "", Label: "samlProviderLabel", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", AttributesMap: attributesMap, IDPInformation: idpInformation, @@ -131,8 +131,8 @@ func TestInitSAMLWithoutPoviderLabel(t *testing.T) { saml.Configuration{ ID: "samlProvider", Label: "", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", AttributesMap: attributesMap, IDPInformation: idpInformation, @@ -179,8 +179,8 @@ func TestAttributesMapWithoutID(t *testing.T) { saml.Configuration{ ID: "samlProvider", Label: "samlProviderLabel", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", AttributesMap: attributesMap, IDPInformation: idpInformation, @@ -275,8 +275,8 @@ func TestInitSAMLWithoutIDPInformation(t *testing.T) { saml.Configuration{ ID: "samlProvider", Label: "samlProviderLabel", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", AttributesMap: attributesMap, }, @@ -322,8 +322,8 @@ func TestInitSAMLWithMissingIDPInformationField(t *testing.T) { saml.Configuration{ ID: "samlProvider", Label: "samlProviderLabel", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", IDPInformation: idpInformation, AttributesMap: attributesMap, @@ -372,8 +372,8 @@ func TestInitSAMLWithExtraIDPInformationField(t *testing.T) { saml.Configuration{ ID: "samlProvider", Label: "samlProviderLabel", - PublicCertPath: "file://testdata/myservice.cert", - PrivateKeyPath: "file://testdata/myservice.key", + PublicCertPath: "file://testdata/sp_cert.pem", + PrivateKeyPath: "file://testdata/sp_key.pem", Mapper: "file://testdata/saml.jsonnet", IDPInformation: idpInformation, AttributesMap: attributesMap, diff --git a/selfservice/strategy/saml/provider_config.go b/selfservice/strategy/saml/provider_config.go index 357c9516f965..096513fa9f56 100644 --- a/selfservice/strategy/saml/provider_config.go +++ b/selfservice/strategy/saml/provider_config.go @@ -36,7 +36,7 @@ type ConfigurationCollection struct { SAMLProviders []Configuration `json:"providers"` } -func (c ConfigurationCollection) Provider(id string, reg registrationStrategyDependencies) (Provider, error) { +func (c ConfigurationCollection) Provider(id string, reg dependencies) (Provider, error) { for k := range c.SAMLProviders { p := c.SAMLProviders[k] if p.ID == id { diff --git a/selfservice/strategy/saml/provider_saml.go b/selfservice/strategy/saml/provider_saml.go index e0917dfd44ea..4b31d814ba55 100644 --- a/selfservice/strategy/saml/provider_saml.go +++ b/selfservice/strategy/saml/provider_saml.go @@ -13,12 +13,12 @@ import ( type ProviderSAML struct { config *Configuration - reg registrationStrategyDependencies + reg dependencies } func NewProviderSAML( config *Configuration, - reg registrationStrategyDependencies, + reg dependencies, ) *ProviderSAML { return &ProviderSAML{ config: config, diff --git a/selfservice/strategy/saml/strategy.go b/selfservice/strategy/saml/strategy.go index 4e2c69b91653..64b49322b7b9 100644 --- a/selfservice/strategy/saml/strategy.go +++ b/selfservice/strategy/saml/strategy.go @@ -17,19 +17,18 @@ import ( "github.com/tidwall/gjson" "github.com/ory/herodot" + "github.com/ory/kratos/cipher" + "github.com/ory/kratos/schema" "github.com/ory/kratos/text" "github.com/ory/kratos/ui/container" "github.com/ory/kratos/ui/node" - "github.com/go-playground/validator/v10" - "github.com/ory/x/decoderx" - "github.com/ory/x/fetcher" + "github.com/ory/x/jsonnetsecure" "github.com/ory/x/jsonx" "github.com/ory/kratos/continuity" "github.com/ory/kratos/driver/config" - "github.com/ory/kratos/hash" "github.com/ory/kratos/identity" "github.com/ory/kratos/selfservice/errorx" "github.com/ory/kratos/selfservice/flow" @@ -56,50 +55,57 @@ const ( var _ identity.ActiveCredentialsCounter = new(Strategy) -type registrationStrategyDependencies interface { - x.LoggingProvider - x.WriterProvider - x.CSRFTokenGeneratorProvider - x.CSRFProvider - x.HTTPClientProvider +type dependencies interface { + errorx.ManagementProvider config.Provider - continuity.ManagementProvider - continuity.ManagementProviderRelayState + x.LoggingProvider + x.CookieProvider + x.CSRFProvider + x.CSRFTokenGeneratorProvider + x.WriterProvider + x.HTTPClientProvider + x.TracingProvider - errorx.ManagementProvider - hash.HashProvider + identity.ValidationProvider + identity.PrivilegedPoolProvider + identity.ActiveCredentialsCounterStrategyProvider + identity.ManagementProvider - registration.HandlerProvider - registration.HooksProvider - registration.ErrorHandlerProvider - registration.HookExecutorProvider - registration.FlowPersistenceProvider + session.ManagementProvider + session.HandlerProvider - login.HooksProvider - login.ErrorHandlerProvider login.HookExecutorProvider login.FlowPersistenceProvider + login.HooksProvider + login.StrategyProvider login.HandlerProvider + login.ErrorHandlerProvider + registration.HookExecutorProvider + registration.FlowPersistenceProvider + registration.HooksProvider + registration.StrategyProvider + registration.HandlerProvider + registration.ErrorHandlerProvider + + settings.ErrorHandlerProvider settings.FlowPersistenceProvider settings.HookExecutorProvider - settings.HooksProvider - settings.ErrorHandlerProvider - identity.PrivilegedPoolProvider - identity.ValidationProvider + continuity.ManagementProvider - session.HandlerProvider - session.ManagementProvider + cipher.Provider + + jsonnetsecure.VMProvider } func (s *Strategy) ID() identity.CredentialsType { return identity.CredentialsTypeSAML } -func (s *Strategy) D() registrationStrategyDependencies { +func (s *Strategy) D() dependencies { return s.d } @@ -115,10 +121,9 @@ func isForced(req interface{}) bool { } type Strategy struct { - d registrationStrategyDependencies - f *fetcher.Fetcher - v *validator.Validate - hd *decoderx.HTTP + d dependencies + validator *schema.Validator + dec *decoderx.HTTP } type authCodeContainer struct { @@ -132,12 +137,10 @@ func generateState(flowID string) string { return base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", flowID, state))) } -func NewStrategy(d registrationStrategyDependencies) *Strategy { +func NewStrategy(d dependencies) *Strategy { return &Strategy{ - d: d, - f: fetcher.NewFetcher(), - v: validator.New(), - hd: decoderx.NewHTTP(), + d: d, + validator: schema.NewValidator(), } } @@ -247,7 +250,7 @@ func (s *Strategy) alreadyAuthenticated(w http.ResponseWriter, r *http.Request, func (s *Strategy) validateCallback(w http.ResponseWriter, r *http.Request) (flow.Flow, *authCodeContainer, error) { var cntnr authCodeContainer - if _, err := s.d.RelayStateContinuityManager().Continue(r.Context(), w, r, sessionName, continuity.WithPayload(&cntnr)); err != nil { + if _, err := s.d.ContinuityManager().Continue(r.Context(), w, r, sessionName, continuity.WithPayload(&cntnr), continuity.UseRelayState()); err != nil { return nil, nil, err } @@ -286,7 +289,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt m, err := GetMiddleware(pid) if err != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } @@ -296,28 +299,28 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt // We parse the SAML Response to get the SAML Assertion assertion, err := m.ServiceProvider.ParseResponse(r, possibleRequestIDs) if err != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } // We get the user's attributes from the SAML Response (assertion) attributes, err := s.GetAttributesFromAssertion(assertion) if err != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } // We get the provider information from the config file provider, err := s.Provider(r.Context(), pid) if err != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } // We translate SAML Attributes into claims (To create an identity we need these claims) claims, err := provider.Claims(r.Context(), s.d.Config(), attributes, pid) if err != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } @@ -326,10 +329,10 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt // Now that we have the claims and the provider, we have to decide if we log or register the user if ff, err := s.processLoginOrRegister(w, r, a, provider, claims); err != nil { if ff != nil { - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) return } - s.forwardError(w, r, err) + s.forwardError(w, r, s.handleError(w, r, req, pid, nil, err)) } return } @@ -385,7 +388,7 @@ func (s *Strategy) populateMethod(r *http.Request, c *container.Container, messa func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Flow, provider string, traits []byte, err error) error { switch rf := f.(type) { case *login.Flow: - return ErrAPIFlowNotSupported.WithTrace(err) + return err case *registration.Flow: // Reset all nodes to not confuse users. // This is kinda hacky and will probably need to be updated at some point. @@ -399,24 +402,22 @@ func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Fl if traits != nil { ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context()) if err != nil { - return ErrInvalidSAMLConfiguration.WithTrace(err) + return err } - traitNodes, err := container.NodesFromJSONSchema(r.Context(), node.SAMLGroup, ds.String(), "", nil) + traitNodes, err := container.NodesFromJSONSchema(r.Context(), node.OpenIDConnectGroup, ds.String(), "", nil) if err != nil { - return herodot.ErrInternalServerError.WithTrace(err) + return err } rf.UI.Nodes = append(rf.UI.Nodes, traitNodes...) - rf.UI.UpdateNodeValuesFromJSON(traits, "traits", node.SAMLGroup) + rf.UI.UpdateNodeValuesFromJSON(traits, "traits", node.OpenIDConnectGroup) } - return herodot.ErrInternalServerError.WithTrace(err) - case *settings.Flow: - return ErrAPIFlowNotSupported.WithTrace(err) + return err } - return herodot.ErrInternalServerError.WithTrace(err) + return err } func (s *Strategy) CountActiveCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) { diff --git a/selfservice/strategy/saml/strategy_login.go b/selfservice/strategy/saml/strategy_login.go index 67faa6682b25..f8670b16fe55 100644 --- a/selfservice/strategy/saml/strategy_login.go +++ b/selfservice/strategy/saml/strategy_login.go @@ -7,9 +7,7 @@ import ( "time" "github.com/gofrs/uuid" - "github.com/google/go-jsonnet" "github.com/pkg/errors" - "github.com/tidwall/gjson" "github.com/ory/herodot" "github.com/ory/kratos/continuity" @@ -58,7 +56,7 @@ type SubmitSelfServiceLoginFlowWithSAMLMethodBody struct { // Login and give a session to the user func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login.Flow, provider Provider, c *identity.Credentials, i *identity.Identity, claims *Claims) (*registration.Flow, error) { - s.updateIdentityTraits(i, provider, claims) + s.updateIdentityTraits(w, r, i, provider, claims) var o identity.CredentialsSAML if err := json.NewDecoder(bytes.NewBuffer(c.Config)).Decode(&o); err != nil { @@ -104,13 +102,13 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow, } state := generateState(f.ID.String()) - if err := s.d.RelayStateContinuityManager().Pause(r.Context(), w, r, sessionName, + if err := s.d.ContinuityManager().Pause(r.Context(), w, r, sessionName, continuity.WithPayload(&authCodeContainer{ State: state, FlowID: f.ID.String(), Traits: p.Traits, }), - continuity.WithLifespan(time.Minute*30)); err != nil { + continuity.WithLifespan(time.Minute*30), continuity.UseRelayState()); err != nil { return nil, s.handleError(w, r, f, pid, nil, err) } @@ -142,28 +140,17 @@ func (s *Strategy) PopulateLoginMethod(r *http.Request, requestedAAL identity.Au } // In order to do a JustInTimeProvisioning, it is important to update the identity traits at each new SAML connection -func (s *Strategy) updateIdentityTraits(i *identity.Identity, provider Provider, claims *Claims) error { - jn, err := s.f.Fetch(provider.Config().Mapper) - if err != nil { - return nil - } +func (s *Strategy) updateIdentityTraits(w http.ResponseWriter, r *http.Request, i *identity.Identity, provider Provider, claims *Claims) error { var jsonClaims bytes.Buffer if err := json.NewEncoder(&jsonClaims).Encode(claims); err != nil { - return nil + return err } - vm := jsonnet.MakeVM() - vm.ExtCode("claims", jsonClaims.String()) - evaluated, err := vm.EvaluateAnonymousSnippet(provider.Config().Mapper, jn.String()) - if err != nil { + if err := s.setTraits(w, r, claims, provider, jsonClaims, i); err != nil { return err - } else if traits := gjson.Get(evaluated, "identity.traits"); !traits.IsObject() { - i.Traits = []byte{'{', '}'} - return errors.New("SAML Jsonnet mapper did not return an object for key identity.traits. Please check your Jsonnet code!") - } else { - i.Traits = []byte(traits.Raw) - return nil } + return nil + } diff --git a/selfservice/strategy/saml/strategy_registration.go b/selfservice/strategy/saml/strategy_registration.go index d536be07d629..208fce11655b 100644 --- a/selfservice/strategy/saml/strategy_registration.go +++ b/selfservice/strategy/saml/strategy_registration.go @@ -34,7 +34,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg } i := identity.NewIdentity(s.d.Config().DefaultIdentityTraitsSchemaID(r.Context())) - if err := s.setTraits(w, r, a, claims, provider, jsonClaims, i); err != nil { + if err := s.setTraits(w, r, claims, provider, jsonClaims, i); err != nil { return nil, s.handleError(w, r, a, provider.Config().ID, i.Traits, err) } @@ -46,7 +46,7 @@ func (s *Strategy) createIdentity(w http.ResponseWriter, r *http.Request, a *reg return i, nil } -func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a *registration.Flow, claims *Claims, provider Provider, jsonClaims bytes.Buffer, i *identity.Identity) error { +func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, claims *Claims, provider Provider, jsonClaims bytes.Buffer, i *identity.Identity) error { traitsMap := make(map[string]interface{}) json.Unmarshal(jsonClaims.Bytes(), &traitsMap) @@ -55,7 +55,7 @@ func (s *Strategy) setTraits(w http.ResponseWriter, r *http.Request, a *registra delete(traitsMap, "sub") traits, err := json.Marshal(traitsMap) if err != nil { - return s.handleError(w, r, a, provider.Config().ID, i.Traits, err) + return err } i.Traits = identity.Traits(traits) @@ -119,7 +119,7 @@ func (s *Strategy) newLinkDecoder(p interface{}, r *http.Request) error { return errors.WithStack(err) } - if err := s.hd.Decode(r, &p, compiler, + if err := s.dec.Decode(r, &p, compiler, decoderx.HTTPKeepRequestBody(true), decoderx.HTTPDecoderSetValidatePayloads(false), decoderx.HTTPDecoderUseQueryAndBody(), diff --git a/selfservice/strategy/saml/testdata/sp2_cert.pem b/selfservice/strategy/saml/testdata/sp2_cert.pem deleted file mode 100644 index e74d8fcb8c00..000000000000 --- a/selfservice/strategy/saml/testdata/sp2_cert.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIFbTCCA1WgAwIBAgIUDxKgJwZpe1Fmo4nFUVYTXwvIgOAwDQYJKoZIhvcNAQEL -BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM -GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAgFw0yMzAxMTExMzQ4MTlaGA8yMjk2 -MTAyNTEzNDgxOVowRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx -ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCAiIwDQYJKoZIhvcN -AQEBBQADggIPADCCAgoCggIBAM6XqltTvOTIqGZeUKnRfWqZdmjiJWU4XtSh4D8v -DUmBRIYXEBbcSqcNKsHjNoHoTXBqa9jkxl5ZIOwE3mCWr5heXTR9T4IhtsWRopWB -Aio8AIdAPuHiQoMxGm0jwZ4KCaZfLsHRIflJloSFEtebungmpDqo4hwsM6IUz5z3 -bKX1rlCCtupHpLihaLMVmgSSBvLDeAPvtHIBtNjNyhhi9TMS8SxFbCI7+PXRTRrR -G2ozUrR8qIBQxh/kA/rQWH4GGLt6phBdsxnmdn2idLhyvTjT8nA09gglvXZZ7gI9 -m8jbmUgpVGzmk38cAw+oENVMufAZWJBdAfrIfK0TR5YqXqhfj6e5BBf4OyM4rMlo -xlJnbQTI30uB2uzPfjry43NSu9jlFc/+r4ufW+ptEH9YIx9HbSz8y+hrpFdofSEQ -8J6w71x3NzN1L/hFKYqhQ8Gv2Flk1kp0iLkZ3tP1UCK2eDFhM9BBD3ZHYyByboWj -sKffst67LUAlufiLW4q0tZkpWwUR5fExAsFF//CciKPIpe8WBKlHxw6FYGPRGva5 -KK10jGZTJoM+KYwcZFpHCaWt8tFQB8ti9bQIP2etmrpsBZuR4CDuZEkMuvSLsJE5 -6SI+rxT3WKwK9zAginomFvEEvaqp43RT2a1zzYAfE3pRETis1tyoNN2I384ywL41 -lofhAgMBAAGjUzBRMB0GA1UdDgQWBBSmECM2phplGyTKV/dGef+JUPVpgjAfBgNV -HSMEGDAWgBSmECM2phplGyTKV/dGef+JUPVpgjAPBgNVHRMBAf8EBTADAQH/MA0G -CSqGSIb3DQEBCwUAA4ICAQCxbwfNQvpw68pTmyCIipb5pkuVDnjp65RV0wJbOfDR -qiQHVQsJexY1xmptOADzCvBQIkAAKCeLfJ8tKS6473Xc3BayREyJpN3oQsr1MDep -j6/ae8I1wt6uJ18M93wArWou/nuDHlkBeEKYlwCQYRZPW++9E38v6ZzKK7qHN+6M -vFKXx/Q98WpaNo5Oj0o8ngEYxS5/9Axn7EUBKLpikb8KNIO+icqc6DPs4GTqKb4/ -wC5FPcROoQAau3RsrZ8cAMU+zGt6OeYWU2Sabsnm1lo3bYMx2XuQOEQvLcrTjBLm -041LYk3SbPotBWc4ahVF4SUZWHKZst76+cZtR5RLZt3jjKjTguq76itPnuZxdM0g -JmdhophvFNwyKjxQ3jbJc9W1mpq5ILrtzO0pWTjOrBDWdZ4GF078GmjYrtJJ2e7T -LI0uuXwKB0K5SktluIM+7PVXYqt3ZnPJ6zjMCuYoQT5ua29hs9Qi/zjAf2Mf8JLo -t2MdAmvVZDr8bSVkyx0RrIKwYKLJ6b+KgdSACb618GV6dLpqMbe3mC+yPPa/FKIS -M64SBf/gBlMQpcUWdH4IWvQXu1Lmn+TPr4+BXi7loMGwAcGH7pcbYouOMsZlJ/CG -1cQ5cf3kevKolmJVaxJC+ZEBCqvM3/FySSmNNCmQidXi5QLHP87uYn/9aRaHh1kM -eA== ------END CERTIFICATE----- diff --git a/selfservice/strategy/saml/testdata/sp2_key.pem b/selfservice/strategy/saml/testdata/sp2_key.pem deleted file mode 100644 index 9bf6aeab4b89..000000000000 --- a/selfservice/strategy/saml/testdata/sp2_key.pem +++ /dev/null @@ -1,52 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQDOl6pbU7zkyKhm -XlCp0X1qmXZo4iVlOF7UoeA/Lw1JgUSGFxAW3EqnDSrB4zaB6E1wamvY5MZeWSDs -BN5glq+YXl00fU+CIbbFkaKVgQIqPACHQD7h4kKDMRptI8GeCgmmXy7B0SH5SZaE -hRLXm7p4JqQ6qOIcLDOiFM+c92yl9a5QgrbqR6S4oWizFZoEkgbyw3gD77RyAbTY -zcoYYvUzEvEsRWwiO/j10U0a0RtqM1K0fKiAUMYf5AP60Fh+Bhi7eqYQXbMZ5nZ9 -onS4cr040/JwNPYIJb12We4CPZvI25lIKVRs5pN/HAMPqBDVTLnwGViQXQH6yHyt -E0eWKl6oX4+nuQQX+DsjOKzJaMZSZ20EyN9Lgdrsz3468uNzUrvY5RXP/q+Ln1vq -bRB/WCMfR20s/Mvoa6RXaH0hEPCesO9cdzczdS/4RSmKoUPBr9hZZNZKdIi5Gd7T -9VAitngxYTPQQQ92R2Mgcm6Fo7Cn37Leuy1AJbn4i1uKtLWZKVsFEeXxMQLBRf/w -nIijyKXvFgSpR8cOhWBj0Rr2uSitdIxmUyaDPimMHGRaRwmlrfLRUAfLYvW0CD9n -rZq6bAWbkeAg7mRJDLr0i7CROekiPq8U91isCvcwIIp6JhbxBL2qqeN0U9mtc82A -HxN6URE4rNbcqDTdiN/OMsC+NZaH4QIDAQABAoICAQCZ1EbOUBDkDiGOcAYCHPIV -EQYhXNrZftrl208d3Qw4wl9itQOO8iNINj6zNltc6bvXy/ZX/ylSEW25MHrhUvKX -MxSVxAUS8cWlYSa9ydzx09HU49qu2YoLI+H4iFpgMjszPcaUHQP+GnRQYsI/9z4m -vyckYqJStfsQYgyhZX7qKIDOhDZtRkF6FP3f82LGqnEwDKpty+wBxBGEKd+kvvKz -QBSCkYLODvf3Gg0evbt7HZIkwHm7aenMzzzDYqWx2RpLZy0GHK8Cxx9Nt0zQFuec -y/zG3jigonFsEdRuqK86JYICQHwTxrDnQdVpsAwwtzvwcv8GJ6sUsHpdaXCxeQUX -ZC0k20JkUzAlNCUT+w+MaN6gfO27NFknmDYvtV46z7dm2gaCE2IGi/Z890xg/wTl -3OelqnNM9+qNRWCl0JLgVPTlcWexceg8DJVLTuE1R7qRND/FIG1IzMJUGQdhGwnn -ehJ3BvoMj+RxHkVnljN2wDYWtlbfJcjbBL/ruUbDHKsUB85j9RCkTWDvetUs6cPG -QB+CJ/SxI4/2Wlw58e/kG22uja+SvZJuwza7M3pOkv1Dxa+6ASj/O6sooEkVJ90y -aoM8A5b+FoS8XGoAJu3D3cV3GwYfE6Yue5nCcp10Wh/MQ+e/HqaWZHmgNjY0zd1I -1vyeQieJyT7oA7XsMDt1UQKCAQEA/mzRWKFf3C7DsgoxHDc7nGF6ojhK+z+ujfxj -qDyUqXZkx7UmY/2hp2/20ndXbla8TqFgkdh1zj4uRqbMlgNStlQnYzIwU6IBtfPo -+NrFPmlnZ4pHqMNkATO+oYegKASOaqF53aD09PV0KhzcfJpkSgxbOOE5Gt5sfj5q -z5AO6aLFQXUaykCZU/mf+WneX5YltXpHLelx6YyIL/H/RQjbR7A8V5ug7dkuy+U/ -RUb07sJ5AmeH+dbMiDpwLyFxfltEKdV21Ex/oGrdVHvPYzuhT/IVHujqdMmkODJ5 -PD+jmHcted0nZ0VH9gxprGkGF3PlsSX+KcuvrGU4uN6yasi9hQKCAQEAz98MXcbs -5akVyq9U4cIGeyn1miGwlNpKPxBOqWEc1lKqUHTCayD+SUrXfd0zSi8R5VSp9MuX -ciiibkb5/vBrLYXZRtZCtaVUPlyH/PvKIlsYP+J4QMqZJGHIFUpPsF9SQjmBmktF -J+N00WT7iGfJiFpHc05aq7yo2/AgiLmjWUsDYXxHwTkDT/q838K4gsIdOCh1sUQr -tT/LhciHEp9yF8hcQ38BPPnp7NHalwx8WndgjNqC6gaQ5T01K0YKH5mtuyPpJj4d -p9RqqcLju8FzNiEld57JQJPSb6MW5cKYu0uSArRLdAMeVYMThYCHHu15Z0zEy3ce -PAlY8G1pfYgxrQKCAQAbdGKizccqW2GCtNbX1J36Igq5tplgw15ys+mNHfxszPnT -ExkxcQ0gpFReIcKthW6MjZ1+H32W497agOVSyskCI9KcQa41WCYXHFrnf7QJKBag -dauF6o/AEXVguOHvb45usz4TTGsig9olMTgZug9YbjzpxmQDIj1S4ilkfIcfbxEa -Hyjk6lOhXC6HG4WDixBGpQtJSQehzChmBBcnu+ztr3bTfVfAUs9Z8UMClsWXfiTQ -vZtOun8XtDam31T/7ZlNaluITTj4do+rrjCS5LxjhBwDWd7y+09dQRUUC0n8CeA+ -Zj76Rd+eDXjZwfuGTFtc4lyq5e/vCn00ddOK8l6BAoIBAQDJPGBHZK2wA4myJxyg -VWpaz5sRdK3y3IRmGs5cEUSOg4aXzwDsHwutPoPxODRQC9NiVR0XfAUIIihlY9bf -JDZN4rceaYw5N22f1YpcshDUQ6XtKrxJ1Rh+bR765W7SCuWicPNzwIyZegx8LiuH -uRoUI3nqOZ9zhHdgPE3yruxhJEqIlH0OpLf9NHqmkGZ5R5xr4ldVne5GUBUiVafV -soAMYA5Z1VkIg9QfTGU2N4MnPUw978gu8N5S3ndbhjmEsAzND43FVPr2n6AG6kH3 -YOa9L0eLTy/7kV92bcdb9JBROW6Hqa0mCWLTW8qJQo0Mts8B3wLhClc9vbrZPsKS -IUgdAoIBADTgbwunp+95gp7eC2VPM42T/cZfCjHmThzF6dNJOWEVgUTsvjTqND7C -+XDyCgTVbyBWiEl8OA3ePl8oDVUX96Bd4TE/7wwfGe4wn85XcchdZmFbIN/2cS7T -eHEq85IY676T8WLU3LdEu/fPn4xYCT9fx32JB7IfZDuJ6liQUrjQFhZrC6eIFQON -He4niCxUTt1VxMb+0dtVGF2sBBW/rfg9BlOW+Jrllhm6RWTzlajCkmZW0BWwl5zi -KTt/KgAF7SkGoW57znDr9soJcLPaPAhjdYkxmuPPqwn94Qw6IX24DL6QLNALVlf5 -d7y+GE7k0jSTLxP851L5BvfqKi2Ay0w= ------END PRIVATE KEY----- diff --git a/selfservice/strategy/saml/vulnerabilities_helper_test.go b/selfservice/strategy/saml/vulnerabilities_helper_test.go index b49f433d81b8..9592902f0f61 100644 --- a/selfservice/strategy/saml/vulnerabilities_helper_test.go +++ b/selfservice/strategy/saml/vulnerabilities_helper_test.go @@ -217,12 +217,12 @@ func startContinuity(resp *httptest.ResponseRecorder, r *http.Request, strategy strategy.D().LoginFlowPersister().CreateLoginFlow(r.Context(), f) state := x.NewUUID().String() - strategy.D().RelayStateContinuityManager().Pause(r.Context(), resp, r, "ory_kratos_saml_auth_code_session", + strategy.D().ContinuityManager().Pause(r.Context(), resp, r, "ory_kratos_saml_auth_code_session", continuity.WithPayload(&authCodeContainer{ State: state, FlowID: f.ID.String(), }), - continuity.WithLifespan(time.Minute*30)) + continuity.WithLifespan(time.Minute*30), continuity.UseRelayState()) } func initRouterParams() httprouter.Params { diff --git a/selfservice/strategy/saml/vulnerabilities_test.go b/selfservice/strategy/saml/vulnerabilities_test.go index 7c0095aef5a8..8af41cc7ee5a 100644 --- a/selfservice/strategy/saml/vulnerabilities_test.go +++ b/selfservice/strategy/saml/vulnerabilities_test.go @@ -496,7 +496,7 @@ func TestAddXMLCommentsInSAMLAttributes(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) traitsMap := make(map[string]interface{}) json.Unmarshal(ids[0].Traits, &traitsMap) @@ -680,7 +680,7 @@ func TestXSW3AssertionWrap1(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) @@ -742,7 +742,7 @@ func TestXSW4AssertionWrap2(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) @@ -804,7 +804,7 @@ func TestXSW5AssertionWrap3(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) @@ -866,7 +866,7 @@ func TestXSW6AssertionWrap4(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) @@ -928,7 +928,7 @@ func TestXSW7AssertionWrap5(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) @@ -988,7 +988,7 @@ func TestXSW8AssertionWrap6(t *testing.T) { strategy.HandleCallback(resp, req, ps) // Get all identities - ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ExpandEverything, 0, 1000) + ids, _ := strategy.D().PrivilegedIdentityPool().ListIdentities(context.Background(), identity.ListIdentityParameters{Expand: identity.ExpandEverything, Page: 0, PerPage: 1000}) // We have to check that there is either an error or an identity created without the modified attribute assert.Check(t, strings.Contains(resp.HeaderMap["Location"][0], "error") || strings.Contains(string(ids[0].Traits), "alice@example.com")) diff --git a/x/provider.go b/x/provider.go index ec2ae5736ce5..f87f3f0fdc95 100644 --- a/x/provider.go +++ b/x/provider.go @@ -29,11 +29,6 @@ type CookieProvider interface { ContinuityCookieManager(ctx context.Context) sessions.StoreExact } -type RelayStateProvider interface { - RelayStateManager(ctx context.Context) sessions.StoreExact - ContinuityRelayStateManager(ctx context.Context) sessions.StoreExact -} - type TracingProvider interface { Tracer(ctx context.Context) *otelx.Tracer }