Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor API struct constructor #4169

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: other

# Change summary; a 80ish characters long description of the change.
summary: Refactor API struct constructor

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
#description:

# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc.
component: fleet-server

# PR URL; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
#pr: https://github.com/owner/repo/1234

# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
issue: https://github.com/elastic/fleet-server/issues/3823
95 changes: 78 additions & 17 deletions internal/pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,90 @@ package api
import (
"net/http"

"github.com/elastic/fleet-server/v7/internal/pkg/build"
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
"go.elastic.co/apm/v2"

"github.com/elastic/fleet-server/v7/internal/pkg/rollback"

"github.com/rs/zerolog/hlog"
)

type APIOpt func(a *apiServer)

func WithCheckin(ct *CheckinT) APIOpt {
return func(a *apiServer) {
a.ct = ct
}
}

func WithEnroller(et *EnrollerT) APIOpt {
return func(a *apiServer) {
a.et = et
}
}

func WithArtifact(at *ArtifactT) APIOpt {
return func(a *apiServer) {
a.at = at
}
}

func WithAck(ack *AckT) APIOpt {
return func(a *apiServer) {
a.ack = ack
}
}

func WithStatus(st *StatusT) APIOpt {
return func(a *apiServer) {
a.st = st
}
}

func WithUpload(ut *UploadT) APIOpt {
return func(a *apiServer) {
a.ut = ut
}
}

func WithFileDelivery(ft *FileDeliveryT) APIOpt {
return func(a *apiServer) {
a.ft = ft
}
}

func WithPGP(pt *PGPRetrieverT) APIOpt {
return func(a *apiServer) {
a.pt = pt
}
}

func WithAudit(audit *AuditT) APIOpt {
return func(a *apiServer) {
a.audit = audit
}
}

func WithTracer(tracer *apm.Tracer) APIOpt {
return func(a *apiServer) {
a.tracer = tracer
}
}

// FIXME: Cleanup needed for: metrics endpoint (actually a separate listener?), endpoint auth
// FIXME: Should we use strict handler
type apiServer struct {
ct *CheckinT
et *EnrollerT
at *ArtifactT
ack *AckT
st *StatusT
sm policy.SelfMonitor
bi build.Info
ut *UploadT
ft *FileDeliveryT
pt *PGPRetrieverT
audit *AuditT
bulker bulk.Bulk
ct *CheckinT
et *EnrollerT
at *ArtifactT
ack *AckT
st *StatusT
ut *UploadT
ft *FileDeliveryT
pt *PGPRetrieverT
audit *AuditT

// tracer is used by the wrapping server to instrument the API server
tracer *apm.Tracer
}

