Skip to content

Commit

Permalink
fix: include flow id in use recovery token query (#2679)
Browse files Browse the repository at this point in the history
This PR adds the `selfservice_recovery_flow_id` to the query used when "using" a token in the recovery flow.

This PR also adds a new enum field for `identity_recovery_tokens` to distinguish the two flows: admin versus self-service recovery.

BREAKING CHANGES: This patch invalidates recovery flows initiated using the Admin API. Please re-generate any admin-generated recovery flows and tokens.
  • Loading branch information
jonas-jonas authored Aug 25, 2022
1 parent 1cd2672 commit d56586b
Show file tree
Hide file tree
Showing 15 changed files with 143 additions and 50 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE identity_recovery_tokens
DROP token_type;
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE identity_recovery_tokens
ADD token_type int NOT NULL DEFAULT 0;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
UPDATE identity_recovery_tokens
SET token_type = 1
WHERE selfservice_recovery_flow_id IS NULL;

UPDATE identity_recovery_tokens
SET token_type = 2
WHERE selfservice_recovery_flow_id IS NOT NULL;
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- SQLITE does not support Check constraints in all cases
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE identity_recovery_tokens
ADD CONSTRAINT identity_recovery_tokens_token_type_ck CHECK (token_type = 1 OR token_type = 2);
4 changes: 2 additions & 2 deletions persistence/sql/persister_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (p *Persister) CreateRecoveryToken(ctx context.Context, token *link.Recover
return nil
}

func (p *Persister) UseRecoveryToken(ctx context.Context, token string) (*link.RecoveryToken, error) {
func (p *Persister) UseRecoveryToken(ctx context.Context, fID uuid.UUID, token string) (*link.RecoveryToken, error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryToken")
defer span.End()

Expand All @@ -74,7 +74,7 @@ func (p *Persister) UseRecoveryToken(ctx context.Context, token string) (*link.R
nid := p.NetworkID(ctx)
if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
for _, secret := range p.r.Config().SecretsSession(ctx) {
if err = tx.Where("token = ? AND nid = ? AND NOT used", p.hmacValueWithSecret(ctx, token, secret), nid).First(&rt); err != nil {
if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_recovery_flow_id = ?", p.hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil {
if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) {
return err
}
Expand Down
4 changes: 3 additions & 1 deletion selfservice/strategy/link/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package link

import (
"context"

"github.com/gofrs/uuid"
)

type (
RecoveryTokenPersister interface {
CreateRecoveryToken(ctx context.Context, token *RecoveryToken) error
UseRecoveryToken(ctx context.Context, token string) (*RecoveryToken, error)
UseRecoveryToken(ctx context.Context, fID uuid.UUID, token string) (*RecoveryToken, error)
DeleteRecoveryToken(ctx context.Context, token string) error
}

Expand Down
10 changes: 5 additions & 5 deletions selfservice/strategy/link/strategy_recovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (s *Strategy) createRecoveryLink(w http.ResponseWriter, r *http.Request, _
return
}

token := NewRecoveryToken(id.ID, expiresIn)
token := NewAdminRecoveryToken(id.ID, req.ID, expiresIn)
if err := s.d.RecoveryTokenPersister().CreateRecoveryToken(r.Context(), token); err != nil {
s.d.Writer().WriteError(w, r, err)
return
Expand Down Expand Up @@ -222,7 +222,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
return s.HandleRecoveryError(w, r, nil, body, err)
}

return s.recoveryUseToken(w, r, body)
return s.recoveryUseToken(w, r, f.ID, body)
}

if _, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil {
Expand Down Expand Up @@ -313,8 +313,8 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request,
return errors.WithStack(flow.ErrCompletedByStrategy)
}

func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body *recoverySubmitPayload) error {
token, err := s.d.RecoveryTokenPersister().UseRecoveryToken(r.Context(), body.Token)
func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, fID uuid.UUID, body *recoverySubmitPayload) error {
token, err := s.d.RecoveryTokenPersister().UseRecoveryToken(r.Context(), fID, body.Token)
if err != nil {
if errors.Is(err, sqlcon.ErrNoRows) {
return s.retryRecoveryFlowWithMessage(w, r, flow.TypeBrowser, text.NewErrorValidationRecoveryTokenInvalidOrAlreadyUsed())
Expand Down Expand Up @@ -351,7 +351,7 @@ func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body
}

// mark address as verified only for a self-service flow
if token.FlowID.Valid {
if token.TokenType == RecoveryTokenTypeSelfService {
if err := s.markRecoveryAddressVerified(w, r, f, recovered, token.RecoveryAddress); err != nil {
return s.HandleRecoveryError(w, r, f, body, err)
}
Expand Down
110 changes: 82 additions & 28 deletions selfservice/strategy/link/strategy_recovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"
"time"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/session"

"github.com/davecgh/go-spew/spew"
Expand Down Expand Up @@ -56,6 +57,23 @@ func init() {
corpx.RegisterFakes()
}

func createIdentityToRecover(t *testing.T, reg *driver.RegistryDefault, email string) *identity.Identity {
var id = &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
"password": {Type: "password", Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{"hashed_password":"foo"}`)}},
Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, email)),
SchemaID: config.DefaultIdentityTraitsSchemaID,
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))

addr, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, email)
assert.NoError(t, err)
assert.False(t, addr.Verified)
assert.Nil(t, addr.VerifiedAt)
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
return id
}

func TestAdminStrategy(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
Expand Down Expand Up @@ -183,6 +201,59 @@ func TestAdminStrategy(t *testing.T) {
assert.Nil(t, addr.VerifiedAt)
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
})

t.Run("case=should not be able to use code from different flow", func(t *testing.T) {
email := strings.ToLower(testhelpers.RandomEmail())
id := createIdentityToRecover(t, reg, email)

rl1, _, err := adminSDK.V0alpha2Api.
AdminCreateSelfServiceRecoveryLink(context.Background()).
AdminCreateSelfServiceRecoveryLinkBody(kratos.AdminCreateSelfServiceRecoveryLinkBody{
IdentityId: id.ID.String(),
}).
Execute()
require.NoError(t, err)

checkLink(t, rl1, time.Now().Add(conf.SelfServiceFlowRecoveryRequestLifespan(ctx)+time.Second))

rl2, _, err := adminSDK.V0alpha2Api.
AdminCreateSelfServiceRecoveryLink(context.Background()).
AdminCreateSelfServiceRecoveryLinkBody(kratos.AdminCreateSelfServiceRecoveryLinkBody{
IdentityId: id.ID.String(),
}).
Execute()
require.NoError(t, err)

checkLink(t, rl2, time.Now().Add(conf.SelfServiceFlowRecoveryRequestLifespan(ctx)+time.Second))

recoveryUrl1, err := url.Parse(rl1.RecoveryLink)
require.NoError(t, err)

recoveryUrl2, err := url.Parse(rl2.RecoveryLink)
require.NoError(t, err)

token1 := recoveryUrl1.Query().Get("token")
require.NotEmpty(t, token1)
token2 := recoveryUrl2.Query().Get("token")
require.NotEmpty(t, token2)
require.NotEqual(t, token1, token2)

values := recoveryUrl1.Query()

values.Set("token", token2)

recoveryUrl1.RawQuery = values.Encode()

action := recoveryUrl1.String()
// Submit the modified link with token from rl2 and flow from rl1
res, err := publicTS.Client().Get(action)
require.NoError(t, err)
body := ioutilx.MustReadAll(res.Body)

action = gjson.GetBytes(body, "ui.action").String()
require.NotEmpty(t, action)
assert.Equal(t, "The recovery token is invalid or has already been used. Please retry the flow.", gjson.GetBytes(body, "ui.messages.0.text").String())
})
}

func TestRecovery(t *testing.T) {
Expand All @@ -197,23 +268,6 @@ func TestRecovery(t *testing.T) {

public, _, publicRouter, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)

var createIdentityToRecover = func(email string) *identity.Identity {
var id = &identity.Identity{
Credentials: map[identity.CredentialsType]identity.Credentials{
"password": {Type: "password", Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{"hashed_password":"foo"}`)}},
Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, email)),
SchemaID: config.DefaultIdentityTraitsSchemaID,
}
require.NoError(t, reg.IdentityManager().Create(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))

addr, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, email)
assert.NoError(t, err)
assert.False(t, addr.Verified)
assert.Nil(t, addr.VerifiedAt)
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
return id
}

var expect = func(t *testing.T, hc *http.Client, isAPI, isSPA bool, values func(url.Values), c int) string {
if hc == nil {
hc = testhelpers.NewDebugClient(t)
Expand Down Expand Up @@ -414,23 +468,23 @@ func TestRecovery(t *testing.T) {

t.Run("type=browser", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, false, false, func(v url.Values) {
v.Set("email", email)
}), email, false)
})

t.Run("type=spa", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, true, true, func(v url.Values) {
v.Set("email", email)
}), email, true)
})

