-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
yangyile
committed
Oct 3, 2024
1 parent
7c00952
commit a64b671
Showing
4 changed files
with
176 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
package utils_kratos_ratelimit | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/go-kratos/kratos/v2/log" | ||
"github.com/go-kratos/kratos/v2/middleware" | ||
"github.com/go-kratos/kratos/v2/middleware/ratelimit" | ||
"github.com/go-kratos/kratos/v2/middleware/selector" | ||
"github.com/go-redis/redis_rate/v10" | ||
"github.com/orzkratos/authkratos/authkratospath" | ||
"github.com/yyle88/erero" | ||
) | ||
|
||
type Config struct { | ||
rateLimitBottle *redis_rate.Limiter | ||
rule *redis_rate.Limit | ||
ucGetUniqueCode func(ctx context.Context) string | ||
selectPath *authkratospath.SelectPath | ||
enable bool | ||
} | ||
|
||
func NewConfig( | ||
rateLimitBottle *redis_rate.Limiter, | ||
rule *redis_rate.Limit, | ||
ucGetUniqueCode func(ctx context.Context) string, | ||
selectPath *authkratospath.SelectPath, | ||
) *Config { | ||
return &Config{ | ||
rateLimitBottle: rateLimitBottle, | ||
rule: rule, | ||
ucGetUniqueCode: ucGetUniqueCode, | ||
selectPath: selectPath, | ||
enable: true, | ||
} | ||
} | ||
|
||
func (a *Config) SetEnable(v bool) { | ||
a.enable = v | ||
} | ||
|
||
func (a *Config) IsEnable() bool { | ||
if a != nil { | ||
return a.enable | ||
} | ||
return false | ||
} | ||
|
||
func NewMiddleware(cfg *Config, LOGGER log.Logger) middleware.Middleware { | ||
LOG := log.NewHelper(LOGGER) | ||
LOG.Infof( | ||
"new rate_limit middleware enable=%v rule=%v include=%v operations=%v", | ||
cfg.IsEnable(), | ||
cfg.rule.String(), | ||
cfg.selectPath.SelectSide, | ||
len(cfg.selectPath.Operations), | ||
) | ||
|
||
return selector.Server(middlewareFunc(cfg, LOGGER)).Match(matchFunc(cfg, LOGGER)).Build() | ||
} | ||
|
||
func matchFunc(cfg *Config, LOGGER log.Logger) selector.MatchFunc { | ||
LOG := log.NewHelper(LOGGER) | ||
|
||
return func(ctx context.Context, operation string) bool { | ||
if !cfg.IsEnable() { | ||
return false | ||
} | ||
match := cfg.selectPath.Match(operation) | ||
if match { | ||
LOG.Debugf("operation=%s include=%v match=%v must check rate", operation, cfg.selectPath.SelectSide, match) | ||
} else { | ||
LOG.Debugf("operation=%s include=%v match=%v skip check rate", operation, cfg.selectPath.SelectSide, match) | ||
} | ||
return match | ||
} | ||
} | ||
|
||
func middlewareFunc(cfg *Config, LOGGER log.Logger) middleware.Middleware { | ||
LOG := log.NewHelper(LOGGER) | ||
|
||
rateLimitRule := *cfg.rule | ||
|
||
return func(handleFunc middleware.Handler) middleware.Handler { | ||
return func(ctx context.Context, req interface{}) (resp interface{}, err error) { | ||
if !cfg.IsEnable() { | ||
LOG.Infof("rate_limit: cfg.enable=false anonymous pass") | ||
return handleFunc(ctx, req) | ||
} | ||
|
||
uck := cfg.ucGetUniqueCode(ctx) | ||
|
||
rls, err := cfg.rateLimitBottle.Allow(ctx, uck, rateLimitRule) | ||
if err != nil { | ||
return nil, erero.WithMessage(err, "rate_limit redis exception") | ||
} | ||
|
||
if rls.Allowed != 0 { | ||
LOG.Debugf("rate_limit allowed=%v remaining=%v so can pass", rls.Allowed, rls.Remaining) | ||
} else { | ||
LOG.Warnf("rate_limit exceeds so reject requests") | ||
|
||
return nil, ratelimit.ErrLimitExceed | ||
} | ||
return handleFunc(ctx, req) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package utils_kratos_ratelimit | ||
|
||
import ( | ||
"testing" | ||
) | ||
|
||
func TestMain(m *testing.M) { | ||
m.Run() | ||
} |