From b3ca0b38f1d126a7b7960c7f6d023f9479126f92 Mon Sep 17 00:00:00 2001 From: Pavel Tatarskiy Date: Tue, 17 Dec 2024 23:27:47 +0300 Subject: [PATCH] update lazymap add rate limitting for NotFound --- go.mod | 2 +- go.sum | 2 + services/abuse.go | 44 ++++++--------- services/providers/badger.go | 15 ++++-- services/providers/redis.go | 16 +++--- services/providers/s3.go | 39 +++++++------- services/server.go | 46 ++++++++-------- services/store.go | 101 ++++++++++++++++++++++++----------- 8 files changed, 154 insertions(+), 111 deletions(-) diff --git a/go.mod b/go.mod index 36c2124..3a2c089 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/urfave/cli v1.22.16 github.com/webtor-io/abuse-store v0.0.0-20221220184115-12e6c0615a10 github.com/webtor-io/common-services v0.0.0-20241022160325-d391acd827ab - github.com/webtor-io/lazymap v0.0.0-20221030185154-1799721becef + github.com/webtor-io/lazymap v0.0.0-20241211155941-e81d935cfa1d golang.org/x/sys v0.26.0 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect google.golang.org/grpc v1.56.3 diff --git a/go.sum b/go.sum index 5a67f9b..f95cc8c 100644 --- a/go.sum +++ b/go.sum @@ -359,6 +359,8 @@ github.com/webtor-io/common-services v0.0.0-20241022160325-d391acd827ab h1:71AxB github.com/webtor-io/common-services v0.0.0-20241022160325-d391acd827ab/go.mod h1:6jUeO6R+ytZnEJj7PlcLEQZfWaxw8ovav73BP83MTlI= github.com/webtor-io/lazymap v0.0.0-20221030185154-1799721becef h1:tSAcIGxgmsxFJnLRyTkhzFhjGAtaOZ5g2osBHG3JYBs= github.com/webtor-io/lazymap v0.0.0-20221030185154-1799721becef/go.mod h1:za/bioTGK3VjG3+mK7/kpx0TV8++ytZkdOQ1MJ2HTjM= +github.com/webtor-io/lazymap v0.0.0-20241211155941-e81d935cfa1d h1:Xi9E0LCDgK++QliA7ZNFdSI11Bpg5qe7efN3AMWJ3dY= +github.com/webtor-io/lazymap v0.0.0-20241211155941-e81d935cfa1d/go.mod h1:kioEFK4hk8YfHrhg47tGvMG40xawOJM4gcfRQ4EeX4k= github.com/webtor-io/stoplist v0.0.0-20230128160543-ea87bdc34deb h1:RCjga119RT7hTqYeELSGPTVCYfMUIyaSyc2QCaDk/ik= github.com/webtor-io/stoplist v0.0.0-20230128160543-ea87bdc34deb/go.mod h1:nlKK64Domln2CfQUQiP2+RcbD0IQPjoBfHDYmiGLuqY= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= diff --git a/services/abuse.go b/services/abuse.go index 2415dde..ce4df43 100644 --- a/services/abuse.go +++ b/services/abuse.go @@ -29,7 +29,7 @@ var ( ) type Abuse struct { - lazymap.LazyMap + lazymap.LazyMap[bool] cl *AbuseClient } @@ -39,36 +39,26 @@ func NewAbuse(c *cli.Context, cl *AbuseClient) *Abuse { } return &Abuse{ cl: cl, - LazyMap: lazymap.New(&lazymap.Config{ + LazyMap: lazymap.New[bool](&lazymap.Config{ Expire: time.Minute, - ErrorExpire: 10 * time.Second, + StoreErrors: false, }), } } -func (s *Abuse) get(h string) error { - cl, err := s.cl.Get() - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - r, err := cl.Check(ctx, &as.CheckRequest{Infohash: h}) - if err != nil { - return err - } - if r.GetExists() { - return ErrAbuse - } - return nil -} - -func (s *Abuse) Get(h string) error { - _, err := s.LazyMap.Get(h, func() (interface{}, error) { - return nil, s.get(h) +func (s *Abuse) Get(ctx context.Context, h string) (bool, error) { + return s.LazyMap.Get(h, func() (bool, error) { + cl, err := s.cl.Get() + if err != nil { + return false, err + } + r, err := cl.Check(ctx, &as.CheckRequest{Infohash: h}) + if err != nil { + return false, err + } + if r.GetExists() { + return true, nil + } + return false, nil }) - if err != nil { - return err - } - return nil } diff --git a/services/providers/badger.go b/services/providers/badger.go index 8b97337..3046b0a 100644 --- a/services/providers/badger.go +++ b/services/providers/badger.go @@ -52,12 +52,11 @@ func (s *Badger) Name() string { return "badger" } -func (s *Badger) Touch(_ context.Context, h string) (err error) { +func (s *Badger) Touch(_ context.Context, h string) (ok bool, err error) { err = s.db.Update(func(txn *badger.Txn) error { i, err := txn.Get([]byte(h)) if errors.Is(err, badger.ErrKeyNotFound) { return ss.ErrNotFound - } else { err = i.Value(func(val []byte) error { e := badger.NewEntry([]byte(h), val).WithTTL(s.exp) @@ -66,15 +65,21 @@ func (s *Badger) Touch(_ context.Context, h string) (err error) { return err } }) - return + if err != nil { + return false, err + } + return true, nil } -func (s *Badger) Push(_ context.Context, h string, torrent []byte) (err error) { +func (s *Badger) Push(_ context.Context, h string, torrent []byte) (ok bool, err error) { err = s.db.Update(func(txn *badger.Txn) error { e := badger.NewEntry([]byte(h), torrent).WithTTL(s.exp) return txn.SetEntry(e) }) - return + if err != nil { + return false, err + } + return true, nil } func (s *Badger) Pull(_ context.Context, h string) (torrent []byte, err error) { diff --git a/services/providers/redis.go b/services/providers/redis.go index 2f4ec8f..d177b23 100644 --- a/services/providers/redis.go +++ b/services/providers/redis.go @@ -51,23 +51,27 @@ func (s *Redis) Name() string { return "redis" } -func (s *Redis) Touch(ctx context.Context, h string) (err error) { +func (s *Redis) Touch(ctx context.Context, h string) (ok bool, err error) { cl := s.cl.Get() res, err := cl.Expire(ctx, h, s.exp).Result() if err != nil { - return err + return false, err } if !res { - return ss.ErrNotFound + return false, ss.ErrNotFound } - return nil + return true, nil } -func (s *Redis) Push(ctx context.Context, h string, torrent []byte) (err error) { +func (s *Redis) Push(ctx context.Context, h string, torrent []byte) (ok bool, err error) { cl := s.cl.Get() - return cl.Set(ctx, h, torrent, s.exp).Err() + err = cl.Set(ctx, h, torrent, s.exp).Err() + if err != nil { + return false, err + } + return true, nil } func (s *Redis) Pull(ctx context.Context, h string) (torrent []byte, err error) { diff --git a/services/providers/s3.go b/services/providers/s3.go index 6c3af64..d1c1f89 100644 --- a/services/providers/s3.go +++ b/services/providers/s3.go @@ -54,22 +54,22 @@ func (s *S3) Name() string { return "s3" } -func (s *S3) Touch(ctx context.Context, h string) (err error) { - //cl := s.cl.Get() - //r, err := cl.GetObjectWithContext(ctx, &s3.GetObjectInput{ - // Bucket: aws.String(s.bucket), - // Key: aws.String(h), - //}) - //if err != nil { - // if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == s3.ErrCodeNoSuchKey { - // return ss.ErrNotFound - // } - // return err - //} - //defer func(Body io.ReadCloser) { - // _ = Body.Close() - //}(r.Body) - return nil +func (s *S3) Touch(ctx context.Context, h string) (ok bool, err error) { + cl := s.cl.Get() + r, err := cl.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(h), + }) + if err != nil { + if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == s3.ErrCodeNoSuchKey { + return false, ss.ErrNotFound + } + return false, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(r.Body) + return true, nil } func (s *S3) makeAWSMD5(b []byte) *string { @@ -78,7 +78,7 @@ func (s *S3) makeAWSMD5(b []byte) *string { return aws.String(m) } -func (s *S3) Push(ctx context.Context, h string, torrent []byte) (err error) { +func (s *S3) Push(ctx context.Context, h string, torrent []byte) (ok bool, err error) { cl := s.cl.Get() _, err = cl.PutObjectWithContext(ctx, &s3.PutObjectInput{ @@ -87,7 +87,10 @@ func (s *S3) Push(ctx context.Context, h string, torrent []byte) (err error) { Body: bytes.NewReader(torrent), ContentMD5: s.makeAWSMD5(torrent), }) - return + if err != nil { + return false, err + } + return true, nil } func (s *S3) Pull(ctx context.Context, h string) (torrent []byte, err error) { diff --git a/services/server.go b/services/server.go index a551af7..59006a6 100644 --- a/services/server.go +++ b/services/server.go @@ -30,22 +30,23 @@ func NewServer(s *Store, a *Abuse, sl *Stoplist) *Server { func (s *Server) Pull(ctx context.Context, in *pb.PullRequest) (*pb.PullReply, error) { t := time.Now() - hLog := log.WithField("infoHash", in.GetInfoHash()) + + hLog := log.WithField("infoHash", in.GetInfoHash()).WithField("method", "pull") hLog.Info("pull torrent request") - err := s.isAbused(in.GetInfoHash()) - if errors.Is(err, ErrAbuse) { - hLog.WithField("duration", time.Since(t)).Warn("abused") - return nil, status.Errorf(codes.PermissionDenied, "Restricted by the rightholder infoHash=%v", in.GetInfoHash()) - } else if err != nil { + abused, err := s.isAbused(ctx, in.GetInfoHash()) + if err != nil { hLog.WithField("duration", time.Since(t)).WithError(err).Error("failed to check abuse") return nil, errors.Wrapf(err, "failed to check abuse infoHash=%v", in.GetInfoHash()) } - + if abused { + hLog.WithField("duration", time.Since(t)).Warn("abused") + return nil, status.Errorf(codes.PermissionDenied, "restricted by the rightholder infoHash=%v", in.GetInfoHash()) + } torrent, err := s.s.Pull(ctx, in.GetInfoHash()) if errors.Is(err, ErrNotFound) { hLog.WithField("duration", time.Since(t)).Info("torrent not found") - return nil, status.Errorf(codes.NotFound, "Unable to find torrent for infoHash=%v", in.GetInfoHash()) + return nil, status.Errorf(codes.NotFound, "unable to find torrent for infoHash=%v", in.GetInfoHash()) } else if err != nil { hLog.WithField("duration", time.Since(t)).WithError(err).Error("failed to pull") return nil, errors.Wrapf(err, "failed to pull torrent infoHash=%v", in.GetInfoHash()) @@ -69,7 +70,7 @@ func (s *Server) checkStoplist(torrent []byte, log *log.Entry, t time.Time, hash } if cr.Found { log.WithField("duration", time.Since(t)).Warnf("found in stoplist %v", cr.String()) - return status.Errorf(codes.PermissionDenied, "Found in stoplist infoHash=%v", hash) + return status.Errorf(codes.PermissionDenied, "found in stoplist infoHash=%v", hash) } return nil } @@ -83,7 +84,7 @@ func (s *Server) Push(ctx context.Context, in *pb.PushRequest) (*pb.PushReply, e return nil, err } infoHash := mi.HashInfoBytes().HexString() - hLog := log.WithField("infoHash", infoHash) + hLog := log.WithField("infoHash", infoHash).WithField("method", "push") hLog.Info("push torrent request") err = s.checkStoplist(in.GetTorrent(), hLog, t, infoHash) @@ -91,15 +92,17 @@ func (s *Server) Push(ctx context.Context, in *pb.PushRequest) (*pb.PushReply, e return nil, err } - err = s.isAbused(infoHash) - if errors.Is(err, ErrAbuse) { - hLog.WithField("duration", time.Since(t)).Warn("abused") - return nil, status.Errorf(codes.PermissionDenied, "Restricted by the rightholder infoHash=%v", infoHash) - } else if err != nil { + abused, err := s.isAbused(ctx, infoHash) + if err != nil { hLog.WithField("duration", time.Since(t)).WithError(err).Error("failed to check abuse") return nil, errors.Wrapf(err, "failed to check abuse infoHash=%v", infoHash) } - err = s.s.Push(ctx, infoHash, in.GetTorrent()) + if abused { + hLog.WithField("duration", time.Since(t)).Warn("abused") + return nil, status.Errorf(codes.PermissionDenied, "restricted by the rightholder infoHash=%v", infoHash) + } + + _, err = s.s.Push(ctx, infoHash, in.GetTorrent()) if err != nil { hLog.WithField("duration", time.Since(t)).WithError(err).Error("failed to push") return nil, errors.Wrapf(err, "failed to push torrent infoHash=%v", infoHash) @@ -109,20 +112,17 @@ func (s *Server) Push(ctx context.Context, in *pb.PushRequest) (*pb.PushReply, e return &pb.PushReply{InfoHash: infoHash}, nil } -func (s *Server) isAbused(h string) error { - if s.a == nil { - return nil - } - return s.a.Get(h) +func (s *Server) isAbused(ctx context.Context, h string) (bool, error) { + return s.a.Get(ctx, h) } func (s *Server) Touch(ctx context.Context, in *pb.TouchRequest) (*pb.TouchReply, error) { t := time.Now() infoHash := in.GetInfoHash() - hLog := log.WithField("infoHash", infoHash) + hLog := log.WithField("infoHash", infoHash).WithField("method", "touch") hLog.Info("touch torrent request") - err := s.s.Touch(ctx, infoHash) + _, err := s.s.Touch(ctx, infoHash) if errors.Is(err, ErrNotFound) { hLog.WithField("duration", time.Since(t)).Info("torrent not found") return nil, status.Errorf(codes.NotFound, "torrent not found infoHash=%v", infoHash) diff --git a/services/store.go b/services/store.go index 4575748..a1f5f69 100644 --- a/services/store.go +++ b/services/store.go @@ -2,6 +2,7 @@ package services import ( "context" + "sync/atomic" "time" "github.com/pkg/errors" @@ -11,18 +12,19 @@ import ( ) type StoreProvider interface { - Push(ctx context.Context, h string, torrent []byte) (err error) + Push(ctx context.Context, h string, torrent []byte) (ok bool, err error) Pull(ctx context.Context, h string) (torrent []byte, err error) - Touch(ctx context.Context, h string) (err error) + Touch(ctx context.Context, h string) (ok bool, err error) Name() string } type Store struct { - pullm *lazymap.LazyMap - pushm *lazymap.LazyMap - touchm *lazymap.LazyMap + pullm *lazymap.LazyMap[[]byte] + pushm *lazymap.LazyMap[bool] + touchm *lazymap.LazyMap[bool] providers []StoreProvider revProviders []StoreProvider + ratem *lazymap.LazyMap[*atomic.Int64] } var ( @@ -31,12 +33,18 @@ var ( func NewStore(providers []StoreProvider) *Store { cfg := &lazymap.Config{ - ErrorExpire: 10 * time.Second, - Expire: time.Minute, + Expire: 5 * time.Minute, + StoreErrors: false, } - pullm := lazymap.New(cfg) - pushm := lazymap.New(cfg) - touchm := lazymap.New(cfg) + + rateCfg := &lazymap.Config{ + Expire: 1 * time.Minute, + StoreErrors: false, + } + pullm := lazymap.New[[]byte](cfg) + pushm := lazymap.New[bool](cfg) + touchm := lazymap.New[bool](cfg) + ratem := lazymap.New[*atomic.Int64](rateCfg) var revProviders []StoreProvider for _, p := range providers { log.WithField("provider", p.Name()).Info("use provider") @@ -49,29 +57,54 @@ func NewStore(providers []StoreProvider) *Store { pullm: &pullm, pushm: &pushm, touchm: &touchm, + ratem: &ratem, providers: providers, revProviders: revProviders, } } -func (s *Store) push(ctx context.Context, h string, torrent []byte) (val interface{}, err error) { +func (s *Store) push(ctx context.Context, h string, torrent []byte) (ok bool, err error) { for _, v := range s.revProviders { t := time.Now() - err = v.Push(ctx, h, torrent) + ok, err = v.Push(ctx, h, torrent) if err != nil { - return nil, err + return false, err } log.WithField("infohash", h).WithField("duration", time.Since(t)).WithField("provider", v.Name()).Info("provider push") } return } -func (s *Store) touch(ctx context.Context, h string) (val interface{}, err error) { - s.touchm.Touch(h) +func (s *Store) checkRate(h string) bool { + a := s.getRate(h) + return a.Load() < 10 +} + +func (s *Store) incRate(h string) { + a := s.getRate(h) + go func() { + <-time.After(time.Minute) + a.Add(-1) + }() + a.Add(1) +} + +func (s *Store) getRate(h string) *atomic.Int64 { + a, _ := s.ratem.Get(h, func() (*atomic.Int64, error) { + return &atomic.Int64{}, nil + }) + return a +} +func (s *Store) touch(ctx context.Context, h string) (ok bool, err error) { + if !s.checkRate(h) { + log.WithField("infohash", h).Warn("get rate limit") + return false, ErrNotFound + } + s.touchm.Touch(h) for i, v := range s.providers { t := time.Now() - err = v.Touch(ctx, h) + ok, err = v.Touch(ctx, h) if errors.Is(err, ErrNotFound) { log.WithField("infohash", h).WithField("duration", time.Since(t)).WithField("provider", v.Name()).Info("provider not touched") continue @@ -81,14 +114,23 @@ func (s *Store) touch(ctx context.Context, h string) (val interface{}, err error } log.WithField("infohash", h).WithField("duration", time.Since(t)).WithField("provider", v.Name()).Info("provider touch") if i > 0 { - go s.pull(ctx, h, i) + go func() { + _, _ = s.pull(ctx, h, i) + }() } break } + if err != nil && errors.Is(err, ErrNotFound) { + s.incRate(h) + } return } func (s *Store) pull(ctx context.Context, h string, start int) (torrent []byte, err error) { + if !s.checkRate(h) { + log.WithField("infohash", h).Warn("get rate limit") + return nil, ErrNotFound + } for i := start; i < len(s.providers); i++ { t := time.Now() torrent, err = s.providers[i].Pull(ctx, h) @@ -101,7 +143,7 @@ func (s *Store) pull(ctx context.Context, h string, start int) (torrent []byte, if torrent != nil { for j := 0; j < i; j++ { log.WithField("infohash", h).WithField("provider", s.providers[j].Name()).Info("provider push") - err = s.providers[j].Push(ctx, h, torrent) + _, err = s.providers[j].Push(ctx, h, torrent) if err != nil { log.WithField("infohash", h).WithField("duration", time.Since(t)).WithField("provider", s.providers[j].Name()).WithError(err).Warn("provider not pushed") continue @@ -110,30 +152,27 @@ func (s *Store) pull(ctx context.Context, h string, start int) (torrent []byte, } break } + if err != nil && errors.Is(err, ErrNotFound) { + s.incRate(h) + } return } -func (s *Store) Pull(ctx context.Context, h string) (torrent []byte, err error) { - v, err := s.pullm.Get(h, func() (interface{}, error) { +func (s *Store) Pull(ctx context.Context, h string) ([]byte, error) { + return s.pullm.Get(h, func() ([]byte, error) { return s.pull(ctx, h, 0) }) - if err != nil { - return nil, err - } - torrent = v.([]byte) - return + } -func (s *Store) Push(ctx context.Context, h string, torrent []byte) (err error) { - _, err = s.pushm.Get(h, func() (interface{}, error) { +func (s *Store) Push(ctx context.Context, h string, torrent []byte) (bool, error) { + return s.pushm.Get(h, func() (bool, error) { return s.push(ctx, h, torrent) }) - return err } -func (s *Store) Touch(ctx context.Context, h string) (err error) { - _, err = s.touchm.Get(h, func() (interface{}, error) { +func (s *Store) Touch(ctx context.Context, h string) (bool, error) { + return s.touchm.Get(h, func() (bool, error) { return s.touch(ctx, h) }) - return err }