From 3db38e641044f06e4b7fb1acb703b41ce2f2982b Mon Sep 17 00:00:00 2001 From: Deng Ming Date: Sun, 7 Apr 2024 11:11:09 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=9B=B4=E6=96=B0=20token=20?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=20claims=20=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/ratelimit/mocks/ratelimit.mock.go | 10 +++- session/global.go | 4 ++ session/memory.go | 6 ++ session/provider.mock_test.go | 32 +++++------ session/redis/provider.go | 19 ++++++- session/redis/provider_test.go | 64 +++++++++++++++++++++- session/types.go | 7 +++ 7 files changed, 119 insertions(+), 23 deletions(-) diff --git a/internal/ratelimit/mocks/ratelimit.mock.go b/internal/ratelimit/mocks/ratelimit.mock.go index b7b3b43..866b50c 100644 --- a/internal/ratelimit/mocks/ratelimit.mock.go +++ b/internal/ratelimit/mocks/ratelimit.mock.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: crawler_detector.go - +// Source: types.go +// +// Generated by this command: +// +// mockgen -source=types.go -package=limitmocks -destination=./mocks/ratelimit.mock.go +// // Package limitmocks is a generated GoMock package. package limitmocks @@ -44,7 +48,7 @@ func (m *MockLimiter) Limit(ctx context.Context, key string) (bool, error) { } // Limit indicates an expected call of Limit. -func (mr *MockLimiterMockRecorder) Limit(ctx, key interface{}) *gomock.Call { +func (mr *MockLimiterMockRecorder) Limit(ctx, key any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Limit", reflect.TypeOf((*MockLimiter)(nil).Limit), ctx, key) } diff --git a/session/global.go b/session/global.go index 68f719b..321e961 100644 --- a/session/global.go +++ b/session/global.go @@ -49,3 +49,7 @@ func CheckLoginMiddleware() gin.HandlerFunc { func RenewAccessToken(ctx *gctx.Context) error { return defaultProvider.RenewAccessToken(ctx) } + +func UpdateClaims(ctx *gctx.Context, claims Claims) error { + return defaultProvider.UpdateClaims(ctx, claims) +} diff --git a/session/memory.go b/session/memory.go index 4a925c4..c12541f 100644 --- a/session/memory.go +++ b/session/memory.go @@ -17,6 +17,8 @@ package session import ( "context" + "github.com/ecodeclub/ginx/gctx" + "github.com/ecodeclub/ekit" "github.com/ecodeclub/ginx/internal/errs" ) @@ -33,6 +35,10 @@ func (m *MemorySession) Destroy(ctx context.Context) error { return nil } +func (m *MemorySession) UpdateClaims(ctx *gctx.Context, claims Claims) error { + return nil +} + func (m *MemorySession) Del(ctx context.Context, key string) error { delete(m.data, key) return nil diff --git a/session/provider.mock_test.go b/session/provider.mock_test.go index 6dc4734..f8c44e9 100644 --- a/session/provider.mock_test.go +++ b/session/provider.mock_test.go @@ -1,23 +1,9 @@ -// Copyright 2023 ecodeclub -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - // Code generated by MockGen. DO NOT EDIT. -// Source: session/crawler_detector.go +// Source: ./types.go // // Generated by this command: // -// mockgen -copyright_file=.license_header -source=session/crawler_detector.go -package=session -destination=session/provider.mock_test.go Provider +// mockgen -source=./types.go -destination=./provider.mock_test.go -package=session Provider // // Package session is a generated GoMock package. package session @@ -190,3 +176,17 @@ func (mr *MockProviderMockRecorder) RenewAccessToken(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewAccessToken", reflect.TypeOf((*MockProvider)(nil).RenewAccessToken), ctx) } + +// UpdateClaims mocks base method. +func (m *MockProvider) UpdateClaims(ctx *gctx.Context, claims Claims) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateClaims", ctx, claims) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateClaims indicates an expected call of UpdateClaims. +func (mr *MockProviderMockRecorder) UpdateClaims(ctx, claims any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateClaims", reflect.TypeOf((*MockProvider)(nil).UpdateClaims), ctx, claims) +} diff --git a/session/redis/provider.go b/session/redis/provider.go index 4b2e987..d7fef27 100644 --- a/session/redis/provider.go +++ b/session/redis/provider.go @@ -34,8 +34,8 @@ var ( var _ session.Provider = &SessionProvider{} -// SessionProvider 默认是预加载机制,即 Get 的时候会顺便把所有的数据都拿过来 -// 默认情况下,产生的 Session 对应了两个 token,access token 和 refresh token +// SessionProvider 默认情况下,产生的 Session 对应了两个 token, +// access token 和 refresh token // 它们会被放进去 http.Response x-access-token 和 x-refresh-token 里面 // 后续前端发送请求的时候,需要把 token 放到 Authorization 中,以 Bearer 的形式传过来 // 很多字段并没有暴露,如果你需要自定义,可以发 issue @@ -49,6 +49,21 @@ type SessionProvider struct { expiration time.Duration } +// UpdateClaims 在这个实现里面,claims 同时写进去了 +func (rsp *SessionProvider) UpdateClaims(ctx *gctx.Context, claims session.Claims) error { + accessToken, err := rsp.m.GenerateAccessToken(claims) + if err != nil { + return err + } + refreshToken, err := rsp.m.GenerateRefreshToken(claims) + if err != nil { + return err + } + ctx.Header(rsp.atHeader, accessToken) + ctx.Header(rsp.rtHeader, refreshToken) + return nil +} + func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error { // 此时这里应该放着 RefreshToken rt := rsp.extractTokenString(ctx) diff --git a/session/redis/provider_test.go b/session/redis/provider_test.go index 35897c7..5afb89e 100644 --- a/session/redis/provider_test.go +++ b/session/redis/provider_test.go @@ -64,6 +64,67 @@ func (s *ProviderTestSuite) TestRenewSession() { require.NoError(s.T(), err) } +func TestSessionProvider_UpdateClaims(t *testing.T) { + testCases := []struct { + name string + mock func(ctrl *gomock.Controller) redis.Cmdable + wantErr error + }{ + { + name: "更新成功", + mock: func(ctrl *gomock.Controller) redis.Cmdable { + cmd := mocks.NewMockCmdable(ctrl) + pip := mocks.NewMockPipeliner(ctrl) + pip.EXPECT().HMSet(gomock.Any(), gomock.Any(), gomock.Any()). + AnyTimes().Return(nil) + pip.EXPECT().Expire(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + pip.EXPECT().Exec(gomock.Any()).Return(nil, nil) + cmd.EXPECT().Pipeline().Return(pip) + return cmd + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + client := tc.mock(ctrl) + sp := NewSessionProvider(client, "123") + recorder := httptest.NewRecorder() + + ctx, _ := gin.CreateTestContext(recorder) + // 先创建一个 + _, err := sp.NewSession(&gctx.Context{ + Context: ctx, + }, 123, map[string]string{"hello": "world"}, map[string]any{}) + + gtx := &gctx.Context{ + Context: ctx, + } + newCl := session.Claims{ + Uid: 234, + SSID: "ssid_123", + Data: map[string]string{"hello": "nihao"}} + + err = sp.UpdateClaims(gtx, newCl) + assert.Equal(t, tc.wantErr, err) + if err != nil { + return + } + token := ctx.Writer.Header().Get("X-Access-Token") + rc, err := sp.m.VerifyAccessToken(token) + require.NoError(t, err) + cl := rc.Data + assert.Equal(t, newCl, cl) + token = ctx.Writer.Header().Get("X-Refresh-Token") + rc, err = sp.m.VerifyAccessToken(token) + require.NoError(t, err) + cl = rc.Data + assert.Equal(t, newCl, cl) + }) + } +} + func TestProvider(t *testing.T) { suite.Run(t, new(ProviderTestSuite)) } @@ -101,8 +162,7 @@ func TestSessionProvider_NewSession(t *testing.T) { ctx, _ := gin.CreateTestContext(recorder) sess, err := sp.NewSession(&gctx.Context{ Context: ctx, - }, 123, - map[string]string{"hello": "world"}, map[string]any{}) + }, 123, map[string]string{"hello": "world"}, map[string]any{}) assert.Equal(t, tc.wantErr, err) if err != nil { return diff --git a/session/types.go b/session/types.go index 4bffea5..ebf9ec4 100644 --- a/session/types.go +++ b/session/types.go @@ -38,6 +38,8 @@ type Session interface { // Provider 定义了 Session 的整个管理机制。 // 所有的 Session 都必须支持 jwt +// +//go:generate mockgen -source=./types.go -destination=./provider.mock_test.go -package=session Provider type Provider interface { // NewSession 将会初始化 Session // 其中 jwtData 将编码进去 jwt 中 @@ -49,6 +51,11 @@ type Provider interface { // 也就是,用户可以预期拿到的 Session 永远是没有过期,直接可用的 Get(ctx *gctx.Context) (Session, error) + // UpdateClaims 修改 claims 的数据 + // 但是因为 jwt 本身是不可变的,所以实际上这里是重新生成了一个 jwt 的 token + // 必须传入正确的 SSID + UpdateClaims(ctx *gctx.Context, claims Claims) error + // RenewAccessToken 刷新并且返回一个新的 access token // 这个过程会校验长 token 的合法性 RenewAccessToken(ctx *gctx.Context) error