Skip to content

Commit

Permalink
feat(saml): continuity manager relaystate refactoring
Browse files Browse the repository at this point in the history
+ test new credentials saml
+ resolving conflicts with master
Signed-off-by: sebferrer <[email protected]>

Co-authored-by: ThibaultHerard <[email protected]>
  • Loading branch information
sebferrer and ThibHrrd committed Feb 22, 2023
1 parent 576df75 commit 7a827c6
Show file tree
Hide file tree
Showing 19 changed files with 167 additions and 386 deletions.
24 changes: 14 additions & 10 deletions continuity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,19 @@ type ManagementProvider interface {
ContinuityManager() Manager
}

type ManagementProviderRelayState interface {
RelayStateContinuityManager() Manager
}

type Manager interface {
Pause(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error
Continue(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) (*Container, error)
Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) error
Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error
}

type managerOptions struct {
iid uuid.UUID
ttl time.Duration
payload json.RawMessage
payloadRaw interface{}
cleanUp bool
iid uuid.UUID
ttl time.Duration
payload json.RawMessage
payloadRaw interface{}
cleanUp bool
useRelayState bool
}

type ManagerOption func(*managerOptions) error
Expand Down Expand Up @@ -87,3 +84,10 @@ func WithPayload(payload interface{}) ManagerOption {
return nil
}
}

func UseRelayState() ManagerOption {
return func(o *managerOptions) error {
o.useRelayState = true
return nil
}
}
29 changes: 21 additions & 8 deletions continuity/manager_cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"

"github.com/gofrs/uuid"
"github.com/gorilla/sessions"
"github.com/pkg/errors"

"github.com/ory/herodot"
Expand Down Expand Up @@ -64,12 +65,12 @@ func (m *ManagerCookie) Pause(ctx context.Context, w http.ResponseWriter, r *htt
}

