Skip to content

Commit

Permalink
context propagation: NewServer()
Browse files Browse the repository at this point in the history
  • Loading branch information
mmetc committed Oct 3, 2024
1 parent 114b8ab commit 837d65c
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ issues:
# `err` is often shadowed, we may continue to do it
- linters:
- govet
text: "shadow: declaration of \"err\" shadows declaration"
text: "shadow: declaration of \"(err|ctx)\" shadows declaration"

- linters:
- errcheck
Expand Down
5 changes: 3 additions & 2 deletions cmd/crowdsec/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"errors"
"fmt"
"runtime"
Expand All @@ -14,12 +15,12 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
)

func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
func initAPIServer(ctx context.Context, cConfig *csconfig.Config) (*apiserver.APIServer, error) {
if cConfig.API.Server.OnlineClient == nil || cConfig.API.Server.OnlineClient.Credentials == nil {
log.Info("push and pull to Central API disabled")
}

apiServer, err := apiserver.NewServer(cConfig.API.Server)
apiServer, err := apiserver.NewServer(ctx, cConfig.API.Server)
if err != nil {
return nil, fmt.Errorf("unable to run local API: %w", err)
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/crowdsec/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func debugHandler(sig os.Signal, cConfig *csconfig.Config) error {
func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
var tmpFile string

ctx := context.TODO()

Check warning on line 56 in cmd/crowdsec/serve.go

View check run for this annotation

Codecov / codecov/patch

cmd/crowdsec/serve.go#L56

Added line #L56 was not covered by tests
// re-initialize tombs
acquisTomb = tomb.Tomb{}
parsersTomb = tomb.Tomb{}
Expand All @@ -74,7 +76,7 @@ func reloadHandler(sig os.Signal) (*csconfig.Config, error) {
cConfig.API.Server.OnlineClient = nil
}

apiServer, err := initAPIServer(cConfig)
apiServer, err := initAPIServer(ctx, cConfig)
if err != nil {
return nil, fmt.Errorf("unable to init api server: %w", err)
}
Expand Down Expand Up @@ -374,7 +376,7 @@ func Serve(cConfig *csconfig.Config, agentReady chan bool) error {
cConfig.API.Server.OnlineClient = nil
}

apiServer, err := initAPIServer(cConfig)
apiServer, err := initAPIServer(ctx, cConfig)
if err != nil {
return fmt.Errorf("api server init: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/apiserver/alerts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (l *LAPI) RecordResponse(t *testing.T, ctx context.Context, verb string, ur
}

func InitMachineTest(t *testing.T, ctx context.Context) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) {
router, config := NewAPITest(t)
router, config := NewAPITest(t, ctx)
loginResp := LoginToTestAPI(t, ctx, router, config)

return router, loginResp, config
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestCreateAlert(t *testing.T) {

func TestCreateAlertChannels(t *testing.T) {
ctx := context.Background()
apiServer, config := NewAPIServer(t)
apiServer, config := NewAPIServer(t, ctx)
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController()

Expand Down Expand Up @@ -437,7 +437,7 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
// cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24", "::"}
cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
cfg.API.Server.ListenURI = "::8080"
server, err := NewServer(cfg.API.Server)
server, err := NewServer(ctx, cfg.API.Server)
require.NoError(t, err)

err = server.InitController()
Expand Down
3 changes: 1 addition & 2 deletions pkg/apiserver/api_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ import (
)

func TestAPIKey(t *testing.T) {
router, config := NewAPITest(t)

ctx := context.Background()
router, config := NewAPITest(t, ctx)

APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)

Expand Down
4 changes: 1 addition & 3 deletions pkg/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,9 @@ func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, erro

// NewServer creates a LAPI server.
// It sets up a gin router, a database client, and a controller.
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
func NewServer(ctx context.Context, config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler

ctx := context.TODO()

dbClient, err := database.NewClient(ctx, config.DbConfig)
if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err)
Expand Down
33 changes: 17 additions & 16 deletions pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
return config
}

func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
func NewAPIServer(t *testing.T, ctx context.Context) (*APIServer, csconfig.Config) {
config := LoadTestConfig(t)

os.Remove("./ent")

apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
require.NoError(t, err)

log.Printf("Creating new API server")
Expand All @@ -149,8 +149,8 @@ func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
return apiServer, config
}

func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t)
func NewAPITest(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t, ctx)

err := apiServer.InitController()
require.NoError(t, err)
Expand All @@ -161,12 +161,12 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
return router, config
}

func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
func NewAPITestForwardedFor(t *testing.T, ctx context.Context) (*gin.Engine, csconfig.Config) {
config := LoadTestConfigForwardedFor(t)

os.Remove("./ent")

apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)
require.NoError(t, err)

err = apiServer.InitController()
Expand Down Expand Up @@ -302,28 +302,29 @@ func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.Datab
}

func TestWithWrongDBConfig(t *testing.T) {
ctx := context.Background()
config := LoadTestConfig(t)
config.API.Server.DbConfig.Type = "test"
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)

cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'")
assert.Nil(t, apiServer)
}

func TestWithWrongFlushConfig(t *testing.T) {
ctx := context.Background()
config := LoadTestConfig(t)
maxItems := -1
config.API.Server.DbConfig.Flush.MaxItems = &maxItems
apiServer, err := NewServer(config.API.Server)
apiServer, err := NewServer(ctx, config.API.Server)

cstest.RequireErrorContains(t, err, "max_items can't be zero or negative")
assert.Nil(t, apiServer)
}

func TestUnknownPath(t *testing.T) {
router, _ := NewAPITest(t)

ctx := context.Background()
router, _ := NewAPITest(t, ctx)

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test", nil)
Expand All @@ -349,6 +350,8 @@ ListenURI string `yaml:"listen_uri,omitempty"` //127.0
*/

func TestLoggingDebugToFileConfig(t *testing.T) {
ctx := context.Background()

/*declare settings*/
maxAge := "1h"
flushConfig := csconfig.FlushDBCfg{
Expand Down Expand Up @@ -378,12 +381,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)

api, err := NewServer(&cfg)
api, err := NewServer(ctx, &cfg)
require.NoError(t, err)
require.NotNil(t, api)

ctx := context.Background()

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
Expand All @@ -402,6 +403,8 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
}

func TestLoggingErrorToFileConfig(t *testing.T) {
ctx := context.Background()

/*declare settings*/
maxAge := "1h"
flushConfig := csconfig.FlushDBCfg{
Expand Down Expand Up @@ -430,12 +433,10 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)

api, err := NewServer(&cfg)
api, err := NewServer(ctx, &cfg)
require.NoError(t, err)
require.NotNil(t, api)

ctx := context.Background()

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
Expand Down
3 changes: 1 addition & 2 deletions pkg/apiserver/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ import (
)

func TestLogin(t *testing.T) {
router, config := NewAPITest(t)

ctx := context.Background()
router, config := NewAPITest(t, ctx)

body := CreateTestMachine(t, router, "")

Expand Down
20 changes: 7 additions & 13 deletions pkg/apiserver/machines_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ import (
)

func TestCreateMachine(t *testing.T) {
router, _ := NewAPITest(t)

ctx := context.Background()
router, _ := NewAPITest(t, ctx)

// Create machine with invalid format
w := httptest.NewRecorder()
Expand Down Expand Up @@ -53,10 +52,9 @@ func TestCreateMachine(t *testing.T) {
}

func TestCreateMachineWithForwardedFor(t *testing.T) {
router, config := NewAPITestForwardedFor(t)
router.TrustedPlatform = "X-Real-IP"

ctx := context.Background()
router, config := NewAPITestForwardedFor(t, ctx)
router.TrustedPlatform = "X-Real-IP"

// Create machine
b, err := json.Marshal(MachineTest)
Expand All @@ -79,9 +77,8 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
}

func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
router, config := NewAPITest(t)

ctx := context.Background()
router, config := NewAPITest(t, ctx)

// Create machine
b, err := json.Marshal(MachineTest)
Expand All @@ -106,9 +103,8 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
}

func TestCreateMachineWithoutForwardedFor(t *testing.T) {
router, config := NewAPITestForwardedFor(t)

ctx := context.Background()
router, config := NewAPITestForwardedFor(t, ctx)

// Create machine
b, err := json.Marshal(MachineTest)
Expand All @@ -132,9 +128,8 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) {
}

func TestCreateMachineAlreadyExist(t *testing.T) {
router, _ := NewAPITest(t)

ctx := context.Background()
router, _ := NewAPITest(t, ctx)

body := CreateTestMachine(t, router, "")

Expand All @@ -153,9 +148,8 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
}

func TestAutoRegistration(t *testing.T) {
router, _ := NewAPITest(t)

ctx := context.Background()
router, _ := NewAPITest(t, ctx)

// Invalid registration token / valid source IP
regReq := MachineTest
Expand Down

0 comments on commit 837d65c

Please sign in to comment.