Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and merge Callbacks #3151

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,52 @@ func (r Wrapper) HandleTokenRequest(ctx context.Context, request HandleTokenRequ
}

func (r Wrapper) Callback(ctx context.Context, request CallbackRequestObject) (CallbackResponseObject, error) {
// check id in path
_, err := r.toOwnedDID(ctx, request.Did)
// validate request
// check did in path
ownDID, err := r.toOwnedDID(ctx, request.Did)
if err != nil {
// this is an OAuthError already, will be rendered as 400 but that's fine (for now) for an illegal id
return nil, err
}
// check if state is present and resolves to a client state
if request.Params.State == nil || *request.Params.State == "" {
return nil, oauthError(oauth.InvalidRequest, "missing state parameter")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the Callback has an error but is missing the client state, the actual error will not be shown

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RFC6749 4.1.2.1. Error Response

state
REQUIRED if a "state" parameter was present in the client
authorization request. The exact value received from the
client.

I wasn't sure what to do with this, but I will wrap the error if there is one

oauthSession := new(OAuthSession)
if err = r.oauthClientStateStore().Get(*request.Params.State, oauthSession); err != nil {
return nil, oauthError(oauth.InvalidRequest, "invalid or expired state", err)
}
if !ownDID.Equals(*oauthSession.OwnDID) {
// TODO: this is a manipulated request, add error logging?
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "session DID does not match request"), oauthSession.redirectURI())
}

// if error is present, delegate call to error handler
if request.Params.Error != nil {
return r.handleCallbackError(request)
// if error is present, redirect error back to application initiating the flow
if request.Params.Error != nil && *request.Params.Error != "" {
requestErr := oauth.OAuth2Error{
Code: oauth.ErrorCode(*request.Params.Error),
RedirectURI: oauthSession.redirectURI(),
}
if request.Params.ErrorDescription != nil {
requestErr.Description = *request.Params.ErrorDescription
}
return nil, requestErr
}

return r.handleCallback(ctx, request)
// check if code is present
if request.Params.Code == nil || *request.Params.Code == "" {
return nil, withCallbackURI(oauthError(oauth.InvalidRequest, "missing code parameter"), oauthSession.redirectURI())
}

// continue flow
switch oauthSession.ClientFlow {
case "openid4vci_credential_request":
gerardsn marked this conversation as resolved.
Show resolved Hide resolved
return r.handleOpenID4VCICallback(ctx, *request.Params.Code, oauthSession)
case "access_token_request":
return r.handleCallback(ctx, *request.Params.Code, oauthSession)
default:
// programming error, should never happen
return nil, withCallbackURI(oauthError(oauth.ServerError, "unknown client flow for callback: '"+oauthSession.ClientFlow+"'"), oauthSession.redirectURI())
}
}