func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) (*Container, error) {
container, err := m.container(ctx, w, r, name)
o, err := newManagerOptions(opts)
if err != nil {
return nil, err
}

o, err := newManagerOptions(opts)
container, err := m.container(ctx, w, r, name, o.useRelayState)
if err != nil {
return nil, err
}
Expand All @@ -95,9 +96,16 @@ func (m *ManagerCookie) Continue(ctx context.Context, w http.ResponseWriter, r *
return container, nil
}

func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (uuid.UUID, error) {
func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, useRelayState bool) (uuid.UUID, error) {
var getStringFunction func(r *http.Request, s sessions.StoreExact, id string, key interface{}) (string, error)
if useRelayState {
getStringFunction = x.SessionGetStringRelayState
} else {
getStringFunction = x.SessionGetString
}

var sid uuid.UUID
if s, err := x.SessionGetString(r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
if s, err := getStringFunction(r, m.d.ContinuityCookieManager(ctx), CookieName, name); err != nil {
_ = x.SessionUnsetKey(w, r, m.d.ContinuityCookieManager(ctx), CookieName, name)
return sid, errors.WithStack(ErrNotResumable.WithDebugf("%+v", err))
} else if sid = x.ParseUUID(s); sid == uuid.Nil {
Expand All @@ -108,8 +116,8 @@ func (m *ManagerCookie) sid(ctx context.Context, w http.ResponseWriter, r *http.
return sid, nil
}

func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) (*Container, error) {
sid, err := m.sid(ctx, w, r, name)
func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, useRelayState bool) (*Container, error) {
sid, err := m.sid(ctx, w, r, name, useRelayState)
if err != nil {
return nil, err
}
Expand All @@ -129,8 +137,13 @@ func (m *ManagerCookie) container(ctx context.Context, w http.ResponseWriter, r
return container, err
}

func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string) error {
sid, err := m.sid(ctx, w, r, name)
func (m ManagerCookie) Abort(ctx context.Context, w http.ResponseWriter, r *http.Request, name string, opts ...ManagerOption) error {
o, err := newManagerOptions(opts)
if err != nil {
return err
}

sid, err := m.sid(ctx, w, r, name, o.useRelayState)
if errors.Is(err, &ErrNotResumable) {
// We do not care about an error here
return nil
Expand Down
149 changes: 0 additions & 149 deletions continuity/manager_relaystate.go

This file was deleted.

19 changes: 12 additions & 7 deletions continuity/manager_relaystate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package continuity_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -31,7 +32,7 @@ import (
"github.com/ory/kratos/x"
)

func TestManagerRelayState(t *testing.T) {
func TestManagerUseRelayState(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)

Expand All @@ -57,6 +58,7 @@ func TestManagerRelayState(t *testing.T) {
r.PostForm = make(url.Values)
r.PostForm.Set("RelayState", relayState)

tc.wo = append(tc.wo, continuity.UseRelayState())
c, err := p.Continue(r.Context(), w, r, ps.ByName("name"), tc.wo...)
if err != nil {
writer.WriteError(w, r, err)
Expand All @@ -71,7 +73,7 @@ func TestManagerRelayState(t *testing.T) {
r.PostForm = make(url.Values)
r.PostForm.Set("RelayState", relayState)

err := p.Abort(r.Context(), w, r, ps.ByName("name"))
err := p.Abort(r.Context(), w, r, ps.ByName("name"), continuity.UseRelayState())
if err != nil {
writer.WriteError(w, r, err)
return
Expand All @@ -90,7 +92,7 @@ func TestManagerRelayState(t *testing.T) {
return &http.Client{Jar: x.EasyCookieJar(t, nil)}
}

p := reg.RelayStateContinuityManager()
p := reg.ContinuityManager()
cl := newClient()

t.Run("case=continue cookie persists with same http client", func(t *testing.T) {
Expand All @@ -115,6 +117,7 @@ func TestManagerRelayState(t *testing.T) {

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

assert.Equal(t, res.StatusCode, 200)
require.Len(t, res.Cookies(), 1)
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})
Expand Down Expand Up @@ -147,7 +150,8 @@ func TestManagerRelayState(t *testing.T) {

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 1)
assert.Equal(t, res.StatusCode, 200)
assert.Len(t, res.Cookies(), 1)
assert.EqualValues(t, res.Cookies()[0].Name, continuity.CookieName)
})

Expand All @@ -165,6 +169,7 @@ func TestManagerRelayState(t *testing.T) {

for _, c := range res.Cookies() {
relayState = c.Value
fmt.Println(relayState)
relayState = strings.Replace(relayState, "a", "b", 1)
}
require.Len(t, res.Cookies(), 1)
Expand All @@ -180,7 +185,7 @@ func TestManagerRelayState(t *testing.T) {

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0)
})

t.Run("case=continue cookie not delivered without relaystate", func(t *testing.T) {
Expand All @@ -205,7 +210,7 @@ func TestManagerRelayState(t *testing.T) {

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0)
})

t.Run("case=pause, abort, and continue session with failure", func(t *testing.T) {
Expand Down Expand Up @@ -236,6 +241,6 @@ func TestManagerRelayState(t *testing.T) {

t.Cleanup(func() { require.NoError(t, res.Body.Close()) })

require.Len(t, res.Cookies(), 0, "the cookie couldn't be reconstructed without a valid relaystate")
assert.True(t, res.StatusCode == 400 || len(res.Cookies()) == 0)
})
}
18 changes: 1 addition & 17 deletions driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ type RegistryDefault struct {

continuityManager continuity.Manager

x.RelayStateProvider
session.ManagementProvider

schemaHandler *schema.Handler

sessionHandler *session.Handler
Expand Down Expand Up @@ -676,25 +673,12 @@ func (m *RegistryDefault) Courier(ctx context.Context) (courier.Courier, error)
}

func (m *RegistryDefault) ContinuityManager() continuity.Manager {
// If m.continuityManager is nil or not a continuity.ManagerCookie
switch m.continuityManager.(type) {
case *continuity.ManagerCookie:
default:
if m.continuityManager == nil {
m.continuityManager = continuity.NewManagerCookie(m)
}
return m.continuityManager
}

func (m *RegistryDefault) RelayStateContinuityManager() continuity.Manager {
// If m.continuityManager is nil or not a continuity.ManagerRelayState
switch m.continuityManager.(type) {
case *continuity.ManagerRelayState:
default:
m.continuityManager = continuity.NewManagerRelayState(m, m)
}
return m.continuityManager
}

func (m *RegistryDefault) ContinuityPersister() continuity.Persister {
return m.persister
}
Expand Down
Loading

0 comments on commit 7a827c6

Please sign in to comment.