Skip to content

Commit

Permalink
支持更新 token 中的 claims 数据
Browse files Browse the repository at this point in the history
  • Loading branch information
flycash committed Apr 7, 2024
1 parent 1226b60 commit 3db38e6
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 23 deletions.
10 changes: 7 additions & 3 deletions internal/ratelimit/mocks/ratelimit.mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions session/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,7 @@ func CheckLoginMiddleware() gin.HandlerFunc {
func RenewAccessToken(ctx *gctx.Context) error {
return defaultProvider.RenewAccessToken(ctx)
}

func UpdateClaims(ctx *gctx.Context, claims Claims) error {
return defaultProvider.UpdateClaims(ctx, claims)

Check warning on line 54 in session/global.go

View check run for this annotation

Codecov / codecov/patch

session/global.go#L53-L54

Added lines #L53 - L54 were not covered by tests
}
6 changes: 6 additions & 0 deletions session/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package session
import (
"context"

"github.com/ecodeclub/ginx/gctx"

"github.com/ecodeclub/ekit"
"github.com/ecodeclub/ginx/internal/errs"
)
Expand All @@ -33,6 +35,10 @@ func (m *MemorySession) Destroy(ctx context.Context) error {
return nil
}

func (m *MemorySession) UpdateClaims(ctx *gctx.Context, claims Claims) error {
return nil

Check warning on line 39 in session/memory.go

View check run for this annotation

Codecov / codecov/patch

session/memory.go#L38-L39

Added lines #L38 - L39 were not covered by tests
}

func (m *MemorySession) Del(ctx context.Context, key string) error {
delete(m.data, key)
return nil
Expand Down
32 changes: 16 additions & 16 deletions session/provider.mock_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 17 additions & 2 deletions session/redis/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ var (

var _ session.Provider = &SessionProvider{}

// SessionProvider 默认是预加载机制,即 Get 的时候会顺便把所有的数据都拿过来
// 默认情况下,产生的 Session 对应了两个 token,access token 和 refresh token
// SessionProvider 默认情况下,产生的 Session 对应了两个 token,
// access token 和 refresh token
// 它们会被放进去 http.Response x-access-token 和 x-refresh-token 里面
// 后续前端发送请求的时候,需要把 token 放到 Authorization 中,以 Bearer 的形式传过来
// 很多字段并没有暴露,如果你需要自定义,可以发 issue
Expand All @@ -49,6 +49,21 @@ type SessionProvider struct {
expiration time.Duration
}

// UpdateClaims 在这个实现里面,claims 同时写进去了
func (rsp *SessionProvider) UpdateClaims(ctx *gctx.Context, claims session.Claims) error {
accessToken, err := rsp.m.GenerateAccessToken(claims)
if err != nil {
return err
}

Check warning on line 57 in session/redis/provider.go

View check run for this annotation

Codecov / codecov/patch

session/redis/provider.go#L56-L57

Added lines #L56 - L57 were not covered by tests
refreshToken, err := rsp.m.GenerateRefreshToken(claims)
if err != nil {
return err
}

Check warning on line 61 in session/redis/provider.go

View check run for this annotation

Codecov / codecov/patch

session/redis/provider.go#L60-L61

Added lines #L60 - L61 were not covered by tests
ctx.Header(rsp.atHeader, accessToken)
ctx.Header(rsp.rtHeader, refreshToken)
return nil
}

func (rsp *SessionProvider) RenewAccessToken(ctx *ginx.Context) error {
// 此时这里应该放着 RefreshToken
rt := rsp.extractTokenString(ctx)
Expand Down
64 changes: 62 additions & 2 deletions session/redis/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,67 @@ func (s *ProviderTestSuite) TestRenewSession() {
require.NoError(s.T(), err)
}

func TestSessionProvider_UpdateClaims(t *testing.T) {
testCases := []struct {
name string
mock func(ctrl *gomock.Controller) redis.Cmdable
wantErr error
}{
{
name: "更新成功",
mock: func(ctrl *gomock.Controller) redis.Cmdable {
cmd := mocks.NewMockCmdable(ctrl)
pip := mocks.NewMockPipeliner(ctrl)
pip.EXPECT().HMSet(gomock.Any(), gomock.Any(), gomock.Any()).
AnyTimes().Return(nil)
pip.EXPECT().Expire(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
pip.EXPECT().Exec(gomock.Any()).Return(nil, nil)
cmd.EXPECT().Pipeline().Return(pip)
return cmd
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
client := tc.mock(ctrl)
sp := NewSessionProvider(client, "123")
recorder := httptest.NewRecorder()

ctx, _ := gin.CreateTestContext(recorder)
// 先创建一个
_, err := sp.NewSession(&gctx.Context{
Context: ctx,
}, 123, map[string]string{"hello": "world"}, map[string]any{})

gtx := &gctx.Context{
Context: ctx,
}
newCl := session.Claims{
Uid: 234,
SSID: "ssid_123",
Data: map[string]string{"hello": "nihao"}}

err = sp.UpdateClaims(gtx, newCl)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
token := ctx.Writer.Header().Get("X-Access-Token")
rc, err := sp.m.VerifyAccessToken(token)
require.NoError(t, err)
cl := rc.Data
assert.Equal(t, newCl, cl)
token = ctx.Writer.Header().Get("X-Refresh-Token")
rc, err = sp.m.VerifyAccessToken(token)
require.NoError(t, err)
cl = rc.Data
assert.Equal(t, newCl, cl)
})
}
}

func TestProvider(t *testing.T) {
suite.Run(t, new(ProviderTestSuite))
}
Expand Down Expand Up @@ -101,8 +162,7 @@ func TestSessionProvider_NewSession(t *testing.T) {
ctx, _ := gin.CreateTestContext(recorder)
sess, err := sp.NewSession(&gctx.Context{
Context: ctx,
}, 123,
map[string]string{"hello": "world"}, map[string]any{})
}, 123, map[string]string{"hello": "world"}, map[string]any{})
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
Expand Down
7 changes: 7 additions & 0 deletions session/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Session interface {

// Provider 定义了 Session 的整个管理机制。
// 所有的 Session 都必须支持 jwt
//
//go:generate mockgen -source=./types.go -destination=./provider.mock_test.go -package=session Provider
type Provider interface {
// NewSession 将会初始化 Session
// 其中 jwtData 将编码进去 jwt 中
Expand All @@ -49,6 +51,11 @@ type Provider interface {
// 也就是,用户可以预期拿到的 Session 永远是没有过期,直接可用的
Get(ctx *gctx.Context) (Session, error)

// UpdateClaims 修改 claims 的数据
// 但是因为 jwt 本身是不可变的,所以实际上这里是重新生成了一个 jwt 的 token
// 必须传入正确的 SSID
UpdateClaims(ctx *gctx.Context, claims Claims) error

// RenewAccessToken 刷新并且返回一个新的 access token
// 这个过程会校验长 token 的合法性
RenewAccessToken(ctx *gctx.Context) error
Expand Down

0 comments on commit 3db38e6

Please sign in to comment.