func (r Wrapper) RetrieveAccessToken(_ context.Context, request RetrieveAccessTokenRequestObject) (RetrieveAccessTokenResponseObject, error) {
Expand Down
129 changes: 107 additions & 22 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,11 @@ func TestWrapper_HandleAuthorizeRequest(t *testing.T) {
// handleAuthorizeRequestFromVerifier
_ = ctx.client.storageEngine.GetSessionDatabase().GetStore(oAuthFlowTimeout, oauthClientStateKey...).Put("state", OAuthSession{
// this is the state from the holder that was stored at the creation of the first authorization request to the verifier
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
ResponseType: "code",
ClientID: holderDID.String(),
Scope: "test",
OwnDID: &holderDID,
ClientState: "state",
RedirectURI: "https://example.com/iam/holder/cb",
})
_ = ctx.client.userSessionStore().Put("session-id", UserSession{
TenantDID: holderDID,
Expand Down Expand Up @@ -461,31 +460,40 @@ func TestWrapper_Callback(t *testing.T) {
errorDescription := "error description"
state := "state"
token := "token"
redirectURI, parseErr := url.Parse("https://example.com/iam/holder/cb")
require.NoError(t, parseErr)

session := OAuthSession{
ClientFlow: "access_token_request",
SessionID: "token",
OwnDID: &holderDID,
RedirectURI: "https://example.com/iam/holder/cb",
VerifierDID: &verifierDID,
RedirectURI: redirectURI.String(),
OtherDID: &verifierDID,
TokenEndpoint: "https://example.com/token",
}

t.Run("ok - error flow", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
Error: &errorCode,
ErrorDescription: &errorDescription,
},
})

require.NoError(t, err)
assert.Equal(t, "https://example.com/iam/holder/cb?error=error&error_description=error+description", res.(Callback302Response).Headers.Location)
var oauthErr oauth.OAuth2Error
require.ErrorAs(t, err, &oauthErr)
assert.Equal(t, oauth.OAuth2Error{
Code: oauth.ErrorCode(errorCode),
Description: errorDescription,
RedirectURI: redirectURI,
}, err)
assert.Nil(t, res)
})
t.Run("ok - success flow", func(t *testing.T) {
ctx := newTestClient(t)
Expand All @@ -494,11 +502,11 @@ func TestWrapper_Callback(t *testing.T) {
putState(ctx, "state", withDPoP)
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, true).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -518,21 +526,22 @@ func TestWrapper_Callback(t *testing.T) {
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "access_token_request",
OwnDID: &holderDID,
PKCEParams: generatePKCEParams(),
RedirectURI: "https://example.com/iam/holder/cb",
SessionID: "token",
UseDPoP: false,
VerifierDID: &verifierDID,
OtherDID: &verifierDID,
TokenEndpoint: session.TokenEndpoint,
})
putToken(ctx, token)
codeVerifier := getState(ctx, state).PKCEParams.Verifier
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil).Times(2)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:123/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
ctx.iamClient.EXPECT().AccessToken(gomock.Any(), code, session.TokenEndpoint, "https://example.com/oauth2/did:web:example.com:iam:holder/callback", holderDID, codeVerifier, false).Return(&oauth.TokenResponse{AccessToken: "access"}, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
Expand All @@ -542,17 +551,93 @@ func TestWrapper_Callback(t *testing.T) {
require.NoError(t, err)
assert.NotNil(t, res)
})
t.Run("unknown did", func(t *testing.T) {
t.Run("err - unknown did", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(false, nil)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(false, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Did: holderDID.String(),
})

assert.EqualError(t, err, "DID document not managed by this node")
assert.Nil(t, res)
})
t.Run("err - did mismatch", func(t *testing.T) {
ctx := newTestClient(t)
putState(ctx, "state", session)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

res, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

assert.Nil(t, res)
requireOAuthError(t, err, oauth.InvalidRequest, "session DID does not match request")

})
t.Run("err - missing state", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing state parameter")
})
t.Run("err - expired state/session", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), webDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: webDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "invalid or expired state")
})
t.Run("err - missing code", func(t *testing.T) {
ctx := newTestClient(t)
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)
putState(ctx, "state", session)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
State: &state,
},
})

requireOAuthError(t, err, oauth.InvalidRequest, "missing code parameter")
})
t.Run("err - unknown flow", func(t *testing.T) {
ctx := newTestClient(t)
_ = ctx.client.oauthClientStateStore().Put(state, OAuthSession{
ClientFlow: "",
OwnDID: &holderDID,
})
ctx.vdr.EXPECT().IsOwner(gomock.Any(), holderDID).Return(true, nil)

_, err := ctx.client.Callback(nil, CallbackRequestObject{
Did: holderDID.String(),
Params: CallbackParams{
Code: &code,
State: &state,
},
})

requireOAuthError(t, err, oauth.ServerError, "unknown client flow for callback: ''")
})
}

func TestWrapper_RetrieveAccessToken(t *testing.T) {
Expand Down
Loading
Loading