t.Run("type=api", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, true, false, func(v url.Values) {
v.Set("email", email)
}), email, true)
Expand Down Expand Up @@ -487,7 +541,7 @@ func TestRecovery(t *testing.T) {

t.Run("type=browser", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, false, false, func(v url.Values) {
v.Set("email", email)
}), email, "")
Expand All @@ -496,7 +550,7 @@ func TestRecovery(t *testing.T) {
t.Run("type=browser set return_to", func(t *testing.T) {
email := "[email protected]"
returnTo := "https://www.ory.sh"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)

hc := testhelpers.NewClientWithCookies(t)
hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper
Expand All @@ -518,15 +572,15 @@ func TestRecovery(t *testing.T) {

t.Run("type=spa", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, true, true, func(v url.Values) {
v.Set("email", email)
}), email, "")
})

t.Run("type=api", func(t *testing.T) {
email := "[email protected]"
createIdentityToRecover(email)
createIdentityToRecover(t, reg, email)
check(t, expectSuccess(t, nil, true, false, func(v url.Values) {
v.Set("email", email)
}), email, "")
Expand Down Expand Up @@ -563,7 +617,7 @@ func TestRecovery(t *testing.T) {
}

email := x.NewUUID().String() + "@ory.sh"
id := createIdentityToRecover(email)
id := createIdentityToRecover(t, reg, email)

t.Run("case=unauthenticated", func(t *testing.T) {
var values = func(v url.Values) {
Expand Down Expand Up @@ -604,7 +658,7 @@ func TestRecovery(t *testing.T) {

recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
email := recoveryEmail
id := createIdentityToRecover(email)
id := createIdentityToRecover(t, reg, email)

sess, err := session.NewActiveSession(ctx, id, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
require.NoError(t, err)
Expand Down Expand Up @@ -659,7 +713,7 @@ func TestRecovery(t *testing.T) {

t.Run("description=should not be able to use an outdated link", func(t *testing.T) {
recoveryEmail := "[email protected]"
createIdentityToRecover(recoveryEmail)
createIdentityToRecover(t, reg, recoveryEmail)
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Minute)
Expand All @@ -685,7 +739,7 @@ func TestRecovery(t *testing.T) {

t.Run("description=should not be able to use an outdated flow", func(t *testing.T) {
recoveryEmail := "[email protected]"
createIdentityToRecover(recoveryEmail)
createIdentityToRecover(t, reg, recoveryEmail)
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200)
t.Cleanup(func() {
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Minute)
Expand Down
32 changes: 19 additions & 13 deletions selfservice/strategy/link/test/persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"})

t.Run("token=recovery", func(t *testing.T) {
t.Run("case=should error when the recovery token does not exist", func(t *testing.T) {
_, err := p.UseRecoveryToken(ctx, "i-do-not-exist")
require.Error(t, err)
})

newRecoveryToken := func(t *testing.T, email string) *link.RecoveryToken {
newRecoveryToken := func(t *testing.T, email string) (*link.RecoveryToken, *recovery.Flow) {
var req recovery.Flow
require.NoError(t, faker.FakeData(&req))
require.NoError(t, p.CreateRecoveryFlow(ctx, &req))
Expand All @@ -52,42 +48,52 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {

require.NoError(t, p.CreateIdentity(ctx, &i))

return &link.RecoveryToken{Token: x.NewUUID().String(), FlowID: uuid.NullUUID{UUID: req.ID, Valid: true},
return &link.RecoveryToken{
Token: x.NewUUID().String(),
FlowID: uuid.NullUUID{UUID: req.ID, Valid: true},
RecoveryAddress: &i.RecoveryAddresses[0],
ExpiresAt: time.Now(),
IssuedAt: time.Now(),
IdentityID: i.ID,
}
TokenType: link.RecoveryTokenTypeAdmin,
}, &req
}

t.Run("case=should error when the recovery token does not exist", func(t *testing.T) {
_, err := p.UseRecoveryToken(ctx, "i-do-not-exist")
_, err := p.UseRecoveryToken(ctx, x.NewUUID(), "i-do-not-exist")
require.Error(t, err)
})

t.Run("case=should create a new recovery token", func(t *testing.T) {
token := newRecoveryToken(t, "[email protected]")
token, _ := newRecoveryToken(t, "[email protected]")
require.NoError(t, p.CreateRecoveryToken(ctx, token))
})

t.Run("case=should error when token is used with different flow id", func(t *testing.T) {
token, _ := newRecoveryToken(t, "[email protected]")
require.NoError(t, p.CreateRecoveryToken(ctx, token))
_, err := p.UseRecoveryToken(ctx, x.NewUUID(), token.Token)
require.Error(t, err)
})

t.Run("case=should create a recovery token and use it", func(t *testing.T) {
expected := newRecoveryToken(t, "[email protected]")
expected, f := newRecoveryToken(t, "[email protected]")
require.NoError(t, p.CreateRecoveryToken(ctx, expected))

t.Run("not work on another network", func(t *testing.T) {
_, p := testhelpers.NewNetwork(t, ctx, p)
_, err := p.UseRecoveryToken(ctx, expected.Token)
_, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.ErrorIs(t, err, sqlcon.ErrNoRows)
})

actual, err := p.UseRecoveryToken(ctx, expected.Token)
actual, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.NoError(t, err)
assert.Equal(t, nid, actual.NID)
assert.Equal(t, expected.IdentityID, actual.IdentityID)
assert.NotEqual(t, expected.Token, actual.Token)
assert.EqualValues(t, expected.FlowID, actual.FlowID)

_, err = p.UseRecoveryToken(ctx, expected.Token)
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
require.Error(t, err)
})

Expand Down
Loading

0 comments on commit d56586b

Please sign in to comment.