From 4770c602654ddfa5585c97bdd20524ab8ec6db45 Mon Sep 17 00:00:00 2001 From: Pavel Tatarskiy Date: Sat, 30 Nov 2024 13:52:51 +0300 Subject: [PATCH] add context to claims --- handlers/embed/post.go | 2 +- services/claims/claims.go | 10 ++++------ services/embed/domain_settings.go | 9 +++++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/handlers/embed/post.go b/handlers/embed/post.go index 159ba53..76334c6 100644 --- a/handlers/embed/post.go +++ b/handlers/embed/post.go @@ -56,7 +56,7 @@ func (s *Handler) post(c *gin.Context) { tpl.HTMLWithErr(err, http.StatusBadRequest, c, pd) return } - dsd, err := s.ds.Get(u.Hostname()) + dsd, err := s.ds.Get(c.Request.Context(), u.Hostname()) if err != nil { tpl.HTMLWithErr(err, http.StatusBadRequest, c, pd) return diff --git a/services/claims/claims.go b/services/claims/claims.go index bbefc0a..b1b847c 100644 --- a/services/claims/claims.go +++ b/services/claims/claims.go @@ -48,14 +48,12 @@ func New(c *cli.Context, cl *Client) *Claims { } } -func (s *Claims) get(email string) (resp *Data, err error) { +func (s *Claims) get(ctx context.Context, email string) (resp *Data, err error) { var cl proto.ClaimsProviderClient cl, err = s.cl.Get() if err != nil { return nil, err } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() resp, err = cl.Get(ctx, &proto.GetRequest{Email: email}) if err != nil { return nil, errors.WithMessage(err, "failed to get claims") @@ -63,9 +61,9 @@ func (s *Claims) get(email string) (resp *Data, err error) { return } -func (s *Claims) Get(email string) (*Data, error) { +func (s *Claims) Get(ctx context.Context, email string) (*Data, error) { resp, err := s.LazyMap.Get(email, func() (interface{}, error) { - return s.get(email) + return s.get(ctx, email) }) if err != nil { return nil, err @@ -75,7 +73,7 @@ func (s *Claims) Get(email string) (*Data, error) { func (s *Claims) MakeUserClaimsFromContext(c *gin.Context) (*Data, error) { u := auth.GetUserFromContext(c) - r, err := s.Get(u.Email) + r, err := s.Get(c.Request.Context(), u.Email) if err != nil { return nil, err } diff --git a/services/embed/domain_settings.go b/services/embed/domain_settings.go index 14346de..27b3fab 100644 --- a/services/embed/domain_settings.go +++ b/services/embed/domain_settings.go @@ -1,6 +1,7 @@ package embed import ( + "context" "github.com/pkg/errors" "time" @@ -31,7 +32,7 @@ func NewDomainSettings(pg *cs.PG, claims *claims.Claims) *DomainSettings { } } -func (s *DomainSettings) get(domain string) (*DomainSettingsData, error) { +func (s *DomainSettings) get(ctx context.Context, domain string) (*DomainSettingsData, error) { if s.pg == nil || s.pg.Get() == nil || s.claims == nil { return &DomainSettingsData{}, nil } @@ -43,16 +44,16 @@ func (s *DomainSettings) get(domain string) (*DomainSettingsData, error) { } else if err != nil { return nil, err } - cl, err := s.claims.Get(em.Email) + cl, err := s.claims.Get(ctx, em.Email) if err != nil { return nil, err } return &DomainSettingsData{Ads: em.Ads || !cl.Claims.Embed.NoAds}, nil } -func (s *DomainSettings) Get(domain string) (*DomainSettingsData, error) { +func (s *DomainSettings) Get(ctx context.Context, domain string) (*DomainSettingsData, error) { resp, err := s.LazyMap.Get(domain, func() (interface{}, error) { - return s.get(domain) + return s.get(ctx, domain) }) if err != nil { return nil, err