Skip to content

Commit

Permalink
context propagation: apic, unit tests (#3271)
Browse files Browse the repository at this point in the history
* context propagation: apic

* context propagation: unit tests
  • Loading branch information
mmetc authored Oct 3, 2024
1 parent af3116d commit 06adbe0
Show file tree
Hide file tree
Showing 10 changed files with 189 additions and 179 deletions.
6 changes: 6 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,12 @@ issues:
path: "pkg/(.+)_test.go"
text: "deep-exit: .*"

# we use t,ctx instead of ctx,t in tests
- linters:
- revive
path: "pkg/(.+)_test.go"
text: "context-as-argument: context.Context should be the first parameter of a function"

# tolerate deep exit in cobra's OnInitialize, for now
- linters:
- revive
Expand Down
144 changes: 72 additions & 72 deletions pkg/apiserver/alerts_test.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/apiserver/api_key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAPIKey(t *testing.T) {

ctx := context.Background()

APIKey := CreateTestBouncer(t, config.API.Server.DbConfig)
APIKey := CreateTestBouncer(t, ctx, config.API.Server.DbConfig)

// Login with empty token
w := httptest.NewRecorder()
Expand Down
22 changes: 9 additions & 13 deletions pkg/apiserver/apic.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,9 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
return nbDeleted, nil
}

func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
func (a *apic) HandleDeletedDecisionsV3(ctx context.Context, deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
var nbDeleted int

ctx := context.TODO()

for _, decisions := range deletedDecisions {
scope := decisions.Scope

Expand Down Expand Up @@ -676,7 +674,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
addCounters, deleteCounters := makeAddAndDeleteCounters()

// process deleted decisions
nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
nbDeleted, err := a.HandleDeletedDecisionsV3(ctx, data.Deleted, deleteCounters)
if err != nil {
return err
}
Expand All @@ -697,7 +695,7 @@ func (a *apic) PullTop(ctx context.Context, forcePull bool) error {
alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)

err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters)
err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, deleteCounters)
if err != nil {
return fmt.Errorf("while saving alerts: %w", err)
}
Expand Down Expand Up @@ -766,9 +764,7 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
return decisions[:outIdx]
}

func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
ctx := context.TODO()