// ensure api implements the ServerInterface
Expand Down Expand Up @@ -120,7 +181,7 @@ func (a *apiServer) UploadChunk(w http.ResponseWriter, r *http.Request, id strin
zlog := hlog.FromRequest(r).With().Str(LogAgentID, id).Logger()
w.Header().Set("Content-Type", "application/json")

if _, err := a.ut.authAPIKey(r, a.bulker, a.ut.cache); err != nil {
if _, err := a.ut.authAPIKey(r, a.ut.bulker, a.ut.cache); err != nil {
cntUploadChunk.IncError(err)
ErrorResp(w, r, err)
return
Expand Down Expand Up @@ -163,7 +224,7 @@ func (a *apiServer) Status(w http.ResponseWriter, r *http.Request, params Status
Str("mod", kStatusMod).
Logger()
w.Header().Set("Content-Type", "application/json")
err := a.st.handleStatus(zlog, a.sm, a.bi, r, w)
err := a.st.handleStatus(zlog, r, w)
if err != nil {
cntStatus.IncError(err)
ErrorResp(w, r, err)
Expand Down
9 changes: 6 additions & 3 deletions internal/pkg/api/handleCheckin.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,12 @@ func NewCheckinT(
pm policy.Monitor,
gcp monitor.GlobalCheckpointProvider,
ad *action.Dispatcher,
tr *action.TokenResolver,
bulker bulk.Bulk,
) *CheckinT {
) (*CheckinT, error) {
tr, err := action.NewTokenResolver(bulker)
if err != nil {
return nil, err
}
ct := &CheckinT{
verCon: verCon,
cfg: cfg,
Expand All @@ -115,7 +118,7 @@ func NewCheckinT(
bulker: bulker,
}

return ct
return ct, nil
}

func (ct *CheckinT) handleCheckin(zlog zerolog.Logger, w http.ResponseWriter, r *http.Request, id, userAgent string) error {
Expand Down
15 changes: 10 additions & 5 deletions internal/pkg/api/handleCheckin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ func TestResolveSeqNo(t *testing.T) {
bulker := ftesting.NewMockBulk()
pim := mockmonitor.NewMockMonitor()
pm := policy.NewMonitor(bulker, pim, config.ServerLimits{PolicyLimit: config.Limit{Interval: 5 * time.Millisecond, Burst: 1}})
ct := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil, nil)
ct, err := NewCheckinT(verCon, cfg, c, bc, pm, nil, nil, nil)
assert.NoError(t, err)

resp, _ := ct.resolveSeqNo(ctx, logger, tc.req, tc.agent)
assert.Equal(t, tc.resp, resp)
Expand Down Expand Up @@ -673,7 +674,8 @@ func Test_CheckinT_writeResponse(t *testing.T) {
CompressionThresh: 1,
}

ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(t, err)

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand All @@ -695,7 +697,8 @@ func Benchmark_CheckinT_writeResponse(b *testing.B) {
CompressionLevel: flate.BestSpeed,
CompressionThresh: 1,
}
ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(b, err)

logger := zerolog.Nop()
req := &http.Request{
Expand All @@ -721,7 +724,8 @@ func BenchmarkParallel_CheckinT_writeResponse(b *testing.B) {
CompressionLevel: flate.BestSpeed,
CompressionThresh: 1,
}
ct := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
ct, err := NewCheckinT(verCon, cfg, nil, nil, nil, nil, nil, ftesting.NewMockBulk())
require.NoError(b, err)

logger := zerolog.Nop()
req := &http.Request{
Expand Down Expand Up @@ -973,7 +977,8 @@ func TestValidateCheckinRequest(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
checkin := NewCheckinT(verCon, tc.cfg, nil, nil, nil, nil, nil, nil, nil)
checkin, err := NewCheckinT(verCon, tc.cfg, nil, nil, nil, nil, nil, nil)
assert.NoError(t, err)
wr := httptest.NewRecorder()
logger := testlog.SetLogger(t)
valid, err := checkin.validateRequest(logger, wr, tc.req, time.Time{}, nil)
Expand Down
24 changes: 19 additions & 5 deletions internal/pkg/api/handleStatus.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,25 @@ type StatusT struct {
cfg *config.Server
bulk bulk.Bulk
cache cache.Cache
sm policy.SelfMonitor
bi build.Info
authfn AuthFunc
}

type OptFunc func(*StatusT)

func WithSelfMonitor(sm policy.SelfMonitor) OptFunc {
return func(st *StatusT) {
st.sm = sm
}
}

func WithBuildInfo(bi build.Info) OptFunc {
return func(st *StatusT) {
st.bi = bi
}
}

func NewStatusT(cfg *config.Server, bulker bulk.Bulk, cache cache.Cache, opts ...OptFunc) *StatusT {
st := &StatusT{
cfg: cfg,
Expand All @@ -63,26 +77,26 @@ func (st StatusT) authenticate(r *http.Request) (*apikey.APIKey, error) {
return authAPIKey(r, st.bulk, st.cache)
}

func (st StatusT) handleStatus(zlog zerolog.Logger, sm policy.SelfMonitor, bi build.Info, r *http.Request, w http.ResponseWriter) error {
func (st StatusT) handleStatus(zlog zerolog.Logger, r *http.Request, w http.ResponseWriter) error {
authed := true
if _, aerr := st.authfn(r); aerr != nil {
zlog.Debug().Err(aerr).Msg("unauthenticated status request, return short status response only")
authed = false
}

span, ctx := apm.StartSpan(r.Context(), "getState", "process")
state := sm.State()
state := st.sm.State()
resp := StatusAPIResponse{
Name: build.ServiceName,
Status: StatusResponseStatus(state.String()), // TODO try to make the oapi codegen less verbose here
}

if authed {
sSpan, _ := apm.StartSpan(ctx, "getVersion", "process")
bt := bi.BuildTime.Format(time.RFC3339)
bt := st.bi.BuildTime.Format(time.RFC3339)
resp.Version = &StatusResponseVersion{
Number: &bi.Version,
BuildHash: &bi.Commit,
Number: &st.bi.Version,
BuildHash: &st.bi.Commit,
BuildTime: &bt,
}
sSpan.End()
Expand Down
12 changes: 5 additions & 7 deletions internal/pkg/api/handleStatus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,11 @@ func TestHandleStatus(t *testing.T) {
ctx = logger.WithContext(ctx)
state := client.UnitState(k)
r := apiServer{
st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn)),
sm: &mockPolicyMonitor{state},
bi: fbuild.Info{
st: NewStatusT(cfg, nil, c, withAuthFunc(tc.AuthFn), WithSelfMonitor(&mockPolicyMonitor{state}), WithBuildInfo(fbuild.Info{
Version: "8.1.0",
Commit: "4eff928",
BuildTime: time.Now(),
},
})),
}

hr := Handler(&r)
Expand All @@ -117,9 +115,9 @@ func TestHandleStatus(t *testing.T) {
// Expect extended version information if authenticated
if tc.Authed {
require.NotNil(t, res.Version)
assert.Equal(t, r.bi.Version, *res.Version.Number)
assert.Equal(t, r.bi.Commit, *res.Version.BuildHash)
assert.Equal(t, r.bi.BuildTime.Format(time.RFC3339), *res.Version.BuildTime)
assert.Equal(t, r.st.bi.Version, *res.Version.Number)
assert.Equal(t, r.st.bi.Commit, *res.Version.BuildHash)
assert.Equal(t, r.st.bi.BuildTime.Format(time.RFC3339), *res.Version.BuildTime)
} else {
require.Nil(t, res.Version)
}
Expand Down
24 changes: 5 additions & 19 deletions internal/pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ import (
"net/http"

"github.com/elastic/elastic-agent-libs/transport/tlscommon"
"github.com/elastic/fleet-server/v7/internal/pkg/build"
"github.com/elastic/fleet-server/v7/internal/pkg/bulk"
"github.com/elastic/fleet-server/v7/internal/pkg/config"
"github.com/elastic/fleet-server/v7/internal/pkg/limit"
"github.com/elastic/fleet-server/v7/internal/pkg/logger"
"github.com/elastic/fleet-server/v7/internal/pkg/policy"
"go.elastic.co/apm/v2"

"github.com/rs/zerolog"
)
Expand All @@ -35,25 +31,15 @@ type server struct {
//
// The server has a listener specific conn limit and endpoint specific rate-limits.
// The underlying API structs (such as *CheckinT) may be shared between servers.
func NewServer(addr string, cfg *config.Server, ct *CheckinT, et *EnrollerT, at *ArtifactT, ack *AckT, st *StatusT, sm policy.SelfMonitor, bi build.Info, ut *UploadT, ft *FileDeliveryT, pt *PGPRetrieverT, audit *AuditT, bulker bulk.Bulk, tracer *apm.Tracer) *server {
a := &apiServer{
ct: ct,
et: et,
at: at,
ack: ack,
st: st,
sm: sm,
bi: bi,
ut: ut,
ft: ft,
pt: pt,
audit: audit,
bulker: bulker,
func NewServer(addr string, cfg *config.Server, opts ...APIOpt) *server {
a := &apiServer{}
for _, opt := range opts {
opt(a)
}
return &server{
addr: addr,
cfg: cfg,
handler: newRouter(&cfg.Limits, a, tracer),
handler: newRouter(&cfg.Limits, a, a.tracer),
}
}

Expand Down
Loading
Loading