Skip to content

Commit

Permalink
Refactor API struct constructor (#4169) (#4174)
Browse files Browse the repository at this point in the history
(cherry picked from commit 3c16452)

Co-authored-by: Michel Laterman <[email protected]>
  • Loading branch information
mergify[bot] and michel-laterman authored Dec 4, 2024
1 parent 1ef97d1 commit 3680c42
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: other

# Change summary; a 80ish characters long description of the change.
summary: Refactor API struct constructor

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
#description:

# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc.
component: fleet-server

# PR URL; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
#pr: https://github.com/owner/repo/1234

# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
issue: https://github.com/elastic/fleet-server/issues/3823
95 changes: 78 additions & 17 deletions internal/pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,90 @@ package api
import (
"net/http"

"github.com/elastic/fleet-server/v7/internal/pkg/build"
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
"go.elastic.co/apm/v2"

"github.com/elastic/fleet-server/v7/internal/pkg/rollback"

"github.com/rs/zerolog/hlog"
)

type APIOpt func(a *apiServer)

func WithCheckin(ct *CheckinT) APIOpt {
return func(a *apiServer) {
a.ct = ct
}
}

func WithEnroller(et *EnrollerT) APIOpt {
return func(a *apiServer) {
a.et = et
}
}

func WithArtifact(at *ArtifactT) APIOpt {
return func(a *apiServer) {
a.at = at
}
}

func WithAck(ack *AckT) APIOpt {
return func(a *apiServer) {
a.ack = ack
}
}

func WithStatus(st *StatusT) APIOpt {
return func(a *apiServer) {
a.st = st
}
}

func WithUpload(ut *UploadT) APIOpt {
return func(a *apiServer) {
a.ut = ut
}
}

func WithFileDelivery(ft *FileDeliveryT) APIOpt {
return func(a *apiServer) {
a.ft = ft
}
}

func WithPGP(pt *PGPRetrieverT) APIOpt {
return func(a *apiServer) {
a.pt = pt
}
}

func WithAudit(audit *AuditT) APIOpt {
return func(a *apiServer) {
a.audit = audit
}
}

func WithTracer(tracer *apm.Tracer) APIOpt {
return func(a *apiServer) {
a.tracer = tracer
}
}

// FIXME: Cleanup needed for: metrics endpoint (actually a separate listener?), endpoint auth
// FIXME: Should we use strict handler
type apiServer struct {
ct *CheckinT
et *EnrollerT
at *ArtifactT
ack *AckT
st *StatusT
sm policy.SelfMonitor
bi build.Info
ut *UploadT
ft *FileDeliveryT
pt *PGPRetrieverT
audit *AuditT
bulker bulk.Bulk
ct *CheckinT
et *EnrollerT
at *ArtifactT
ack *AckT
st *StatusT
ut *UploadT
ft *FileDeliveryT
pt *PGPRetrieverT
audit *AuditT

// tracer is used by the wrapping server to instrument the API server
tracer *apm.Tracer
}

// ensure api implements the ServerInterface
Expand Down Expand Up @@ -120,7 +181,7 @@ func (a *apiServer) UploadChunk(w http.ResponseWriter, r *http.Request, id strin
zlog := hlog.FromRequest(r).With().Str(LogAgentID, id).Logger()
w.Header().Set("Content-Type", "application/json")

if _, err := a.ut.authAPIKey(r, a.bulker, a.ut.cache); err != nil {
if _, err := a.ut.authAPIKey(r, a.ut.bulker, a.ut.cache); err != nil {
cntUploadChunk.IncError(err)
ErrorResp(w, r, err)
return
Expand Down Expand Up @@ -163,7 +224,7 @@ func (a *apiServer) Status(w http.ResponseWriter, r *http.Request, params Status
Str("mod", kStatusMod).
Logger()
w.Header().Set("Content-Type", "application/json")
err := a.st.handleStatus(zlog, a.sm, a.bi, r, w)
err := a.st.handleStatus(zlog, r, w)
if err != nil {
cntStatus.IncError(err)
ErrorResp(w, r, err)
Expand Down
9 changes: 6 additions & 3 deletions internal/pkg/api/handleCheckin.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,12 @@ func NewCheckinT(
pm policy.Monitor,
gcp monitor.GlobalCheckpointProvider,
ad *action.Dispatcher,
tr *action.TokenResolver,
bulker bulk.Bulk,
) *CheckinT {
) (*CheckinT, error) {
tr, err := action.NewTokenResolver(bulker)
if err != nil {
return nil, err
}
ct := &CheckinT{
verCon: verCon,
cfg: cfg,
Expand All @@ -115,7 +118,7 @@ func NewCheckinT(
bulker: bulker,
}

return ct
return ct, nil
}

func (ct *CheckinT) handleCheckin(zlog zerolog.Logger, w http.ResponseWriter, r *http.Request, id, userAgent string) error {
Expand Down
15 changes: 10 additions & 5 deletions internal/pkg/api/handleCheckin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ func TestResolveSeqNo(t *testing.T) {
bulker := ftesting.NewMockBulk()
pim := mockmonitor.NewMockMonitor()
pm := policy.NewMonitor(bulker, pim, config.ServerLimits{PolicyLimit: config.Limit{Interval: 5 * time.Millisecond, Burst: 1}})
ct := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil, nil)
ct, err := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil)
assert.NoError(t, err)

resp, _ := ct.resolveSeqNo(ctx, logger, tc.req, tc.agent)
assert.Equal(t, tc.resp, resp)
Expand Down Expand Up @@ -673,7 +674,8 @@ func Test_CheckinT_writeResponse(t *testing.T) {
CompressionThresh: 1,
}

ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(t, err)

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand All @@ -695,7 +697,8 @@ func Benchmark_CheckinT_writeResponse(b *testing.B) {
CompressionLevel: flate.BestSpeed,
CompressionThresh: 1,
}
ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(b, err)

logger := zerolog.Nop()
req := &http.Request{
Expand All @@ -721,7 +724,8 @@ func BenchmarkParallel_CheckinT_writeResponse(b *testing.B) {
CompressionLevel: flate.BestSpeed,
CompressionThresh: 1,
}
ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(b, err)

logger := zerolog.Nop()
req := &http.Request{
Expand Down Expand Up @@ -973,7 +977,8 @@ func TestValidateCheckinRequest(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
checkin := NewCheckinT(verCon, tc.cfg, nil, nil, nil, nil, nil, nil, nil)
checkin, err := NewCheckinT(verCon, tc.cfg, nil, nil, nil, nil, nil, nil)
assert.NoError(t, err)
wr := httptest.NewRecorder()
logger := testlog.SetLogger(t)
valid, err := checkin.validateRequest(logger, wr, tc.req, time.Time{}, nil)
Expand Down
24 changes: 19 additions & 5 deletions internal/pkg/api/handleStatus.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,25 @@ type StatusT struct {
cfg *config.Server
bulk bulk.Bulk
cache cache.Cache
sm policy.SelfMonitor
bi build.Info
authfn AuthFunc
}

type OptFunc func(*StatusT)

func WithSelfMonitor(sm policy.SelfMonitor) OptFunc {
return func(st *StatusT) {
st.sm = sm
}
}

func WithBuildInfo(bi build.Info) OptFunc {
return func(st *StatusT) {
st.bi = bi
}
}

func NewStatusT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache, opts ...OptFunc) *StatusT {
st := &StatusT{
cfg: cfg,
Expand All @@ -63,26 +77,26 @@ func (st StatusT) authenticate(r *http.Request) (*apikey.APIKey, error) {
return authAPIKey(r, st.bulk, st.cache)
}

func (st StatusT) handleStatus(zlog zerolog.Logger, sm policy.SelfMonitor, bi build.Info, r *http.Request, w http.ResponseWriter) error {
func (st StatusT) handleStatus(zlog zerolog.Logger, r *http.Request, w http.ResponseWriter) error {
authed := true
if _, aerr := st.authfn(r); aerr != nil {
zlog.Debug().Err(aerr).Msg("unauthenticated status request, return short status response only")
authed = false
}

span, ctx := apm.StartSpan(r.Context(), "getState", "process")
state := sm.State()
state := st.sm.State()
resp := StatusAPIResponse{
Name: build.ServiceName,
Status: StatusResponseStatus(state.String()), // TODO try to make the oapi codegen less verbose here
}

if authed {
sSpan, _ := apm.StartSpan(ctx, "getVersion", "process")
bt := bi.BuildTime.Format(time.RFC3339)
bt := st.bi.BuildTime.Format(time.RFC3339)
resp.Version = &StatusResponseVersion{
Number: &bi.Version,
BuildHash: &bi.Commit,
Number: &st.bi.Version,
BuildHash: &st.bi.Commit,
BuildTime: &bt,
}
sSpan.End()
Expand Down
12 changes: 5 additions & 7 deletions internal/pkg/api/handleStatus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,11 @@ func TestHandleStatus(t *testing.T) {
ctx = logger.WithContext(ctx)
state := client.UnitState(k)
r := apiServer{
st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn)),
sm: &mockPolicyMonitor{state},
bi: fbuild.Info{
st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn), WithSelfMonitor(&mockPolicyMonitor{state}), WithBuildInfo(fbuild.Info{
Version: "8.1.0",
Commit: "4eff928",
BuildTime: time.Now(),
},
})),
}

hr := Handler(&r)
Expand All @@ -117,9 +115,9 @@ func TestHandleStatus(t *testing.T) {
// Expect extended version information if authenticated
if tc.Authed {
require.NotNil(t, res.Version)
assert.Equal(t, r.bi.Version, *res.Version.Number)
assert.Equal(t, r.bi.Commit, *res.Version.BuildHash)
assert.Equal(t, r.bi.BuildTime.Format(time.RFC3339), *res.Version.BuildTime)
assert.Equal(t, r.st.bi.Version, *res.Version.Number)
assert.Equal(t, r.st.bi.Commit, *res.Version.BuildHash)
assert.Equal(t, r.st.bi.BuildTime.Format(time.RFC3339), *res.Version.BuildTime)
} else {
require.Nil(t, res.Version)
}
Expand Down
24 changes: 5 additions & 19 deletions internal/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ import (
"net/http"

"github.com/elastic/elastic-agent-libs/transport/tlscommon"
"github.com/elastic/fleet-server/v7/internal/pkg/build"
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
"github.com/elastic/fleet-server/v7/internal/pkg/config"
"github.com/elastic/fleet-server/v7/internal/pkg/limit"
"github.com/elastic/fleet-server/v7/internal/pkg/logger"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
"go.elastic.co/apm/v2"

"github.com/rs/zerolog"
)
Expand All @@ -35,25 +31,15 @@ type server struct {
//
// The server has a listener specific conn limit and endpoint specific rate-limits.
// The underlying API structs (such as *CheckinT) may be shared between servers.
func NewServer(addr string, cfg *config.Server, ct *CheckinT, et *EnrollerT, at *ArtifactT, ack *AckT, st *StatusT, sm policy.SelfMonitor, bi build.Info, ut *UploadT, ft *FileDeliveryT, pt *PGPRetrieverT, audit *AuditT, bulker bulk.Bulk, tracer *apm.Tracer) *server {
a := &apiServer{
ct: ct,
et: et,
at: at,
ack: ack,
st: st,
sm: sm,
bi: bi,
ut: ut,
ft: ft,
pt: pt,
audit: audit,
bulker: bulker,
func NewServer(addr string, cfg *config.Server, opts ...APIOpt) *server {
a := &apiServer{}
for _, opt := range opts {
opt(a)
}
return &server{
addr: addr,
cfg: cfg,
handler: newRouter(&cfg.Limits, a, tracer),
handler: newRouter(&cfg.Limits, a, a.tracer),
}
}

Expand Down
Loading

0 comments on commit 3680c42

Please sign in to comment.