func (a *apic) SaveAlerts(ctx context.Context, alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
for _, alert := range alertsFromCapi {
setAlertScenario(alert, addCounters, deleteCounters)
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
Expand All @@ -788,13 +784,13 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string
return nil
}

func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) {
func (a *apic) ShouldForcePullBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) (bool, error) {
// we should force pull if the blocklist decisions are about to expire or there's no decision in the db
alertQuery := a.dbClient.Ent.Alert.Query()
alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
alertQuery.Order(ent.Desc(alert.FieldCreatedAt))

alertInstance, err := alertQuery.First(context.Background())
alertInstance, err := alertQuery.First(ctx)
if err != nil {
if ent.IsNotFound(err) {
log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
Expand All @@ -807,7 +803,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
decisionQuery := a.dbClient.Ent.Decision.Query()
decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))

firstDecision, err := decisionQuery.First(context.Background())
firstDecision, err := decisionQuery.First(ctx)
if err != nil {
if ent.IsNotFound(err) {
log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
Expand Down Expand Up @@ -837,7 +833,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
}

if !forcePull {
_forcePull, err := a.ShouldForcePullBlocklist(blocklist)
_forcePull, err := a.ShouldForcePullBlocklist(ctx, blocklist)
if err != nil {
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
}
Expand Down Expand Up @@ -889,7 +885,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)

err = a.SaveAlerts(alertsFromCapi, addCounters, nil)
err = a.SaveAlerts(ctx, alertsFromCapi, addCounters, nil)
if err != nil {
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/apic_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func TestAPICSendMetrics(t *testing.T) {
)
require.NoError(t, err)

api := getAPIC(t)
api := getAPIC(t, ctx)
api.pushInterval = time.Millisecond
api.pushIntervalFirst = time.Millisecond
api.apiClient = apiClient
Expand Down
67 changes: 33 additions & 34 deletions pkg/apiserver/apic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)

func getDBClient(t *testing.T) *database.Client {
func getDBClient(t *testing.T, ctx context.Context) *database.Client {
t.Helper()

ctx := context.Background()

dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
dbClient, err := database.NewClient(ctx, &csconfig.DatabaseCfg{
Expand All @@ -51,9 +49,9 @@ func getDBClient(t *testing.T) *database.Client {
return dbClient
}

func getAPIC(t *testing.T) *apic {
func getAPIC(t *testing.T, ctx context.Context) *apic {
t.Helper()
dbClient := getDBClient(t)
dbClient := getDBClient(t, ctx)

return &apic{
AlertsAddChan: make(chan []*models.Alert),
Expand Down Expand Up @@ -84,8 +82,8 @@ func absDiff(a int, b int) int {
return c
}

func assertTotalDecisionCount(t *testing.T, dbClient *database.Client, count int) {
d := dbClient.Ent.Decision.Query().AllX(context.Background())
func assertTotalDecisionCount(t *testing.T, ctx context.Context, dbClient *database.Client, count int) {
d := dbClient.Ent.Decision.Query().AllX(ctx)
assert.Len(t, d, count)
}

Expand All @@ -111,9 +109,8 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
}

func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t)

ctx := context.Background()
api := getAPIC(t, ctx)

isOld, err := api.CAPIPullIsOld(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -169,7 +166,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t)
api := getAPIC(t, ctx)
for machineID, scenarios := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Create().
SetMachineId(machineID).
Expand All @@ -183,7 +180,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
require.NoError(t, err)

for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(ctx)
}

assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
Expand All @@ -192,6 +189,8 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
}

func TestNewAPIC(t *testing.T) {
ctx := context.Background()

var testConfig *csconfig.OnlineApiClientCfg

setConfig := func() {
Expand Down Expand Up @@ -219,23 +218,21 @@ func TestNewAPIC(t *testing.T) {
name: "simple",
action: func() {},
args: args{
dbClient: getDBClient(t),
dbClient: getDBClient(t, ctx),
consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig,
},
},
{
name: "error in parsing URL",
action: func() { testConfig.Credentials.URL = "foobar http://" },
args: args{
dbClient: getDBClient(t),
dbClient: getDBClient(t, ctx),
consoleConfig: LoadTestConfig(t).API.Server.ConsoleConfig,
},
expectedErr: "first path segment in URL cannot contain colon",
},
}

ctx := context.Background()

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
setConfig()
Expand All @@ -259,7 +256,8 @@ func TestNewAPIC(t *testing.T) {
}

func TestAPICHandleDeletedDecisions(t *testing.T) {
api := getAPIC(t)
ctx := context.Background()
api := getAPIC(t, ctx)
_, deleteCounters := makeAddAndDeleteCounters()

decision1 := api.dbClient.Ent.Decision.Create().
Expand All @@ -280,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
SetOrigin(types.CAPIOrigin).
SaveX(context.Background())

assertTotalDecisionCount(t, api.dbClient, 2)
assertTotalDecisionCount(t, ctx, api.dbClient, 2)

nbDeleted, err := api.HandleDeletedDecisions([]*models.Decision{{
Value: ptr.Of("1.2.3.4"),
Expand Down Expand Up @@ -359,7 +357,7 @@ func TestAPICGetMetrics(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
apiClient := getAPIC(t)
apiClient := getAPIC(t, ctx)
cleanUp(apiClient)

for i, machineID := range tc.machineIDs {
Expand All @@ -370,7 +368,7 @@ func TestAPICGetMetrics(t *testing.T) {
SetScenarios("crowdsecurity/test").
SetLastPush(time.Time{}).
SetUpdatedAt(time.Time{}).
ExecX(context.Background())
ExecX(ctx)
}

for i, bouncerName := range tc.bouncers {
Expand All @@ -380,7 +378,7 @@ func TestAPICGetMetrics(t *testing.T) {
SetAPIKey("foobar").
SetRevoked(false).
SetLastPull(time.Time{}).
ExecX(context.Background())
ExecX(ctx)
}

foundMetrics, err := apiClient.GetMetrics(ctx)
Expand Down Expand Up @@ -555,7 +553,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {

func TestAPICWhitelists(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api := getAPIC(t, ctx)
// one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
Expand All @@ -578,7 +576,7 @@ func TestAPICWhitelists(t *testing.T) {
SetScenario("crowdsecurity/ssh-bf").
SetUntil(time.Now().Add(time.Hour)).
ExecX(context.Background())
assertTotalDecisionCount(t, api.dbClient, 1)
assertTotalDecisionCount(t, ctx, api.dbClient, 1)
assertTotalValidDecisionCount(t, api.dbClient, 1)
httpmock.Activate()

Expand Down Expand Up @@ -693,7 +691,7 @@ func TestAPICWhitelists(t *testing.T) {
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
assertTotalDecisionCount(t, ctx, api.dbClient, 5) // 2 from FIRE + 2 from bl + 1 existing
assertTotalValidDecisionCount(t, api.dbClient, 4)
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
Expand Down Expand Up @@ -742,16 +740,16 @@ func TestAPICWhitelists(t *testing.T) {

func TestAPICPullTop(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api := getAPIC(t, ctx)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
SetType("ban").
SetValue("9.9.9.9").
SetScope("Ip").
SetScenario("crowdsecurity/ssh-bf").
SetUntil(time.Now().Add(time.Hour)).
ExecX(context.Background())
assertTotalDecisionCount(t, api.dbClient, 1)
ExecX(ctx)
assertTotalDecisionCount(t, ctx, api.dbClient, 1)
assertTotalValidDecisionCount(t, api.dbClient, 1)
httpmock.Activate()

Expand Down Expand Up @@ -835,7 +833,7 @@ func TestAPICPullTop(t *testing.T) {
err = api.PullTop(ctx, false)
require.NoError(t, err)

assertTotalDecisionCount(t, api.dbClient, 5)
assertTotalDecisionCount(t, ctx, api.dbClient, 5)
assertTotalValidDecisionCount(t, api.dbClient, 4)
assertTotalAlertCount(t, api.dbClient, 3) // 2 for list sub , 1 for community list.
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
Expand Down Expand Up @@ -868,7 +866,7 @@ func TestAPICPullTop(t *testing.T) {
func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
ctx := context.Background()
// no decision in db, no last modified parameter.
api := getAPIC(t)
api := getAPIC(t, ctx)

httpmock.Activate()
defer httpmock.DeactivateAndReset()
Expand Down Expand Up @@ -943,7 +941,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {

func TestAPICPullTopBLCacheForceCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api := getAPIC(t, ctx)

httpmock.Activate()
defer httpmock.DeactivateAndReset()
Expand Down Expand Up @@ -1019,7 +1017,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {

func TestAPICPullBlocklistCall(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api := getAPIC(t, ctx)

httpmock.Activate()
defer httpmock.DeactivateAndReset()
Expand Down Expand Up @@ -1052,6 +1050,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
}

func TestAPICPush(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
alerts []*models.Alert
Expand Down Expand Up @@ -1105,7 +1104,7 @@ func TestAPICPush(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t)
api := getAPIC(t, ctx)
api.pushInterval = time.Millisecond
api.pushIntervalFirst = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
Expand Down Expand Up @@ -1144,7 +1143,7 @@ func TestAPICPush(t *testing.T) {

func TestAPICPull(t *testing.T) {
ctx := context.Background()
api := getAPIC(t)
api := getAPIC(t, ctx)
tests := []struct {
name string
setUp func()
Expand Down Expand Up @@ -1172,7 +1171,7 @@ func TestAPICPull(t *testing.T) {

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
api = getAPIC(t)
api = getAPIC(t, ctx)
api.pullInterval = time.Millisecond
api.pullIntervalFirst = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
Expand Down Expand Up @@ -1223,7 +1222,7 @@ func TestAPICPull(t *testing.T) {
time.Sleep(time.Millisecond * 500)
logrus.SetOutput(os.Stderr)
assert.Contains(t, buf.String(), tc.logContains)
assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount)
assertTotalDecisionCount(t, ctx, api.dbClient, tc.expectedDecisionCount)
})
}
}
Expand Down
4 changes: 1 addition & 3 deletions pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,7 @@ func CreateTestMachine(t *testing.T, router *gin.Engine, token string) string {
return body
}

func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
ctx := context.Background()

func CreateTestBouncer(t *testing.T, ctx context.Context, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(ctx, config)
require.NoError(t, err)

Expand Down
Loading

0 comments on commit 06adbe0

Please sign in to comment.