diff --git a/pkg/lib/oauth/grant_code.go b/pkg/lib/oauth/grant_code.go index 49a8c03b7a..5b3b5356b9 100644 --- a/pkg/lib/oauth/grant_code.go +++ b/pkg/lib/oauth/grant_code.go @@ -5,12 +5,14 @@ import ( "github.com/authgear/authgear-server/pkg/lib/authn/authenticationinfo" "github.com/authgear/authgear-server/pkg/lib/oauth/protocol" + "github.com/authgear/authgear-server/pkg/lib/session" ) type CodeGrant struct { AppID string `json:"app_id"` AuthorizationID string `json:"authz_id"` - IDPSessionID string `json:"session_id"` + SessionType session.Type `json:"session_type"` + SessionID string `json:"session_id"` AuthenticationInfo authenticationinfo.T `json:"authentication_info"` IDTokenHintSID string `json:"id_token_hint_sid"` diff --git a/pkg/lib/oauth/handler/handler_authz.go b/pkg/lib/oauth/handler/handler_authz.go index 16a54bb4f2..9016076c1f 100644 --- a/pkg/lib/oauth/handler/handler_authz.go +++ b/pkg/lib/oauth/handler/handler_authz.go @@ -482,12 +482,10 @@ func (h *AuthorizationHandler) doHandle( authenticationInfo := resolvedSession.GetAuthenticationInfo() autoGrantAuthz := client.IsFirstParty() - idpSessionID := "" - if resolvedSession.SessionType() == session.TypeIdentityProvider { - idpSessionID = resolvedSession.SessionID() - } + sessionType := resolvedSession.SessionType() + sessionID := resolvedSession.SessionID() - result, err := h.finish(redirectURI, r, idpSessionID, authenticationInfo, idTokenHintSID, nil, autoGrantAuthz) + result, err := h.finish(redirectURI, r, sessionType, sessionID, authenticationInfo, idTokenHintSID, nil, autoGrantAuthz) if err != nil { if errors.Is(err, oauth.ErrAuthorizationNotFound) { return nil, protocol.NewError("access_denied", "authorization required") @@ -562,7 +560,8 @@ func (h *AuthorizationHandler) doHandleAppInitiatedSSOToWeb( func (h *AuthorizationHandler) finish( redirectURI *url.URL, r protocol.AuthorizationRequest, - idpSessionID string, + sessionType session.Type, + sessionID string, authenticationInfo authenticationinfo.T, idTokenHintSID string, cookies []*http.Cookie, @@ -591,13 +590,17 @@ func (h *AuthorizationHandler) finish( responseType := r.ResponseType() switch { case responseType.Equal(SettingsActonResponseType): + idpSessionID := "" + if sessionType == session.TypeIdentityProvider { + idpSessionID = sessionID + } err = h.generateSettingsActionResponse(redirectURI.String(), idpSessionID, authenticationInfo, idTokenHintSID, r, authz, resp) if err != nil { return nil, err } case responseType.Equal(CodeResponseType): - err = h.generateCodeResponse(redirectURI.String(), idpSessionID, authenticationInfo, idTokenHintSID, r, authz, resp) + err = h.generateCodeResponse(redirectURI.String(), sessionType, sessionID, authenticationInfo, idTokenHintSID, r, authz, resp) if err != nil { return nil, err } @@ -645,12 +648,14 @@ func (h *AuthorizationHandler) doHandleConsentRequest( } idTokenHintSID := uiInfoByProduct.IDTokenHintSID - var idpSessionID string - if s := session.GetSession(h.Context); s != nil && s.SessionType() == session.TypeIdentityProvider { - idpSessionID = s.SessionID() + sessionID := "" + var sessionType session.Type = "" + if s := session.GetSession(h.Context); s != nil { + sessionID = s.SessionID() + sessionType = s.SessionType() } - return h.finish(redirectURI, r, idpSessionID, authenticationInfo, idTokenHintSID, []*http.Cookie{}, grantAuthz) + return h.finish(redirectURI, r, sessionType, sessionID, authenticationInfo, idTokenHintSID, []*http.Cookie{}, grantAuthz) } func (h *AuthorizationHandler) validateAppInitiatedSSOToWebTokenRequest( @@ -746,7 +751,8 @@ func (h *AuthorizationHandler) validateRequest( func (h *AuthorizationHandler) generateCodeResponse( redirectURI string, - idpSessionID string, + sessionType session.Type, + sessionID string, authenticationInfo authenticationinfo.T, idTokenHintSID string, r protocol.AuthorizationRequest, @@ -755,7 +761,8 @@ func (h *AuthorizationHandler) generateCodeResponse( ) error { code, _, err := h.CodeGrantService.CreateCodeGrant(&CreateCodeGrantOptions{ Authorization: authz, - IDPSessionID: idpSessionID, + SessionType: sessionType, + SessionID: sessionID, AuthenticationInfo: authenticationInfo, IDTokenHintSID: idTokenHintSID, RedirectURI: redirectURI, diff --git a/pkg/lib/oauth/handler/handler_authz_test.go b/pkg/lib/oauth/handler/handler_authz_test.go index c133670996..412d528dfb 100644 --- a/pkg/lib/oauth/handler/handler_authz_test.go +++ b/pkg/lib/oauth/handler/handler_authz_test.go @@ -21,6 +21,7 @@ import ( "github.com/authgear/authgear-server/pkg/lib/oauth/handler" "github.com/authgear/authgear-server/pkg/lib/oauth/oidc" "github.com/authgear/authgear-server/pkg/lib/oauth/protocol" + "github.com/authgear/authgear-server/pkg/lib/session" sessiontest "github.com/authgear/authgear-server/pkg/lib/session/test" "github.com/authgear/authgear-server/pkg/util/clock" ) @@ -293,7 +294,8 @@ func TestAuthorizationHandler(t *testing.T) { So(codeGrantStore.grants[0], ShouldResemble, oauth.CodeGrant{ AppID: "app-id", AuthorizationID: authorization.ID, - IDPSessionID: "session-id", + SessionType: session.TypeIdentityProvider, + SessionID: "session-id", AuthenticationInfo: authenticationinfo.T{ UserID: "user-id", }, @@ -345,7 +347,8 @@ func TestAuthorizationHandler(t *testing.T) { So(codeGrantStore.grants[0], ShouldResemble, oauth.CodeGrant{ AppID: "app-id", AuthorizationID: "authz-id", - IDPSessionID: "session-id", + SessionType: session.TypeIdentityProvider, + SessionID: "session-id", AuthenticationInfo: authenticationinfo.T{ UserID: "user-id", }, diff --git a/pkg/lib/oauth/handler/handler_token.go b/pkg/lib/oauth/handler/handler_token.go index d488ad1c84..769836cb76 100644 --- a/pkg/lib/oauth/handler/handler_token.go +++ b/pkg/lib/oauth/handler/handler_token.go @@ -166,6 +166,12 @@ type TokenHandlerTokenService interface { opts IssueOfflineGrantOptions, resp protocol.TokenResponse, ) (offlineGrant *oauth.OfflineGrant, tokenHash string, err error) + IssueRefreshTokenForOfflineGrant( + offlineGrantID string, + client *config.OAuthClientConfig, + opts IssueOfflineGrantRefreshTokenOptions, + resp protocol.TokenResponse, + ) (offlineGrant *oauth.OfflineGrant, tokenHash string, err error) IssueDeviceSecret(resp protocol.TokenResponse) (deviceSecretHash string) } @@ -1282,9 +1288,16 @@ func (h *TokenHandler) handleApp2AppRequest( artificialAuthorizationRequest["x_sso_enabled"] = "true" } + originalIDPSessionID := originalOfflineGrant.IDPSessionID + var sessionType session.Type = "" + if originalIDPSessionID != "" { + sessionType = session.TypeIdentityProvider + } + code, _, err := h.CodeGrantService.CreateCodeGrant(&CreateCodeGrantOptions{ Authorization: authz, - IDPSessionID: originalOfflineGrant.IDPSessionID, + SessionType: sessionType, + SessionID: originalIDPSessionID, AuthenticationInfo: info, IDTokenHintSID: "", RedirectURI: redirectURI.String(), @@ -1485,22 +1498,39 @@ func (h *TokenHandler) doIssueTokensForAuthorizationCode( Scopes: scopes, AuthorizationID: authz.ID, AuthenticationInfo: info, - IDPSessionID: code.IDPSessionID, + IDPSessionID: code.SessionID, DeviceInfo: deviceInfo, SSOEnabled: code.AuthorizationRequest.SSOEnabled(), App2AppDeviceKey: app2appDevicePublicKey, IssueDeviceSecret: issueDeviceToken, } if issueRefreshToken { - offlineGrant, tokenHash, err := h.issueOfflineGrant( - client, - code.AuthenticationInfo.UserID, - resp, - opts, - true) - if err != nil { - return nil, err + var offlineGrant *oauth.OfflineGrant + var tokenHash string + var err error + switch code.SessionType { + case session.TypeOfflineGrant: + offlineGrant, tokenHash, err = h.TokenService.IssueRefreshTokenForOfflineGrant(code.SessionID, client, IssueOfflineGrantRefreshTokenOptions{ + Scopes: scopes, + AuthorizationID: authz.ID, + }, resp) + if err != nil { + return nil, err + } + case session.TypeIdentityProvider: + fallthrough + default: + offlineGrant, tokenHash, err = h.issueOfflineGrant( + client, + code.AuthenticationInfo.UserID, + resp, + opts, + true) + if err != nil { + return nil, err + } } + sid = oidc.EncodeSID(offlineGrant) accessTokenSessionID = offlineGrant.ID accessTokenSessionKind = oauth.GrantSessionKindOffline @@ -1646,18 +1676,6 @@ func (h *TokenHandler) issueTokensForRefreshToken( return resp, nil } -type IssueOfflineGrantOptions struct { - AuthenticationInfo authenticationinfo.T - Scopes []string - AuthorizationID string - IDPSessionID string - DeviceInfo map[string]interface{} - IdentityID string - SSOEnabled bool - App2AppDeviceKey jwk.Key - IssueDeviceSecret bool -} - func (h *TokenHandler) IssueAppSessionToken(refreshToken string) (string, *oauth.AppSessionToken, error) { authz, grant, refreshTokenHash, err := h.TokenService.ParseRefreshToken(refreshToken) if err != nil { diff --git a/pkg/lib/oauth/handler/handler_token_mock_test.go b/pkg/lib/oauth/handler/handler_token_mock_test.go index 9d460c003b..b0d2d1c90f 100644 --- a/pkg/lib/oauth/handler/handler_token_mock_test.go +++ b/pkg/lib/oauth/handler/handler_token_mock_test.go @@ -682,6 +682,22 @@ func (mr *MockTokenHandlerTokenServiceMockRecorder) IssueOfflineGrant(client, op return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssueOfflineGrant", reflect.TypeOf((*MockTokenHandlerTokenService)(nil).IssueOfflineGrant), client, opts, resp) } +// IssueRefreshTokenForOfflineGrant mocks base method. +func (m *MockTokenHandlerTokenService) IssueRefreshTokenForOfflineGrant(offlineGrantID string, client *config.OAuthClientConfig, opts handler.IssueOfflineGrantRefreshTokenOptions, resp protocol.TokenResponse) (*oauth.OfflineGrant, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IssueRefreshTokenForOfflineGrant", offlineGrantID, client, opts, resp) + ret0, _ := ret[0].(*oauth.OfflineGrant) + ret1, _ := ret[1].(string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// IssueRefreshTokenForOfflineGrant indicates an expected call of IssueRefreshTokenForOfflineGrant. +func (mr *MockTokenHandlerTokenServiceMockRecorder) IssueRefreshTokenForOfflineGrant(offlineGrantID, client, opts, resp interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IssueRefreshTokenForOfflineGrant", reflect.TypeOf((*MockTokenHandlerTokenService)(nil).IssueRefreshTokenForOfflineGrant), offlineGrantID, client, opts, resp) +} + // ParseRefreshToken mocks base method. func (m *MockTokenHandlerTokenService) ParseRefreshToken(token string) (*oauth.Authorization, *oauth.OfflineGrant, string, error) { m.ctrl.T.Helper() diff --git a/pkg/lib/oauth/handler/service_code_grant.go b/pkg/lib/oauth/handler/service_code_grant.go index 11f266e72e..f56663ee88 100644 --- a/pkg/lib/oauth/handler/service_code_grant.go +++ b/pkg/lib/oauth/handler/service_code_grant.go @@ -5,6 +5,7 @@ import ( "github.com/authgear/authgear-server/pkg/lib/config" "github.com/authgear/authgear-server/pkg/lib/oauth" "github.com/authgear/authgear-server/pkg/lib/oauth/protocol" + "github.com/authgear/authgear-server/pkg/lib/session" "github.com/authgear/authgear-server/pkg/util/clock" ) @@ -18,7 +19,8 @@ type CodeGrantService struct { type CreateCodeGrantOptions struct { Authorization *oauth.Authorization - IDPSessionID string + SessionType session.Type + SessionID string AuthenticationInfo authenticationinfo.T IDTokenHintSID string RedirectURI string @@ -32,7 +34,8 @@ func (s *CodeGrantService) CreateCodeGrant(opts *CreateCodeGrantOptions) (code s codeGrant := &oauth.CodeGrant{ AppID: string(s.AppID), AuthorizationID: opts.Authorization.ID, - IDPSessionID: opts.IDPSessionID, + SessionType: opts.SessionType, + SessionID: opts.SessionID, AuthenticationInfo: opts.AuthenticationInfo, IDTokenHintSID: opts.IDTokenHintSID, diff --git a/pkg/lib/oauth/handler/service_token.go b/pkg/lib/oauth/handler/service_token.go index a991889217..0d06ba9087 100644 --- a/pkg/lib/oauth/handler/service_token.go +++ b/pkg/lib/oauth/handler/service_token.go @@ -4,6 +4,9 @@ import ( "encoding/json" "errors" + "github.com/lestrrat-go/jwx/v2/jwk" + + "github.com/authgear/authgear-server/pkg/lib/authn/authenticationinfo" "github.com/authgear/authgear-server/pkg/lib/authn/user" "github.com/authgear/authgear-server/pkg/lib/config" "github.com/authgear/authgear-server/pkg/lib/oauth" @@ -17,6 +20,23 @@ import ( var ErrInvalidRefreshToken = protocol.NewError("invalid_grant", "invalid refresh token") +type IssueOfflineGrantOptions struct { + AuthenticationInfo authenticationinfo.T + Scopes []string + AuthorizationID string + IDPSessionID string + DeviceInfo map[string]interface{} + IdentityID string + SSOEnabled bool + App2AppDeviceKey jwk.Key + IssueDeviceSecret bool +} + +type IssueOfflineGrantRefreshTokenOptions struct { + Scopes []string + AuthorizationID string +} + type TokenService struct { RemoteIP httputil.RemoteIP UserAgentString httputil.UserAgentString @@ -112,6 +132,32 @@ func (s *TokenService) IssueOfflineGrant( return offlineGrant, tokenHash, nil } +func (s *TokenService) IssueRefreshTokenForOfflineGrant( + offlineGrantID string, + client *config.OAuthClientConfig, + opts IssueOfflineGrantRefreshTokenOptions, + resp protocol.TokenResponse, +) (offlineGrant *oauth.OfflineGrant, tokenHash string, err error) { + offlineGrant, err = s.OfflineGrants.GetOfflineGrant(offlineGrantID) + if err != nil { + return nil, "", err + } + + newRefreshTokenResult, newOfflineGrant, err := s.OfflineGrantService.CreateNewRefreshToken( + offlineGrant, client.ClientID, opts.Scopes, opts.AuthorizationID, + ) + if err != nil { + return nil, "", err + } + offlineGrant = newOfflineGrant + + if resp != nil { + resp.RefreshToken(oauth.EncodeRefreshToken(newRefreshTokenResult.Token, offlineGrant.ID)) + } + + return newOfflineGrant, newRefreshTokenResult.TokenHash, nil +} + func (s *TokenService) IssueAccessGrant( client *config.OAuthClientConfig, scopes []string, diff --git a/pkg/lib/oauth/resolver.go b/pkg/lib/oauth/resolver.go index f7ae2279b1..d010834afb 100644 --- a/pkg/lib/oauth/resolver.go +++ b/pkg/lib/oauth/resolver.go @@ -67,8 +67,27 @@ func (re *Resolver) Resolve(rw http.ResponseWriter, r *http.Request) (session.Re return nil, nil } -func (re *Resolver) resolveByAccessGrant(grant *AccessGrant) (session.ResolvedSession, error) { - _, err := re.Authorizations.GetByID(grant.AuthorizationID) +func (re *Resolver) resolveAccessToken(token string) (session.ResolvedSession, error) { + tok, isHash, err := re.AccessTokenDecoder.DecodeAccessToken(token) + if err != nil { + return nil, session.ErrInvalidSession + } + + var tokenHash string + if isHash { + tokenHash = tok + } else { + tokenHash = HashToken(token) + } + + grant, err := re.AccessGrants.GetAccessGrant(tokenHash) + if errors.Is(err, ErrGrantNotFound) { + return nil, session.ErrInvalidSession + } else if err != nil { + return nil, err + } + + _, err = re.Authorizations.GetByID(grant.AuthorizationID) if errors.Is(err, ErrAuthorizationNotFound) { // Authorization does not exists (e.g. revoked) return nil, session.ErrInvalidSession @@ -121,26 +140,7 @@ func (re *Resolver) resolveHeader(r *http.Request) (session.ResolvedSession, err return nil, nil } - tok, isHash, err := re.AccessTokenDecoder.DecodeAccessToken(token) - if err != nil { - return nil, session.ErrInvalidSession - } - - var tokenHash string - if isHash { - tokenHash = tok - } else { - tokenHash = HashToken(token) - } - - grant, err := re.AccessGrants.GetAccessGrant(tokenHash) - if errors.Is(err, ErrGrantNotFound) { - return nil, session.ErrInvalidSession - } else if err != nil { - return nil, err - } - - return re.resolveByAccessGrant(grant) + return re.resolveAccessToken(token) } func (re *Resolver) resolveAccessTokenCookie(r *http.Request) (session.ResolvedSession, error) { @@ -149,14 +149,8 @@ func (re *Resolver) resolveAccessTokenCookie(r *http.Request) (session.ResolvedS // No access token cookie. Simply proceed. return nil, nil } - grant, err := re.AccessGrants.GetAccessGrant(HashToken(cookie.Value)) - if errors.Is(err, ErrGrantNotFound) { - return nil, session.ErrInvalidSession - } else if err != nil { - return nil, err - } - return re.resolveByAccessGrant(grant) + return re.resolveAccessToken(cookie.Value) } func (re *Resolver) resolveAppSessionCookie(r *http.Request) (session.ResolvedSession, error) {